diff --git a/patchscopes/code/README.md b/patchscopes/code/README.md
new file mode 100644
index 00000000..8df91386
--- /dev/null
+++ b/patchscopes/code/README.md
@@ -0,0 +1,50 @@
+## 🩺 Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models
+
+
+### Overview
+We propose a framework that decodes specific information from a representation within an LLM by “patching” it into the inference pass on a different prompt that has been designed to encourage the extraction of that information. A "Patchscope" is a configuration of our framework that can be viewed as an inspection tool geared towards a particular objective.
+
+For example, this figure shows a simple Patchscope for decoding what is encoded in the representation of "CEO" in the source prompt (left). We patch a target prompt (right) comprised of few-shot demonstrations of token repetitions, which encourages decoding the token identity given a hidden representation.
+
+[**[Paper]**](https://arxiv.org/abs/2401.06102) [**[Project Website]**](https://pair-code.github.io/interpretability/patchscopes/)
+
+
+
+### 💾 Download textual data
+The script is provided [**here**](download_the_pile_text_data.py). Use the following command to run it:
+```python
+python3 download_the_pile_text_data.py
+```
+
+### 🦙 For using Vicuna-13B
+Run the following command for using the Vicuna 13b model (see also details [here](https://huggingface.co/CarperAI/stable-vicuna-13b-delta)):
+```python
+python3 apply_delta.py --base meta-llama/Llama-2-13b-hf --target ./stable-vicuna-13b --delta CarperAI/stable-vicuna-13b-delta
+```
+
+### 🧪 Experiments
+
+#### (1) Next Token Prediction
+The main code used appears [here](next_token_prediction.ipynb).
+#### (2) Attribute Extraction
+For this experiment, you should download the `preprocessed_data` directory.
+The main code used appears [here](attribute_extraction.ipynb).
+#### (3) Entity Processing
+The main code used appears [here](entity_processing.ipynb). The dataset is available for downloading [here](https://github.com/AlexTMallen/adaptive-retrieval/blob/main/data/popQA.tsv).
+#### (4) Cross-model Patching
+The main code used appears [here](patch_cross_model.ipynb).
+#### (5) Self-Correction in Multi-Hop Reasoning
+For this experiment, you should download the `preprocessed_data` directory.
+The main code used appears [here](multihop-CoT.ipynb). The code provided supports the Vicuna-13B model.
+
+### 📙 BibTeX
+```bibtex
+@misc{ghandeharioun2024patchscopes,
+ title={Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models},
+ author={Ghandeharioun, Asma and Caciularu, Avi and Pearce, Adam and Dixon, Lucas and Geva, Mor},
+ year={2024},
+ eprint={2401.06102},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+```
diff --git a/patchscopes/code/apply_delta.py b/patchscopes/code/apply_delta.py
new file mode 100644
index 00000000..127ea0af
--- /dev/null
+++ b/patchscopes/code/apply_delta.py
@@ -0,0 +1,51 @@
+"""
+Usage:
+python3 apply_delta.py --base /path/to/model_weights/llama-13b --target stable-vicuna-13b --delta pvduy/stable-vicuna-13b-delta
+
+The code was adopted from https://github.com/GanjinZero/RRHF/blob/main/apply_delta.py
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+def apply_delta(base_model_path, target_model_path, delta_path):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading delta")
+ delta = AutoModelForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
+
+ DEFAULT_PAD_TOKEN = "[PAD]"
+ base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
+ num_new_tokens = base_tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))
+
+ base.resize_token_embeddings(len(base_tokenizer))
+ input_embeddings = base.get_input_embeddings().weight.data
+ output_embeddings = base.get_output_embeddings().weight.data
+ input_embeddings[-num_new_tokens:] = 0
+ output_embeddings[-num_new_tokens:] = 0
+
+ print("Applying delta")
+ for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
+ assert name in delta.state_dict()
+ param.data += delta.state_dict()[name]
+
+ print("Saving target model")
+ base.save_pretrained(target_model_path)
+ delta_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
diff --git a/patchscopes/code/attribute_extraction.ipynb b/patchscopes/code/attribute_extraction.ipynb
new file mode 100644
index 00000000..8acf0916
--- /dev/null
+++ b/patchscopes/code/attribute_extraction.ipynb
@@ -0,0 +1,668 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GMJYfysaREkb"
+ },
+ "source": [
+ "# Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mdEmY4rDQ3ik",
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "from ast import literal_eval\n",
+ "import functools\n",
+ "import json\n",
+ "import os\n",
+ "import random\n",
+ "import shutil\n",
+ "\n",
+ "# Scienfitic packages\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "import datasets\n",
+ "from torch import cuda\n",
+ "torch.set_grad_enabled(False)\n",
+ "\n",
+ "# Visuals\n",
+ "from matplotlib import pyplot as plt\n",
+ "import seaborn as sns\n",
+ "sns.set(context=\"notebook\",\n",
+ " rc={\"font.size\":16,\n",
+ " \"axes.titlesize\":16,\n",
+ " \"axes.labelsize\":16,\n",
+ " \"xtick.labelsize\": 16.0,\n",
+ " \"ytick.labelsize\": 16.0,\n",
+ " \"legend.fontsize\": 16.0})\n",
+ "palette_ = sns.color_palette(\"Set1\")\n",
+ "palette = palette_[2:5] + palette_[7:]\n",
+ "sns.set_theme(style='whitegrid')\n",
+ "\n",
+ "# Utilities\n",
+ "\n",
+ "from general_utils import (\n",
+ " ModelAndTokenizer,\n",
+ " make_inputs,\n",
+ " decode_tokens,\n",
+ " find_token_range,\n",
+ " predict_from_input,\n",
+ ")\n",
+ "\n",
+ "from patchscopes_utils import *\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "tqdm.pandas()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-iVlmvjRahV6"
+ },
+ "outputs": [],
+ "source": [
+ "model_to_hook = {\n",
+ " \"EleutherAI/pythia-6.9b\": set_hs_patch_hooks_neox,\n",
+ " \"EleutherAI/pythia-12b\": set_hs_patch_hooks_neox,\n",
+ " \"meta-llama/Llama-2-13b-hf\": set_hs_patch_hooks_llama,\n",
+ " \"lmsys/vicuna-7b-v1.5\": set_hs_patch_hooks_llama,\n",
+ " \"./stable-vicuna-13b\": set_hs_patch_hooks_llama,\n",
+ " \"CarperAI/stable-vicuna-13b-delta\": set_hs_patch_hooks_llama,\n",
+ " \"EleutherAI/gpt-j-6b\": set_hs_patch_hooks_gptj\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MJu_u30hA9dd"
+ },
+ "outputs": [],
+ "source": [
+ "# Load model\n",
+ "\n",
+ "# 0-shot with GPT-J\n",
+ "model_name = \"gpt-j-6B\"\n",
+ "sos_tok = False\n",
+ "\n",
+ "if \"13b\" in model_name or \"12b\" in model_name:\n",
+ " torch_dtype = torch.float16\n",
+ "else:\n",
+ " torch_dtype = None\n",
+ "\n",
+ "my_device = torch.device(\"cuda:0\")\n",
+ "\n",
+ "mt = ModelAndTokenizer(\n",
+ " model_name,\n",
+ " low_cpu_mem_usage=False,\n",
+ " torch_dtype=torch_dtype,\n",
+ " device=my_device,\n",
+ ")\n",
+ "mt.set_hs_patch_hooks = model_to_hook[model_name]\n",
+ "mt.model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Ly4N9cT7ahV7"
+ },
+ "outputs": [],
+ "source": [
+ "def run_experiment(task_type, task_name, data_dir, output_dir, batch_size=512, n_samples=-1,\n",
+ " save_output=True, replace=False, only_correct=False, is_icl=True):\n",
+ " fdir_out = f\"{output_dir}/{task_type}\"\n",
+ " fname_out = f\"{fdir_out}/{task_name}_only_correct_{only_correct}.pkl\"\n",
+ " if not replace and os.path.exists(fname_out):\n",
+ " print(f\"File {fname_out} exists. Skipping...\")\n",
+ " return\n",
+ " print(f\"Running experiment on {task_type}/{task_name}...\")\n",
+ " df = pd.read_pickle(f\"{data_dir}/{task_type}/{task_name}.pkl\")\n",
+ " if only_correct:\n",
+ " df = df[df[\"is_correct_baseline\"]].reset_index(drop=True)\n",
+ " # Dropping empty prompt sources. This is an artifact of saving and reloading inputs\n",
+ " df = df[df[\"prompt_source\"]!=\"\"].reset_index(drop=True)\n",
+ " # Dropping prompt sources with \\n. pandas read_pickle is not able to handle them properly and drops the rest of the input.\n",
+ " df = df[~df[\"prompt_source\"].str.contains('\\n')].reset_index(drop=True)\n",
+ " # After manual inspection, this example seems to have tokenization issues. Dropping.\n",
+ " if task_name == \"star_constellation\":\n",
+ " df = df[~df[\"prompt_source\"].str.contains(\"service\")].reset_index(drop=True)\n",
+ " elif task_name == \"object_superclass\":\n",
+ " df = df[~df[\"prompt_source\"].str.contains(\"Swainson ’ s hawk and the prairie\")].reset_index(drop=True)\n",
+ " print(f\"\\tNumber of samples: {len(df)}\")\n",
+ "\n",
+ " # BATCHED\n",
+ " batch = []\n",
+ " for _, row in tqdm.tqdm(df.iterrows()):\n",
+ " for layer_source in range(mt.num_layers-1):\n",
+ " for layer_target in range(mt.num_layers-1):\n",
+ " item = dict(row)\n",
+ " item.update({\n",
+ " \"layer_source\": layer_source,\n",
+ " \"layer_target\": layer_target,\n",
+ " })\n",
+ " batch.append(item)\n",
+ " experiment_df = pd.DataFrame.from_records(batch)\n",
+ "\n",
+ " if n_samples > 0 and n_samples 0 and n_samples 0 and n_samples 1:
+ return [decode_tokens(tokenizer, row) for row in token_array]
+ return [tokenizer.decode([t]) for t in token_array]
+
+
+def find_token_range(tokenizer, token_array, substring):
+ """Find the tokens corresponding to the given substring in token_array."""
+ toks = decode_tokens(tokenizer, token_array)
+ whole_string = "".join(toks)
+ char_loc = whole_string.index(substring)
+ loc = 0
+ tok_start, tok_end = None, None
+ for i, t in enumerate(toks):
+ loc += len(t)
+ if tok_start is None and loc > char_loc:
+ tok_start = i
+ if tok_end is None and loc >= char_loc + len(substring):
+ tok_end = i + 1
+ break
+ return (tok_start, tok_end)
+
+
+def predict_from_input(model, inp):
+ out = model(**inp)["logits"]
+ probs = torch.softmax(out[:, -1], dim=1)
+ p, preds = torch.max(probs, dim=1)
+ return preds, p
+
+
+def set_requires_grad(requires_grad, *models):
+ for model in models:
+ if isinstance(model, torch.nn.Module):
+ for param in model.parameters():
+ param.requires_grad = requires_grad
+ elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
+ model.requires_grad = requires_grad
+ else:
+ assert False, "unknown type %r" % type(model)
diff --git a/patchscopes/code/images/patchscopes.png b/patchscopes/code/images/patchscopes.png
new file mode 100644
index 00000000..3be77cf8
Binary files /dev/null and b/patchscopes/code/images/patchscopes.png differ
diff --git a/patchscopes/code/multihop-CoT.ipynb b/patchscopes/code/multihop-CoT.ipynb
new file mode 100644
index 00000000..3853d4c9
--- /dev/null
+++ b/patchscopes/code/multihop-CoT.ipynb
@@ -0,0 +1,1146 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GMJYfysaREkb"
+ },
+ "source": [
+ "# Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mdEmY4rDQ3ik",
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "from ast import literal_eval\n",
+ "import functools\n",
+ "import json\n",
+ "import os\n",
+ "import random\n",
+ "import shutil\n",
+ "\n",
+ "# Scienfitic packages\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "import torch\n",
+ "import datasets\n",
+ "from torch import cuda\n",
+ "torch.set_grad_enabled(False)\n",
+ "\n",
+ "# Visuals\n",
+ "from matplotlib import pyplot as plt\n",
+ "import seaborn as sns\n",
+ "sns.set(context=\"notebook\",\n",
+ " rc={\"font.size\":16,\n",
+ " \"axes.titlesize\":16,\n",
+ " \"axes.labelsize\":16,\n",
+ " \"xtick.labelsize\": 16.0,\n",
+ " \"ytick.labelsize\": 16.0,\n",
+ " \"legend.fontsize\": 16.0})\n",
+ "palette_ = sns.color_palette(\"Set1\")\n",
+ "palette = palette_[2:5] + palette_[7:]\n",
+ "sns.set_theme(style='whitegrid')\n",
+ "\n",
+ "# Utilities\n",
+ "\n",
+ "from general_utils import (\n",
+ " ModelAndTokenizer,\n",
+ " make_inputs,\n",
+ " decode_tokens,\n",
+ " find_token_range,\n",
+ " predict_from_input,\n",
+ ")\n",
+ "\n",
+ "from patchscopes_utils import *\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "tqdm.pandas()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "oURHfJrzap1H"
+ },
+ "outputs": [],
+ "source": [
+ "# Load model\n",
+ "\n",
+ "model_name = \"vicuna-13b-v1.1\"\n",
+ "sos_tok = False\n",
+ "\n",
+ "if \"13b\" in model_name or \"12b\" in model_name:\n",
+ " torch_dtype = torch.float16\n",
+ "else:\n",
+ " torch_dtype = None\n",
+ "\n",
+ "my_device = torch.device(\"cuda:2\")\n",
+ "\n",
+ "mt = ModelAndTokenizer(\n",
+ " model_name,\n",
+ " low_cpu_mem_usage=False,\n",
+ " torch_dtype=torch_dtype,\n",
+ " device=my_device,\n",
+ ")\n",
+ "mt.set_hs_patch_hooks = set_hs_patch_hooks_llama_batch\n",
+ "mt.model.eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AAtiySLToTY7"
+ },
+ "source": [
+ "# MultiHop reasoning experiments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "PfhXcB54ap1I"
+ },
+ "outputs": [],
+ "source": [
+ "def generate_baseline_multihop(\n",
+ " mt, df, batch_size=256, max_gen_len=10,\n",
+ "):\n",
+ " def _generate_baseline_single_batch(batch_df):\n",
+ " batch_size = len(batch_df)\n",
+ " cases = [(\"baseline_hop2\", \"hop2\"),\n",
+ " (\"baseline_hop3\", \"hop3\"),\n",
+ " (\"baseline_multihop3\", \"hop3\"),\n",
+ " ]\n",
+ " results = {}\n",
+ " for target_col, object_col in cases:\n",
+ "\n",
+ " target_baseline_batch = np.array(batch_df[target_col])\n",
+ " object_batch = np.array(batch_df[object_col])\n",
+ "\n",
+ "\n",
+ " # Step 0: run the the model on target prompt baseline (having the subject token in input rather than patched)\n",
+ " # The goal of this step is to calculate whether the model works correctly by default, and to calculate surprisal\n",
+ " inp_target_baseline = make_inputs(mt.tokenizer, target_baseline_batch, mt.device)\n",
+ " seq_len_target_baseline = len(inp_target_baseline[\"input_ids\"][0])\n",
+ " output_target_baseline_toks = mt.model.generate(\n",
+ " inp_target_baseline[\"input_ids\"],\n",
+ " max_length=seq_len_target_baseline + max_gen_len,\n",
+ " pad_token_id=mt.model.generation_config.eos_token_id,\n",
+ " )[:, seq_len_target_baseline:]\n",
+ " generations_baseline = decode_tokens(mt.tokenizer, output_target_baseline_toks)\n",
+ " generations_baseline_txt = np.array([\" \".join(sample_gen) for sample_gen in generations_baseline])\n",
+ "\n",
+ "\n",
+ " is_correct_baseline = np.array([\n",
+ " (object_batch[i] in generations_baseline_txt[i] or\n",
+ " object_batch[i].replace(\" \", \"\") in generations_baseline_txt[i].replace(\" \", \"\"))\n",
+ " for i in range(batch_size)\n",
+ " ])\n",
+ " results.update(\n",
+ " {\n",
+ " f\"generations_{target_col}\": generations_baseline_txt,\n",
+ " f\"is_correct_{target_col}\": is_correct_baseline,\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " return results\n",
+ "\n",
+ " results = {}\n",
+ " n_batches = len(df) // batch_size\n",
+ " if len(df)%batch_size !=0:\n",
+ " n_batches +=1\n",
+ " for i in tqdm.tqdm(range(n_batches)):\n",
+ " cur_df = df.iloc[batch_size * i : batch_size * (i + 1)]\n",
+ " batch_results = _generate_baseline_single_batch(cur_df)\n",
+ " for key, value in batch_results.items():\n",
+ " if key in results:\n",
+ " results[key] = np.concatenate((results[key], value))\n",
+ " else:\n",
+ " results[key] = value\n",
+ "\n",
+ " return results\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "X3YQSI-MoTY8"
+ },
+ "source": [
+ "# Experiment 1: Multihop Product Company CEO tuples\n",
+ "\n",
+ "This is a subset made only from (product, company) and (company, CEO) tuples from the LRE dataset.\n",
+ "We only picked 3 (company, CEO) tuples, and 15 (product, company) tuples for each that the model is more likely to know the answer to.\n",
+ "\n",
+ "This is an exploratory experiment. There is a more complete experiment later in the colab.\n",
+ "Hop 1: Product\n",
+ "Hop 2: company\n",
+ "Hop 3: CEO"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fW9dio43ap1J"
+ },
+ "outputs": [],
+ "source": [
+ "multihop_samples = {\n",
+ " (\"Satya Nadella\", \"Microsoft\"): [\"WinDbg\", \".NET Framework\", \"Internet Explorer\", \"MS-DOS\", \"Office Open XML\",\n",
+ " \"TypeScript\", \"Bing Maps Platform\", \"Outlook Express\", \"PowerShell\", \"Windows 95\",\n",
+ " \"Xbox 360\", \"Zune\", \"Visual Basic Script\", \"Virtual Hard Disk\", \"Robocopy\",\n",
+ " ],\n",
+ " (\"Tim Cook\", \"Apple\"): [\"Siri\", \"App Store\", \"CarPlay\", \"MacBook Air\", \"Xcode\",\n",
+ " \"macOS\", \"iWork\", \"Safari\", \"QuickTime\", \"TextEdit\",\n",
+ " \"WebKit\", \"QuickDraw\", \"Time Machine (macOS)\", \"MessagePad\", \"Macbook Pro\",\n",
+ " ],\n",
+ " (\"Sundar Pichai\", \"Google\"): [\"Chromecast\", \"Chromebook\", \"Wear OS\", \"G Suite\", \"Picasa\",\n",
+ " \"WebP Lossless\", \"General Transit Feed Specification Lossless\", \"Cloud Spanner\", \"Android TV\", \"Android Runtime\",\n",
+ " \"Android Jelly Bean\", \"Android Auto\", \"App Inventor\", \"Chromebook Pixel\", \"Project Ara\",\n",
+ " ]\n",
+ "}\n",
+ "\n",
+ "def generate_multihop_data_ceo(fdir_out=\"./outputs/factual\", batch_size=512, max_gen_len=20, replace=False):\n",
+ " if not os.path.exists(fdir_out):\n",
+ " os.makedirs(fdir_out)\n",
+ " fname_out = \"multihop_product_company_ceo\"\n",
+ " if not replace and os.path.exists(f\"{fdir_out}/{fname_out}.pkl\"):\n",
+ " print(f\"File {fdir_out}/{fname_out}.pkl exists. Skipping generation. Reading file...\")\n",
+ " df = pd.read_pickle(os.path.join(fdir_out, f\"{fname_out}.pkl\"))\n",
+ " return df\n",
+ " prompt_source_template = \"{} was created by\"\n",
+ " prompt_target_template = \"Who is the current CEO of {}\"\n",
+ " sample_id = 0\n",
+ "\n",
+ " print(\"Step 1: Prepare dataset...\")\n",
+ " records = []\n",
+ "\n",
+ " for key, value in multihop_samples.items():\n",
+ " hop3, hop2 = key\n",
+ " for hop1 in value:\n",
+ " # hop1: Product\n",
+ " # hop2: Company\n",
+ " # hop3: CEO\n",
+ " records.append({\n",
+ " \"sample_id\": sample_id,\n",
+ " \"prompt_source\": prompt_source_template.replace(\"{}\", hop1),\n",
+ " \"position_source\": -1, # always doing next token prediction\n",
+ " \"prompt_target\": prompt_target_template,\n",
+ " \"position_target\": -1,\n",
+ "\n",
+ " \"baseline_hop2\": f\"{hop1} was created by\", # hop2\n",
+ " \"baseline_hop3\": f\"Who is the current CEO of {hop2}\", # hop3\n",
+ " \"baseline_multihop3\": f\"Who is the current CEO of the company that created {hop1}\", # hop3\n",
+ "\n",
+ " \"hop1\": hop1,\n",
+ " \"hop2\": hop2,\n",
+ " \"hop3\": hop3,\n",
+ " })\n",
+ " sample_id +=1\n",
+ "\n",
+ " # Step 2: Compute baseline generations\n",
+ " print(\"Step 2: Compute baseline generations...\")\n",
+ " df = pd.DataFrame.from_records(records)\n",
+ " eval_results = generate_baseline_multihop(mt, df, batch_size=batch_size, max_gen_len=max_gen_len)\n",
+ " for key, value in eval_results.items():\n",
+ " df[key] = list(value)\n",
+ "\n",
+ " df.to_csv(os.path.join(fdir_out, f\"{fname_out}.tsv\"), sep=\"\\t\")\n",
+ " df.to_pickle(os.path.join(fdir_out, f\"{fname_out}.pkl\"))\n",
+ " return df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0Ax2wLmxap1J"
+ },
+ "outputs": [],
+ "source": [
+ "multihop_df = generate_multihop_data_ceo(batch_size=128, max_gen_len=20)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0uklW1xTap1K"
+ },
+ "outputs": [],
+ "source": [
+ "def evaluate_attriburte_exraction_batch_multihop(\n",
+ " mt, df, batch_size=256, max_gen_len=10, transform=None\n",
+ "):\n",
+ " def _evaluate_attriburte_exraction_single_batch(batch_df):\n",
+ " batch_size = len(batch_df)\n",
+ " prompt_source_batch = np.array(batch_df[\"prompt_source\"])\n",
+ " prompt_target_batch = np.array(batch_df[\"prompt_target\"])\n",
+ " layer_source_batch = np.array(batch_df[\"layer_source\"])\n",
+ " layer_target_batch = np.array(batch_df[\"layer_target\"])\n",
+ " position_source_batch = np.array(batch_df[\"position_source\"])\n",
+ " position_target_batch = np.array(batch_df[\"position_target\"])\n",
+ "\n",
+ " object_batch = np.array(batch_df[\"hop3\"])\n",
+ "\n",
+ "\n",
+ " # Adjust position_target to be absolute rather than relative\n",
+ " inp_target = make_inputs(mt.tokenizer, prompt_target_batch, mt.device)\n",
+ " for i in range(batch_size):\n",
+ " if position_target_batch[i] < 0:\n",
+ " position_target_batch[i] += len(inp_target[\"input_ids\"][i])\n",
+ "\n",
+ " # Step 1: run the the model on source without patching and get the hidden representations.\n",
+ " inp_source = make_inputs(mt.tokenizer, prompt_source_batch, mt.device)\n",
+ " output_orig = mt.model(**inp_source, output_hidden_states=True)\n",
+ "\n",
+ " # hidden_states size (n_layers, n_sample, seq_len, hidden_dim)\n",
+ " hidden_rep = [\n",
+ " output_orig.hidden_states[layer_source_batch[i] + 1][i][\n",
+ " position_source_batch[i]\n",
+ " ]\n",
+ " for i in range(batch_size)\n",
+ " ]\n",
+ " if transform is not None:\n",
+ " for i in range(batch_size):\n",
+ " hidden_rep[i] = transform(hidden_rep[i])\n",
+ "\n",
+ " # Step 2: do second run on target prompt, while patching the input hidden state.\n",
+ " hs_patch_config = [\n",
+ " {\n",
+ " \"batch_idx\": i,\n",
+ " \"layer_target\": layer_target_batch[i],\n",
+ " \"position_target\": position_target_batch[i],\n",
+ " \"hidden_rep\": hidden_rep[i],\n",
+ " \"skip_final_ln\": (\n",
+ " layer_source_batch[i]\n",
+ " == layer_target_batch[i]\n",
+ " == mt.num_layers - 1\n",
+ " ),\n",
+ " }\n",
+ " for i in range(batch_size)\n",
+ " ]\n",
+ " patch_hooks = mt.set_hs_patch_hooks(\n",
+ " mt.model, hs_patch_config, patch_input=False, generation_mode=True\n",
+ " )\n",
+ "\n",
+ " output = mt.model(**inp_target)\n",
+ "\n",
+ " # NOTE: inputs are left padded,\n",
+ " # and sequence length is the same across batch\n",
+ " seq_len = len(inp_target[\"input_ids\"][0])\n",
+ " output_toks = mt.model.generate(\n",
+ " inp_target[\"input_ids\"],\n",
+ " max_length=seq_len + max_gen_len,\n",
+ " pad_token_id=mt.model.generation_config.eos_token_id,\n",
+ " )[:, seq_len:]\n",
+ " generations_patched = decode_tokens(mt.tokenizer, output_toks)\n",
+ " generations_patched_txt = np.array([\n",
+ " \" \".join(generations_patched[i])\n",
+ " for i in range(batch_size)\n",
+ " ])\n",
+ " is_correct_patched = np.array([\n",
+ " (object_batch[i] in generations_patched_txt[i]\n",
+ " or object_batch[i].replace(\" \", \"\") in generations_patched_txt[i].replace(\" \", \"\"))\n",
+ " for i in range(batch_size)\n",
+ " ])\n",
+ "\n",
+ " # remove patching hooks\n",
+ " remove_hooks(patch_hooks)\n",
+ "\n",
+ " cpu_hidden_rep = np.array([hidden_rep[i].detach().cpu().numpy() for i in range(batch_size)])\n",
+ "\n",
+ " results = {\n",
+ " \"generations_patched\": generations_patched,\n",
+ " \"is_correct_patched\": is_correct_patched,\n",
+ " \"hidden_rep\": cpu_hidden_rep,\n",
+ "\n",
+ " }\n",
+ "\n",
+ " return results\n",
+ "\n",
+ " results = {}\n",
+ " n_batches = len(df) // batch_size\n",
+ " if len(df)%batch_size !=0:\n",
+ " n_batches +=1\n",
+ " for i in tqdm.tqdm(range(len(df) // batch_size)):\n",
+ " cur_df = df.iloc[batch_size * i : batch_size * (i + 1)]\n",
+ " batch_results = _evaluate_attriburte_exraction_single_batch(cur_df)\n",
+ " for key, value in batch_results.items():\n",
+ " if key in results:\n",
+ " results[key] = np.concatenate((results[key], value))\n",
+ " else:\n",
+ " results[key] = value\n",
+ "\n",
+ " return results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "iWL0cllWap1K"
+ },
+ "outputs": [],
+ "source": [
+ "def run_experiment(fname_in, fdir_out, fname_out = \"multihop\", batch_size=512, n_samples=-1,\n",
+ " save_output=True, replace=False):\n",
+ " print(f\"Running experiment on {fname_in}...\")\n",
+ " if not replace and os.path.exists(f\"{fdir_out}/{fname_out}.pkl\"):\n",
+ " print(f\"File {fdir_out}/{fname_out}.pkl exists. Skipping generation. Reading file...\")\n",
+ " results_df = pd.read_pickle(f\"{fdir_out}/{fname_out}.pkl\")\n",
+ " return results_df\n",
+ " df = pd.read_pickle(f\"{fname_in}\")\n",
+ " print(f\"\\tNumber of samples: {len(df)}\")\n",
+ "\n",
+ " # BATCHED\n",
+ " batch = []\n",
+ " for layer_source in tqdm.tqdm(range(mt.num_layers)):\n",
+ " for layer_target in range(mt.num_layers):\n",
+ " for _, row in df.iterrows():\n",
+ " item = dict(row)\n",
+ " item.update({\n",
+ " \"layer_source\": layer_source,\n",
+ " \"layer_target\": layer_target,\n",
+ " })\n",
+ " batch.append(item)\n",
+ " experiment_df = pd.DataFrame.from_records(batch)\n",
+ "\n",
+ " if n_samples > 0 and n_samples cat\\n1135 -> 1135\\nhello -> hello\\n?\"\n",
+ "inp_target = make_inputs(mt.tokenizer, [prompt_target], device=mt.model.device)\n",
+ "\n",
+ "data = {}\n",
+ "for sentence, split in tqdm(sentences):\n",
+ " inp = make_inputs(mt.tokenizer, [sentence], device=mt.model.device)\n",
+ " if sos_tok:\n",
+ " start_pos = 1\n",
+ " else:\n",
+ " start_pos = 0\n",
+ " position = random.randint(start_pos, len(inp['input_ids'][0]) - 2)\n",
+ "\n",
+ " if (sentence, position, split, \"source\") not in data:\n",
+ " output = mt.model(**inp, output_hidden_states = True)\n",
+ " _, answer_t = torch.max(torch.softmax(output.logits[0, -1, :], dim=0), dim=0)\n",
+ " data[(sentence, position, split, \"source\")] = [\n",
+ " output[\"hidden_states\"][layer+1][0][position].detach().cpu().numpy()\n",
+ " for layer in range(mt.num_layers)\n",
+ " ]\n",
+ "\n",
+ " inp_target['input_ids'][0][-1] = answer_t\n",
+ " output = mt.model(**inp_target, output_hidden_states = True)\n",
+ " data[(sentence, position, split, \"target\")] = [\n",
+ " output[\"hidden_states\"][layer+1][0][-1].detach().cpu().numpy()\n",
+ " for layer in range(mt.num_layers)\n",
+ " ]\n",
+ "\n",
+ "df = pd.Series(data).reset_index()\n",
+ "df.columns = ['full_text', 'position', 'data_split', 'prompt', 'hidden_rep']\n",
+ "\n",
+ "df.to_pickle(model_name+\"_pile_trn_val.pkl\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kEemdkGOGdLd"
+ },
+ "outputs": [],
+ "source": [
+ "# Pad and unpad \n",
+ "\n",
+ "pad = lambda x: np.hstack([x, np.ones((x.shape[0], 1))])\n",
+ "unpad = lambda x: x[:,:-1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "R_TIyUtoGdLd",
+ "outputId": "a9d67b1c-3ba9-4e8b-e3d1-e45b8fdbd6c9"
+ },
+ "outputs": [],
+ "source": [
+ "# Across layer mappings\n",
+ "\n",
+ "output_dir = f'{model_name}_mappings_pile'\n",
+ "if not os.path.exists(output_dir):\n",
+ " os.makedirs(output_dir)\n",
+ "\n",
+ "df_trn = pd.DataFrame(df[df['data_split'] == 'train']['hidden_rep'].to_list(),\n",
+ " columns=[layer for layer in range(mt.num_layers)])\n",
+ "\n",
+ "target_layer = mt.num_layers - 1\n",
+ "Y = np.array(\n",
+ " df_trn[target_layer].values.tolist()\n",
+ ")\n",
+ "\n",
+ "mappings = []\n",
+ "for layer in range(mt.num_layers):\n",
+ " X = np.array(\n",
+ " df_trn[layer].values.tolist()\n",
+ " )\n",
+ "\n",
+ " # Solve the least squares problem X * A = Y\n",
+ " # to find our transformation matrix A\n",
+ " A, res, rank, s = np.linalg.lstsq(pad(X), pad(Y))\n",
+ " transform = lambda x: unpad(pad(x) @ A)\n",
+ "\n",
+ " mappings.append(A)\n",
+ " with open(f'{output_dir}/mapping_{layer}-{target_layer}.npy', 'wb') as fd:\n",
+ " np.save(fd, A)\n",
+ "\n",
+ " print(layer, \"max error on train:\", np.abs(Y - transform(X)).max())\n",
+ "\n",
+ "shutil.make_archive(output_dir, 'zip', output_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BoEZwZLEGdLd",
+ "outputId": "afbbe9a9-596a-4e8d-8689-96b1a2758a7a"
+ },
+ "outputs": [],
+ "source": [
+ "# Prompt-id mappings\n",
+ "\n",
+ "output_dir = f'{model_name}_mappings_pile_prompt-id'\n",
+ "if not os.path.exists(output_dir):\n",
+ " os.makedirs(output_dir)\n",
+ "\n",
+ "df_trn_src = pd.DataFrame(df[(df['data_split'] == 'train') & (df['prompt'] == 'source')]['hidden_rep'].to_list(),\n",
+ " columns=[layer for layer in range(mt.num_layers)])\n",
+ "df_trn_tgt = pd.DataFrame(df[(df['data_split'] == 'train') & (df['prompt'] == 'target')]['hidden_rep'].to_list(),\n",
+ " columns=[layer for layer in range(mt.num_layers)])\n",
+ "\n",
+ "mappings = []\n",
+ "for layer in range(mt.num_layers):\n",
+ " X = np.array(\n",
+ " df_trn_src[layer].values.tolist()\n",
+ " )\n",
+ " Y = np.array(\n",
+ " df_trn_tgt[layer].values.tolist()\n",
+ " )\n",
+ "\n",
+ " # Solve the least squares problem X * A = Y\n",
+ " # to find our transformation matrix A\n",
+ " A, res, rank, s = np.linalg.lstsq(pad(X), pad(Y))\n",
+ " transform = lambda x: unpad(pad(x) @ A)\n",
+ "\n",
+ " mappings.append(A)\n",
+ " with open(f'{output_dir}/mapping_{layer}.npy', 'wb') as fd:\n",
+ " np.save(fd, A)\n",
+ "\n",
+ " print(layer, \"max error on train:\", np.abs(Y - transform(X)).max())\n",
+ "\n",
+ "shutil.make_archive(output_dir, 'zip', output_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7UhNsl2HGdLd"
+ },
+ "outputs": [],
+ "source": [
+ "mappings = []\n",
+ "for layer in tqdm(range(mt.num_layers)):\n",
+ " with open(f'{model_name}_mappings_pile/mapping_{layer}-{mt.num_layers-1}.npy', 'rb') as fd:\n",
+ " A = np.load(fd)\n",
+ " mappings.append(A)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DH8FA6WsGdLe"
+ },
+ "outputs": [],
+ "source": [
+ "# Evaluate linear mappings on the validation set of WikiText\n",
+ "device = mt.model.device\n",
+ "target_layer = mt.num_layers - 1\n",
+ "\n",
+ "records = []\n",
+ "for layer in tqdm(range(mt.num_layers)):\n",
+ " A = mappings[layer]\n",
+ " transform = lambda x: torch.tensor(\n",
+ " np.squeeze(\n",
+ " unpad(np.dot(\n",
+ " pad(np.expand_dims(x.detach().cpu().numpy(), 0)),\n",
+ " A\n",
+ " ))\n",
+ " )\n",
+ " ).to(device)\n",
+ "\n",
+ " for idx, row in df[df['data_split'] == 'validation'].iterrows():\n",
+ " prompt = row['full_text']\n",
+ " position = row['position']\n",
+ " prec_1, surprisal = evaluate_patch_next_token_prediction(\n",
+ " mt, prompt, prompt, layer, target_layer,\n",
+ " position, position, position_prediction=position, transform=transform)\n",
+ "\n",
+ " records.append({'layer': layer, 'prec_1': prec_1, 'surprisal': surprisal})\n",
+ "\n",
+ "\n",
+ "results = pd.DataFrame.from_records(records)\n",
+ "results.to_csv(f'{model_name}_mappings_pile_eval.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "dNVyKXLAGdLe",
+ "outputId": "9d8a9374-5bd5-4ad8-cadf-18831f3c8846"
+ },
+ "outputs": [],
+ "source": [
+ "# Evaluate identity mapping on the validation set of WikiText\n",
+ "\n",
+ "target_layer = mt.num_layers - 1\n",
+ "\n",
+ "records = []\n",
+ "for layer in tqdm(range(mt.num_layers)):\n",
+ " for idx, row in df[df['data_split'] == 'validation'].iterrows():\n",
+ " prompt = row['full_text']\n",
+ " position = row['position']\n",
+ " prec_1, surprisal = evaluate_patch_next_token_prediction(\n",
+ " mt, prompt, prompt, layer, target_layer,\n",
+ " position, position, position_prediction=position)\n",
+ "\n",
+ " records.append({'layer': layer, 'prec_1': prec_1, 'surprisal': surprisal})\n",
+ "\n",
+ "results = pd.DataFrame.from_records(records)\n",
+ "results.to_csv(f'{model_name}_identity_pile_eval.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yuKvnceYGdLe"
+ },
+ "outputs": [],
+ "source": [
+ "# Evaluate the ID prompt on the validation set of WikiText (with/without mappings)\n",
+ "\n",
+ "device = mt.model.device\n",
+ "\n",
+ "prompt_target = \"cat -> cat\\n1135 -> 1135\\nhello -> hello\\n?\"\n",
+ "position_target = -1\n",
+ "apply_mappings = True\n",
+ "\n",
+ "records = []\n",
+ "for layer in tqdm(range(mt.num_layers)):\n",
+ " if apply_mappings:\n",
+ " A = mappings[layer]\n",
+ " transform = lambda x: torch.tensor(\n",
+ " np.squeeze(\n",
+ " unpad(np.dot(\n",
+ " pad(np.expand_dims(x.detach().cpu().numpy(), 0)),\n",
+ " A\n",
+ " ))\n",
+ " )\n",
+ " ).to(device)\n",
+ " else:\n",
+ " transform = None\n",
+ "\n",
+ " for idx, row in df[df['data_split'] == 'validation'].iterrows():\n",
+ " if 'prompt' in row and row['prompt'] == 'target':\n",
+ " continue\n",
+ " prompt_source = row['full_text']\n",
+ " position_source = row['position']\n",
+ " prec_1, surprisal = evaluate_patch_next_token_prediction(\n",
+ " mt, prompt_source, prompt_target, layer, layer,\n",
+ " position_source, position_target, position_prediction=position_target, transform=transform)\n",
+ "\n",
+ " records.append({'layer': layer, 'prec_1': prec_1, 'surprisal': surprisal})\n",
+ "\n",
+ "results = pd.DataFrame.from_records(records)\n",
+ "if apply_mappings:\n",
+ " results.to_csv(f'{model_name}_prompt-id-mapping_pile_eval.csv')\n",
+ "else:\n",
+ " results.to_csv(f'{model_name}_prompt-id_pile_eval.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "u4v3LlWQGdLe",
+ "outputId": "e4bb6dee-10ba-4792-bf80-54299bcf66a5"
+ },
+ "outputs": [],
+ "source": [
+ "results1 = pd.read_csv(f'{model_name}_identity_pile_eval.csv')\n",
+ "results1[\"variant\"] = \"identity\"\n",
+ "results2 = pd.read_csv(f'{model_name}_mappings_pile_eval.csv')\n",
+ "results2[\"variant\"] = \"affine mapping\"\n",
+ "results3 = pd.read_csv(f'{model_name}_prompt-id_pile_eval.csv')\n",
+ "results3[\"variant\"] = \"prompt id\"\n",
+ "\n",
+ "results = pd.concat([results1, results2, results3], ignore_index=True)\n",
+ "\n",
+ "for metric in ['prec_1', 'surprisal']:\n",
+ " ax = sns.lineplot(data=results, x='layer', y=metric, hue=\"variant\")\n",
+ " ax.set_title(model_name.strip('./'))\n",
+ " ax.legend_.set_title('')\n",
+ " plt.show()\n",
+ " plt.clf()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.12"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/patchscopes/code/patch_cross_model.ipynb b/patchscopes/code/patch_cross_model.ipynb
new file mode 100644
index 00000000..f3ffac93
--- /dev/null
+++ b/patchscopes/code/patch_cross_model.ipynb
@@ -0,0 +1,498 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GMJYfysaREkb"
+ },
+ "source": [
+ "# Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "mdEmY4rDQ3ik",
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "from ast import literal_eval\n",
+ "import functools\n",
+ "import json\n",
+ "import os\n",
+ "import random\n",
+ "import shutil\n",
+ "\n",
+ "# Scienfitic packages\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "import datasets\n",
+ "from torch import cuda\n",
+ "torch.set_grad_enabled(False)\n",
+ "\n",
+ "# Visuals\n",
+ "from matplotlib import pyplot as plt\n",
+ "import seaborn as sns\n",
+ "sns.set(context=\"notebook\",\n",
+ " rc={\"font.size\":16,\n",
+ " \"axes.titlesize\":16,\n",
+ " \"axes.labelsize\":16,\n",
+ " \"xtick.labelsize\": 16.0,\n",
+ " \"ytick.labelsize\": 16.0,\n",
+ " \"legend.fontsize\": 16.0})\n",
+ "palette_ = sns.color_palette(\"Set1\")\n",
+ "palette = palette_[2:5] + palette_[7:]\n",
+ "sns.set_theme(style='whitegrid')\n",
+ "\n",
+ "# Utilities\n",
+ "\n",
+ "from general_utils import (\n",
+ " ModelAndTokenizer,\n",
+ " make_inputs,\n",
+ " decode_tokens,\n",
+ " find_token_range,\n",
+ " predict_from_input,\n",
+ ")\n",
+ "\n",
+ "from patchscopes_utils import *\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "tqdm.pandas()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_to_hook = {\n",
+ " \"EleutherAI/pythia-6.9b\": set_hs_patch_hooks_neox,\n",
+ " \"EleutherAI/pythia-12b\": set_hs_patch_hooks_neox,\n",
+ " \"meta-llama/Llama-2-13b-hf\": set_hs_patch_hooks_llama,\n",
+ " \"lmsys/vicuna-7b-v1.5\": set_hs_patch_hooks_llama,\n",
+ " \"./stable-vicuna-13b\": set_hs_patch_hooks_llama,\n",
+ " \"CarperAI/stable-vicuna-13b-delta\": set_hs_patch_hooks_llama,\n",
+ " \"EleutherAI/gpt-j-6b\": set_hs_patch_hooks_gptj\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "referenced_widgets": [
+ "4479f16b9a544b79bb8790693701d8de"
+ ]
+ },
+ "id": "fKGGJO3GQ3in",
+ "outputId": "aed82adb-d542-4de6-ade7-c2a4f7aadcc6"
+ },
+ "outputs": [],
+ "source": [
+ "# Load model 1\n",
+ "\n",
+ "model_name_1 = \"lmsys/vicuna-7b-v1.5\"\n",
+ "sos_tok_1 = False\n",
+ "\n",
+ "if \"13b\" in model_name_1 or \"12b\" in model_name_1:\n",
+ " torch_dtype = torch.float16\n",
+ "else:\n",
+ " torch_dtype = None\n",
+ "\n",
+ "mt_1 = ModelAndTokenizer(\n",
+ " model_name_1,\n",
+ " low_cpu_mem_usage=False,\n",
+ " torch_dtype=torch_dtype,\n",
+ " device=\"cuda:1\"\n",
+ ")\n",
+ "mt_1.set_hs_patch_hooks = model_to_hook[model_name_1]\n",
+ "mt_1.model.eval()\n",
+ "mt_1.model.to(mt_1.device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load model 2\n",
+ "\n",
+ "model_name_2 = \"./stable-vicuna-13b\"\n",
+ "model_name_2_ = model_name_2.strip('./')\n",
+ "sos_tok_2 = False\n",
+ "\n",
+ "if \"13b\" in model_name_2 or \"12b\" in model_name_2:\n",
+ " torch_dtype = torch.float16\n",
+ "else:\n",
+ " torch_dtype = None\n",
+ "\n",
+ "mt_2 = ModelAndTokenizer(\n",
+ " model_name_2,\n",
+ " low_cpu_mem_usage=False,\n",
+ " torch_dtype=torch_dtype,\n",
+ " device=\"cuda:0\"\n",
+ ")\n",
+ "mt_2.set_hs_patch_hooks = model_to_hook[model_name_2]\n",
+ "mt_2.model.eval()\n",
+ "mt_2.model.to(mt_2.device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Next token prediction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pile_dataset = datasets.load_from_disk('./the_pile_deduplicated')\n",
+ "pile_dataset = pile_dataset.shuffle(seed=42)\n",
+ "print(len(pile_dataset))\n",
+ "\n",
+ "trn_n = 100000\n",
+ "val_n = 2000\n",
+ "pile_trn = pile_dataset['text'][:trn_n]\n",
+ "pile_val = pile_dataset['text'][trn_n:trn_n+val_n]\n",
+ "sentences = [(x, 'train') for x in pile_trn] + [(x, 'validation') for x in pile_val]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "max_len = 256\n",
+ "\n",
+ "data = {}\n",
+ "for sentence, split in tqdm(sentences):\n",
+ " \n",
+ " inp_1_ = make_inputs(mt_1.tokenizer, [sentence], device=mt_1.device)\n",
+ " inp_2_ = make_inputs(mt_2.tokenizer, [sentence], device=mt_2.device)\n",
+ " position = None\n",
+ " k = 0\n",
+ " while k<10:\n",
+ " position_tmp = random.randint(\n",
+ " 0, min(max_len - 1, \n",
+ " len(inp_1_['input_ids'][0]) - 1, \n",
+ " len(inp_2_['input_ids'][0]) - 1)\n",
+ " )\n",
+ " # cut the tokenized input at the sampled position and turn it back into a string.\n",
+ " # add some buffer at the end such that the tokenization is not modified around the sampled position.\n",
+ " prefix_1 = mt_1.tokenizer.decode(inp_1_['input_ids'][0][:position_tmp + int(sos_tok_1) + 5])\n",
+ " prefix_2 = mt_2.tokenizer.decode(inp_2_['input_ids'][0][:position_tmp + int(sos_tok_2) + 5])\n",
+ " \n",
+ " # check that the selected position corresponds to the same part of the string by \n",
+ " # comparing the prefixes until the sampled position. also make sure that this re-tokenization\n",
+ " # does not shift the sampled position off the sequence length.\n",
+ " inp_1 = make_inputs(mt_1.tokenizer, [prefix_1], device=mt_1.device)\n",
+ " inp_2 = make_inputs(mt_2.tokenizer, [prefix_2], device=mt_2.device)\n",
+ " if prefix_1 == prefix_2 and position_tmp < min(len(inp_1['input_ids'][0]), \n",
+ " len(inp_2['input_ids'][0])):\n",
+ " position = position_tmp\n",
+ " break\n",
+ " k += 1\n",
+ " if position is None:\n",
+ " continue\n",
+ " \n",
+ " for mt, model_name, inp, sos_tok in zip(\n",
+ " [mt_1, mt_2],\n",
+ " [model_name_1, model_name_2],\n",
+ " [inp_1, inp_2],\n",
+ " [sos_tok_1, sos_tok_2]\n",
+ " ):\n",
+ " position_ = position + int(sos_tok)\n",
+ " if (prefix_1, position_, split, model_name) not in data:\n",
+ " output = mt.model(**inp, output_hidden_states = True)\n",
+ "\n",
+ " data[(prefix_1, position_, split, model_name)] = [\n",
+ " output[\"hidden_states\"][layer+1][0][position_].detach().cpu().numpy()\n",
+ " for layer in range(mt.num_layers)\n",
+ " ]\n",
+ "\n",
+ "df = pd.Series(data).reset_index()\n",
+ "df.columns = ['full_text', 'position', 'data_split', 'model_name', 'hidden_rep'] \n",
+ "\n",
+ "for model_name in [model_name_1, model_name_2]:\n",
+ " df[df['model_name'] == model_name].to_pickle(f\"{model_name}_pile_trn_val.pkl\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Pad and unpad \n",
+ "\n",
+ "pad = lambda x: np.hstack([x, np.ones((x.shape[0], 1))])\n",
+ "unpad = lambda x: x[:,:-1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "layer_sources = [l for l in range(0, mt_1.num_layers, 5)]\n",
+ "layer_targets = [l for l in range(0, mt_2.num_layers, 5)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "output_dir = f'{model_name_1}_{model_name_2_}_mappings_pile'\n",
+ "if not os.path.exists(output_dir):\n",
+ " os.makedirs(output_dir)\n",
+ " \n",
+ "df_trn_1 = pd.DataFrame(df[(df['data_split'] == 'train') & \n",
+ " (df['model_name'] == model_name_1)]['hidden_rep'].to_list(), \n",
+ " columns=[layer for layer in range(mt_1.num_layers)])\n",
+ "df_trn_2 = pd.DataFrame(df[(df['data_split'] == 'train') & \n",
+ " (df['model_name'] == model_name_2)]['hidden_rep'].to_list(), \n",
+ " columns=[layer for layer in range(mt_2.num_layers)])\n",
+ "\n",
+ "layer_sources = [l for l in range(0, mt_1.num_layers, 5)]\n",
+ "layer_targets = [l for l in range(0, mt_2.num_layers, 5)]\n",
+ "\n",
+ "mappings = {}\n",
+ "for layer_source in tqdm(layer_sources):\n",
+ " for layer_target in layer_targets:\n",
+ " X = np.array(\n",
+ " df_trn_1[layer_source].values.tolist()\n",
+ " )\n",
+ " Y = np.array(\n",
+ " df_trn_2[layer_target].values.tolist()\n",
+ " )\n",
+ "\n",
+ " # Solve the least squares problem X * A = Y\n",
+ " # to find our transformation matrix A\n",
+ " A, res, rank, s = np.linalg.lstsq(pad(X), pad(Y))\n",
+ " transform = lambda x: unpad(pad(x) @ A)\n",
+ "\n",
+ " mappings[(layer_source, layer_target)] = A\n",
+ " with open(f'{model_name_1}_{model_name_2_}_mappings_pile/mapping_{layer_source}-{layer_target}.npy', 'wb') as fd:\n",
+ " np.save(fd, A)\n",
+ "\n",
+ " print(layer_source, layer_target, \"max error on train:\", np.abs(Y - transform(X)).max())\n",
+ "\n",
+ "shutil.make_archive(output_dir, 'zip', output_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mappings = {}\n",
+ "for layer_source in tqdm(layer_sources):\n",
+ " for layer_target in layer_targets:\n",
+ " with open(f'{model_name_1}_{model_name_2_}_mappings_pile/mapping_{layer_source}-{layer_target}.npy', 'rb') as fd:\n",
+ " A = np.load(fd)\n",
+ " mappings[(layer_source, layer_target)] = A"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Re-organize validation set\n",
+ "\n",
+ "df_val = df[(df['data_split'] == 'validation')].groupby(['full_text', 'data_split']).agg(pd.Series.tolist).reset_index()\n",
+ "cols = ['position', 'model_name', 'hidden_rep']\n",
+ "for col in cols:\n",
+ " df_val[[f'{col}_1', f'{col}_2']] = df_val[col].to_list()\n",
+ "\n",
+ "df_val = df_val[[col for col in df_val.columns if col not in cols]]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Evaluate linear mappings on the validation set of WikiText/a sample from the Pile\n",
+ "\n",
+ "records = []\n",
+ "for layer_source in tqdm(layer_sources):\n",
+ " for layer_target in tqdm(layer_targets):\n",
+ " A = mappings[(layer_source, layer_target)]\n",
+ " transform = lambda x: torch.tensor(\n",
+ " np.squeeze(\n",
+ " unpad(np.dot(\n",
+ " pad(np.expand_dims(x.detach().cpu().numpy(), 0)), \n",
+ " A\n",
+ " ))\n",
+ " )\n",
+ " ).to(mt_2.device)\n",
+ "\n",
+ " for idx, row in df_val.iterrows():\n",
+ " prompt = row['full_text']\n",
+ " position_source = row['position_1']\n",
+ " position_target = row['position_2']\n",
+ " prec_1, surprisal = evaluate_patch_next_token_prediction_x_model(\n",
+ " mt_1, mt_2, prompt, prompt, layer_source, layer_target,\n",
+ " position_source, position_target, position_prediction=position_target, transform=transform)\n",
+ "\n",
+ " records.append({'layer_source': layer_source,\n",
+ " 'layer_target': layer_target,\n",
+ " 'position_source': position_source,\n",
+ " 'position_target': position_target,\n",
+ " 'prec_1': prec_1, \n",
+ " 'surprisal': surprisal})\n",
+ " \n",
+ "\n",
+ "results = pd.DataFrame.from_records(records)\n",
+ "results.to_csv(f'{model_name_1}_{model_name_2_}_mappings_pile_eval.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot the resulted heatmap\n",
+ "metric = 'prec_1'\n",
+ "tmp = results[['layer_source', 'layer_target', metric]].groupby(['layer_source', 'layer_target']).agg(\"mean\").reset_index()\n",
+ "tmp = tmp.pivot(index='layer_source', columns='layer_target', values=metric)\n",
+ "\n",
+ "sns.heatmap(tmp, annot=True, fmt=\".1f\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Evaluate identity mapping on the validation set of WikiText\n",
+ "\n",
+ "records = []\n",
+ "for layer_source in tqdm(layer_sources):\n",
+ " for layer_target in tqdm(layer_targets):\n",
+ " for idx, row in df_val.iterrows():\n",
+ " prompt = row['full_text']\n",
+ " position_source = row['position_1']\n",
+ " position_target = row['position_2']\n",
+ " prec_1, surprisal = evaluate_patch_next_token_prediction_x_model(\n",
+ " mt_1, mt_2, prompt, prompt, layer_source, layer_target,\n",
+ " position_source, position_target, position_prediction=position_target)\n",
+ "\n",
+ " records.append({'layer_source': layer_source,\n",
+ " 'layer_target': layer_target,\n",
+ " 'position_source': position_source,\n",
+ " 'position_target': position_target,\n",
+ " 'prec_1': prec_1, \n",
+ " 'surprisal': surprisal})\n",
+ " \n",
+ "results = pd.DataFrame.from_records(records)\n",
+ "results.to_csv(f'{model_name_1}_{model_name_2_}_identity_pile_eval.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Evaluate the ID prompt on the validation set of WikiText\n",
+ "\n",
+ "prompt_target = \"cat -> cat\\n1135 -> 1135\\nhello -> hello\\n?\"\n",
+ "position_target = -1\n",
+ "\n",
+ "records = []\n",
+ "for layer_source in tqdm(layer_sources):\n",
+ " for layer_target in tqdm(layer_targets):\n",
+ " for idx, row in df_val.iterrows():\n",
+ " prompt_source = row['full_text']\n",
+ " position_source = row['position_1']\n",
+ " prec_1, surprisal = evaluate_patch_next_token_prediction_x_model(\n",
+ " mt_1, mt_2, prompt_source, prompt_target, layer_source, layer_target,\n",
+ " position_source, position_target, position_prediction=position_target, transform=None)\n",
+ "\n",
+ " records.append({'layer_source': layer_source,\n",
+ " 'layer_target': layer_target,\n",
+ " 'position_source': position_source,\n",
+ " 'position_target': position_target,\n",
+ " 'prec_1': prec_1, \n",
+ " 'surprisal': surprisal})\n",
+ " \n",
+ "results = pd.DataFrame.from_records(records)\n",
+ "results.to_csv(f'{model_name_1}_{model_name_2_}_prompt-id_pile_eval.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "results1 = pd.read_csv(f'{model_name_1}_{model_name_2_}_identity_pile_eval.csv')\n",
+ "results1[\"variant\"] = \"identity\"\n",
+ "results2 = pd.read_csv(f'{model_name_1}_{model_name_2_}_mappings_pile_eval.csv')\n",
+ "results2[\"variant\"] = \"affine mapping\"\n",
+ "results3 = pd.read_csv(f'{model_name_1}_{model_name_2_}_prompt-id_pile_eval.csv')\n",
+ "results3[\"variant\"] = \"prompt id\"\n",
+ "\n",
+ "results = pd.concat([results1, results2, results3], ignore_index=True)\n",
+ "\n",
+ "for metric in ['prec_1', 'surprisal']:\n",
+ " ax = sns.lineplot(data=results, x='layer', y=metric, hue=\"variant\")\n",
+ " ax.set_title(f\"{model_name_1.strip('./')} --> {model_name_2_}\")\n",
+ " ax.legend_.set_title('')\n",
+ " plt.show()\n",
+ " plt.clf()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.12"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/patchscopes/code/patchscopes_utils.py b/patchscopes/code/patchscopes_utils.py
new file mode 100644
index 00000000..1973fad7
--- /dev/null
+++ b/patchscopes/code/patchscopes_utils.py
@@ -0,0 +1,1157 @@
+# coding=utf-8
+# Copyright 2024 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+import tqdm
+from general_utils import decode_tokens
+from general_utils import make_inputs
+
+
+# ##############
+#
+# Hooks
+#
+# ##############
+
+
+def set_hs_patch_hooks_neox(
+ model,
+ hs_patch_config,
+ module="hs", # mlp, attn
+ patch_input=False,
+ skip_final_ln=False,
+ generation_mode=False,
+):
+ """Neox patch hooks."""
+ # when using mode.generate() the hidden states in the input are cached after
+ # the first inference pass, and in the next steps the input/output are of
+ # size 1. In these cases we don't need to patch anymore the previous hidden
+ # states from the initial input, because they are cached, but we do need to
+ # handle these cases in this call because this hook wraps the generation call.
+ #
+ # NOTE: To use generation mode, we must patch a position that is not the
+ # first one. This is because in this case we don't know during generation if
+ # we are handling the initial input or a future step and thus don't know if
+ # a patching is needed or not.
+
+ # if generation_mode:
+ # for i in hs_patch_config:
+ # for position_, _ in hs_patch_config[i]:
+ # assert position_ > 0
+
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ def patch_hs(name, position_hs, patch_input, generation_mode):
+ def pre_hook(module, input):
+ # (batch, sequence, hidden_state)
+ input_len = len(input[0][0])
+ if generation_mode and input_len == 1:
+ return
+ for position_, hs_ in position_hs:
+ input[0][0, position_] = hs_
+
+ def post_hook(module, input, output):
+ if "skip_ln" in name:
+ # output: (batch, sequence, hidden_state)
+ output_len = len(output[0])
+ else:
+ # output[0]: (batch, sequence, hidden_state)
+ output_len = len(output[0][0])
+
+ if generation_mode and output_len == 1:
+ return
+ for position_, hs_ in position_hs:
+ if "skip_ln" in name:
+ output[0][position_] = hs_
+ else:
+ output[0][0, position_] = hs_
+
+ if patch_input:
+ return pre_hook
+ else:
+ return post_hook
+
+ hooks = []
+ for i in hs_patch_config:
+ if patch_input:
+ hooks.append(
+ model.gpt_neox.layers[i].register_forward_pre_hook(
+ patch_hs(
+ f"patch_hs_{i}",
+ hs_patch_config[i],
+ patch_input,
+ generation_mode,
+ )
+ )
+ )
+ else:
+ # when patching a last-layer representation to the last layer of the
+ # same model, the final layer norm is not needed because it was already
+ # applied (assuming that the representation for patching was obtained by
+ # setting output_hidden_representations to True).
+ if skip_final_ln and i == len(model.gpt_neox.layers) - 1:
+ hooks.append(
+ model.gpt_neox.final_layer_norm.register_forward_hook(
+ patch_hs(
+ f"patch_hs_{i}_skip_ln",
+ hs_patch_config[i],
+ patch_input,
+ generation_mode,
+ )
+ )
+ )
+ else:
+ hooks.append(
+ model.gpt_neox.layers[i].register_forward_hook(
+ patch_hs(
+ f"patch_hs_{i}",
+ hs_patch_config[i],
+ patch_input,
+ generation_mode,
+ )
+ )
+ )
+
+ return hooks
+
+
+def set_hs_patch_hooks_llama(
+ model,
+ hs_patch_config,
+ module="hs", # mlp, attn
+ patch_input=False,
+ skip_final_ln=False,
+ generation_mode=False,
+):
+ """Llama patch hooks."""
+ # when using mode.generate() the hidden states in the input are cached after
+ # the first inference pass, and in the next steps the input/output are of
+ # size 1. In these cases we don't need to patch anymore the previous hidden
+ # states from the initial input, because they are cached, but we do need to
+ # handle these cases in this call because this hook wraps the generation call.
+ #
+ # NOTE: To use generation mode, we must patch a position that is not the
+ # first one. This is because in this case we don't know during generation if
+ # we are handling the initial input or a future step and thus don't know if
+ # a patching is needed or not.
+
+ # if generation_mode:
+ # for i in hs_patch_config:
+ # for position_, _ in hs_patch_config[i]:
+ # assert position_ > 0
+
+ def patch_hs(name, position_hs, patch_input, generation_mode):
+ def pre_hook(module, input):
+ # (batch, sequence, hidden_state)
+ input_len = len(input[0][0])
+ if generation_mode and input_len == 1:
+ return
+ for position_, hs_ in position_hs:
+ input[0][0, position_] = hs_
+
+ def post_hook(module, input, output):
+ if "skip_ln" in name or "mlp" in name:
+ # output: (batch, sequence, hidden_state)
+ output_len = len(output[0])
+ else:
+ # output[0]: (batch, sequence, hidden_state)
+ output_len = len(output[0][0])
+
+ if generation_mode and output_len == 1:
+ return
+ for position_, hs_ in position_hs:
+ if "skip_ln" in name or "mlp" in name:
+ output[0][position_] = hs_
+ else:
+ output[0][0, position_] = hs_
+
+ if patch_input:
+ return pre_hook
+ else:
+ return post_hook
+
+ hooks = []
+ for i in hs_patch_config:
+ patch_hook = patch_hs(
+ f"patch_{module}_{i}",
+ position_hs=hs_patch_config[i],
+ patch_input=patch_input,
+ generation_mode=generation_mode,
+ )
+ if patch_input:
+ if module == "hs":
+ hooks.append(
+ model.model.layers[i].register_forward_pre_hook(patch_hook)
+ )
+ elif module == "mlp":
+ hooks.append(
+ model.model.layers[i].mlp.register_forward_pre_hook(patch_hook)
+ )
+ elif module == "attn":
+ hooks.append(
+ model.model.layers[i].self_attn.register_forward_pre_hook(
+ patch_hook
+ )
+ )
+ else:
+ raise ValueError("Module %s not supported", module)
+ else:
+ # when patching a last-layer representation to the last layer of the same
+ # model, the final layer norm is not needed because it was already applied
+ # (assuming that the representation for patching was obtained by
+ # setting output_hidden_representations to True).
+ if skip_final_ln and i == len(model.model.layers) - 1 and module == "hs":
+ hooks.append(
+ model.model.norm.register_forward_hook(
+ patch_hs(
+ f"patch_hs_{i}_skip_ln",
+ hs_patch_config[i],
+ patch_input,
+ generation_mode,
+ )
+ )
+ )
+ else:
+ if module == "hs":
+ hooks.append(model.model.layers[i].register_forward_hook(patch_hook))
+ elif module == "mlp":
+ hooks.append(
+ model.model.layers[i].mlp.register_forward_hook(patch_hook)
+ )
+ elif module == "attn":
+ hooks.append(
+ model.model.layers[i].self_attn.register_forward_hook(patch_hook)
+ )
+ else:
+ raise ValueError("Module %s not supported", module)
+
+ return hooks
+
+
+def set_hs_patch_hooks_gptj(
+ model,
+ hs_patch_config,
+ module="hs", # mlp, attn
+ patch_input=False,
+ skip_final_ln=False,
+ generation_mode=False,
+):
+ """GPTJ patch hooks."""
+ # when using mode.generate() the hidden states in the input are cached after
+ # the first inference pass, and in the next steps the input/output are of
+ # size 1. In these cases we don't need to patch anymore the previous hidden
+ # states from the initial input, because they are cached, but we do need
+ # to handle these cases in this call because this hook wraps the generation
+ # call.
+ #
+ # NOTE: To use generation mode, we must patch a position that is not the
+ # first one. This is because in this case we don't know during generation
+ # if we are handling the initial input or a future step and thus don't know
+ # if a patching is needed or not.
+
+ # if generation_mode:
+ # for i in hs_patch_config:
+ # for position_, _ in hs_patch_config[i]:
+ # assert position_ > 0
+
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ def patch_hs(name, position_hs, patch_input, generation_mode):
+ def pre_hook(module, input):
+ # (batch, sequence, hidden_state)
+ input_len = len(input[0][0])
+ if generation_mode and input_len == 1:
+ return
+ for position_, hs_ in position_hs:
+ input[0][0, position_] = hs_
+
+ def post_hook(module, input, output):
+ if "skip_ln" in name:
+ # output: (batch, sequence, hidden_state)
+ output_len = len(output[0])
+ else:
+ # output[0]: (batch, sequence, hidden_state)
+ output_len = len(output[0][0])
+
+ if generation_mode and output_len == 1:
+ return
+ for position_, hs_ in position_hs:
+ if "skip_ln" in name:
+ output[0][position_] = hs_
+ else:
+ output[0][0, position_] = hs_
+
+ if patch_input:
+ return pre_hook
+ else:
+ return post_hook
+
+ hooks = []
+ for i in hs_patch_config:
+ if patch_input:
+ hooks.append(
+ model.transformer.h[i].register_forward_pre_hook(
+ patch_hs(
+ f"patch_hs_{i}",
+ hs_patch_config[i],
+ patch_input,
+ generation_mode,
+ )
+ )
+ )
+ else:
+ # when patching a last-layer representation to the last layer of the same
+ # model, the final layer norm is not needed because it was already applied
+ # (assuming that the representation for patching was obtained by
+ # setting output_hidden_representations to True).
+ if skip_final_ln and i == len(model.transformer.h) - 1:
+ hooks.append(
+ model.transformer.ln_f.register_forward_hook(
+ patch_hs(
+ f"patch_hs_{i}_skip_ln",
+ hs_patch_config[i],
+ patch_input,
+ generation_mode,
+ )
+ )
+ )
+ else:
+ hooks.append(
+ model.transformer.h[i].register_forward_hook(
+ patch_hs(
+ f"patch_hs_{i}",
+ hs_patch_config[i],
+ patch_input,
+ generation_mode,
+ )
+ )
+ )
+
+ return hooks
+
+
+def remove_hooks(hooks):
+ for hook in hooks:
+ hook.remove()
+
+
+# ##############
+#
+# Inspection
+#
+# ##############
+
+
+def inspect(
+ mt,
+ prompt_source,
+ prompt_target,
+ layer_source,
+ layer_target,
+ position_source,
+ position_target,
+ module="hs",
+ generation_mode=False,
+ max_gen_len=20,
+ verbose=False,
+ temperature=None,
+):
+ """Inspection via patching."""
+ # adjust position_target to be absolute rather than relative
+ inp_target = make_inputs(mt.tokenizer, [prompt_target], mt.device)
+ if position_target < 0:
+ position_target = len(inp_target["input_ids"][0]) + position_target
+
+ # first run the the model on prompt_patch and get all hidden states.
+ inp_source = make_inputs(mt.tokenizer, [prompt_source], mt.device)
+ if verbose:
+ print(
+ "prompt_patch:",
+ [mt.tokenizer.decode(x) for x in inp_source["input_ids"][0]],
+ )
+
+ hs_cache_ = []
+ # We manually store intermediate states that the model API does not expose
+ store_hooks = []
+ if module == "mlp":
+
+ def store_mlp_hook(module, input, output):
+ hs_cache_.append(output[0])
+
+ for layer in mt.model.model.layers:
+ store_hooks.append(layer.mlp.register_forward_hook(store_mlp_hook))
+ elif module == "attn":
+
+ def store_attn_hook(module, input, output):
+ hs_cache_.append(output[0].squeeze())
+
+ for layer in mt.model.model.layers:
+ store_hooks.append(layer.self_attn.register_forward_hook(store_attn_hook))
+
+ output = mt.model(**inp_source, output_hidden_states=True)
+ if module == "hs":
+ hs_cache_ = [
+ output["hidden_states"][layer + 1][0] for layer in range(mt.num_layers)
+ ]
+
+ remove_hooks(store_hooks)
+ # now do a second run on prompt, while patching
+ # a specific hidden state from the first run.
+ hs_patch_config = {
+ layer_target: [(
+ position_target,
+ hs_cache_[layer_source][position_source],
+ )]
+ }
+
+ if layer_source == layer_target == mt.num_layers - 1:
+ skip_final_ln = True
+ else:
+ skip_final_ln = False
+ patch_hooks = mt.set_hs_patch_hooks(
+ mt.model,
+ hs_patch_config,
+ module=module,
+ patch_input=False,
+ skip_final_ln=skip_final_ln,
+ generation_mode=True,
+ )
+
+ # Single prediction / generation
+ if verbose:
+ print(
+ "prompt:", [mt.tokenizer.decode(x) for x in inp_source["input_ids"][0]]
+ )
+ print(
+ f"patching position {position_target} with the hidden state from layer"
+ f" {layer_source} at position {position_source}."
+ )
+ if generation_mode:
+ # Checking if should perform temperature sampling, to allow smoother
+ # non-repeating long outputs.
+ if temperature:
+ output_toks = mt.model.generate(
+ inp_target["input_ids"],
+ max_length=len(inp_target["input_ids"][0]) + max_gen_len,
+ pad_token_id=mt.model.generation_config.eos_token_id,
+ temperature=temperature,
+ do_sample=True,
+ top_k=0,
+ )[0][len(inp_target["input_ids"][0]) :]
+ else:
+ output_toks = mt.model.generate(
+ inp_target["input_ids"],
+ max_length=len(inp_target["input_ids"][0]) + max_gen_len,
+ pad_token_id=mt.model.generation_config.eos_token_id,
+ )[0][len(inp_target["input_ids"][0]) :]
+
+ output = mt.tokenizer.decode(output_toks)
+ if verbose:
+ print(
+ "generation with patching: ",
+ [mt.tokenizer.decode(x) for x in output_toks],
+ )
+ else:
+ output = mt.model(**inp_target)
+ answer_prob, answer_t = torch.max(
+ torch.softmax(output.logits[0, -1, :], dim=0), dim=0
+ )
+ output = decode_tokens(mt.tokenizer, [answer_t])[0], round(
+ answer_prob.cpu().item(), 4
+ )
+ if verbose:
+ print("prediction with patching: ", output)
+
+ # remove patching hooks
+ remove_hooks(patch_hooks)
+
+ return output
+
+
+def evaluate_patch_next_token_prediction(
+ mt,
+ prompt_source,
+ prompt_target,
+ layer_source,
+ layer_target,
+ position_source,
+ position_target,
+ module="hs",
+ position_prediction=-1,
+ transform=None,
+):
+ """Evaluate next token prediction."""
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ # adjust position_target to be absolute rather than relative
+ inp_target = make_inputs(mt.tokenizer, [prompt_target], mt.device)
+ if position_target < 0:
+ position_target = len(inp_target["input_ids"][0]) + position_target
+
+ # first run the the model on without patching and get the results.
+ inp_source = make_inputs(mt.tokenizer, [prompt_source], mt.device)
+ output_orig = mt.model(**inp_source, output_hidden_states=True)
+ dist_orig = torch.softmax(output_orig.logits[0, position_source, :], dim=0)
+ _, answer_t_orig = torch.max(dist_orig, dim=0)
+ hidden_rep = output_orig["hidden_states"][layer_source + 1][0][
+ position_source
+ ]
+ if transform is not None:
+ hidden_rep = transform(hidden_rep)
+
+ # now do a second run on prompt, while patching the input hidden state.
+ hs_patch_config = {layer_target: [(position_target, hidden_rep)]}
+ if layer_source == layer_target == mt.num_layers - 1:
+ skip_final_ln = True
+ else:
+ skip_final_ln = False
+ patch_hooks = mt.set_hs_patch_hooks(
+ mt.model,
+ hs_patch_config,
+ module=module,
+ patch_input=False,
+ skip_final_ln=skip_final_ln,
+ generation_mode=True,
+ )
+ output = mt.model(**inp_target)
+ dist = torch.softmax(output.logits[0, position_prediction, :], dim=0)
+ _, answer_t = torch.max(dist, dim=0)
+
+ # remove patching hooks
+ remove_hooks(patch_hooks)
+
+ prec_1 = (answer_t == answer_t_orig).detach().cpu().item()
+ surprisal = -torch.log(dist_orig[answer_t]).detach().cpu().numpy()
+
+ return prec_1, surprisal
+
+
+def evaluate_patch_next_token_prediction_x_model(
+ mt_1,
+ mt_2,
+ prompt_source,
+ prompt_target,
+ layer_source,
+ layer_target,
+ position_source,
+ position_target,
+ module="hs",
+ position_prediction=-1,
+ transform=None,
+):
+ """evaluate next token prediction across models."""
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ # adjust position_target to be absolute rather than relative
+ inp_target = make_inputs(mt_2.tokenizer, [prompt_target], device=mt_2.device)
+ if position_target < 0:
+ position_target = len(inp_target["input_ids"][0]) + position_target
+
+ # first run the the model on without patching and get the results.
+ inp_source = make_inputs(mt_1.tokenizer, [prompt_source], device=mt_1.device)
+ output_orig = mt_1.model(**inp_source, output_hidden_states=True)
+ dist_orig = torch.softmax(output_orig.logits[0, position_source, :], dim=0)
+ _, answer_t_orig = torch.max(dist_orig, dim=0)
+ hidden_rep = output_orig["hidden_states"][layer_source + 1][0][
+ position_source
+ ]
+ if transform is not None:
+ hidden_rep = transform(hidden_rep)
+
+ # now do a second run on prompt, while patching the input hidden state.
+ hs_patch_config = {layer_target: [(position_target, hidden_rep)]}
+ skip_final_ln = False
+ patch_hooks = mt_2.set_hs_patch_hooks(
+ mt_2.model,
+ hs_patch_config,
+ module=module,
+ patch_input=False,
+ skip_final_ln=skip_final_ln,
+ generation_mode=True,
+ )
+ output = mt_2.model(**inp_target)
+ dist = torch.softmax(output.logits[0, position_prediction, :], dim=0)
+ _, answer_t = torch.max(dist, dim=0)
+
+ # remove patching hooks
+ remove_hooks(patch_hooks)
+
+ prec_1 = answer_t.detach().cpu().item() == answer_t_orig.detach().cpu().item()
+ surprisal = -torch.log(dist_orig[answer_t]).detach().cpu().numpy()
+
+ return prec_1, surprisal
+
+
+# Adding support for batched patching. More than 10x speedup
+# Currently only supporting GPT-J
+def set_hs_patch_hooks_gptj_batch(
+ model,
+ hs_patch_config,
+ module="hs",
+ patch_input=False,
+ generation_mode=False,
+):
+ """GPTJ patch hooks - supporting batch."""
+ # when using mode.generate() the hidden states in the input are cached after
+ # the first inference pass, and in the next steps the input/output are of
+ # size 1. In these cases we don't need to patch anymore the previous hidden
+ # states from the initial input, because they are cached, but we do need to
+ # handle these cases in this call because this hook wraps the generation call.
+ #
+ # NOTE: To use generation mode, we must patch a position that is not the
+ # first one. This is because in this case we don't know during generation if
+ # we are handling the initial input or a future step and thus don't know if
+ # a patching is needed or not.
+
+ # if generation_mode:
+ # for i in hs_patch_config:
+ # for position_, _ in hs_patch_config[i]:
+ # assert position_ > 0
+
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ def patch_hs(name, position_hs, patch_input, generation_mode):
+ def pre_hook(module, inp):
+ # (batch, sequence, hidden_state)
+ idx_, position_, hs_ = (
+ position_hs["batch_idx"],
+ position_hs["position_target"],
+ position_hs["hidden_rep"],
+ )
+ input_len = len(inp[0][idx_])
+ if generation_mode and input_len == 1:
+ return
+ inp[0][idx_][position_] = hs_
+
+ def post_hook(module, inp, output):
+ idx_, position_, hs_ = (
+ position_hs["batch_idx"],
+ position_hs["position_target"],
+ position_hs["hidden_rep"],
+ )
+ if "skip_ln" in name:
+ # output: (batch, sequence, hidden_state)
+ output_len = len(output[idx_])
+ if generation_mode and output_len == 1:
+ return
+ output[idx_][position_] = hs_
+ else:
+ # output[0]: (batch, sequence, hidden_state)
+ output_len = len(output[0][idx_])
+ if generation_mode and output_len == 1:
+ return
+ output[0][idx_][position_] = hs_
+
+ if patch_input:
+ return pre_hook
+ else:
+ return post_hook
+
+ hooks = []
+ for item in hs_patch_config:
+ i = item["layer_target"]
+ skip_final_ln = item["skip_final_ln"]
+ if patch_input:
+ hooks.append(
+ model.transformer.h[i].register_forward_pre_hook(
+ patch_hs(f"patch_hs_{i}", item, patch_input, generation_mode)
+ )
+ )
+ else:
+ # when patching a last-layer representation to the last layer of the same
+ # model, the final layer norm is not needed because it was already
+ # applied (assuming that the representation for patching was obtained by
+ # setting output_hidden_representations to True).
+ if skip_final_ln and i == len(model.transformer.h) - 1:
+ hooks.append(
+ model.transformer.ln_f.register_forward_hook(
+ patch_hs(
+ f"patch_hs_{i}_skip_ln",
+ item,
+ patch_input,
+ generation_mode,
+ )
+ )
+ )
+ else:
+ hooks.append(
+ model.transformer.h[i].register_forward_hook(
+ patch_hs(f"patch_hs_{i}", item, patch_input, generation_mode)
+ )
+ )
+
+ return hooks
+
+
+def set_hs_patch_hooks_llama_batch(
+ model,
+ hs_patch_config,
+ module="hs",
+ patch_input=False,
+ generation_mode=False,
+):
+ """LLAMA patch hooks - supporting batch."""
+ # when using mode.generate() the hidden states in the input are cached after
+ # the first inference pass, and in the next steps the input/output are of
+ # size 1. In these cases we don't need to patch anymore the previous hidden
+ # states from the initial input, because they are cached, but we do need to
+ # handle these cases in this call because this hook wraps the generation call.
+ #
+ # NOTE: To use generation mode, we must patch a position that is not the
+ # first one. This is because in this case we don't know during generation if
+ # we are handling the initial input or a future step and thus don't know if
+ # a patching is needed or not.
+
+ # if generation_mode:
+ # for i in hs_patch_config:
+ # for position_, _ in hs_patch_config[i]:
+ # assert position_ > 0
+
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ def patch_hs(name, position_hs, patch_input, generation_mode):
+ def pre_hook(module, inp):
+ # inp[0]: (batch, sequence, hidden_state)
+ idx_, position_, hs_ = (
+ position_hs["batch_idx"],
+ position_hs["position_target"],
+ position_hs["hidden_rep"],
+ )
+ input_len = len(inp[0][idx_])
+ if generation_mode and input_len == 1:
+ return
+ inp[0][idx_][position_] = hs_
+
+ def post_hook(module, inp, output):
+ idx_, position_, hs_ = (
+ position_hs["batch_idx"],
+ position_hs["position_target"],
+ position_hs["hidden_rep"],
+ )
+ if "skip_ln" in name:
+ # output: (batch, sequence, hidden_state)
+ output_len = len(output[idx_])
+ if generation_mode and output_len == 1:
+ return
+ output[idx_][position_] = hs_
+ else:
+ # output[0]: (batch, sequence, hidden_state)
+ output_len = len(output[0][idx_])
+ if generation_mode and output_len == 1:
+ return
+ output[0][idx_][position_] = hs_
+
+ if patch_input:
+ return pre_hook
+ else:
+ return post_hook
+
+ hooks = []
+
+ for item in hs_patch_config:
+ i = item["layer_target"]
+ skip_final_ln = item["skip_final_ln"]
+ if patch_input:
+ hooks.append(
+ model.model.layers[i].register_forward_pre_hook(
+ patch_hs(f"patch_hs_{i}", item, patch_input, generation_mode)
+ )
+ )
+ else:
+ # when patching a last-layer representation to the last layer of the same
+ # model, the final layer norm is not needed because it was already applied
+ # (assuming that the representation for patching was obtained by setting
+ # output_hidden_representations to True).
+ if skip_final_ln and i == len(model.model.layers) - 1:
+ hooks.append(
+ model.model.norm.register_forward_hook(
+ patch_hs(
+ f"patch_hs_{i}_skip_ln", item, patch_input, generation_mode
+ )
+ )
+ )
+ else:
+ hooks.append(
+ model.model.layers[i].register_forward_hook(
+ patch_hs(f"patch_hs_{i}", item, patch_input, generation_mode)
+ )
+ )
+
+ return hooks
+
+
+def evaluate_patch_next_token_prediction_batch(
+ mt, df, batch_size=256, transform=None, module="hs"
+):
+ """Evaluate next token prediction with batch support."""
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ prec_1 = np.zeros(0)
+ surprisal = np.zeros(0)
+ next_token = np.zeros(0)
+ # generations = []
+
+ def _evaluat_single_batch(batch_df):
+ batch_size = len(batch_df)
+ prompt_source_batch = np.array(batch_df["prompt_source"])
+ prompt_target_batch = np.array(batch_df["prompt_target"])
+ layer_source_batch = np.array(batch_df["layer_source"])
+ layer_target_batch = np.array(batch_df["layer_target"])
+ position_source_batch = np.array(batch_df["position_source"])
+ position_target_batch = np.array(batch_df["position_target"])
+ position_prediction_batch = np.ones_like(position_target_batch) * -1
+ # max_gen_len = np.array(batch_df["max_gen_len"])
+
+ # adjust position_target to be absolute rather than relative
+ inp_target = make_inputs(mt.tokenizer, prompt_target_batch, mt.device)
+ for i in range(batch_size):
+ if position_target_batch[i] < 0:
+ position_target_batch[i] += len(inp_target["input_ids"][i])
+
+ # first run the the model on without patching and get the results.
+ inp_source = make_inputs(mt.tokenizer, prompt_source_batch, mt.device)
+ output_orig = mt.model(**inp_source, output_hidden_states=True)
+ dist_orig = torch.softmax(
+ output_orig.logits[
+ np.array(range(batch_size)), position_source_batch, :
+ ],
+ dim=-1,
+ )
+ _, answer_t_orig = torch.max(dist_orig, dim=-1)
+ # hidden_states size (n_layers, n_sample, seq_len, hidden_dim)
+ hidden_rep = [
+ output_orig.hidden_states[layer_source_batch[i] + 1][i][
+ position_source_batch[i]
+ ]
+ for i in range(batch_size)
+ ]
+ if transform is not None:
+ for i in range(batch_size):
+ hidden_rep[i] = transform(hidden_rep[i])
+
+ # now do a second run on prompt, while patching the input hidden state.
+ hs_patch_config = [
+ {
+ "batch_idx": i,
+ "layer_target": layer_target_batch[i],
+ "position_target": position_target_batch[i],
+ "hidden_rep": hidden_rep[i],
+ "skip_final_ln": (
+ layer_source_batch[i]
+ == layer_target_batch[i]
+ == mt.num_layers - 1
+ ),
+ }
+ for i in range(batch_size)
+ ]
+ patch_hooks = mt.set_hs_patch_hooks(
+ mt.model,
+ hs_patch_config,
+ module=module,
+ patch_input=False,
+ generation_mode=False,
+ )
+
+ output = mt.model(**inp_target)
+
+ # # NOTE: inputs are left padded,
+ # # and sequence length is the same across batch
+ # # to support generations of variable lengths,
+ # # first generate with maximum number of tokens needed in the batch
+ # seq_len = len(inp_target["input_ids"][0])
+ # output_toks = mt.model.generate(
+ # inp_target["input_ids"],
+ # max_length=seq_len + max(max_gen_len),
+ # pad_token_id=mt.model.generation_config.eos_token_id,
+ # )[:, seq_len:]
+
+ # # then, we select only the subset of tokens that we need
+ # generations = [
+ # mt.tokenizer.decode(output_toks[i][: max_gen_len[i]])
+ # for i in range(batch_size)
+ # ]
+
+ dist = torch.softmax(
+ output.logits[
+ np.array(range(batch_size)), position_prediction_batch, :
+ ],
+ dim=-1,
+ )
+ _, answer_t = torch.max(dist, dim=-1)
+ next_token = [mt.tokenizer.decode(tok) for tok in answer_t]
+
+ # remove patching hooks
+ remove_hooks(patch_hooks)
+
+ prec_1 = (answer_t == answer_t_orig).detach().cpu().numpy()
+ surprisal = (
+ -torch.log(dist_orig[np.array(range(batch_size)), answer_t])
+ .detach()
+ .cpu()
+ .numpy()
+ )
+
+ return prec_1, surprisal, next_token
+
+ for i in tqdm.tqdm(range(len(df) // batch_size)):
+ cur_df = df.iloc[batch_size * i : batch_size * (i + 1)]
+ batch_prec_1, batch_surprisal, batch_next_token = _evaluat_single_batch(
+ cur_df
+ )
+ prec_1 = np.concatenate((prec_1, batch_prec_1))
+ surprisal = np.concatenate((surprisal, batch_surprisal))
+ next_token = np.concatenate((next_token, batch_next_token))
+
+ return prec_1, surprisal, next_token
+
+
+def inspect_batch(mt, df, batch_size=256, transform=None, module="hs"):
+ """Inspects batch: source/target layer/position could differ within batch."""
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ generations = []
+
+ def _inspect_single_batch(batch_df):
+ batch_size = len(batch_df)
+ prompt_source_batch = np.array(batch_df["prompt_source"])
+ prompt_target_batch = np.array(batch_df["prompt_target"])
+ layer_source_batch = np.array(batch_df["layer_source"])
+ layer_target_batch = np.array(batch_df["layer_target"])
+ position_source_batch = np.array(batch_df["position_source"])
+ position_target_batch = np.array(batch_df["position_target"])
+ max_gen_len = np.array(batch_df["max_gen_len"])
+
+ # adjust position_target to be absolute rather than relative
+ inp_target = make_inputs(mt.tokenizer, prompt_target_batch, mt.device)
+ for i in range(batch_size):
+ if position_target_batch[i] < 0:
+ position_target_batch[i] += len(inp_target["input_ids"][i])
+
+ # first run the the model on without patching and get the results.
+ inp_source = make_inputs(mt.tokenizer, prompt_source_batch, mt.device)
+ output_orig = mt.model(**inp_source, output_hidden_states=True)
+
+ # hidden_states size (n_layers, n_sample, seq_len, hidden_dim)
+ hidden_rep = [
+ output_orig.hidden_states[layer_source_batch[i] + 1][i][
+ position_source_batch[i]
+ ]
+ for i in range(batch_size)
+ ]
+ if transform is not None:
+ for i in range(batch_size):
+ hidden_rep[i] = transform(hidden_rep[i])
+
+ # now do a second run on prompt, while patching the input hidden state.
+ hs_patch_config = [
+ {
+ "batch_idx": i,
+ "layer_target": layer_target_batch[i],
+ "position_target": position_target_batch[i],
+ "hidden_rep": hidden_rep[i],
+ "skip_final_ln": (
+ layer_source_batch[i]
+ == layer_target_batch[i]
+ == mt.num_layers - 1
+ ),
+ }
+ for i in range(batch_size)
+ ]
+ patch_hooks = mt.set_hs_patch_hooks(
+ mt.model,
+ hs_patch_config,
+ module=module,
+ patch_input=False,
+ generation_mode=True,
+ )
+
+ # NOTE: inputs are left padded,
+ # and sequence length is the same across batch
+ # to support generations of variable lengths,
+ # first generate with maximum number of tokens needed in the batch
+ seq_len = len(inp_target["input_ids"][0])
+ output_toks = mt.model.generate(
+ inp_target["input_ids"],
+ max_length=seq_len + max(max_gen_len),
+ pad_token_id=mt.model.generation_config.eos_token_id,
+ )[:, seq_len:]
+
+ # then, we select only the subset of tokens that we need
+ generations = [
+ mt.tokenizer.decode(output_toks[i][: max_gen_len[i]])
+ for i in range(batch_size)
+ ]
+
+ # remove patching hooks
+ remove_hooks(patch_hooks)
+
+ return generations
+
+ for i in tqdm.tqdm(range(1 + len(df) // batch_size)):
+ cur_df = df.iloc[batch_size * i : batch_size * (i + 1)]
+ batch_generations = _inspect_single_batch(cur_df)
+ generations.extend(batch_generations)
+
+ return generations
+
+
+def evaluate_attriburte_exraction_batch(
+ mt,
+ df,
+ batch_size=256,
+ max_gen_len=10,
+ transform=None,
+ is_icl=True,
+ module="hs",
+):
+ """Evaluates attribute extraction with batch support."""
+ # We don't know the exact token position of the
+ # attribute, as it is not necessarily the next token. So, precision and
+ # surprisal may not apply directly.
+
+ if module != "hs":
+ raise ValueError("Module %s not yet supported", module)
+
+ def _evaluate_attriburte_exraction_single_batch(batch_df):
+ batch_size = len(batch_df)
+ prompt_source_batch = np.array(batch_df["prompt_source"])
+ prompt_target_batch = np.array(batch_df["prompt_target"])
+ layer_source_batch = np.array(batch_df["layer_source"])
+ layer_target_batch = np.array(batch_df["layer_target"])
+ position_source_batch = np.array(batch_df["position_source"])
+ position_target_batch = np.array(batch_df["position_target"])
+
+ object_batch = np.array(batch_df["object"])
+
+ # Adjust position_target to be absolute rather than relative
+ inp_target = make_inputs(mt.tokenizer, prompt_target_batch, mt.device)
+ for i in range(batch_size):
+ if position_target_batch[i] < 0:
+ position_target_batch[i] += len(inp_target["input_ids"][i])
+
+ # Step 1: run model on source prompt without patching and get the hidden
+ # representations.
+ inp_source = make_inputs(mt.tokenizer, prompt_source_batch, mt.device)
+ output_orig = mt.model(**inp_source, output_hidden_states=True)
+
+ # hidden_states size (n_layers, n_sample, seq_len, hidden_dim)
+ # hidden_rep = []
+ # for i in range(batch_size):
+ # hidden_rep.append(output_orig.hidden_states[layer_source_batch[i] + 1][i][position_source_batch[i]])
+ hidden_rep = [
+ output_orig.hidden_states[layer_source_batch[i] + 1][i][
+ position_source_batch[i]
+ ]
+ for i in range(batch_size)
+ ]
+ if transform is not None:
+ for i in range(batch_size):
+ hidden_rep[i] = transform(hidden_rep[i])
+
+ # Step 2: Do second run on target prompt, while patching the input
+ # hidden state.
+ hs_patch_config = [
+ {
+ "batch_idx": i,
+ "layer_target": layer_target_batch[i],
+ "position_target": position_target_batch[i],
+ "hidden_rep": hidden_rep[i],
+ "skip_final_ln": (
+ layer_source_batch[i]
+ == layer_target_batch[i]
+ == mt.num_layers - 1
+ ),
+ }
+ for i in range(batch_size)
+ ]
+ patch_hooks = mt.set_hs_patch_hooks(
+ mt.model,
+ hs_patch_config,
+ module=module,
+ patch_input=False,
+ generation_mode=True,
+ )
+
+ # Note that inputs are left padded,
+ # and sequence length is the same across batch
+ seq_len = len(inp_target["input_ids"][0])
+ output_toks = mt.model.generate(
+ inp_target["input_ids"],
+ max_length=seq_len + max_gen_len,
+ pad_token_id=mt.model.generation_config.eos_token_id,
+ )[:, seq_len:]
+ generations_patched = decode_tokens(mt.tokenizer, output_toks)
+ if is_icl:
+ prefix = batch_df["prefix"].iloc[0]
+
+ def _crop_by_prefix(generations, prefix):
+ concatenated_str = " ".join(generations)
+ _pos = concatenated_str.find(prefix)
+ return concatenated_str[:_pos]
+
+ generations_patched_postprocessed = np.array([
+ _crop_by_prefix(generations_patched[i], prefix)
+ for i in range(batch_size)
+ ])
+ else:
+ generations_patched_postprocessed = np.array(
+ [" ".join(generations_patched[i]) for i in range(batch_size)]
+ )
+
+ is_correct_patched = np.array([
+ object_batch[i].replace(" ", "")
+ in generations_patched_postprocessed[i].replace(" ", "")
+ for i in range(batch_size)
+ ])
+
+ # remove patching hooks
+ remove_hooks(patch_hooks)
+
+ cpu_hidden_rep = np.array(
+ [hidden_rep[i].detach().cpu().numpy() for i in range(batch_size)]
+ )
+
+ results = {
+ "generations_patched": generations_patched,
+ "generations_patched_postprocessed": generations_patched_postprocessed,
+ "is_correct_patched": is_correct_patched,
+ "hidden_rep": cpu_hidden_rep,
+ }
+
+ return results
+
+ results = {}
+ n_batches = len(df) // batch_size
+ if len(df) % batch_size != 0:
+ n_batches += 1
+ for i in tqdm(range(len(df) // batch_size)):
+ cur_df = df.iloc[batch_size * i : batch_size * (i + 1)]
+ batch_results = _evaluate_attriburte_exraction_single_batch(cur_df)
+ for key, value in batch_results.items():
+ if key in results:
+ results[key] = np.concatenate((results[key], value))
+ else:
+ results[key] = value
+
+ return results
diff --git a/patchscopes/code/preprocessed_data/commonsense/fruit_inside_color.tsv b/patchscopes/code/preprocessed_data/commonsense/fruit_inside_color.tsv
new file mode 100644
index 00000000..eb39c351
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/commonsense/fruit_inside_color.tsv
@@ -0,0 +1,267 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 On the inside, the color of x -1 On the inside, the color of bananas white bananas [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' brown', '.', '\n', '\n', 'The', ' skin', ' is', ' thick', ' and', ' tough', ',', ' and', ' the'] " is a deep , rich , dark brown .
+
+ The skin is thick and tough , and the" False and 5,662 acres of bananas were destroyed, 6 [' and', ' 5', ',', '662', ' acres', ' of', ' bananas']
+1 0 On the inside, the color of x -1 On the inside, the color of bananas white bananas [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' brown', '.', '\n', '\n', 'The', ' skin', ' is', ' thick', ' and', ' tough', ',', ' and', ' the'] " is a deep , rich , dark brown .
+
+ The skin is thick and tough , and the" False their stolen hoard of bananas from King K. Rool 4 [' their', ' stolen', ' hoard', ' of', ' bananas']
+2 0 On the inside, the color of x -1 On the inside, the color of bananas white bananas [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' brown', '.', '\n', '\n', 'The', ' skin', ' is', ' thick', ' and', ' tough', ',', ' and', ' the'] " is a deep , rich , dark brown .
+
+ The skin is thick and tough , and the" False While collecting bananas on the island's vastly 2 [' While', ' collecting', ' bananas']
+3 0 On the inside, the color of x -1 On the inside, the color of bananas white bananas [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' brown', '.', '\n', '\n', 'The', ' skin', ' is', ' thick', ' and', ' tough', ',', ' and', ' the'] " is a deep , rich , dark brown .
+
+ The skin is thick and tough , and the" False hectares of bananas and 686,893 hectares 2 [' hectares', ' of', ' bananas']
+4 0 On the inside, the color of x -1 On the inside, the color of bananas white bananas [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' brown', '.', '\n', '\n', 'The', ' skin', ' is', ' thick', ' and', ' tough', ',', ' and', ' the'] " is a deep , rich , dark brown .
+
+ The skin is thick and tough , and the" False non-regenerating bananas that add to speed 5 [' non', '-', 'reg', 'ener', 'ating', ' bananas']
+5 1 On the inside, the color of x -1 On the inside, the color of apples white apples [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' beautiful', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' shiny', ',', ' and', ' beautiful'] is a deep , rich , and beautiful red . The outside is a bright , shiny , and beautiful False are French cut apples served in a fry box 3 [' are', ' French', ' cut', ' apples']
+6 1 On the inside, the color of x -1 On the inside, the color of apples white apples [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' beautiful', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' shiny', ',', ' and', ' beautiful'] is a deep , rich , and beautiful red . The outside is a bright , shiny , and beautiful False him three poisoned apples to get rid of 3 [' him', ' three', ' poisoned', ' apples']
+7 1 On the inside, the color of x -1 On the inside, the color of apples white apples [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' beautiful', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' shiny', ',', ' and', ' beautiful'] is a deep , rich , and beautiful red . The outside is a bright , shiny , and beautiful False eating dozens of apples from the forbidden 3 [' eating', ' dozens', ' of', ' apples']
+8 1 On the inside, the color of x -1 On the inside, the color of apples white apples [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' beautiful', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' shiny', ',', ' and', ' beautiful'] is a deep , rich , and beautiful red . The outside is a bright , shiny , and beautiful False as large as apples washed onshore around 3 [' as', ' large', ' as', ' apples']
+9 1 On the inside, the color of x -1 On the inside, the color of apples white apples [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' beautiful', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' shiny', ',', ' and', ' beautiful'] is a deep , rich , and beautiful red . The outside is a bright , shiny , and beautiful False " with her basket of apples on its banner.
+" 4 [' with', ' her', ' basket', ' of', ' apples']
+10 2 On the inside, the color of x -1 On the inside, the color of watermelons red watermelons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' same', ' rich', ',', ' deep', ',', ' and', ' saturated', ' colors', ' are', ' still', ' there', '.'] is a bit more muted , but the same rich , deep , and saturated colors are still there . False peaches, plums and watermelons were cultivated 8 [' pe', 'aches', ',', ' pl', 'ums', ' and', ' water', 'mel', 'ons']
+11 2 On the inside, the color of x -1 On the inside, the color of watermelons red watermelons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' same', ' rich', ',', ' deep', ',', ' and', ' saturated', ' colors', ' are', ' still', ' there', '.'] is a bit more muted , but the same rich , deep , and saturated colors are still there . False spit seeds by eating watermelons — which were present 6 [' spit', ' seeds', ' by', ' eating', ' water', 'mel', 'ons']
+12 2 On the inside, the color of x -1 On the inside, the color of watermelons red watermelons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' same', ' rich', ',', ' deep', ',', ' and', ' saturated', ' colors', ' are', ' still', ' there', '.'] is a bit more muted , but the same rich , deep , and saturated colors are still there . False Yoshi can also eat watermelons that let him 6 [' Yoshi', ' can', ' also', ' eat', ' water', 'mel', 'ons']
+13 2 On the inside, the color of x -1 On the inside, the color of watermelons red watermelons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' same', ' rich', ',', ' deep', ',', ' and', ' saturated', ' colors', ' are', ' still', ' there', '.'] is a bit more muted , but the same rich , deep , and saturated colors are still there . False Yoshi can also eat watermelons that let him 6 [' Yoshi', ' can', ' also', ' eat', ' water', 'mel', 'ons']
+14 2 On the inside, the color of x -1 On the inside, the color of watermelons red watermelons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' same', ' rich', ',', ' deep', ',', ' and', ' saturated', ' colors', ' are', ' still', ' there', '.'] is a bit more muted , but the same rich , deep , and saturated colors are still there . False peaches, plums and watermelons were cultivated 8 [' pe', 'aches', ',', ' pl', 'ums', ' and', ' water', 'mel', 'ons']
+15 3 On the inside, the color of x -1 On the inside, the color of kiwis green kiwis [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' fruit', ' is', ' still', ' very', ' sweet', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more'] is a bit more muted , but the fruit is still very sweet . The flavor is a bit more False of five species of kiwis residing in 6 [' of', ' five', ' species', ' of', ' k', 'iw', 'is']
+16 3 On the inside, the color of x -1 On the inside, the color of kiwis green kiwis [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' fruit', ' is', ' still', ' very', ' sweet', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more'] is a bit more muted , but the fruit is still very sweet . The flavor is a bit more False suitable for mangoes, kiwis and bananas, 7 [' suitable', ' for', ' mango', 'es', ',', ' k', 'iw', 'is']
+17 3 On the inside, the color of x -1 On the inside, the color of kiwis green kiwis [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' fruit', ' is', ' still', ' very', ' sweet', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more'] is a bit more muted , but the fruit is still very sweet . The flavor is a bit more False adult great spotted kiwis are large and 5 [' adult', ' great', ' spotted', ' k', 'iw', 'is']
+18 3 On the inside, the color of x -1 On the inside, the color of kiwis green kiwis [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' fruit', ' is', ' still', ' very', ' sweet', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more'] is a bit more muted , but the fruit is still very sweet . The flavor is a bit more False suitable for mangoes, kiwis and bananas, while 7 [' suitable', ' for', ' mango', 'es', ',', ' k', 'iw', 'is']
+19 3 On the inside, the color of x -1 On the inside, the color of kiwis green kiwis [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' fruit', ' is', ' still', ' very', ' sweet', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more'] is a bit more muted , but the fruit is still very sweet . The flavor is a bit more False species of kiwis residing in New 4 [' species', ' of', ' k', 'iw', 'is']
+20 5 On the inside, the color of x -1 On the inside, the color of eggplants white eggplants [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' smooth', ' and', ' shiny', '.', ' The', ' flesh', ' is', ' white', ' and', ' soft', '.'] is a deep purple , and the skin is smooth and shiny . The flesh is white and soft . True tomatoes, onions, eggplants and also sometimes 6 [' tomatoes', ',', ' onions', ',', ' egg', 'pl', 'ants']
+21 5 On the inside, the color of x -1 On the inside, the color of eggplants white eggplants [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' smooth', ' and', ' shiny', '.', ' The', ' flesh', ' is', ' white', ' and', ' soft', '.'] is a deep purple , and the skin is smooth and shiny . The flesh is white and soft . True 1992 result (but using eggplants instead of dolls), 7 [' 1992', ' result', ' (', 'but', ' using', ' egg', 'pl', 'ants']
+22 5 On the inside, the color of x -1 On the inside, the color of eggplants white eggplants [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' smooth', ' and', ' shiny', '.', ' The', ' flesh', ' is', ' white', ' and', ' soft', '.'] is a deep purple , and the skin is smooth and shiny . The flesh is white and soft . True peppers, tomatoes, onions, eggplants and also sometimes 8 [' peppers', ',', ' tomatoes', ',', ' onions', ',', ' egg', 'pl', 'ants']
+23 5 On the inside, the color of x -1 On the inside, the color of eggplants white eggplants [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' smooth', ' and', ' shiny', '.', ' The', ' flesh', ' is', ' white', ' and', ' soft', '.'] is a deep purple , and the skin is smooth and shiny . The flesh is white and soft . True uprooted two Bt eggplants and more than 7 [' upro', 'oted', ' two', ' B', 't', ' egg', 'pl', 'ants']
+24 5 On the inside, the color of x -1 On the inside, the color of eggplants white eggplants [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' smooth', ' and', ' shiny', '.', ' The', ' flesh', ' is', ' white', ' and', ' soft', '.'] is a deep purple , and the skin is smooth and shiny . The flesh is white and soft . True uprooted two Bt eggplants and more than 100 7 [' upro', 'oted', ' two', ' B', 't', ' egg', 'pl', 'ants']
+25 6 On the inside, the color of x -1 On the inside, the color of zucchinis white zucchinis [' is', ' a', ' deep', ' green', ',', ' but', ' the', ' outside', ' is', ' a', ' pale', ' yellow', '.', ' The', ' z', 'uc', 'ch', 'inis', ' are', ' ready'] is a deep green , but the outside is a pale yellow . The z uc ch inis are ready False Kousa mahshi are zucchinis stuffed with 9 [' K', 'ous', 'a', ' mah', 'shi', ' are', ' z', 'uc', 'ch', 'inis']
+26 7 On the inside, the color of x -1 On the inside, the color of pineapples yellow pineapples [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'The', ' color', ' of', ' the', ' inside', ' of', ' the', ' pineapple', ' is'] " is a little bit different from the outside .
+
+ The color of the inside of the pineapple is" False with a load of pineapples — was raided 6 [' with', ' a', ' load', ' of', ' pine', 'app', 'les']
+27 7 On the inside, the color of x -1 On the inside, the color of pineapples yellow pineapples [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'The', ' color', ' of', ' the', ' inside', ' of', ' the', ' pineapple', ' is'] " is a little bit different from the outside .
+
+ The color of the inside of the pineapple is" False Moringa oleifera, pineapples and rubber. UPLB 10 [' M', 'oring', 'a', ' o', 'le', 'ifer', 'a', ',', ' pine', 'app', 'les']
+28 7 On the inside, the color of x -1 On the inside, the color of pineapples yellow pineapples [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'The', ' color', ' of', ' the', ' inside', ' of', ' the', ' pineapple', ' is'] " is a little bit different from the outside .
+
+ The color of the inside of the pineapple is" False Turumun farmed pineapples and coconuts. Wahid 7 [' Tur', 'um', 'un', ' far', 'med', ' pine', 'app', 'les']
+29 7 On the inside, the color of x -1 On the inside, the color of pineapples yellow pineapples [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'The', ' color', ' of', ' the', ' inside', ' of', ' the', ' pineapple', ' is'] " is a little bit different from the outside .
+
+ The color of the inside of the pineapple is" False with a load of pineapples — was raided by United 6 [' with', ' a', ' load', ' of', ' pine', 'app', 'les']
+30 7 On the inside, the color of x -1 On the inside, the color of pineapples yellow pineapples [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'The', ' color', ' of', ' the', ' inside', ' of', ' the', ' pineapple', ' is'] " is a little bit different from the outside .
+
+ The color of the inside of the pineapple is" False chocolate and pineapples from Mexico and 4 [' chocolate', ' and', ' pine', 'app', 'les']
+31 8 On the inside, the color of x -1 On the inside, the color of mangoes orange mangoes [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'M', 'ango', ' is', ' a', ' fruit', ' that', ' is', ' very', ' popular'] " is a little bit different from the outside .
+
+ M ango is a fruit that is very popular" False a crate of mangoes to Dayal Babu's house. 4 [' a', ' crate', ' of', ' mango', 'es']
+32 8 On the inside, the color of x -1 On the inside, the color of mangoes orange mangoes [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'M', 'ango', ' is', ' a', ' fruit', ' that', ' is', ' very', ' popular'] " is a little bit different from the outside .
+
+ M ango is a fruit that is very popular" False a crate of mangoes to Dayal Babu's house. 4 [' a', ' crate', ' of', ' mango', 'es']
+33 8 On the inside, the color of x -1 On the inside, the color of mangoes orange mangoes [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'M', 'ango', ' is', ' a', ' fruit', ' that', ' is', ' very', ' popular'] " is a little bit different from the outside .
+
+ M ango is a fruit that is very popular" False delivers a crate of mangoes to Dayal Babu's house. 5 [' delivers', ' a', ' crate', ' of', ' mango', 'es']
+34 8 On the inside, the color of x -1 On the inside, the color of mangoes orange mangoes [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'M', 'ango', ' is', ' a', ' fruit', ' that', ' is', ' very', ' popular'] " is a little bit different from the outside .
+
+ M ango is a fruit that is very popular" False delivers a crate of mangoes to Dayal Babu's house. 5 [' delivers', ' a', ' crate', ' of', ' mango', 'es']
+35 8 On the inside, the color of x -1 On the inside, the color of mangoes orange mangoes [' is', ' a', ' little', ' bit', ' different', ' from', ' the', ' outside', '.', '\n', '\n', 'M', 'ango', ' is', ' a', ' fruit', ' that', ' is', ' very', ' popular'] " is a little bit different from the outside .
+
+ M ango is a fruit that is very popular" False rubber, turmeric, mangoes and spices. Other 6 [' rubber', ',', ' tur', 'meric', ',', ' mango', 'es']
+36 9 On the inside, the color of x -1 On the inside, the color of cucumbers white cucumbers "[' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if']" " is a bit more muted , but the flavor is still there .
+
+ I 'm not sure if" False eaten some cucumbers with zest and 3 [' eaten', ' some', ' cuc', 'umbers']
+37 9 On the inside, the color of x -1 On the inside, the color of cucumbers white cucumbers "[' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if']" " is a bit more muted , but the flavor is still there .
+
+ I 'm not sure if" False wheat flour, lentils, cucumbers and cereal grain. Acute 7 [' wheat', ' flour', ',', ' lent', 'ils', ',', ' cuc', 'umbers']
+38 9 On the inside, the color of x -1 On the inside, the color of cucumbers white cucumbers "[' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if']" " is a bit more muted , but the flavor is still there .
+
+ I 'm not sure if" False properties, tomatoes, cucumbers and melons are grown 5 [' properties', ',', ' tomatoes', ',', ' cuc', 'umbers']
+39 9 On the inside, the color of x -1 On the inside, the color of cucumbers white cucumbers "[' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if']" " is a bit more muted , but the flavor is still there .
+
+ I 'm not sure if" False tomatoes and cucumbers combined with olive 3 [' tomatoes', ' and', ' cuc', 'umbers']
+40 9 On the inside, the color of x -1 On the inside, the color of cucumbers white cucumbers "[' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if']" " is a bit more muted , but the flavor is still there .
+
+ I 'm not sure if" False chopped tomatoes and cucumbers dressed in olive oil, 4 [' chopped', ' tomatoes', ' and', ' cuc', 'umbers']
+41 10 On the inside, the color of x -1 On the inside, the color of radishes white radishes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' flavor', ' is', ' a', ' bit', ' like', ' a', ' cross', ' between', ' a', ' rad', 'ish', ' and', ' a'] is a deep purple , and the flavor is a bit like a cross between a rad ish and a False beetroot, cucumbers, radishes and green 7 [' beet', 'root', ',', ' cuc', 'umbers', ',', ' rad', 'ishes']
+42 10 On the inside, the color of x -1 On the inside, the color of radishes white radishes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' flavor', ' is', ' a', ' bit', ' like', ' a', ' cross', ' between', ' a', ' rad', 'ish', ' and', ' a'] is a deep purple , and the flavor is a bit like a cross between a rad ish and a False " Spring or summer radishes ===
+" 4 [' Spring', ' or', ' summer', ' rad', 'ishes']
+43 10 On the inside, the color of x -1 On the inside, the color of radishes white radishes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' flavor', ' is', ' a', ' bit', ' like', ' a', ' cross', ' between', ' a', ' rad', 'ish', ' and', ' a'] is a deep purple , and the flavor is a bit like a cross between a rad ish and a False million tons of radishes are produced 4 [' million', ' tons', ' of', ' rad', 'ishes']
+44 10 On the inside, the color of x -1 On the inside, the color of radishes white radishes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' flavor', ' is', ' a', ' bit', ' like', ' a', ' cross', ' between', ' a', ' rad', 'ish', ' and', ' a'] is a deep purple , and the flavor is a bit like a cross between a rad ish and a False Broadly speaking, radishes can be categorized 5 [' Broad', 'ly', ' speaking', ',', ' rad', 'ishes']
+45 10 On the inside, the color of x -1 On the inside, the color of radishes white radishes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' flavor', ' is', ' a', ' bit', ' like', ' a', ' cross', ' between', ' a', ' rad', 'ish', ' and', ' a'] is a deep purple , and the flavor is a bit like a cross between a rad ish and a False " leeks, garlic, radishes and lettuces.
+" 6 [' le', 'eks', ',', ' garlic', ',', ' rad', 'ishes']
+46 12 On the inside, the color of x -1 On the inside, the color of nectarines yellow nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' The', ' skin', ' is', ' thin', ' and', ' tender', ',', ' and', ' the'] is a deep , rich , almost purple - red . The skin is thin and tender , and the False peaches, grapes and nectarines were grown in greenhouses. 7 [' pe', 'aches', ',', ' grapes', ' and', ' n', 'ectar', 'ines']
+47 12 On the inside, the color of x -1 On the inside, the color of nectarines yellow nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' The', ' skin', ' is', ' thin', ' and', ' tender', ',', ' and', ' the'] is a deep , rich , almost purple - red . The skin is thin and tender , and the False designation), lemons, nectarines and figs are grown 7 [' designation', '),', ' le', 'mons', ',', ' n', 'ectar', 'ines']
+48 12 On the inside, the color of x -1 On the inside, the color of nectarines yellow nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' The', ' skin', ' is', ' thin', ' and', ' tender', ',', ' and', ' the'] is a deep , rich , almost purple - red . The skin is thin and tender , and the False peaches, grapes and nectarines were grown in greenhouses. 7 [' pe', 'aches', ',', ' grapes', ' and', ' n', 'ectar', 'ines']
+49 12 On the inside, the color of x -1 On the inside, the color of nectarines yellow nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' The', ' skin', ' is', ' thin', ' and', ' tender', ',', ' and', ' the'] is a deep , rich , almost purple - red . The skin is thin and tender , and the False peaches, grapes and nectarines were grown in greenhouses. 7 [' pe', 'aches', ',', ' grapes', ' and', ' n', 'ectar', 'ines']
+50 13 On the inside, the color of x -1 On the inside, the color of plums red plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' of', ' a', ' ripe', ' peach', ' and', ' the'] and the texture of a ripe peach . On the outside , the color of a ripe peach and the False prunellé from wild plums (modern-day 7 [' pr', 'un', 'ell', 'é', ' from', ' wild', ' pl', 'ums']
+51 13 On the inside, the color of x -1 On the inside, the color of plums red plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' of', ' a', ' ripe', ' peach', ' and', ' the'] and the texture of a ripe peach . On the outside , the color of a ripe peach and the False " the public ""the plums and no cake"", as" 5 "[' the', ' public', ' ""', 'the', ' pl', 'ums']"
+52 13 On the inside, the color of x -1 On the inside, the color of plums red plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' of', ' a', ' ripe', ' peach', ' and', ' the'] and the texture of a ripe peach . On the outside , the color of a ripe peach and the False longifolia, the green plums of Panda oleosa 7 [' long', 'if', 'olia', ',', ' the', ' green', ' pl', 'ums']
+53 13 On the inside, the color of x -1 On the inside, the color of plums red plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' of', ' a', ' ripe', ' peach', ' and', ' the'] and the texture of a ripe peach . On the outside , the color of a ripe peach and the False " described as ""plums"", many plum cakes" 4 "[' described', ' as', ' ""', 'pl', 'ums']"
+54 13 On the inside, the color of x -1 On the inside, the color of plums red plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' of', ' a', ' ripe', ' peach', ' and', ' the'] and the texture of a ripe peach . On the outside , the color of a ripe peach and the False " responds, ""Just sugar plums and buttercups.""" 6 "[' responds', ',', ' ""', 'Just', ' sugar', ' pl', 'ums']"
+55 15 On the inside, the color of x -1 On the inside, the color of strawberries red strawberries [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' a', ' bright', ' red', '.', ' The', ' skin', ' is', ' thin', ' and', ' the', ' flesh'] is a deep red , and the flesh is a bright red . The skin is thin and the flesh True was getting strawberries on my knees 2 [' was', ' getting', ' strawberries']
+56 15 On the inside, the color of x -1 On the inside, the color of strawberries red strawberries [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' a', ' bright', ' red', '.', ' The', ' skin', ' is', ' thin', ' and', ' the', ' flesh'] is a deep red , and the flesh is a bright red . The skin is thin and the flesh True pears, plums, and strawberries were more common. 7 [' p', 'ears', ',', ' pl', 'ums', ',', ' and', ' strawberries']
+57 15 On the inside, the color of x -1 On the inside, the color of strawberries red strawberries [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' a', ' bright', ' red', '.', ' The', ' skin', ' is', ' thin', ' and', ' the', ' flesh'] is a deep red , and the flesh is a bright red . The skin is thin and the flesh True locally-grown strawberries that it carried. 3 [' locally', '-', 'grown', ' strawberries']
+58 15 On the inside, the color of x -1 On the inside, the color of strawberries red strawberries [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' a', ' bright', ' red', '.', ' The', ' skin', ' is', ' thin', ' and', ' the', ' flesh'] is a deep red , and the flesh is a bright red . The skin is thin and the flesh True as raspberries and strawberries in summer pudding, 5 [' as', ' r', 'asp', 'berries', ' and', ' strawberries']
+59 15 On the inside, the color of x -1 On the inside, the color of strawberries red strawberries [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' a', ' bright', ' red', '.', ' The', ' skin', ' is', ' thin', ' and', ' the', ' flesh'] is a deep red , and the flesh is a bright red . The skin is thin and the flesh True Angevine' pears. The strawberries and melons are also 7 "[' Ange', 'vine', ""'"", ' p', 'ears', '.', ' The', ' strawberries']"
+60 16 On the inside, the color of x -1 On the inside, the color of avocados green avocados [' is', ' a', ' deep', ' green', ',', ' but', ' on', ' the', ' outside', ',', ' they', ' are', ' a', ' bright', ',', ' creamy', ' yellow', '.', '\n', '\n'] " is a deep green , but on the outside , they are a bright , creamy yellow .
+
+" True lemons, grapefruit, and avocados are grown too. This 9 [' le', 'mons', ',', ' grape', 'fruit', ',', ' and', ' av', 'oc', 'ados']
+61 16 On the inside, the color of x -1 On the inside, the color of avocados green avocados [' is', ' a', ' deep', ' green', ',', ' but', ' on', ' the', ' outside', ',', ' they', ' are', ' a', ' bright', ',', ' creamy', ' yellow', '.', '\n', '\n'] " is a deep green , but on the outside , they are a bright , creamy yellow .
+
+" True grapefruit, and avocados are grown too. This 6 [' grape', 'fruit', ',', ' and', ' av', 'oc', 'ados']
+62 16 On the inside, the color of x -1 On the inside, the color of avocados green avocados [' is', ' a', ' deep', ' green', ',', ' but', ' on', ' the', ' outside', ',', ' they', ' are', ' a', ' bright', ',', ' creamy', ' yellow', '.', '\n', '\n'] " is a deep green , but on the outside , they are a bright , creamy yellow .
+
+" True ingredients. The avocados may be mashed, 5 [' ingredients', '.', ' The', ' av', 'oc', 'ados']
+63 16 On the inside, the color of x -1 On the inside, the color of avocados green avocados [' is', ' a', ' deep', ' green', ',', ' but', ' on', ' the', ' outside', ',', ' they', ' are', ' a', ' bright', ',', ' creamy', ' yellow', '.', '\n', '\n'] " is a deep green , but on the outside , they are a bright , creamy yellow .
+
+" True cake using raw avocados and other raw 5 [' cake', ' using', ' raw', ' av', 'oc', 'ados']
+64 16 On the inside, the color of x -1 On the inside, the color of avocados green avocados [' is', ' a', ' deep', ' green', ',', ' but', ' on', ' the', ' outside', ',', ' they', ' are', ' a', ' bright', ',', ' creamy', ' yellow', '.', '\n', '\n'] " is a deep green , but on the outside , they are a bright , creamy yellow .
+
+" True ingredient. Raw avocados may be used in its 5 [' ingredient', '.', ' Raw', ' av', 'oc', 'ados']
+65 17 On the inside, the color of x -1 On the inside, the color of peaches yellow peaches [' and', ' cream', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sea'] and cream , the color of the sky , the color of the sun , the color of the sea False " cherries, pears, peaches and walnuts.
+" 7 [' cher', 'ries', ',', ' p', 'ears', ',', ' pe', 'aches']
+66 17 On the inside, the color of x -1 On the inside, the color of peaches yellow peaches [' and', ' cream', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sea'] and cream , the color of the sky , the color of the sun , the color of the sea False Persian conquest, peaches were added to the 4 [' Persian', ' conquest', ',', ' pe', 'aches']
+67 17 On the inside, the color of x -1 On the inside, the color of peaches yellow peaches [' and', ' cream', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sea'] and cream , the color of the sky , the color of the sun , the color of the sea False Shanghai, 20 % of peaches and 40 % of pears were 6 [' Shanghai', ',', ' 20', ' %', ' of', ' pe', 'aches']
+68 17 On the inside, the color of x -1 On the inside, the color of peaches yellow peaches [' and', ' cream', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sea'] and cream , the color of the sky , the color of the sun , the color of the sea False observed feeding on peaches and apricots. 4 [' observed', ' feeding', ' on', ' pe', 'aches']
+69 17 On the inside, the color of x -1 On the inside, the color of peaches yellow peaches [' and', ' cream', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sea'] and cream , the color of the sky , the color of the sun , the color of the sea False Persian conquest, peaches were added to 4 [' Persian', ' conquest', ',', ' pe', 'aches']
+70 18 On the inside, the color of x -1 On the inside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant', ' red', '.'] is a deep , rich red . On the outside , the color is a bright , vibrant red . True such as grapes and pomegranates have been created 7 [' such', ' as', ' grapes', ' and', ' p', 'ome', 'gran', 'ates']
+71 18 On the inside, the color of x -1 On the inside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant', ' red', '.'] is a deep , rich red . On the outside , the color is a bright , vibrant red . True figs, sweet melons, pomegranates and dates. Surrounding 10 [' fig', 's', ',', ' sweet', ' mel', 'ons', ',', ' p', 'ome', 'gran', 'ates']
+72 18 On the inside, the color of x -1 On the inside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant', ' red', '.'] is a deep , rich red . On the outside , the color is a bright , vibrant red . True sweet melons, pomegranates and dates. 7 [' sweet', ' mel', 'ons', ',', ' p', 'ome', 'gran', 'ates']
+73 18 On the inside, the color of x -1 On the inside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant', ' red', '.'] is a deep , rich red . On the outside , the color is a bright , vibrant red . True " include citrus, almonds, pomegranates and grapes.
+" 8 [' include', ' citrus', ',', ' almonds', ',', ' p', 'ome', 'gran', 'ates']
+74 18 On the inside, the color of x -1 On the inside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant', ' red', '.'] is a deep , rich red . On the outside , the color is a bright , vibrant red . True symbolised by monkeys, pomegranates and nesting birds. 8 [' symbol', 'ised', ' by', ' monkeys', ',', ' p', 'ome', 'gran', 'ates']
+75 19 On the inside, the color of x -1 On the inside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' red', '-', 'orange'] is a deep , rich , dark red . The outside is a bright , vibrant , red - orange True their coffee cherries to specifications 3 [' their', ' coffee', ' cher', 'ries']
+76 19 On the inside, the color of x -1 On the inside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' red', '-', 'orange'] is a deep , rich , dark red . The outside is a bright , vibrant , red - orange True named for the Islay cherries that grew wildly 6 [' named', ' for', ' the', ' Is', 'lay', ' cher', 'ries']
+77 19 On the inside, the color of x -1 On the inside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' red', '-', 'orange'] is a deep , rich , dark red . The outside is a bright , vibrant , red - orange True same name, eating cherries was projected 5 [' same', ' name', ',', ' eating', ' cher', 'ries']
+78 19 On the inside, the color of x -1 On the inside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' red', '-', 'orange'] is a deep , rich , dark red . The outside is a bright , vibrant , red - orange True Lilacs, weeping cherries and hibiscus 5 [' Lil', 'acs', ',', ' weeping', ' cher', 'ries']
+79 19 On the inside, the color of x -1 On the inside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' rich', ',', ' dark', ' red', '.', ' The', ' outside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' red', '-', 'orange'] is a deep , rich , dark red . The outside is a bright , vibrant , red - orange True house. Lilacs, weeping cherries and hibiscus were 7 [' house', '.', ' Lil', 'acs', ',', ' weeping', ' cher', 'ries']
+80 20 On the inside, the color of x -1 On the inside, the color of grapes green grapes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' a', ' light', ' green', '.', ' The', ' flavor', ' is', ' sweet', ',', ' with', ' a'] is a deep purple , and the skin is a light green . The flavor is sweet , with a True noticed that the grapes were different from 3 [' noticed', ' that', ' the', ' grapes']
+81 20 On the inside, the color of x -1 On the inside, the color of grapes green grapes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' a', ' light', ' green', '.', ' The', ' flavor', ' is', ' sweet', ',', ' with', ' a'] is a deep purple , and the skin is a light green . The flavor is sweet , with a True peasants harvesting grapes from a row of 2 [' peasants', ' harvesting', ' grapes']
+82 20 On the inside, the color of x -1 On the inside, the color of grapes green grapes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' a', ' light', ' green', '.', ' The', ' flavor', ' is', ' sweet', ',', ' with', ' a'] is a deep purple , and the skin is a light green . The flavor is sweet , with a True with Rhône grapes like Syrah. 4 [' with', ' Rh', 'ô', 'ne', ' grapes']
+83 20 On the inside, the color of x -1 On the inside, the color of grapes green grapes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' a', ' light', ' green', '.', ' The', ' flavor', ' is', ' sweet', ',', ' with', ' a'] is a deep purple , and the skin is a light green . The flavor is sweet , with a True on wheat, barley and grapes. The Muslims 5 [' on', ' wheat', ',', ' barley', ' and', ' grapes']
+84 20 On the inside, the color of x -1 On the inside, the color of grapes green grapes [' is', ' a', ' deep', ' purple', ',', ' and', ' the', ' skin', ' is', ' a', ' light', ' green', '.', ' The', ' flavor', ' is', ' sweet', ',', ' with', ' a'] is a deep purple , and the skin is a light green . The flavor is sweet , with a True production of table grapes as early as 3 [' production', ' of', ' table', ' grapes']
+85 21 On the inside, the color of x -1 On the inside, the color of blueberries green blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' light', ','] is a deep , rich , and vibrant blue . On the outside , the color is a light , False cranberries and blueberries grow in the 4 [' cran', 'berries', ' and', ' blue', 'berries']
+86 21 On the inside, the color of x -1 On the inside, the color of blueberries green blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' light', ','] is a deep , rich , and vibrant blue . On the outside , the color is a light , False cranberries and blueberries grow in the 4 [' cran', 'berries', ' and', ' blue', 'berries']
+87 21 On the inside, the color of x -1 On the inside, the color of blueberries green blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' light', ','] is a deep , rich , and vibrant blue . On the outside , the color is a light , False valley, bilberries, blueberries and cowberry. 6 [' valley', ',', ' bil', 'berries', ',', ' blue', 'berries']
+88 21 On the inside, the color of x -1 On the inside, the color of blueberries green blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' light', ','] is a deep , rich , and vibrant blue . On the outside , the color is a light , False typically blueberries and black huckleberry. 2 [' typically', ' blue', 'berries']
+89 21 On the inside, the color of x -1 On the inside, the color of blueberries green blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' light', ','] is a deep , rich , and vibrant blue . On the outside , the color is a light , False out picking blueberries some distance from 3 [' out', ' picking', ' blue', 'berries']
+90 22 On the inside, the color of x -1 On the inside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False shading through oranges to a saturated 2 [' shading', ' through', ' oranges']
+91 22 On the inside, the color of x -1 On the inside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False deep reds and oranges for the main 4 [' deep', ' red', 's', ' and', ' oranges']
+92 22 On the inside, the color of x -1 On the inside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False parsnips, lemons, oranges and several 7 [' pars', 'n', 'ips', ',', ' le', 'mons', ',', ' oranges']
+93 22 On the inside, the color of x -1 On the inside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False reds, browns, and oranges are the first attested 7 [' red', 's', ',', ' brown', 's', ',', ' and', ' oranges']
+94 22 On the inside, the color of x -1 On the inside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False parsnips, lemons, oranges and several other 7 [' pars', 'n', 'ips', ',', ' le', 'mons', ',', ' oranges']
+95 23 On the inside, the color of x -1 On the inside, the color of lemons yellow lemons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' overall', ' effect', ' is', ' still', ' very', ' bright', '.', '\n', '\n', 'The', ' color', ' of'] " is a bit more muted , but the overall effect is still very bright .
+
+ The color of" False 16,000 cases of lemons from Palermo, and 6 [' 16', ',', '000', ' cases', ' of', ' le', 'mons']
+96 23 On the inside, the color of x -1 On the inside, the color of lemons yellow lemons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' overall', ' effect', ' is', ' still', ' very', ' bright', '.', '\n', '\n', 'The', ' color', ' of'] " is a bit more muted , but the overall effect is still very bright .
+
+ The color of" False " you who can take lemons and make lemonade.""
+" 5 [' you', ' who', ' can', ' take', ' le', 'mons']
+97 23 On the inside, the color of x -1 On the inside, the color of lemons yellow lemons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' overall', ' effect', ' is', ' still', ' very', ' bright', '.', '\n', '\n', 'The', ' color', ' of'] " is a bit more muted , but the overall effect is still very bright .
+
+ The color of" False the six million lemons that had been sold, 4 [' the', ' six', ' million', ' le', 'mons']
+98 23 On the inside, the color of x -1 On the inside, the color of lemons yellow lemons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' overall', ' effect', ' is', ' still', ' very', ' bright', '.', '\n', '\n', 'The', ' color', ' of'] " is a bit more muted , but the overall effect is still very bright .
+
+ The color of" False such as wine, lemons and olive oil. 5 [' such', ' as', ' wine', ',', ' le', 'mons']
+99 23 On the inside, the color of x -1 On the inside, the color of lemons yellow lemons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' overall', ' effect', ' is', ' still', ' very', ' bright', '.', '\n', '\n', 'The', ' color', ' of'] " is a bit more muted , but the overall effect is still very bright .
+
+ The color of" False plastic juice lemons were sold by Coldcrops. 3 [' plastic', ' juice', ' le', 'mons']
+100 24 On the inside, the color of x -1 On the inside, the color of limes green limes ['cale', ' is', ' a', ' dark', ' brown', '.', '\n', '\n', 'The', ' inside', ' of', ' the', ' tank', ' is', ' a', ' dark', ' brown', '.', '\n', '\n'] "cale is a dark brown .
+
+ The inside of the tank is a dark brown .
+
+" False construct a new limes between the 4 [' construct', ' a', ' new', ' l', 'imes']
+101 24 On the inside, the color of x -1 On the inside, the color of limes green limes ['cale', ' is', ' a', ' dark', ' brown', '.', '\n', '\n', 'The', ' inside', ' of', ' the', ' tank', ' is', ' a', ' dark', ' brown', '.', '\n', '\n'] "cale is a dark brown .
+
+ The inside of the tank is a dark brown .
+
+" False those found on Roman limes sites. The ear 5 [' those', ' found', ' on', ' Roman', ' l', 'imes']
+102 24 On the inside, the color of x -1 On the inside, the color of limes green limes ['cale', ' is', ' a', ' dark', ' brown', '.', '\n', '\n', 'The', ' inside', ' of', ' the', ' tank', ' is', ' a', ' dark', ' brown', '.', '\n', '\n'] "cale is a dark brown .
+
+ The inside of the tank is a dark brown .
+
+" False construct a new limes between the Firth 4 [' construct', ' a', ' new', ' l', 'imes']
+103 24 On the inside, the color of x -1 On the inside, the color of limes green limes ['cale', ' is', ' a', ' dark', ' brown', '.', '\n', '\n', 'The', ' inside', ' of', ' the', ' tank', ' is', ' a', ' dark', ' brown', '.', '\n', '\n'] "cale is a dark brown .
+
+ The inside of the tank is a dark brown .
+
+" False his ships with limes to fend off 4 [' his', ' ships', ' with', ' l', 'imes']
+104 24 On the inside, the color of x -1 On the inside, the color of limes green limes ['cale', ' is', ' a', ' dark', ' brown', '.', '\n', '\n', 'The', ' inside', ' of', ' the', ' tank', ' is', ' a', ' dark', ' brown', '.', '\n', '\n'] "cale is a dark brown .
+
+ The inside of the tank is a dark brown .
+
+" False ha) of grapefruit and limes were also damaged, 7 [' ha', ')', ' of', ' grape', 'fruit', ' and', ' l', 'imes']
+105 26 On the inside, the color of x -1 On the inside, the color of blackberries red blackberries [' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' outside', ' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' inside', ' is', ' a'] is a deep , rich purple . The outside is a deep , rich purple . The inside is a False and, in autumn, blackberries to gather. The 6 [' and', ',', ' in', ' autumn', ',', ' black', 'berries']
+106 26 On the inside, the color of x -1 On the inside, the color of blackberries red blackberries [' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' outside', ' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' inside', ' is', ' a'] is a deep , rich purple . The outside is a deep , rich purple . The inside is a False boxthorn, watsonia and blackberries have displaced 9 [' box', 'th', 'orn', ',', ' w', 'ats', 'onia', ' and', ' black', 'berries']
+107 26 On the inside, the color of x -1 On the inside, the color of blackberries red blackberries [' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' outside', ' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' inside', ' is', ' a'] is a deep , rich purple . The outside is a deep , rich purple . The inside is a False " has a fondness for blackberries and grapes.
+" 6 [' has', ' a', ' fond', 'ness', ' for', ' black', 'berries']
+108 26 On the inside, the color of x -1 On the inside, the color of blackberries red blackberries [' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' outside', ' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' inside', ' is', ' a'] is a deep , rich purple . The outside is a deep , rich purple . The inside is a False boxthorn, watsonia and blackberries have displaced native 9 [' box', 'th', 'orn', ',', ' w', 'ats', 'onia', ' and', ' black', 'berries']
+109 26 On the inside, the color of x -1 On the inside, the color of blackberries red blackberries [' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' outside', ' is', ' a', ' deep', ',', ' rich', ' purple', '.', ' The', ' inside', ' is', ' a'] is a deep , rich purple . The outside is a deep , rich purple . The inside is a False watsonia and blackberries have displaced native 5 [' w', 'ats', 'onia', ' and', ' black', 'berries']
+110 27 On the inside, the color of x -1 On the inside, the color of raspberries red raspberries [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more', ' tart'] is a bit more muted , but the flavor is still there . The flavor is a bit more tart False fruits, such as raspberries and strawberries in 6 [' fruits', ',', ' such', ' as', ' r', 'asp', 'berries']
+111 27 On the inside, the color of x -1 On the inside, the color of raspberries red raspberries [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more', ' tart'] is a bit more muted , but the flavor is still there . The flavor is a bit more tart False largest exporter of raspberries in the world 6 [' largest', ' exp', 'orter', ' of', ' r', 'asp', 'berries']
+112 27 On the inside, the color of x -1 On the inside, the color of raspberries red raspberries [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more', ' tart'] is a bit more muted , but the flavor is still there . The flavor is a bit more tart False " watermills, wild raspberries and buttercups ""like" 7 [' water', 'm', 'ills', ',', ' wild', ' r', 'asp', 'berries']
+113 27 On the inside, the color of x -1 On the inside, the color of raspberries red raspberries [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more', ' tart'] is a bit more muted , but the flavor is still there . The flavor is a bit more tart False exporter of raspberries in the world (as 5 [' exp', 'orter', ' of', ' r', 'asp', 'berries']
+114 27 On the inside, the color of x -1 On the inside, the color of raspberries red raspberries [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', ' The', ' flavor', ' is', ' a', ' bit', ' more', ' tart'] is a bit more muted , but the flavor is still there . The flavor is a bit more tart False other fruits, such as raspberries and strawberries in 7 [' other', ' fruits', ',', ' such', ' as', ' r', 'asp', 'berries']
+115 28 On the inside, the color of x -1 On the inside, the color of papayas orange papayas [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'The', ' pap', 'aya', ' is', ' a'] " is a bit more muted , but the flavor is still there .
+
+ The pap aya is a" False like bananas and papayas and uprooted 5 [' like', ' bananas', ' and', ' pap', 'ay', 'as']
+116 28 On the inside, the color of x -1 On the inside, the color of papayas orange papayas [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'The', ' pap', 'aya', ' is', ' a'] " is a bit more muted , but the flavor is still there .
+
+ The pap aya is a" False Gilliard found papayas growing at 5 [' Gill', 'iard', ' found', ' pap', 'ay', 'as']
+117 28 On the inside, the color of x -1 On the inside, the color of papayas orange papayas [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'The', ' pap', 'aya', ' is', ' a'] " is a bit more muted , but the flavor is still there .
+
+ The pap aya is a" False Gilliard found papayas growing at the 5 [' Gill', 'iard', ' found', ' pap', 'ay', 'as']
+118 28 On the inside, the color of x -1 On the inside, the color of papayas orange papayas [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'The', ' pap', 'aya', ' is', ' a'] " is a bit more muted , but the flavor is still there .
+
+ The pap aya is a" False palms, bananas and papayas are not considered 6 [' palms', ',', ' bananas', ' and', ' pap', 'ay', 'as']
+119 29 On the inside, the color of x -1 On the inside, the color of apricots orange apricots [' and', ' the', ' texture', ' of', ' a', ' peach', '.', '\n', '\n', 'The', ' scent', ' of', ' a', ' garden', 'ia', ' and', ' the', ' taste', ' of', ' a'] " and the texture of a peach .
+
+ The scent of a garden ia and the taste of a" False from a dry sheet of apricots soaked in water, 7 [' from', ' a', ' dry', ' sheet', ' of', ' apr', 'ic', 'ots']
+120 29 On the inside, the color of x -1 On the inside, the color of apricots orange apricots [' and', ' the', ' texture', ' of', ' a', ' peach', '.', '\n', '\n', 'The', ' scent', ' of', ' a', ' garden', 'ia', ' and', ' the', ' taste', ' of', ' a'] " and the texture of a peach .
+
+ The scent of a garden ia and the taste of a" False like dates, apricots or raisins, 5 [' like', ' dates', ',', ' apr', 'ic', 'ots']
+121 29 On the inside, the color of x -1 On the inside, the color of apricots orange apricots [' and', ' the', ' texture', ' of', ' a', ' peach', '.', '\n', '\n', 'The', ' scent', ' of', ' a', ' garden', 'ia', ' and', ' the', ' taste', ' of', ' a'] " and the texture of a peach .
+
+ The scent of a garden ia and the taste of a" False " and abundant"" apricots and mulberries. Although" 5 "[' and', ' abundant', '""', ' apr', 'ic', 'ots']"
+122 29 On the inside, the color of x -1 On the inside, the color of apricots orange apricots [' and', ' the', ' texture', ' of', ' a', ' peach', '.', '\n', '\n', 'The', ' scent', ' of', ' a', ' garden', 'ia', ' and', ' the', ' taste', ' of', ' a'] " and the texture of a peach .
+
+ The scent of a garden ia and the taste of a" False California, where he grew apricots on a ranch. He died 7 [' California', ',', ' where', ' he', ' grew', ' apr', 'ic', 'ots']
+123 29 On the inside, the color of x -1 On the inside, the color of apricots orange apricots [' and', ' the', ' texture', ' of', ' a', ' peach', '.', '\n', '\n', 'The', ' scent', ' of', ' a', ' garden', 'ia', ' and', ' the', ' taste', ' of', ' a'] " and the texture of a peach .
+
+ The scent of a garden ia and the taste of a" False dry sheet of apricots soaked in water, 5 [' dry', ' sheet', ' of', ' apr', 'ic', 'ots']
+124 30 On the inside, the color of x -1 On the inside, the color of tomatoes red tomatoes [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' firm', ' and', ' juicy', '.', '\n', '\n', 'The', ' fruit', ' is', ' harvested', ' when'] " is a deep red , and the flesh is firm and juicy .
+
+ The fruit is harvested when" True stop purchasing tomatoes from farms 2 [' stop', ' purchasing', ' tomatoes']
+125 30 On the inside, the color of x -1 On the inside, the color of tomatoes red tomatoes [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' firm', ' and', ' juicy', '.', '\n', '\n', 'The', ' fruit', ' is', ' harvested', ' when'] " is a deep red , and the flesh is firm and juicy .
+
+ The fruit is harvested when" True potatoes, onions and tomatoes. Some recipes 4 [' potatoes', ',', ' onions', ' and', ' tomatoes']
+126 30 On the inside, the color of x -1 On the inside, the color of tomatoes red tomatoes [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' firm', ' and', ' juicy', '.', '\n', '\n', 'The', ' fruit', ' is', ' harvested', ' when'] " is a deep red , and the flesh is firm and juicy .
+
+ The fruit is harvested when" True in eggs, 3rd in tomatoes, 6th in soybeans, 6 [' in', ' eggs', ',', ' 3', 'rd', ' in', ' tomatoes']
+127 30 On the inside, the color of x -1 On the inside, the color of tomatoes red tomatoes [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' firm', ' and', ' juicy', '.', '\n', '\n', 'The', ' fruit', ' is', ' harvested', ' when'] " is a deep red , and the flesh is firm and juicy .
+
+ The fruit is harvested when" True consisting of spiced tomatoes and onions, 4 [' consisting', ' of', ' sp', 'iced', ' tomatoes']
+128 30 On the inside, the color of x -1 On the inside, the color of tomatoes red tomatoes [' is', ' a', ' deep', ' red', ',', ' and', ' the', ' flesh', ' is', ' firm', ' and', ' juicy', '.', '\n', '\n', 'The', ' fruit', ' is', ' harvested', ' when'] " is a deep red , and the flesh is firm and juicy .
+
+ The fruit is harvested when" True chopped greens, onions, tomatoes and chilies 5 [' chopped', ' greens', ',', ' onions', ',', ' tomatoes']
+129 31 On the inside, the color of x -1 On the inside, the color of bell peppers red bell peppers [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ','] is a deep , rich , and vibrant red . On the outside , the color is a bright , True such as eggplants, bell peppers and tomatoes, 7 [' such', ' as', ' egg', 'pl', 'ants', ',', ' bell', ' peppers']
+130 31 On the inside, the color of x -1 On the inside, the color of bell peppers red bell peppers [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ','] is a deep , rich , and vibrant red . On the outside , the color is a bright , True Onions and bell peppers were long part 4 [' On', 'ions', ' and', ' bell', ' peppers']
+131 31 On the inside, the color of x -1 On the inside, the color of bell peppers red bell peppers [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ','] is a deep , rich , and vibrant red . On the outside , the color is a bright , True of grass, green bell peppers and nettles 5 [' of', ' grass', ',', ' green', ' bell', ' peppers']
+132 31 On the inside, the color of x -1 On the inside, the color of bell peppers red bell peppers [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ','] is a deep , rich , and vibrant red . On the outside , the color is a bright , True such as eggplants, bell peppers and tomatoes, 7 [' such', ' as', ' egg', 'pl', 'ants', ',', ' bell', ' peppers']
+133 31 On the inside, the color of x -1 On the inside, the color of bell peppers red bell peppers [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' red', '.', ' On', ' the', ' outside', ',', ' the', ' color', ' is', ' a', ' bright', ','] is a deep , rich , and vibrant red . On the outside , the color is a bright , True pronounced green bell peppers and vegetal flavors. 3 [' pronounced', ' green', ' bell', ' peppers']
+134 32 On the inside, the color of x -1 On the inside, the color of persimmons orange persimmons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'The', ' flavor', ' of', ' the', ' pers'] " is a bit more muted , but the flavor is still there .
+
+ The flavor of the pers" False him a couple persimmons, one of the most 5 [' him', ' a', ' couple', ' pers', 'imm', 'ons']
+135 32 On the inside, the color of x -1 On the inside, the color of persimmons orange persimmons [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'The', ' flavor', ' of', ' the', ' pers'] " is a bit more muted , but the flavor is still there .
+
+ The flavor of the pers" False him a couple persimmons, one of the most 5 [' him', ' a', ' couple', ' pers', 'imm', 'ons']
+136 33 On the inside, the color of x -1 On the inside, the color of lychees white lychees [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'The', ' ly', 'che', 'e', ' is'] " is a bit more muted , but the flavor is still there .
+
+ The ly che e is" False resources include coffee, lychees and shrimp. Key 6 [' resources', ' include', ' coffee', ',', ' ly', 'che', 'es']
+137 33 On the inside, the color of x -1 On the inside, the color of lychees white lychees [' is', ' a', ' bit', ' more', ' muted', ',', ' but', ' the', ' flavor', ' is', ' still', ' there', '.', '\n', '\n', 'The', ' ly', 'che', 'e', ' is'] " is a bit more muted , but the flavor is still there .
+
+ The ly che e is" False include coffee, lychees and shrimp. 5 [' include', ' coffee', ',', ' ly', 'che', 'es']
+138 34 On the inside, the color of x -1 On the inside, the color of coconuts white coconuts [' is', ' a', ' bit', ' different', ' from', ' the', ' outside', '.', ' The', ' inside', ' is', ' a', ' bit', ' darker', ' and', ' the', ' color', ' is', ' a', ' bit'] is a bit different from the outside . The inside is a bit darker and the color is a bit False trade was coconuts and coir. The 4 [' trade', ' was', ' coc', 'on', 'uts']
+139 34 On the inside, the color of x -1 On the inside, the color of coconuts white coconuts [' is', ' a', ' bit', ' different', ' from', ' the', ' outside', '.', ' The', ' inside', ' is', ' a', ' bit', ' darker', ' and', ' the', ' color', ' is', ' a', ' bit'] is a bit different from the outside . The inside is a bit darker and the color is a bit False injured by falling coconuts and other airborne 5 [' injured', ' by', ' falling', ' coc', 'on', 'uts']
+140 34 On the inside, the color of x -1 On the inside, the color of coconuts white coconuts [' is', ' a', ' bit', ' different', ' from', ' the', ' outside', '.', ' The', ' inside', ' is', ' a', ' bit', ' darker', ' and', ' the', ' color', ' is', ' a', ' bit'] is a bit different from the outside . The inside is a bit darker and the color is a bit False climb trees to pick coconuts, which it then 6 [' climb', ' trees', ' to', ' pick', ' coc', 'on', 'uts']
+141 34 On the inside, the color of x -1 On the inside, the color of coconuts white coconuts [' is', ' a', ' bit', ' different', ' from', ' the', ' outside', '.', ' The', ' inside', ' is', ' a', ' bit', ' darker', ' and', ' the', ' color', ' is', ' a', ' bit'] is a bit different from the outside . The inside is a bit darker and the color is a bit False shell. These include coconuts (Cocos nucifera), 6 [' shell', '.', ' These', ' include', ' coc', 'on', 'uts']
+142 34 On the inside, the color of x -1 On the inside, the color of coconuts white coconuts [' is', ' a', ' bit', ' different', ' from', ' the', ' outside', '.', ' The', ' inside', ' is', ' a', ' bit', ' darker', ' and', ' the', ' color', ' is', ' a', ' bit'] is a bit different from the outside . The inside is a bit darker and the color is a bit False main trade was coconuts and coir. After Antonio 5 [' main', ' trade', ' was', ' coc', 'on', 'uts']
diff --git a/patchscopes/code/preprocessed_data/commonsense/fruit_outside_color.tsv b/patchscopes/code/preprocessed_data/commonsense/fruit_outside_color.tsv
new file mode 100644
index 00000000..de6d74a1
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/commonsense/fruit_outside_color.tsv
@@ -0,0 +1,271 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 On the outside, the color of x -1 On the outside, the color of bananas yellow bananas [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The banana is a fruit that" True effects of a diet of bananas in 1924. This 5 [' effects', ' of', ' a', ' diet', ' of', ' bananas']
+1 0 On the outside, the color of x -1 On the outside, the color of bananas yellow bananas [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The banana is a fruit that" True primarily from selling bananas in a shopping 3 [' primarily', ' from', ' selling', ' bananas']
+2 0 On the outside, the color of x -1 On the outside, the color of bananas yellow bananas [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The banana is a fruit that" True exhibiting a bunch of bananas in a gallery, and they'll 4 [' exhibiting', ' a', ' bunch', ' of', ' bananas']
+3 0 On the outside, the color of x -1 On the outside, the color of bananas yellow bananas [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The banana is a fruit that" True consumerism by tossing bananas at East Germans 4 [' consumer', 'ism', ' by', ' tossing', ' bananas']
+4 0 On the outside, the color of x -1 On the outside, the color of bananas yellow bananas [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The banana is a fruit that" True 155,000 short tons) of bananas and 87,078 tonnes 7 [' 155', ',', '000', ' short', ' tons', ')', ' of', ' bananas']
+5 1 On the outside, the color of x -1 On the outside, the color of apples red apples [' is', ' a', ' deep', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' apple'] " is a deep red , but on the inside , they are a deep purple .
+
+ The apple" True about 75 % of apples and pears were lost. 4 [' about', ' 75', ' %', ' of', ' apples']
+6 1 On the outside, the color of x -1 On the outside, the color of apples red apples [' is', ' a', ' deep', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' apple'] " is a deep red , but on the inside , they are a deep purple .
+
+ The apple" True " data are a blend of apples and oranges""," 5 [' data', ' are', ' a', ' blend', ' of', ' apples']
+7 1 On the outside, the color of x -1 On the outside, the color of apples red apples [' is', ' a', ' deep', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' apple'] " is a deep red , but on the inside , they are a deep purple .
+
+ The apple" True Sabrina, throws apples at a house, and 4 [' Sab', 'rina', ',', ' throws', ' apples']
+8 1 On the outside, the color of x -1 On the outside, the color of apples red apples [' is', ' a', ' deep', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' apple'] " is a deep red , but on the inside , they are a deep purple .
+
+ The apple" True objects such as apples or bottles. Although 3 [' objects', ' such', ' as', ' apples']
+9 1 On the outside, the color of x -1 On the outside, the color of apples red apples [' is', ' a', ' deep', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' apple'] " is a deep red , but on the inside , they are a deep purple .
+
+ The apple" True 1 ['app', 'les']
+10 2 On the outside, the color of x -1 On the outside, the color of watermelons green watermelons "[',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' I']" ", but on the inside , the color of the sky .
+
+ I 'm not sure if I" False helps him carry watermelons to the staff 5 [' helps', ' him', ' carry', ' water', 'mel', 'ons']
+11 2 On the outside, the color of x -1 On the outside, the color of watermelons green watermelons "[',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' I']" ", but on the inside , the color of the sky .
+
+ I 'm not sure if I" False spit seeds by eating watermelons — which were present 6 [' spit', ' seeds', ' by', ' eating', ' water', 'mel', 'ons']
+12 2 On the outside, the color of x -1 On the outside, the color of watermelons green watermelons "[',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' I']" ", but on the inside , the color of the sky .
+
+ I 'm not sure if I" False spit seeds by eating watermelons — which were 6 [' spit', ' seeds', ' by', ' eating', ' water', 'mel', 'ons']
+13 2 On the outside, the color of x -1 On the outside, the color of watermelons green watermelons "[',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' I']" ", but on the inside , the color of the sky .
+
+ I 'm not sure if I" False and helps him carry watermelons to the staff 6 [' and', ' helps', ' him', ' carry', ' water', 'mel', 'ons']
+14 2 On the outside, the color of x -1 On the outside, the color of watermelons green watermelons "[',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' I']" ", but on the inside , the color of the sky .
+
+ I 'm not sure if I" False seeds by eating watermelons — which were 5 [' seeds', ' by', ' eating', ' water', 'mel', 'ons']
+15 3 On the outside, the color of x -1 On the outside, the color of kiwis brown kiwis [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' k', 'iw', 'i', ' fruit', ' is'] " is green , but on the inside , it is yellow .
+
+ The k iw i fruit is" False 2 ['ki', 'w', 'is']
+16 3 On the outside, the color of x -1 On the outside, the color of kiwis brown kiwis [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' k', 'iw', 'i', ' fruit', ' is'] " is green , but on the inside , it is yellow .
+
+ The k iw i fruit is" False great spotted kiwis lived in New 4 [' great', ' spotted', ' k', 'iw', 'is']
+17 3 On the outside, the color of x -1 On the outside, the color of kiwis brown kiwis [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' k', 'iw', 'i', ' fruit', ' is'] " is green , but on the inside , it is yellow .
+
+ The k iw i fruit is" False uses it to protect kiwis from the invasive 6 [' uses', ' it', ' to', ' protect', ' k', 'iw', 'is']
+18 3 On the outside, the color of x -1 On the outside, the color of kiwis brown kiwis [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' k', 'iw', 'i', ' fruit', ' is'] " is green , but on the inside , it is yellow .
+
+ The k iw i fruit is" False million great spotted kiwis lived in New Zealand. 5 [' million', ' great', ' spotted', ' k', 'iw', 'is']
+19 3 On the outside, the color of x -1 On the outside, the color of kiwis brown kiwis [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' k', 'iw', 'i', ' fruit', ' is'] " is green , but on the inside , it is yellow .
+
+ The k iw i fruit is" False great spotted kiwis are large and 4 [' great', ' spotted', ' k', 'iw', 'is']
+20 5 On the outside, the color of x -1 On the outside, the color of eggplants purple eggplants [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', ' The', ' white', ' color', ' is', ' due', ' to'] is a deep purple , but on the inside , they are white . The white color is due to True potatoes, mallow, eggplants and zucchini 7 [' potatoes', ',', ' m', 'allow', ',', ' egg', 'pl', 'ants']
+21 5 On the outside, the color of x -1 On the outside, the color of eggplants purple eggplants [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', ' The', ' white', ' color', ' is', ' due', ' to'] is a deep purple , but on the inside , they are white . The white color is due to True leaves, potatoes, mallow, eggplants and zucchini squash, 9 [' leaves', ',', ' potatoes', ',', ' m', 'allow', ',', ' egg', 'pl', 'ants']
+22 5 On the outside, the color of x -1 On the outside, the color of eggplants purple eggplants [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', ' The', ' white', ' color', ' is', ' due', ' to'] is a deep purple , but on the inside , they are white . The white color is due to True and uprooted two Bt eggplants and more than 8 [' and', ' upro', 'oted', ' two', ' B', 't', ' egg', 'pl', 'ants']
+23 5 On the outside, the color of x -1 On the outside, the color of eggplants purple eggplants [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', ' The', ' white', ' color', ' is', ' due', ' to'] is a deep purple , but on the inside , they are white . The white color is due to True potatoes, mallow, eggplants and zucchini 7 [' potatoes', ',', ' m', 'allow', ',', ' egg', 'pl', 'ants']
+24 5 On the outside, the color of x -1 On the outside, the color of eggplants purple eggplants [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', ' The', ' white', ' color', ' is', ' due', ' to'] is a deep purple , but on the inside , they are white . The white color is due to True potatoes, mallow, eggplants and zucchini squash, 7 [' potatoes', ',', ' m', 'allow', ',', ' egg', 'pl', 'ants']
+25 6 On the outside, the color of x -1 On the outside, the color of zucchinis green zucchinis [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' it', '�', '�', 's', ' yellow', '.', '\n', '\n', 'The', ' z', 'uc', 'chini'] " is green , but on the inside , it � � s yellow .
+
+ The z uc chini" True Kousa mahshi are zucchinis stuffed with 9 [' K', 'ous', 'a', ' mah', 'shi', ' are', ' z', 'uc', 'ch', 'inis']
+26 6 On the outside, the color of x -1 On the outside, the color of zucchinis green zucchinis [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' it', '�', '�', 's', ' yellow', '.', '\n', '\n', 'The', ' z', 'uc', 'chini'] " is green , but on the inside , it � � s yellow .
+
+ The z uc chini" True Kousa mahshi are zucchinis stuffed with 9 [' K', 'ous', 'a', ' mah', 'shi', ' are', ' z', 'uc', 'ch', 'inis']
+27 6 On the outside, the color of x -1 On the outside, the color of zucchinis green zucchinis [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' it', '�', '�', 's', ' yellow', '.', '\n', '\n', 'The', ' z', 'uc', 'chini'] " is green , but on the inside , it � � s yellow .
+
+ The z uc chini" True Kousa mahshi are zucchinis stuffed with 9 [' K', 'ous', 'a', ' mah', 'shi', ' are', ' z', 'uc', 'ch', 'inis']
+28 7 On the outside, the color of x -1 On the outside, the color of pineapples brown pineapples [' and', ' the', ' color', ' of', ' the', ' sky', ' are', ' the', ' same', '.', ' But', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' pine', 'app'] and the color of the sky are the same . But on the inside , the color of pine app False with a load of pineapples — was raided by 6 [' with', ' a', ' load', ' of', ' pine', 'app', 'les']
+29 7 On the outside, the color of x -1 On the outside, the color of pineapples brown pineapples [' and', ' the', ' color', ' of', ' the', ' sky', ' are', ' the', ' same', '.', ' But', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' pine', 'app'] and the color of the sky are the same . But on the inside , the color of pine app False spitroast cooks pineapples for one of the restaurant's 6 [' spit', 'ro', 'ast', ' cooks', ' pine', 'app', 'les']
+30 7 On the outside, the color of x -1 On the outside, the color of pineapples brown pineapples [' and', ' the', ' color', ' of', ' the', ' sky', ' are', ' the', ' same', '.', ' But', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' pine', 'app'] and the color of the sky are the same . But on the inside , the color of pine app False with a load of pineapples — was raided by 6 [' with', ' a', ' load', ' of', ' pine', 'app', 'les']
+31 7 On the outside, the color of x -1 On the outside, the color of pineapples brown pineapples [' and', ' the', ' color', ' of', ' the', ' sky', ' are', ' the', ' same', '.', ' But', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' pine', 'app'] and the color of the sky are the same . But on the inside , the color of pine app False spitroast cooks pineapples for one of the restaurant's 6 [' spit', 'ro', 'ast', ' cooks', ' pine', 'app', 'les']
+32 7 On the outside, the color of x -1 On the outside, the color of pineapples brown pineapples [' and', ' the', ' color', ' of', ' the', ' sky', ' are', ' the', ' same', '.', ' But', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' pine', 'app'] and the color of the sky are the same . But on the inside , the color of pine app False its gilded pineapples is now lost). In 1922, 5 [' its', ' g', 'ilded', ' pine', 'app', 'les']
+33 8 On the outside, the color of x -1 On the outside, the color of mangoes green mangoes [' is', ' a', ' deep', ',', ' rich', ',', ' golden', ' yellow', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant'] is a deep , rich , golden yellow . On the inside , the color is a bright , vibrant False rubber, turmeric, mangoes and spices. Other 6 [' rubber', ',', ' tur', 'meric', ',', ' mango', 'es']
+34 8 On the outside, the color of x -1 On the outside, the color of mangoes green mangoes [' is', ' a', ' deep', ',', ' rich', ',', ' golden', ' yellow', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant'] is a deep , rich , golden yellow . On the inside , the color is a bright , vibrant False (for paan), langra mangoes and khoa (solidified 8 [' (', 'for', ' pa', 'an', '),', ' lang', 'ra', ' mango', 'es']
+35 8 On the outside, the color of x -1 On the outside, the color of mangoes green mangoes [' is', ' a', ' deep', ',', ' rich', ',', ' golden', ' yellow', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant'] is a deep , rich , golden yellow . On the inside , the color is a bright , vibrant False (for paan), langra mangoes and khoa (solidified 8 [' (', 'for', ' pa', 'an', '),', ' lang', 'ra', ' mango', 'es']
+36 8 On the outside, the color of x -1 On the outside, the color of mangoes green mangoes [' is', ' a', ' deep', ',', ' rich', ',', ' golden', ' yellow', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant'] is a deep , rich , golden yellow . On the inside , the color is a bright , vibrant False rainfall damaged mangoes in Maharashtra, 3 [' rainfall', ' damaged', ' mango', 'es']
+37 8 On the outside, the color of x -1 On the outside, the color of mangoes green mangoes [' is', ' a', ' deep', ',', ' rich', ',', ' golden', ' yellow', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright', ',', ' vibrant'] is a deep , rich , golden yellow . On the inside , the color is a bright , vibrant False rubber, turmeric, mangoes and spices. Other major 6 [' rubber', ',', ' tur', 'meric', ',', ' mango', 'es']
+38 9 On the outside, the color of x -1 On the outside, the color of cucumbers green cucumbers [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' cuc', 'umber', ' is', ' a', ' fruit'] " is green , but on the inside , they are white .
+
+ The cuc umber is a fruit" True bulgur, diced tomatoes, cucumbers and is sautéed with 7 [' bul', 'gur', ',', ' diced', ' tomatoes', ',', ' cuc', 'umbers']
+39 9 On the outside, the color of x -1 On the outside, the color of cucumbers green cucumbers [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' cuc', 'umber', ' is', ' a', ' fruit'] " is green , but on the inside , they are white .
+
+ The cuc umber is a fruit" True tomatoes and cucumbers are the main ingredients. 3 [' tomatoes', ' and', ' cuc', 'umbers']
+40 9 On the outside, the color of x -1 On the outside, the color of cucumbers green cucumbers [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' cuc', 'umber', ' is', ' a', ' fruit'] " is green , but on the inside , they are white .
+
+ The cuc umber is a fruit" True diced tomatoes and cucumbers combined with olive 4 [' diced', ' tomatoes', ' and', ' cuc', 'umbers']
+41 9 On the outside, the color of x -1 On the outside, the color of cucumbers green cucumbers [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' cuc', 'umber', ' is', ' a', ' fruit'] " is green , but on the inside , they are white .
+
+ The cuc umber is a fruit" True diced tomatoes and cucumbers combined with 4 [' diced', ' tomatoes', ' and', ' cuc', 'umbers']
+42 9 On the outside, the color of x -1 On the outside, the color of cucumbers green cucumbers [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' cuc', 'umber', ' is', ' a', ' fruit'] " is green , but on the inside , they are white .
+
+ The cuc umber is a fruit" True bulgur, diced tomatoes, cucumbers and is sautéed with 7 [' bul', 'gur', ',', ' diced', ' tomatoes', ',', ' cuc', 'umbers']
+43 10 On the outside, the color of x -1 On the outside, the color of radishes pink radishes [' is', ' a', ' bright', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' rad', 'ish', ' is', ' a'] " is a bright red , but the inside is a deep purple .
+
+ The rad ish is a" False million tons of radishes are produced 4 [' million', ' tons', ' of', ' rad', 'ishes']
+44 10 On the outside, the color of x -1 On the outside, the color of radishes pink radishes [' is', ' a', ' bright', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' rad', 'ish', ' is', ' a'] " is a bright red , but the inside is a deep purple .
+
+ The rad ish is a" False taste somewhat like radishes or cucumber. 4 [' taste', ' somewhat', ' like', ' rad', 'ishes']
+45 10 On the outside, the color of x -1 On the outside, the color of radishes pink radishes [' is', ' a', ' bright', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' rad', 'ish', ' is', ' a'] " is a bright red , but the inside is a deep purple .
+
+ The rad ish is a" False " leeks, garlic, radishes and lettuces.
+" 6 [' le', 'eks', ',', ' garlic', ',', ' rad', 'ishes']
+46 10 On the outside, the color of x -1 On the outside, the color of radishes pink radishes [' is', ' a', ' bright', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' rad', 'ish', ' is', ' a'] " is a bright red , but the inside is a deep purple .
+
+ The rad ish is a" False The seeds of radishes can be pressed to extract 4 [' The', ' seeds', ' of', ' rad', 'ishes']
+47 10 On the outside, the color of x -1 On the outside, the color of radishes pink radishes [' is', ' a', ' bright', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' deep', ' purple', '.', '\n', '\n', 'The', ' rad', 'ish', ' is', ' a'] " is a bright red , but the inside is a deep purple .
+
+ The rad ish is a" False warmer climates, radishes are normally planted 4 [' warmer', ' climates', ',', ' rad', 'ishes']
+48 12 On the outside, the color of x -1 On the outside, the color of nectarines red nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright'] is a deep , rich , almost purple - red . On the inside , the color is a bright True peaches, grapes and nectarines were grown in greenhouses. 7 [' pe', 'aches', ',', ' grapes', ' and', ' n', 'ectar', 'ines']
+49 12 On the outside, the color of x -1 On the outside, the color of nectarines red nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright'] is a deep , rich , almost purple - red . On the inside , the color is a bright True designation), lemons, nectarines and figs are grown 7 [' designation', '),', ' le', 'mons', ',', ' n', 'ectar', 'ines']
+50 12 On the outside, the color of x -1 On the outside, the color of nectarines red nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright'] is a deep , rich , almost purple - red . On the inside , the color is a bright True grapes and nectarines were grown 4 [' grapes', ' and', ' n', 'ectar', 'ines']
+51 12 On the outside, the color of x -1 On the outside, the color of nectarines red nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright'] is a deep , rich , almost purple - red . On the inside , the color is a bright True peaches, grapes and nectarines were grown in greenhouses. 7 [' pe', 'aches', ',', ' grapes', ' and', ' n', 'ectar', 'ines']
+52 12 On the outside, the color of x -1 On the outside, the color of nectarines red nectarines [' is', ' a', ' deep', ',', ' rich', ',', ' almost', ' purple', '-', 'red', '.', ' On', ' the', ' inside', ',', ' the', ' color', ' is', ' a', ' bright'] is a deep , rich , almost purple - red . On the inside , the color is a bright True peaches, grapes and nectarines were grown in 7 [' pe', 'aches', ',', ' grapes', ' and', ' n', 'ectar', 'ines']
+53 13 On the outside, the color of x -1 On the outside, the color of plums purple plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', ',', ' the', ' taste', ' of', ' a', ' ripe', ' peach', ',', ' the', ' smell', ' of', ' a', ' ripe'] and the texture of a ripe peach , the taste of a ripe peach , the smell of a ripe False apple butter), plums and walnuts. Cream 4 [' apple', ' butter', '),', ' pl', 'ums']
+54 13 On the outside, the color of x -1 On the outside, the color of plums purple plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', ',', ' the', ' taste', ' of', ' a', ' ripe', ' peach', ',', ' the', ' smell', ' of', ' a', ' ripe'] and the texture of a ripe peach , the taste of a ripe peach , the smell of a ripe False longifolia, the green plums of Panda oleosa 7 [' long', 'if', 'olia', ',', ' the', ' green', ' pl', 'ums']
+55 13 On the outside, the color of x -1 On the outside, the color of plums purple plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', ',', ' the', ' taste', ' of', ' a', ' ripe', ' peach', ',', ' the', ' smell', ' of', ' a', ' ripe'] and the texture of a ripe peach , the taste of a ripe peach , the smell of a ripe False " described as ""plums"", many plum" 4 "[' described', ' as', ' ""', 'pl', 'ums']"
+56 13 On the outside, the color of x -1 On the outside, the color of plums purple plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', ',', ' the', ' taste', ' of', ' a', ' ripe', ' peach', ',', ' the', ' smell', ' of', ' a', ' ripe'] and the texture of a ripe peach , the taste of a ripe peach , the smell of a ripe False referred to as plums since around 4 [' referred', ' to', ' as', ' pl', 'ums']
+57 13 On the outside, the color of x -1 On the outside, the color of plums purple plums [' and', ' the', ' texture', ' of', ' a', ' ripe', ' peach', ',', ' the', ' taste', ' of', ' a', ' ripe', ' peach', ',', ' the', ' smell', ' of', ' a', ' ripe'] and the texture of a ripe peach , the taste of a ripe peach , the smell of a ripe False " described as ""plums"", many plum cakes" 4 "[' described', ' as', ' ""', 'pl', 'ums']"
+58 16 On the outside, the color of x -1 On the outside, the color of limes green limes ['cale', ' is', ' a', ' dull', ',', ' gray', 'ish', '-', 'brown', '.', ' On', ' the', ' inside', ',', ' it', ' is', ' a', ' bright', ',', ' yellow'] cale is a dull , gray ish - brown . On the inside , it is a bright , yellow False ha) of grapefruit and limes were also damaged, 7 [' ha', ')', ' of', ' grape', 'fruit', ' and', ' l', 'imes']
+59 16 On the outside, the color of x -1 On the outside, the color of limes green limes ['cale', ' is', ' a', ' dull', ',', ' gray', 'ish', '-', 'brown', '.', ' On', ' the', ' inside', ',', ' it', ' is', ' a', ' bright', ',', ' yellow'] cale is a dull , gray ish - brown . On the inside , it is a bright , yellow False stocked his ships with limes to fend off scurvy, 5 [' stocked', ' his', ' ships', ' with', ' l', 'imes']
+60 16 On the outside, the color of x -1 On the outside, the color of limes green limes ['cale', ' is', ' a', ' dull', ',', ' gray', 'ish', '-', 'brown', '.', ' On', ' the', ' inside', ',', ' it', ' is', ' a', ' bright', ',', ' yellow'] cale is a dull , gray ish - brown . On the inside , it is a bright , yellow False recommended using lemons and limes to avoid scurvy, which 6 [' recommended', ' using', ' le', 'mons', ' and', ' l', 'imes']
+61 16 On the outside, the color of x -1 On the outside, the color of limes green limes ['cale', ' is', ' a', ' dull', ',', ' gray', 'ish', '-', 'brown', '.', ' On', ' the', ' inside', ',', ' it', ' is', ' a', ' bright', ',', ' yellow'] cale is a dull , gray ish - brown . On the inside , it is a bright , yellow False construction of a new limes between the Firth 5 [' construction', ' of', ' a', ' new', ' l', 'imes']
+62 16 On the outside, the color of x -1 On the outside, the color of limes green limes ['cale', ' is', ' a', ' dull', ',', ' gray', 'ish', '-', 'brown', '.', ' On', ' the', ' inside', ',', ' it', ' is', ' a', ' bright', ',', ' yellow'] cale is a dull , gray ish - brown . On the inside , it is a bright , yellow False construct a new limes between the Firth 4 [' construct', ' a', ' new', ' l', 'imes']
+63 17 On the outside, the color of x -1 On the outside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False focus on growing oranges for export. Khrushchev 3 [' focus', ' on', ' growing', ' oranges']
+64 17 On the outside, the color of x -1 On the outside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False " pinks, through to oranges and whites.
+" 5 [' p', 'inks', ',', ' through', ' to', ' oranges']
+65 17 On the outside, the color of x -1 On the outside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False In addition to oranges and grapes, the 3 [' In', ' addition', ' to', ' oranges']
+66 17 On the outside, the color of x -1 On the outside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False " poranges, who says,
+" 1 [' por', 'anges']
+67 17 On the outside, the color of x -1 On the outside, the color of oranges orange oranges [' and', ' yell', 'ows', ',', ' the', ' color', ' of', ' the', ' sun', ',', ' the', ' color', ' of', ' the', ' sky', ',', ' the', ' color', ' of', ' the'] and yell ows , the color of the sun , the color of the sky , the color of the False producer of oranges in the world, damage 2 [' producer', ' of', ' oranges']
+68 18 On the outside, the color of x -1 On the outside, the color of peaches pink peaches [' and', ' cream', ',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' blood', '.', '\n', '\n', 'The', ' first', ' time', ' I', ' saw'] " and cream , but on the inside , the color of blood .
+
+ The first time I saw" False " with people. He's all peaches and cream on television.""" 7 "[' with', ' people', '.', ' He', ""'s"", ' all', ' pe', 'aches']"
+69 18 On the outside, the color of x -1 On the outside, the color of peaches pink peaches [' and', ' cream', ',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' blood', '.', '\n', '\n', 'The', ' first', ' time', ' I', ' saw'] " and cream , but on the inside , the color of blood .
+
+ The first time I saw" False Shanghai, 20 % of peaches and 40 % of pears 6 [' Shanghai', ',', ' 20', ' %', ' of', ' pe', 'aches']
+70 18 On the outside, the color of x -1 On the outside, the color of peaches pink peaches [' and', ' cream', ',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' blood', '.', '\n', '\n', 'The', ' first', ' time', ' I', ' saw'] " and cream , but on the inside , the color of blood .
+
+ The first time I saw" False the Chinese acquired peaches from Samarkand, 4 [' the', ' Chinese', ' acquired', ' pe', 'aches']
+71 18 On the outside, the color of x -1 On the outside, the color of peaches pink peaches [' and', ' cream', ',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' blood', '.', '\n', '\n', 'The', ' first', ' time', ' I', ' saw'] " and cream , but on the inside , the color of blood .
+
+ The first time I saw" False Shanghai, 20 % of peaches and 40 % of pears 6 [' Shanghai', ',', ' 20', ' %', ' of', ' pe', 'aches']
+72 18 On the outside, the color of x -1 On the outside, the color of peaches pink peaches [' and', ' cream', ',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' blood', '.', '\n', '\n', 'The', ' first', ' time', ' I', ' saw'] " and cream , but on the inside , the color of blood .
+
+ The first time I saw" False Shanghai, 20 % of peaches and 40 % of 6 [' Shanghai', ',', ' 20', ' %', ' of', ' pe', 'aches']
+73 19 On the outside, the color of x -1 On the outside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' Inside', ',', ' the', ' fruit', ' is', ' a', ' bright', ',', ' deep', ' red', ',', ' with', ' a'] is a deep , rich red . Inside , the fruit is a bright , deep red , with a True coastal plain. Figs, pomegranates and olives also grow 9 [' coastal', ' plain', '.', ' F', 'igs', ',', ' p', 'ome', 'gran', 'ates']
+74 19 On the outside, the color of x -1 On the outside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' Inside', ',', ' the', ' fruit', ' is', ' a', ' bright', ',', ' deep', ' red', ',', ' with', ' a'] is a deep , rich red . Inside , the fruit is a bright , deep red , with a True plain. Figs, pomegranates and olives also grow 8 [' plain', '.', ' F', 'igs', ',', ' p', 'ome', 'gran', 'ates']
+75 19 On the outside, the color of x -1 On the outside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' Inside', ',', ' the', ' fruit', ' is', ' a', ' bright', ',', ' deep', ' red', ',', ' with', ' a'] is a deep , rich red . Inside , the fruit is a bright , deep red , with a True by monkeys, pomegranates and nesting birds. 6 [' by', ' monkeys', ',', ' p', 'ome', 'gran', 'ates']
+76 19 On the outside, the color of x -1 On the outside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' Inside', ',', ' the', ' fruit', ' is', ' a', ' bright', ',', ' deep', ' red', ',', ' with', ' a'] is a deep , rich red . Inside , the fruit is a bright , deep red , with a True symbolised by monkeys, pomegranates and nesting birds. 8 [' symbol', 'ised', ' by', ' monkeys', ',', ' p', 'ome', 'gran', 'ates']
+77 19 On the outside, the color of x -1 On the outside, the color of pomegranates red pomegranates [' is', ' a', ' deep', ',', ' rich', ' red', '.', ' Inside', ',', ' the', ' fruit', ' is', ' a', ' bright', ',', ' deep', ' red', ',', ' with', ' a'] is a deep , rich red . Inside , the fruit is a bright , deep red , with a True such as palm trees, pomegranates and lilies in 8 [' such', ' as', ' palm', ' trees', ',', ' p', 'ome', 'gran', 'ates']
+78 20 On the outside, the color of x -1 On the outside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' dark', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' and', ' juicy', ' red', '.'] is a deep , dark red , but the inside is a bright , vibrant , and juicy red . True tank. The best cherries sink to the 5 [' tank', '.', ' The', ' best', ' cher', 'ries']
+79 20 On the outside, the color of x -1 On the outside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' dark', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' and', ' juicy', ' red', '.'] is a deep , dark red , but the inside is a bright , vibrant , and juicy red . True currants, raisins, cherries conserved, barberries 8 [' cur', 'rants', ',', ' ra', 'is', 'ins', ',', ' cher', 'ries']
+80 20 On the outside, the color of x -1 On the outside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' dark', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' and', ' juicy', ' red', '.'] is a deep , dark red , but the inside is a bright , vibrant , and juicy red . True Lilacs, weeping cherries and hibiscus were 5 [' Lil', 'acs', ',', ' weeping', ' cher', 'ries']
+81 20 On the outside, the color of x -1 On the outside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' dark', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' and', ' juicy', ' red', '.'] is a deep , dark red , but the inside is a bright , vibrant , and juicy red . True " saying: ""Don 't eat cherries with great lords" 8 "[' saying', ':', ' ""', 'Don', "" '"", 't', ' eat', ' cher', 'ries']"
+82 20 On the outside, the color of x -1 On the outside, the color of cherries red cherries [' is', ' a', ' deep', ',', ' dark', ' red', ',', ' but', ' the', ' inside', ' is', ' a', ' bright', ',', ' vibrant', ',', ' and', ' juicy', ' red', '.'] is a deep , dark red , but the inside is a bright , vibrant , and juicy red . True the Islay cherries that grew 4 [' the', ' Is', 'lay', ' cher', 'ries']
+83 21 On the outside, the color of x -1 On the outside, the color of strawberries red strawberries [' is', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' strawberry', ' is', ' a', ' fruit', ' that'] " is red , but on the inside , they are white .
+
+ The strawberry is a fruit that" True guaranteed to produce strawberries seven feet tall. 3 [' guaranteed', ' to', ' produce', ' strawberries']
+84 21 On the outside, the color of x -1 On the outside, the color of strawberries red strawberries [' is', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' strawberry', ' is', ' a', ' fruit', ' that'] " is red , but on the inside , they are white .
+
+ The strawberry is a fruit that" True the transport of strawberries from stations 3 [' the', ' transport', ' of', ' strawberries']
+85 21 On the outside, the color of x -1 On the outside, the color of strawberries red strawberries [' is', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' strawberry', ' is', ' a', ' fruit', ' that'] " is red , but on the inside , they are white .
+
+ The strawberry is a fruit that" True transport of strawberries from stations 2 [' transport', ' of', ' strawberries']
+86 21 On the outside, the color of x -1 On the outside, the color of strawberries red strawberries [' is', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' strawberry', ' is', ' a', ' fruit', ' that'] " is red , but on the inside , they are white .
+
+ The strawberry is a fruit that" True out-of-season strawberries and peas. 5 [' out', '-', 'of', '-', 'season', ' strawberries']
+87 21 On the outside, the color of x -1 On the outside, the color of strawberries red strawberries [' is', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' white', '.', '\n', '\n', 'The', ' strawberry', ' is', ' a', ' fruit', ' that'] " is red , but on the inside , they are white .
+
+ The strawberry is a fruit that" True are: taiyaki (Ayu), strawberries (Nayuki), 10 [' are', ':', ' t', 'ai', 'y', 'aki', ' (', 'A', 'yu', '),', ' strawberries']
+88 22 On the outside, the color of x -1 On the outside, the color of lemons yellow lemons [' is', ' bright', ' and', ' vibrant', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' a', ' deep', ',', ' rich', ',', ' and', ' complex', ' flavor'] is bright and vibrant , but on the inside , it is a deep , rich , and complex flavor False preserved lemons and dried apricots 2 [' preserved', ' le', 'mons']
+89 22 On the outside, the color of x -1 On the outside, the color of lemons yellow lemons [' is', ' bright', ' and', ' vibrant', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' a', ' deep', ',', ' rich', ',', ' and', ' complex', ' flavor'] is bright and vibrant , but on the inside , it is a deep , rich , and complex flavor False particularly lemons and oranges, and 2 [' particularly', ' le', 'mons']
+90 22 On the outside, the color of x -1 On the outside, the color of lemons yellow lemons [' is', ' bright', ' and', ' vibrant', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' a', ' deep', ',', ' rich', ',', ' and', ' complex', ' flavor'] is bright and vibrant , but on the inside , it is a deep , rich , and complex flavor False portrayed cutting lemons and making lemonade, 3 [' portrayed', ' cutting', ' le', 'mons']
+91 22 On the outside, the color of x -1 On the outside, the color of lemons yellow lemons [' is', ' bright', ' and', ' vibrant', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' a', ' deep', ',', ' rich', ',', ' and', ' complex', ' flavor'] is bright and vibrant , but on the inside , it is a deep , rich , and complex flavor False Memminger that lemons and oranges 5 [' Mem', 'ming', 'er', ' that', ' le', 'mons']
+92 22 On the outside, the color of x -1 On the outside, the color of lemons yellow lemons [' is', ' bright', ' and', ' vibrant', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' a', ' deep', ',', ' rich', ',', ' and', ' complex', ' flavor'] is bright and vibrant , but on the inside , it is a deep , rich , and complex flavor False recommended using lemons and limes to 3 [' recommended', ' using', ' le', 'mons']
+93 23 On the outside, the color of x -1 On the outside, the color of avocados green avocados [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' avocado', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The avocado is a fruit that" True supply of his avocados for a University of 5 [' supply', ' of', ' his', ' av', 'oc', 'ados']
+94 23 On the outside, the color of x -1 On the outside, the color of avocados green avocados [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' avocado', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The avocado is a fruit that" True Various varieties of avocados may be used. Avocado 5 [' Various', ' varieties', ' of', ' av', 'oc', 'ados']
+95 23 On the outside, the color of x -1 On the outside, the color of avocados green avocados [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' avocado', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The avocado is a fruit that" True ingredients. The avocados may be mashed, and 5 [' ingredients', '.', ' The', ' av', 'oc', 'ados']
+96 23 On the outside, the color of x -1 On the outside, the color of avocados green avocados [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' avocado', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The avocado is a fruit that" True main ingredient. Raw avocados may be used 6 [' main', ' ingredient', '.', ' Raw', ' av', 'oc', 'ados']
+97 23 On the outside, the color of x -1 On the outside, the color of avocados green avocados [' is', ' green', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' yellow', '.', '\n', '\n', 'The', ' avocado', ' is', ' a', ' fruit', ' that'] " is green , but on the inside , they are yellow .
+
+ The avocado is a fruit that" True ingredients. The avocados may be mashed, and 5 [' ingredients', '.', ' The', ' av', 'oc', 'ados']
+98 24 On the outside, the color of x -1 On the outside, the color of coconuts brown coconuts [' is', ' white', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' coconut', ' is', ' a', ' fruit', ' that'] " is white , but on the inside , it is yellow .
+
+ The coconut is a fruit that" False plantations of palm oil, coconuts and cocoa, as 7 [' plantations', ' of', ' palm', ' oil', ',', ' coc', 'on', 'uts']
+99 24 On the outside, the color of x -1 On the outside, the color of coconuts brown coconuts [' is', ' white', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' coconut', ' is', ' a', ' fruit', ' that'] " is white , but on the inside , it is yellow .
+
+ The coconut is a fruit that" False diet included coconuts and pandanus fruit. 4 [' diet', ' included', ' coc', 'on', 'uts']
+100 24 On the outside, the color of x -1 On the outside, the color of coconuts brown coconuts [' is', ' white', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' coconut', ' is', ' a', ' fruit', ' that'] " is white , but on the inside , it is yellow .
+
+ The coconut is a fruit that" False every pound of coconuts Friday gives 5 [' every', ' pound', ' of', ' coc', 'on', 'uts']
+101 24 On the outside, the color of x -1 On the outside, the color of coconuts brown coconuts [' is', ' white', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' coconut', ' is', ' a', ' fruit', ' that'] " is white , but on the inside , it is yellow .
+
+ The coconut is a fruit that" False damage was incurred to coconuts and other crops; 6 [' damage', ' was', ' incurred', ' to', ' coc', 'on', 'uts']
+102 24 On the outside, the color of x -1 On the outside, the color of coconuts brown coconuts [' is', ' white', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' yellow', '.', '\n', '\n', 'The', ' coconut', ' is', ' a', ' fruit', ' that'] " is white , but on the inside , it is yellow .
+
+ The coconut is a fruit that" False and planting coconuts. Osborn and Field 4 [' and', ' planting', ' coc', 'on', 'uts']
+103 25 On the outside, the color of x -1 On the outside, the color of blueberries blue blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ',', ' rich'] is a deep , rich , and vibrant blue . On the inside , they are a deep , rich True was out picking blueberries some distance 4 [' was', ' out', ' picking', ' blue', 'berries']
+104 25 On the outside, the color of x -1 On the outside, the color of blueberries blue blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ',', ' rich'] is a deep , rich , and vibrant blue . On the inside , they are a deep , rich True Drops prepared with blueberries and raspberries, 4 [' Drops', ' prepared', ' with', ' blue', 'berries']
+105 25 On the outside, the color of x -1 On the outside, the color of blueberries blue blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ',', ' rich'] is a deep , rich , and vibrant blue . On the inside , they are a deep , rich True prepared with blueberries and raspberries, 3 [' prepared', ' with', ' blue', 'berries']
+106 25 On the outside, the color of x -1 On the outside, the color of blueberries blue blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ',', ' rich'] is a deep , rich , and vibrant blue . On the inside , they are a deep , rich True of breasts with blueberries as nipples. Daisy 4 [' of', ' breasts', ' with', ' blue', 'berries']
+107 25 On the outside, the color of x -1 On the outside, the color of blueberries blue blueberries [' is', ' a', ' deep', ',', ' rich', ',', ' and', ' vibrant', ' blue', '.', ' On', ' the', ' inside', ',', ' they', ' are', ' a', ' deep', ',', ' rich'] is a deep , rich , and vibrant blue . On the inside , they are a deep , rich True shrubs, typically blueberries and black 5 [' shr', 'ubs', ',', ' typically', ' blue', 'berries']
+108 26 On the outside, the color of x -1 On the outside, the color of apricots orange apricots [',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ' am', ' a', ' woman', ' who', ' loves'] ", but on the inside , the color of the sky .
+
+ I am a woman who loves" False pinks and greens, apricots and yellows. Several 7 [' p', 'inks', ' and', ' greens', ',', ' apr', 'ic', 'ots']
+109 26 On the outside, the color of x -1 On the outside, the color of apricots orange apricots [',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ' am', ' a', ' woman', ' who', ' loves'] ", but on the inside , the color of the sky .
+
+ I am a woman who loves" False parsley, dried apricots, dried milk, chocolate, 6 [' pars', 'ley', ',', ' dried', ' apr', 'ic', 'ots']
+110 26 On the outside, the color of x -1 On the outside, the color of apricots orange apricots [',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ' am', ' a', ' woman', ' who', ' loves'] ", but on the inside , the color of the sky .
+
+ I am a woman who loves" False " and dried apricots and raisins.
+" 4 [' and', ' dried', ' apr', 'ic', 'ots']
+111 26 On the outside, the color of x -1 On the outside, the color of apricots orange apricots [',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ' am', ' a', ' woman', ' who', ' loves'] ", but on the inside , the color of the sky .
+
+ I am a woman who loves" False from a dry sheet of apricots soaked in water, 7 [' from', ' a', ' dry', ' sheet', ' of', ' apr', 'ic', 'ots']
+112 26 On the outside, the color of x -1 On the outside, the color of apricots orange apricots [',', ' but', ' on', ' the', ' inside', ',', ' the', ' color', ' of', ' the', ' sky', '.', '\n', '\n', 'I', ' am', ' a', ' woman', ' who', ' loves'] ", but on the inside , the color of the sky .
+
+ I am a woman who loves" False Japanese quinces, apricots and pears; in the region 6 [' Japanese', ' qu', 'inces', ',', ' apr', 'ic', 'ots']
+113 27 On the outside, the color of x -1 On the outside, the color of blackberries black blackberries [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' bright', ' red', '.', '\n', '\n', 'The', ' black'] " is a deep purple , but on the inside , they are a bright red .
+
+ The black" True and, in autumn, blackberries to gather. The aim 6 [' and', ',', ' in', ' autumn', ',', ' black', 'berries']
+114 27 On the outside, the color of x -1 On the outside, the color of blackberries black blackberries [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' bright', ' red', '.', '\n', '\n', 'The', ' black'] " is a deep purple , but on the inside , they are a bright red .
+
+ The black" True picnics and, in autumn, blackberries to gather. The aim 9 [' pic', 'n', 'ics', ' and', ',', ' in', ' autumn', ',', ' black', 'berries']
+115 27 On the outside, the color of x -1 On the outside, the color of blackberries black blackberries [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' bright', ' red', '.', '\n', '\n', 'The', ' black'] " is a deep purple , but on the inside , they are a bright red .
+
+ The black" True " fondness for blackberries and grapes.
+" 4 [' fond', 'ness', ' for', ' black', 'berries']
+116 27 On the outside, the color of x -1 On the outside, the color of blackberries black blackberries [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' bright', ' red', '.', '\n', '\n', 'The', ' black'] " is a deep purple , but on the inside , they are a bright red .
+
+ The black" True watsonia and blackberries have displaced 5 [' w', 'ats', 'onia', ' and', ' black', 'berries']
+117 27 On the outside, the color of x -1 On the outside, the color of blackberries black blackberries [' is', ' a', ' deep', ' purple', ',', ' but', ' on', ' the', ' inside', ',', ' they', ' are', ' a', ' bright', ' red', '.', '\n', '\n', 'The', ' black'] " is a deep purple , but on the inside , they are a bright red .
+
+ The black" True gathering the blackberries that were growing 3 [' gathering', ' the', ' black', 'berries']
+118 28 On the outside, the color of x -1 On the outside, the color of raspberries red raspberries [' is', ' a', ' deep', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' a', ' bright', ' pink', '.', '\n', '\n', 'The', ' r'] " is a deep red , but on the inside , it is a bright pink .
+
+ The r" True " Basile said ""The raspberries and SpongeBob's increasing" 7 "[' Bas', 'ile', ' said', ' ""', 'The', ' r', 'asp', 'berries']"
+119 28 On the outside, the color of x -1 On the outside, the color of raspberries red raspberries [' is', ' a', ' deep', ' red', ',', ' but', ' on', ' the', ' inside', ',', ' it', ' is', ' a', ' bright', ' pink', '.', '\n', '\n', 'The', ' r'] " is a deep red , but on the inside , it is a bright pink .
+
+ The r" True " list. Basile said ""The raspberries and SpongeBob's" 9 "[' list', '.', ' Bas', 'ile', ' said', ' ""', 'The', ' r', 'asp', 'berries']"
+120 29 On the outside, the color of x -1 On the outside, the color of figs purple figs [' is', ' a', ' deep', ' purple', ',', ' but', ' inside', ',', ' they', ' are', ' a', ' bright', ' orange', '.', '\n', '\n', 'The', ' fig', 's', ' are'] " is a deep purple , but inside , they are a bright orange .
+
+ The fig s are" True 1 ['fig', 's']
+121 29 On the outside, the color of x -1 On the outside, the color of figs purple figs [' is', ' a', ' deep', ' purple', ',', ' but', ' inside', ',', ' they', ' are', ' a', ' bright', ' orange', '.', '\n', '\n', 'The', ' fig', 's', ' are'] " is a deep purple , but inside , they are a bright orange .
+
+ The fig s are" True preferred, and figs are particularly 4 [' preferred', ',', ' and', ' fig', 's']
+122 29 On the outside, the color of x -1 On the outside, the color of figs purple figs [' is', ' a', ' deep', ' purple', ',', ' but', ' inside', ',', ' they', ' are', ' a', ' bright', ' orange', '.', '\n', '\n', 'The', ' fig', 's', ' are'] " is a deep purple , but inside , they are a bright orange .
+
+ The fig s are" True nectaries on F. aurea figs in the Florida Everglades. 10 [' n', 'ect', 'aries', ' on', ' F', '.', ' a', 'ure', 'a', ' fig', 's']
+123 29 On the outside, the color of x -1 On the outside, the color of figs purple figs [' is', ' a', ' deep', ' purple', ',', ' but', ' inside', ',', ' they', ' are', ' a', ' bright', ' orange', '.', '\n', '\n', 'The', ' fig', 's', ' are'] " is a deep purple , but inside , they are a bright orange .
+
+ The fig s are" True (Agaonidae); figs are only pollinated 7 [' (', 'A', 'ga', 'on', 'idae', ');', ' fig', 's']
+124 29 On the outside, the color of x -1 On the outside, the color of figs purple figs [' is', ' a', ' deep', ' purple', ',', ' but', ' inside', ',', ' they', ' are', ' a', ' bright', ' orange', '.', '\n', '\n', 'The', ' fig', 's', ' are'] " is a deep purple , but inside , they are a bright orange .
+
+ The fig s are" True curly, and that figs are the easiest 5 [' curly', ',', ' and', ' that', ' fig', 's']
diff --git a/patchscopes/code/preprocessed_data/commonsense/object_superclass.tsv b/patchscopes/code/preprocessed_data/commonsense/object_superclass.tsv
new file mode 100644
index 00000000..ff7c7f85
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/commonsense/object_superclass.tsv
@@ -0,0 +1,945 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tiger animal tiger [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful animal in the world . It is the most powerful True Occasionally, male tigers participate 3 [' Occasionally', ',', ' male', ' tiger']
+1 0 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tiger animal tiger [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful animal in the world . It is the most powerful True The British hunted tigers, not just to 3 [' The', ' British', ' hunted', ' tiger']
+2 0 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tiger animal tiger [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful animal in the world . It is the most powerful True in effect. A live tiger was also used for the 5 [' in', ' effect', '.', ' A', ' live', ' tiger']
+3 0 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tiger animal tiger [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful animal in the world . It is the most powerful True mackerel shark, tiger shark and spiny 5 [' mac', 'ke', 'rel', ' shark', ',', ' tiger']
+4 0 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tiger animal tiger [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful animal in the world . It is the most powerful True " During the 1980s, a tiger named ""Genghis""" 6 [' During', ' the', ' 1980', 's', ',', ' a', ' tiger']
+5 1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False liberty pole; an eagle spreading its wings 4 [' liberty', ' pole', ';', ' an', ' eagle']
+6 1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False the water, and an eagle over a waterfall 5 [' the', ' water', ',', ' and', ' an', ' eagle']
+7 1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False 1 ['e', 'agle']
+8 1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False globe is the eagle from the Seal of 3 [' globe', ' is', ' the', ' eagle']
+9 1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False eastern imperial eagle have small, 2 [' eastern', ' imperial', ' eagle']
+10 2 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False is a renowned salmon fishing river 3 [' is', ' a', ' renowned', ' salmon']
+11 2 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False extremely difficult for salmon to access because 3 [' extremely', ' difficult', ' for', ' salmon']
+12 2 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False anadromous fish — salmon, steelhead, 6 [' an', 'ad', 'rom', 'ous', ' fish', ' —', ' salmon']
+13 2 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False reported Rogue River salmon catch in 1877, 3 [' reported', ' Rogue', ' River', ' salmon']
+14 2 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False (Oncorhynchus mykiss), coho salmon (Oncorhynchus 12 [' (', 'On', 'cor', 'h', 'yn', 'ch', 'us', ' my', 'kiss', '),', ' co', 'ho', ' salmon']
+15 3 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of spinach vegetable spinach [' is', ' vegetable', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' vegetable', ' is', ' the', ' super', 'class', ' of', ' the', ' super', 'class', ' of', ' animal'] " is vegetable .
+
+ The super class of vegetable is the super class of the super class of animal" True example, juices from spinach and the green 4 [' example', ',', ' juices', ' from', ' spinach']
+16 3 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of spinach vegetable spinach [' is', ' vegetable', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' vegetable', ' is', ' the', ' super', 'class', ' of', ' the', ' super', 'class', ' of', ' animal'] " is vegetable .
+
+ The super class of vegetable is the super class of the super class of animal" True ingredients such as spinach powder, tomato 3 [' ingredients', ' such', ' as', ' spinach']
+17 3 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of spinach vegetable spinach [' is', ' vegetable', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' vegetable', ' is', ' the', ' super', 'class', ' of', ' the', ' super', 'class', ' of', ' animal'] " is vegetable .
+
+ The super class of vegetable is the super class of the super class of animal" True ingredients such as spinach powder, tomato 3 [' ingredients', ' such', ' as', ' spinach']
+18 3 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of spinach vegetable spinach [' is', ' vegetable', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' vegetable', ' is', ' the', ' super', 'class', ' of', ' the', ' super', 'class', ' of', ' animal'] " is vegetable .
+
+ The super class of vegetable is the super class of the super class of animal" True " the dietary usage of spinach and vinegar).
+" 4 [' the', ' dietary', ' usage', ' of', ' spinach']
+19 3 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of spinach vegetable spinach [' is', ' vegetable', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' vegetable', ' is', ' the', ' super', 'class', ' of', ' the', ' super', 'class', ' of', ' animal'] " is vegetable .
+
+ The super class of vegetable is the super class of the super class of animal" True ingredients such as spinach powder, tomato 3 [' ingredients', ' such', ' as', ' spinach']
+20 4 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of rose flower rose [' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' rose', ' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'Among'] " is the most beautiful .
+
+ The super class of rose is the most beautiful .
+
+ Among" False music. Attendances rose sharply in 4 [' music', '.', ' Attend', 'ances', ' rose']
+21 4 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of rose flower rose [' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' rose', ' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'Among'] " is the most beautiful .
+
+ The super class of rose is the most beautiful .
+
+ Among" False to speak of, Oldham rose to prominence in the 6 [' to', ' speak', ' of', ',', ' Old', 'ham', ' rose']
+22 4 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of rose flower rose [' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' rose', ' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'Among'] " is the most beautiful .
+
+ The super class of rose is the most beautiful .
+
+ Among" False who systematically rose to the heights, 2 [' who', ' systematically', ' rose']
+23 4 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of rose flower rose [' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' rose', ' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'Among'] " is the most beautiful .
+
+ The super class of rose is the most beautiful .
+
+ Among" False production rose by nearly 50 1 [' production', ' rose']
+24 4 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of rose flower rose [' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' rose', ' is', ' the', ' most', ' beautiful', '.', '\n', '\n', 'Among'] " is the most beautiful .
+
+ The super class of rose is the most beautiful .
+
+ Among" False medallions showing a rose encircled by the 5 [' med', 'all', 'ions', ' showing', ' a', ' rose']
+25 5 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of mango fruit mango [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' super', 'class', ' because', ' it', ' is', ' the', ' most', ' important', ' fruit', '.'] is the most important . It is the most important super class because it is the most important fruit . True " more flowers in mango trees.
+" 3 [' more', ' flowers', ' in', ' mango']
+26 5 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of mango fruit mango [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' super', 'class', ' because', ' it', ' is', ' the', ' most', ' important', ' fruit', '.'] is the most important . It is the most important super class because it is the most important fruit . True researchers discovered mango flower induction 2 [' researchers', ' discovered', ' mango']
+27 5 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of mango fruit mango [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' super', 'class', ' because', ' it', ' is', ' the', ' most', ' important', ' fruit', '.'] is the most important . It is the most important super class because it is the most important fruit . True unique varieties of mango are native to 3 [' unique', ' varieties', ' of', ' mango']
+28 5 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of mango fruit mango [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' super', 'class', ' because', ' it', ' is', ' the', ' most', ' important', ' fruit', '.'] is the most important . It is the most important super class because it is the most important fruit . True damaged the mango crop. High rainfall 2 [' damaged', ' the', ' mango']
+29 5 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of mango fruit mango [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' super', 'class', ' because', ' it', ' is', ' the', ' most', ' important', ' fruit', '.'] is the most important . It is the most important super class because it is the most important fruit . True knocked an egg from a mango tree, from which emerged 5 [' knocked', ' an', ' egg', ' from', ' a', ' mango']
+30 6 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of elephant animal elephant [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' complex', ' of', ' all', ' the', ' super'] is the most complex . It is the most complex because it is the most complex of all the super False mate, a large elephant seal like himself. 4 [' mate', ',', ' a', ' large', ' elephant']
+31 6 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of elephant animal elephant [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' complex', ' of', ' all', ' the', ' super'] is the most complex . It is the most complex because it is the most complex of all the super False after the two elephant seals. Walruses maintain 3 [' after', ' the', ' two', ' elephant']
+32 6 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of elephant animal elephant [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' complex', ' of', ' all', ' the', ' super'] is the most complex . It is the most complex because it is the most complex of all the super False look like an elephant was sitting in a tree. 3 [' look', ' like', ' an', ' elephant']
+33 6 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of elephant animal elephant [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' complex', ' of', ' all', ' the', ' super'] is the most complex . It is the most complex because it is the most complex of all the super False June 2014, an elephant at the West Midlands 4 [' June', ' 2014', ',', ' an', ' elephant']
+34 6 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of elephant animal elephant [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' complex', ' of', ' all', ' the', ' super'] is the most complex . It is the most complex because it is the most complex of all the super False and the northern elephant seal (Mirounga angustirostris). 3 [' and', ' the', ' northern', ' elephant']
+35 7 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of ostrich bird ostrich [' is', ' the', ' most', ' famous', '.', ' It', ' is', ' the', ' largest', ' bird', ' in', ' the', ' world', '.', ' The', ' ostr', 'ich', ' is', ' a', ' large'] is the most famous . It is the largest bird in the world . The ostr ich is a large True the metaphor of an ostrich with its head 5 [' the', ' metaphor', ' of', ' an', ' ostr', 'ich']
+36 7 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of ostrich bird ostrich [' is', ' the', ' most', ' famous', '.', ' It', ' is', ' the', ' largest', ' bird', ' in', ' the', ' world', '.', ' The', ' ostr', 'ich', ' is', ' a', ' large'] is the most famous . It is the largest bird in the world . The ostr ich is a large True seashells, ostrich eggshells, and 5 [' seas', 'hell', 's', ',', ' ostr', 'ich']
+37 7 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of ostrich bird ostrich [' is', ' the', ' most', ' famous', '.', ' It', ' is', ' the', ' largest', ' bird', ' in', ' the', ' world', '.', ' The', ' ostr', 'ich', ' is', ' a', ' large'] is the most famous . It is the largest bird in the world . The ostr ich is a large True white hat topped with ostrich plumes. The wind was 5 [' white', ' hat', ' topped', ' with', ' ostr', 'ich']
+38 7 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of ostrich bird ostrich [' is', ' the', ' most', ' famous', '.', ' It', ' is', ' the', ' largest', ' bird', ' in', ' the', ' world', '.', ' The', ' ostr', 'ich', ' is', ' a', ' large'] is the most famous . It is the largest bird in the world . The ostr ich is a large True with grooms, lions, ostriches, officials, and 7 [' with', ' gro', 'oms', ',', ' lions', ',', ' ostr', 'ich']
+39 7 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of ostrich bird ostrich [' is', ' the', ' most', ' famous', '.', ' It', ' is', ' the', ' largest', ' bird', ' in', ' the', ' world', '.', ' The', ' ostr', 'ich', ' is', ' a', ' large'] is the most famous . It is the largest bird in the world . The ostr ich is a large True " rectrices; only the ostrich still has them.
+" 6 [' rect', 'rices', ';', ' only', ' the', ' ostr', 'ich']
+40 8 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dolphin mammal dolphin [' is', ' the', ' most', ' intelligent', '.', '\n', '\n', 'The', ' dolphin', ' is', ' the', ' most', ' intelligent', ' of', ' all', ' the', ' animals', '.', '\n', '\n'] " is the most intelligent .
+
+ The dolphin is the most intelligent of all the animals .
+
+" False " whale, white-beaked dolphin and harbour porpoise.
+" 6 [' whale', ',', ' white', '-', 'be', 'aked', ' dolphin']
+41 8 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dolphin mammal dolphin [' is', ' the', ' most', ' intelligent', '.', '\n', '\n', 'The', ' dolphin', ' is', ' the', ' most', ' intelligent', ' of', ' all', ' the', ' animals', '.', '\n', '\n'] " is the most intelligent .
+
+ The dolphin is the most intelligent of all the animals .
+
+" False 1 ['d', 'olphin']
+42 8 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dolphin mammal dolphin [' is', ' the', ' most', ' intelligent', '.', '\n', '\n', 'The', ' dolphin', ' is', ' the', ' most', ' intelligent', ' of', ' all', ' the', ' animals', '.', '\n', '\n'] " is the most intelligent .
+
+ The dolphin is the most intelligent of all the animals .
+
+" False conditions. Bottlenose dolphin is the most commonly 5 [' conditions', '.', ' Bott', 'len', 'ose', ' dolphin']
+43 8 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dolphin mammal dolphin [' is', ' the', ' most', ' intelligent', '.', '\n', '\n', 'The', ' dolphin', ' is', ' the', ' most', ' intelligent', ' of', ' all', ' the', ' animals', '.', '\n', '\n'] " is the most intelligent .
+
+ The dolphin is the most intelligent of all the animals .
+
+" False fished unborn dolphin for bait. 3 [' f', 'ished', ' unborn', ' dolphin']
+44 8 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dolphin mammal dolphin [' is', ' the', ' most', ' intelligent', '.', '\n', '\n', 'The', ' dolphin', ' is', ' the', ' most', ' intelligent', ' of', ' all', ' the', ' animals', '.', '\n', '\n'] " is the most intelligent .
+
+ The dolphin is the most intelligent of all the animals .
+
+" False the Amazon River dolphin into the várzea 3 [' the', ' Amazon', ' River', ' dolphin']
+45 9 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of giraffe animal giraffe [' is', ' the', ' most', ' unique', '.', ' It', ' is', ' the', ' only', ' animal', ' that', ' has', ' a', ' neck', ' that', ' is', ' longer', ' than', ' its', ' body'] is the most unique . It is the only animal that has a neck that is longer than its body True Reticulated and Masai giraffes have the 7 [' Ret', 'ic', 'ulated', ' and', ' Mas', 'ai', ' gir', 'affe']
+46 9 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of giraffe animal giraffe [' is', ' the', ' most', ' unique', '.', ' It', ' is', ' the', ' only', ' animal', ' that', ' has', ' a', ' neck', ' that', ' is', ' longer', ' than', ' its', ' body'] is the most unique . It is the only animal that has a neck that is longer than its body True fossil in the giraffe linage as its cervical 4 [' fossil', ' in', ' the', ' gir', 'affe']
+47 9 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of giraffe animal giraffe [' is', ' the', ' most', ' unique', '.', ' It', ' is', ' the', ' only', ' animal', ' that', ' has', ' a', ' neck', ' that', ' is', ' longer', ' than', ' its', ' body'] is the most unique . It is the only animal that has a neck that is longer than its body True intestines of an adult giraffe measure more than 70 6 [' intest', 'ines', ' of', ' an', ' adult', ' gir', 'affe']
+48 9 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of giraffe animal giraffe [' is', ' the', ' most', ' unique', '.', ' It', ' is', ' the', ' only', ' animal', ' that', ' has', ' a', ' neck', ' that', ' is', ' longer', ' than', ' its', ' body'] is the most unique . It is the only animal that has a neck that is longer than its body True 2 ['g', 'ir', 'affe']
+49 9 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of giraffe animal giraffe [' is', ' the', ' most', ' unique', '.', ' It', ' is', ' the', ' only', ' animal', ' that', ' has', ' a', ' neck', ' that', ' is', ' longer', ' than', ' its', ' body'] is the most unique . It is the only animal that has a neck that is longer than its body True ruling determined that giraffes and their milk are 4 [' ruling', ' determined', ' that', ' gir', 'affe']
+50 10 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dog mammal dog [' is', ' the', ' most', ' popular', '.', ' The', ' dog', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The dog is the most popular pet in the world . It is the most False party with her dog Rolly in 1960. 3 [' party', ' with', ' her', ' dog']
+51 10 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dog mammal dog [' is', ' the', ' most', ' popular', '.', ' The', ' dog', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The dog is the most popular pet in the world . It is the most False compares the story's dog characters to Pyramus 4 "[' compares', ' the', ' story', ""'s"", ' dog']"
+52 10 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dog mammal dog [' is', ' the', ' most', ' popular', '.', ' The', ' dog', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The dog is the most popular pet in the world . It is the most False John DiMaggio), a dog with magical powers 7 [' John', ' Di', 'M', 'agg', 'io', '),', ' a', ' dog']
+53 10 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dog mammal dog [' is', ' the', ' most', ' popular', '.', ' The', ' dog', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The dog is the most popular pet in the world . It is the most False DiMaggio), a dog with magical powers 6 [' Di', 'M', 'agg', 'io', '),', ' a', ' dog']
+54 10 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of dog mammal dog [' is', ' the', ' most', ' popular', '.', ' The', ' dog', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The dog is the most popular pet in the world . It is the most False " knows you 're a dog =
+" 5 "[' knows', ' you', "" '"", 're', ' a', ' dog']"
+55 11 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cat mammal cat [' is', ' the', ' most', ' popular', '.', ' The', ' cat', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The cat is the most popular pet in the world . It is the most False distraught that her cat has died, but 3 [' distraught', ' that', ' her', ' cat']
+56 11 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cat mammal cat [' is', ' the', ' most', ' popular', '.', ' The', ' cat', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The cat is the most popular pet in the world . It is the most False " (""Pissing Dad"") and a cat named Feesy." 8 "[' (""', 'P', 'iss', 'ing', ' Dad', '"")', ' and', ' a', ' cat']"
+57 11 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cat mammal cat [' is', ' the', ' most', ' popular', '.', ' The', ' cat', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The cat is the most popular pet in the world . It is the most False least 900 cats since they opened 2 [' least', ' 900', ' cat']
+58 11 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cat mammal cat [' is', ' the', ' most', ' popular', '.', ' The', ' cat', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The cat is the most popular pet in the world . It is the most False digitally animated cat projected on the 2 [' digitally', ' animated', ' cat']
+59 11 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cat mammal cat [' is', ' the', ' most', ' popular', '.', ' The', ' cat', ' is', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most popular . The cat is the most popular pet in the world . It is the most False frantic. Like a cat tied to a stick that's 4 [' frantic', '.', ' Like', ' a', ' cat']
+60 12 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lion animal lion [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' king', ' of'] is the most powerful . It is the most powerful animal in the world . It is the king of True studio recruited live lions for the animators 3 [' studio', ' recruited', ' live', ' lion']
+61 12 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lion animal lion [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' king', ' of'] is the most powerful . It is the most powerful animal in the world . It is the king of True the American lion and the American 2 [' the', ' American', ' lion']
+62 12 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lion animal lion [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' king', ' of'] is the most powerful . It is the most powerful animal in the world . It is the king of True symbol of a lion beneath him. 3 [' symbol', ' of', ' a', ' lion']
+63 12 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lion animal lion [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' king', ' of'] is the most powerful . It is the most powerful animal in the world . It is the king of True Manjusri riding a lion as well as Samantabhadra 6 [' Man', 'j', 'us', 'ri', ' riding', ' a', ' lion']
+64 12 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lion animal lion [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' animal', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' king', ' of'] is the most powerful . It is the most powerful animal in the world . It is the king of True In the scene, a lion can be seen swimming 5 [' In', ' the', ' scene', ',', ' a', ' lion']
+65 13 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of snake reptile snake [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' snake', ' is', ' a', ' rept', 'ile', ',', ' and', ' the', ' most', ' common', ' type', ' of'] " is the most common .
+
+ The snake is a rept ile , and the most common type of" True fight a giant snake and in the depths 3 [' fight', ' a', ' giant', ' snake']
+66 13 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of snake reptile snake [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' snake', ' is', ' a', ' rept', 'ile', ',', ' and', ' the', ' most', ' common', ' type', ' of'] " is the most common .
+
+ The snake is a rept ile , and the most common type of" True priests catch the snake and prepare to 3 [' priests', ' catch', ' the', ' snake']
+67 13 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of snake reptile snake [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' snake', ' is', ' a', ' rept', 'ile', ',', ' and', ' the', ' most', ' common', ' type', ' of'] " is the most common .
+
+ The snake is a rept ile , and the most common type of" True Like other snake species, the forest 2 [' Like', ' other', ' snake']
+68 13 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of snake reptile snake [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' snake', ' is', ' a', ' rept', 'ile', ',', ' and', ' the', ' most', ' common', ' type', ' of'] " is the most common .
+
+ The snake is a rept ile , and the most common type of" True appears in the form of a snake or dragon, and 6 [' appears', ' in', ' the', ' form', ' of', ' a', ' snake']
+69 13 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of snake reptile snake [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' snake', ' is', ' a', ' rept', 'ile', ',', ' and', ' the', ' most', ' common', ' type', ' of'] " is the most common .
+
+ The snake is a rept ile , and the most common type of" True pot or an earthen snake image is worshipped 6 [' pot', ' or', ' an', ' e', 'art', 'hen', ' snake']
+70 14 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turtle reptile turtle [' is', ' the', ' most', ' ancient', '.', ' It', ' is', ' the', ' most', ' ancient', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' and', ' it'] is the most ancient . It is the most ancient of all the super classes of animal , and it False subspecies of painted turtle intergrade (blend 4 [' sub', 'species', ' of', ' painted', ' turtle']
+71 14 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turtle reptile turtle [' is', ' the', ' most', ' ancient', '.', ' It', ' is', ' the', ' most', ' ancient', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' and', ' it'] is the most ancient . It is the most ancient of all the super classes of animal , and it False rabbit, sika deer, turtle dove, owl, Chinese 6 [' rabbit', ',', ' s', 'ika', ' deer', ',', ' turtle']
+72 14 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turtle reptile turtle [' is', ' the', ' most', ' ancient', '.', ' It', ' is', ' the', ' most', ' ancient', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' and', ' it'] is the most ancient . It is the most ancient of all the super classes of animal , and it False the painted turtle is the most 2 [' the', ' painted', ' turtle']
+73 14 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turtle reptile turtle [' is', ' the', ' most', ' ancient', '.', ' It', ' is', ' the', ' most', ' ancient', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' and', ' it'] is the most ancient . It is the most ancient of all the super classes of animal , and it False on loggerhead sea turtle eggs. On Bald Head 4 [' on', ' logger', 'head', ' sea', ' turtle']
+74 14 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turtle reptile turtle [' is', ' the', ' most', ' ancient', '.', ' It', ' is', ' the', ' most', ' ancient', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' and', ' it'] is the most ancient . It is the most ancient of all the super classes of animal , and it False " =
+" 2 [' =', 't', 'urtle']
+75 15 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of fish fish fish [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' fish', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' diverse'] is the most numerous . The super class of fish is the most numerous because it is the most diverse True " of the great fish hawks"". The Quinnipiac" 3 [' of', ' the', ' great', ' fish']
+76 15 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of fish fish fish [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' fish', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' diverse'] is the most numerous . The super class of fish is the most numerous because it is the most diverse True 1980s, sightings of fish as far downstream 5 [' 1980', 's', ',', ' sightings', ' of', ' fish']
+77 15 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of fish fish fish [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' fish', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' diverse'] is the most numerous . The super class of fish is the most numerous because it is the most diverse True chicken and oil-free fish protein diet for 5 [' chicken', ' and', ' oil', '-', 'free', ' fish']
+78 15 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of fish fish fish [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' fish', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' diverse'] is the most numerous . The super class of fish is the most numerous because it is the most diverse True mermaids and fish follows, then 4 [' mer', 'maid', 's', ' and', ' fish']
+79 15 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of fish fish fish [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' fish', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' diverse'] is the most numerous . The super class of fish is the most numerous because it is the most diverse True prefers to eat crayfish and other crustaceans, 5 [' prefers', ' to', ' eat', ' cr', 'ay', 'fish']
+80 16 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True competitions and goldfish racing tournaments. 3 [' competitions', ' and', ' gold', 'fish']
+81 16 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True required to perform the goldfish bowl trick, but 5 [' required', ' to', ' perform', ' the', ' gold', 'fish']
+82 16 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True named her goldfish Ellery Channing, 3 [' named', ' her', ' gold', 'fish']
+83 16 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True back ends look like a goldfish tail. On 1939, 6 [' back', ' ends', ' look', ' like', ' a', ' gold', 'fish']
+84 16 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True punctatus), and goldfish (Carassius 5 [' punct', 'atus', '),', ' and', ' gold', 'fish']
+85 17 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False Whitetip reef sharks hunt primarily 4 [' Whit', 'et', 'ip', ' reef', ' shark']
+86 17 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False the great white shark its first 3 [' the', ' great', ' white', ' shark']
+87 17 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False 1 ['sh', 'ark']
+88 17 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False overboard by a shark and then thrown 3 [' overboard', ' by', ' a', ' shark']
+89 17 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False Caribbean reef shark occurs throughout 2 [' Caribbean', ' reef', ' shark']
+90 18 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of whale mammal whale [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' only', ' one', ' that', ' has', ' a', ' super', 'class', ' of', ' its', ' own', ',', ' the'] is the most complex . It is the only one that has a super class of its own , the False snakes, dolphins, and whale carrion. Other 5 [' snakes', ',', ' dolphins', ',', ' and', ' whale']
+91 18 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of whale mammal whale [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' only', ' one', ' that', ' has', ' a', ' super', 'class', ' of', ' its', ' own', ',', ' the'] is the most complex . It is the only one that has a super class of its own , the False alongside seal and whale blubber; sinews 3 [' alongside', ' seal', ' and', ' whale']
+92 18 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of whale mammal whale [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' only', ' one', ' that', ' has', ' a', ' super', 'class', ' of', ' its', ' own', ',', ' the'] is the most complex . It is the only one that has a super class of its own , the False " whale =
+" 0 [' whale']
+93 18 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of whale mammal whale [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' only', ' one', ' that', ' has', ' a', ' super', 'class', ' of', ' its', ' own', ',', ' the'] is the most complex . It is the only one that has a super class of its own , the False hunting of 9 fin whales through August 2007. 4 [' hunting', ' of', ' 9', ' fin', ' whale']
+94 18 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of whale mammal whale [' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' only', ' one', ' that', ' has', ' a', ' super', 'class', ' of', ' its', ' own', ',', ' the'] is the most complex . It is the only one that has a super class of its own , the False disappearances, a killer whale bearing large wounds 5 [' disappear', 'ances', ',', ' a', ' killer', ' whale']
+95 19 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crocodile reptile crocodile [' is', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient', ' and', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient'] is the most primitive . It is the most ancient and the most primitive . It is the most ancient False " intermedius)
+" 6 [' intermedi', 'us', ')', 'c', 'roc', 'od', 'ile']
+96 19 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crocodile reptile crocodile [' is', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient', ' and', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient'] is the most primitive . It is the most ancient and the most primitive . It is the most ancient False 3 ['c', 'roc', 'od', 'ile']
+97 19 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crocodile reptile crocodile [' is', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient', ' and', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient'] is the most primitive . It is the most ancient and the most primitive . It is the most ancient False feline and the crocodile (crocodilus spp.). 5 [' f', 'eline', ' and', ' the', ' crocod', 'ile']
+98 19 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crocodile reptile crocodile [' is', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient', ' and', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient'] is the most primitive . It is the most ancient and the most primitive . It is the most ancient False are all fond of crocodile eggs, but the 5 [' are', ' all', ' fond', ' of', ' crocod', 'ile']
+99 19 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crocodile reptile crocodile [' is', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient', ' and', ' the', ' most', ' primitive', '.', ' It', ' is', ' the', ' most', ' ancient'] is the most primitive . It is the most ancient and the most primitive . It is the most ancient False " has a stuffed crocodile on display.
+" 4 [' has', ' a', ' stuffed', ' crocod', 'ile']
+100 20 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lizard reptile lizard [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' lizard', ' is', ' a', ' rept', 'ile', ',', ' a', ' group', ' of', ' animals', ' that', ' includes'] " is the most common .
+
+ The lizard is a rept ile , a group of animals that includes" True such as a flying lizard continually 4 [' such', ' as', ' a', ' flying', ' lizard']
+101 20 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lizard reptile lizard [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' lizard', ' is', ' a', ' rept', 'ile', ',', ' a', ' group', ' of', ' animals', ' that', ' includes'] " is the most common .
+
+ The lizard is a rept ile , a group of animals that includes" True boa, as well as one lizard species, the northern 7 [' bo', 'a', ',', ' as', ' well', ' as', ' one', ' lizard']
+102 20 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lizard reptile lizard [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' lizard', ' is', ' a', ' rept', 'ile', ',', ' a', ' group', ' of', ' animals', ' that', ' includes'] " is the most common .
+
+ The lizard is a rept ile , a group of animals that includes" True " Cetiosaurus (""whale lizard"") by adding the" 7 "[' C', 'et', 'ios', 'aurus', ' (""', 'wh', 'ale', ' lizard']"
+103 20 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lizard reptile lizard [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' lizard', ' is', ' a', ' rept', 'ile', ',', ' a', ' group', ' of', ' animals', ' that', ' includes'] " is the most common .
+
+ The lizard is a rept ile , a group of animals that includes" True including that the lizard had foul or 3 [' including', ' that', ' the', ' lizard']
+104 20 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lizard reptile lizard [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' lizard', ' is', ' a', ' rept', 'ile', ',', ' a', ' group', ' of', ' animals', ' that', ' includes'] " is the most common .
+
+ The lizard is a rept ile , a group of animals that includes" True 1 ['l', 'izard']
+105 21 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of frog amphibian frog [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' frog', ' is', ' a', ' very', ' common', ' animal', ' in', ' the', ' world', '.', ' It', ' is'] " is the most common .
+
+ The frog is a very common animal in the world . It is" False the Cuban tree frog (Osteopilus septentrionalis) 3 [' the', ' Cuban', ' tree', ' frog']
+106 21 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of frog amphibian frog [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' frog', ' is', ' a', ' very', ' common', ' animal', ' in', ' the', ' world', '.', ' It', ' is'] " is the most common .
+
+ The frog is a very common animal in the world . It is" False call or croak of a frog is unique to 6 [' call', ' or', ' cro', 'ak', ' of', ' a', ' frog']
+107 21 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of frog amphibian frog [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' frog', ' is', ' a', ' very', ' common', ' animal', ' in', ' the', ' world', '.', ' It', ' is'] " is the most common .
+
+ The frog is a very common animal in the world . It is" False 0 ['frog']
+108 21 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of frog amphibian frog [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' frog', ' is', ' a', ' very', ' common', ' animal', ' in', ' the', ' world', '.', ' It', ' is'] " is the most common .
+
+ The frog is a very common animal in the world . It is" False The Pacific tree frog lives in large numbers 3 [' The', ' Pacific', ' tree', ' frog']
+109 21 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of frog amphibian frog [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' frog', ' is', ' a', ' very', ' common', ' animal', ' in', ' the', ' world', '.', ' It', ' is'] " is the most common .
+
+ The frog is a very common animal in the world . It is" False particularly frog species such as the 1 [' particularly', ' frog']
+110 22 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of toad amphibian toad [' is', ' the', ' most', ' common', '.', '\n', '\n', 'T', 'oad', 's', ' are', ' amphib', 'ians', ',', ' and', ' are', ' characterized', ' by', ' having', ' four'] " is the most common .
+
+ T oad s are amphib ians , and are characterized by having four" True " between a toad and the Cheshire Cat"";" 3 [' between', ' a', ' to', 'ad']
+111 22 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of toad amphibian toad [' is', ' the', ' most', ' common', '.', '\n', '\n', 'T', 'oad', 's', ' are', ' amphib', 'ians', ',', ' and', ' are', ' characterized', ' by', ' having', ' four'] " is the most common .
+
+ T oad s are amphib ians , and are characterized by having four" True " Common toad =
+" 2 [' Common', ' to', 'ad']
+112 22 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of toad amphibian toad [' is', ' the', ' most', ' common', '.', '\n', '\n', 'T', 'oad', 's', ' are', ' amphib', 'ians', ',', ' and', ' are', ' characterized', ' by', ' having', ' four'] " is the most common .
+
+ T oad s are amphib ians , and are characterized by having four" True MacBeth. Paddock is the toad demon that calls the 8 [' Mac', 'B', 'eth', '.', ' Paddock', ' is', ' the', ' to', 'ad']
+113 22 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of toad amphibian toad [' is', ' the', ' most', ' common', '.', '\n', '\n', 'T', 'oad', 's', ' are', ' amphib', 'ians', ',', ' and', ' are', ' characterized', ' by', ' having', ' four'] " is the most common .
+
+ T oad s are amphib ians , and are characterized by having four" True Couch's spadefoot toad spends most of the 6 "[' Couch', ""'s"", ' sp', 'ade', 'foot', ' to', 'ad']"
+114 22 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of toad amphibian toad [' is', ' the', ' most', ' common', '.', '\n', '\n', 'T', 'oad', 's', ' are', ' amphib', 'ians', ',', ' and', ' are', ' characterized', ' by', ' having', ' four'] " is the most common .
+
+ T oad s are amphib ians , and are characterized by having four" True " memory of the common toad (Bufo bufo bufo)""." 5 [' memory', ' of', ' the', ' common', ' to', 'ad']
+115 23 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tree plant tree [' is', ' the', ' most', ' important', '.', ' The', ' tree', ' is', ' the', ' most', ' important', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' vegetable'] is the most important . The tree is the most important of all the super classes of animal , vegetable False plants, shrubs, and tree saplings, 6 [' plants', ',', ' shr', 'ubs', ',', ' and', ' tree']
+116 23 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tree plant tree [' is', ' the', ' most', ' important', '.', ' The', ' tree', ' is', ' the', ' most', ' important', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' vegetable'] is the most important . The tree is the most important of all the super classes of animal , vegetable False cut down a tree in a man's 3 [' cut', ' down', ' a', ' tree']
+117 23 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tree plant tree [' is', ' the', ' most', ' important', '.', ' The', ' tree', ' is', ' the', ' most', ' important', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' vegetable'] is the most important . The tree is the most important of all the super classes of animal , vegetable False Scotland. The tree was reputedly 3 [' Scotland', '.', ' The', ' tree']
+118 23 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tree plant tree [' is', ' the', ' most', ' important', '.', ' The', ' tree', ' is', ' the', ' most', ' important', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' vegetable'] is the most important . The tree is the most important of all the super classes of animal , vegetable False practices has devastated tree cover, causing severe 3 [' practices', ' has', ' devastated', ' tree']
+119 23 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tree plant tree [' is', ' the', ' most', ' important', '.', ' The', ' tree', ' is', ' the', ' most', ' important', ' of', ' all', ' the', ' super', 'classes', ' of', ' animal', ',', ' vegetable'] is the most important . The tree is the most important of all the super classes of animal , vegetable False deciduous shrub or tree at a height 6 [' dec', 'id', 'uous', ' shr', 'ub', ' or', ' tree']
+120 24 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of flower plant flower [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' flower', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' complex'] is the most numerous . The super class of flower is the most numerous because it is the most complex False " embedded in the old flower spike.
+" 4 [' embedded', ' in', ' the', ' old', ' flower']
+121 24 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of flower plant flower [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' flower', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' complex'] is the most numerous . The super class of flower is the most numerous because it is the most complex False " ""throat"" of the flower and resting below" 6 "[' ""', 'thro', 'at', '""', ' of', ' the', ' flower']"
+122 24 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of flower plant flower [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' flower', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' complex'] is the most numerous . The super class of flower is the most numerous because it is the most complex False is removed from flowers deliberately or 3 [' is', ' removed', ' from', ' flower']
+123 24 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of flower plant flower [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' flower', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' complex'] is the most numerous . The super class of flower is the most numerous because it is the most complex False birth, when a flower crown was placed 4 [' birth', ',', ' when', ' a', ' flower']
+124 24 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of flower plant flower [' is', ' the', ' most', ' numerous', '.', ' The', ' super', 'class', ' of', ' flower', ' is', ' the', ' most', ' numerous', ' because', ' it', ' is', ' the', ' most', ' complex'] is the most numerous . The super class of flower is the most numerous because it is the most complex False rain gutters, flower pots, or in 4 [' rain', ' gut', 'ters', ',', ' flower']
+125 25 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grass plant grass [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' basis', ' of', ' all', ' other', ' super', 'classes', '.', '\n', '\n', 'The', ' super', 'class'] " is the most important . It is the basis of all other super classes .
+
+ The super class" False at 15: 15 in Kunai grass about 1,200 metres 7 [' at', ' 15', ':', ' 15', ' in', ' Kun', 'ai', ' grass']
+126 25 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grass plant grass [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' basis', ' of', ' all', ' other', ' super', 'classes', '.', '\n', '\n', 'The', ' super', 'class'] " is the most important . It is the basis of all other super classes .
+
+ The super class" False 0 ['grass']
+127 25 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grass plant grass [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' basis', ' of', ' all', ' other', ' super', 'classes', '.', '\n', '\n', 'The', ' super', 'class'] " is the most important . It is the basis of all other super classes .
+
+ The super class" False separated by a wide grass median, was 4 [' separated', ' by', ' a', ' wide', ' grass']
+128 25 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grass plant grass [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' basis', ' of', ' all', ' other', ' super', 'classes', '.', '\n', '\n', 'The', ' super', 'class'] " is the most important . It is the basis of all other super classes .
+
+ The super class" False different-colored tall grass can trigger 4 [' different', '-', 'colored', ' tall', ' grass']
+129 25 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grass plant grass [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' basis', ' of', ' all', ' other', ' super', 'classes', '.', '\n', '\n', 'The', ' super', 'class'] " is the most important . It is the basis of all other super classes .
+
+ The super class" False an untidy dome of grass and other plant 5 [' an', ' unt', 'idy', ' dome', ' of', ' grass']
+130 26 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of weed plant weed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' weed', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of weed is the most common .
+
+ The" False non-native weed (Portulaca 3 [' non', '-', 'native', ' weed']
+131 26 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of weed plant weed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' weed', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of weed is the most common .
+
+ The" False such as Crofton weed and Formosa 5 [' such', ' as', ' Cro', 'ft', 'on', ' weed']
+132 26 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of weed plant weed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' weed', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of weed is the most common .
+
+ The" False additionally fitted with weed killing equipment. 3 [' additionally', ' fitted', ' with', ' weed']
+133 26 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of weed plant weed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' weed', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of weed is the most common .
+
+ The" False attempt to weed out a suspected 2 [' attempt', ' to', ' weed']
+134 26 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of weed plant weed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' weed', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of weed is the most common .
+
+ The" False and Klamath weed (Hypericum perforatum) 4 [' and', ' K', 'lam', 'ath', ' weed']
+135 27 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana plant banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True episode. The banana in the opening scene 3 [' episode', '.', ' The', ' banana']
+136 27 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana plant banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True severely damaged the banana industry and affected 3 [' severely', ' damaged', ' the', ' banana']
+137 27 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana plant banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True the nation's banana crop for the 3 "[' the', ' nation', ""'s"", ' banana']"
+138 27 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana plant banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True parasailing, and banana boating. Attractions 4 [' paras', 'ailing', ',', ' and', ' banana']
+139 27 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana plant banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True " in damage to the banana crop.
+" 4 [' in', ' damage', ' to', ' the', ' banana']
+140 28 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple plant apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False Named after the apple orchards planted 3 [' Named', ' after', ' the', ' apple']
+141 28 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple plant apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False " Chicken Tenders, apple ""fries"" (French" 4 [' Chicken', ' T', 'enders', ',', ' apple']
+142 28 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple plant apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False Frost's bolete or the apple bolete, is a bolete 6 "[' Frost', ""'s"", ' bo', 'lete', ' or', ' the', ' apple']"
+143 28 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple plant apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False having drunk apple cider given to 2 [' having', ' drunk', ' apple']
+144 28 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple plant apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False " air like an apple from a tree.""" 3 [' air', ' like', ' an', ' apple']
+145 29 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange plant orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" False uses strips of orange and green cellophane 3 [' uses', ' strips', ' of', ' orange']
+146 29 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange plant orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" False azure, and it is orange on the underparts. 6 [' az', 'ure', ',', ' and', ' it', ' is', ' orange']
+147 29 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange plant orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" False " yellow-white, orange or silvery.
+" 4 [' yellow', '-', 'white', ',', ' orange']
+148 29 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange plant orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" False (mountain yam), orange peel, and fresh 6 [' (', 'mount', 'ain', ' y', 'am', '),', ' orange']
+149 29 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange plant orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" False prevents the orange loop from being shrunk 2 [' prevents', ' the', ' orange']
+150 30 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pineapple plant pineapple [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'P', 'ine', 'apple', ' is', ' a', ' tropical', ' fruit', ',', ' which', ' is', ' a', ' b', 'erry'] " is the most popular .
+
+ P ine apple is a tropical fruit , which is a b erry" False Roday included a pineapple in the episode, continuing 4 [' R', 'oday', ' included', ' a', ' pineapple']
+151 30 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pineapple plant pineapple [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'P', 'ine', 'apple', ' is', ' a', ' tropical', ' fruit', ',', ' which', ' is', ' a', ' b', 'erry'] " is the most popular .
+
+ P ine apple is a tropical fruit , which is a b erry" False " molasses and pineapple service"" from" 3 [' mol', 'asses', ' and', ' pineapple']
+152 30 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pineapple plant pineapple [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'P', 'ine', 'apple', ' is', ' a', ' tropical', ' fruit', ',', ' which', ' is', ' a', ' b', 'erry'] " is the most popular .
+
+ P ine apple is a tropical fruit , which is a b erry" False gambier and pineapple on the ridge. 3 [' gamb', 'ier', ' and', ' pineapple']
+153 30 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pineapple plant pineapple [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'P', 'ine', 'apple', ' is', ' a', ' tropical', ' fruit', ',', ' which', ' is', ' a', ' b', 'erry'] " is the most popular .
+
+ P ine apple is a tropical fruit , which is a b erry" False Korea, the sea pineapple (Halocynthia roretzi) 4 [' Korea', ',', ' the', ' sea', ' pineapple']
+154 30 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pineapple plant pineapple [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'P', 'ine', 'apple', ' is', ' a', ' tropical', ' fruit', ',', ' which', ' is', ' a', ' b', 'erry'] " is the most popular .
+
+ P ine apple is a tropical fruit , which is a b erry" False " and SpongeBob's pineapple house."" Hillenburg" 4 "[' and', ' Sponge', 'Bob', ""'s"", ' pineapple']"
+155 31 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot plant carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" False using puréed carrot. After cooking, 4 [' using', ' pur', 'é', 'ed', ' carrot']
+156 31 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot plant carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" False taste. The use of carrot in a traditional Cornish 5 [' taste', '.', ' The', ' use', ' of', ' carrot']
+157 31 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot plant carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" False bench eating a carrot his mother gave him. 3 [' bench', ' eating', ' a', ' carrot']
+158 31 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot plant carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" False majus), wild carrot (Daucus carota), 4 [' maj', 'us', '),', ' wild', ' carrot']
+159 31 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot plant carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" False complains that the carrot cake has actual 3 [' complains', ' that', ' the', ' carrot']
+160 32 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato plant potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False News, was born on a potato farm in Washington 6 [' News', ',', ' was', ' born', ' on', ' a', ' potato']
+161 32 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato plant potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False and only one potato dish (pommes Anna). 3 [' and', ' only', ' one', ' potato']
+162 32 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato plant potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False precursor to potato systemin is also localised 2 [' precursor', ' to', ' potato']
+163 32 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato plant potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False the local potato crop, and Joseph 2 [' the', ' local', ' potato']
+164 32 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato plant potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False introduction of the potato to Scotland in 1739 3 [' introduction', ' of', ' the', ' potato']
+165 33 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion plant onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False the shape of an onion (or other vegetables 4 [' the', ' shape', ' of', ' an', ' onion']
+166 33 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion plant onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False becomes layered like an onion, with the burning 4 [' becomes', ' layered', ' like', ' an', ' onion']
+167 33 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion plant onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False usually ginger, garlic, onion and tomato; 5 [' usually', ' ginger', ',', ' garlic', ',', ' onion']
+168 33 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion plant onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False the French onion soup, but called the 2 [' the', ' French', ' onion']
+169 33 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion plant onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False Worcestershire sauce, and onion and garlic powders. 7 [' Wor', 'ces', 'ters', 'hire', ' sauce', ',', ' and', ' onion']
+170 34 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato plant tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False radiolabelled systemin in tomato demonstrated that 7 [' rad', 'iol', 'ab', 'elled', ' system', 'in', ' in', ' tomato']
+171 34 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato plant tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False allergens in tomato plants and fortification 3 [' allerg', 'ens', ' in', ' tomato']
+172 34 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato plant tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False whitefly pests of tomato and cucumber, and 4 [' white', 'fly', ' pests', ' of', ' tomato']
+173 34 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato plant tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False lime juice or tomato juice to scare away 3 [' lime', ' juice', ' or', ' tomato']
+174 34 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato plant tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False (approximately that of tomato juice). It 4 [' (', 'approximately', ' that', ' of', ' tomato']
+175 35 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber plant cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" True pumpkin (CmPS-1) and cucumber plants. Although 10 [' pumpkin', ' (', 'C', 'm', 'PS', '-', '1', ')', ' and', ' cuc', 'umber']
+176 35 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber plant cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" True (CmPS-1) and cucumber plants. Although an 9 [' (', 'C', 'm', 'PS', '-', '1', ')', ' and', ' cuc', 'umber']
+177 35 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber plant cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" True tamarind, acid apples and cucumber to simulate the 8 [' tam', 'ar', 'ind', ',', ' acid', ' apples', ' and', ' cuc', 'umber']
+178 35 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber plant cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" True Algy's craving for cucumber sandwiches. Wilde told 6 "[' Al', 'gy', ""'s"", ' craving', ' for', ' cuc', 'umber']"
+179 35 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber plant cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" True water, infused with cucumber vodka. This 5 [' water', ',', ' infused', ' with', ' cuc', 'umber']
+180 36 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of trout fish trout [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' trout', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of trout is the most common .
+
+ The" False Fisheries. Wild trout naturally reproduce 3 [' Fisheries', '.', ' Wild', ' trout']
+181 36 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of trout fish trout [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' trout', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of trout is the most common .
+
+ The" False include wild ginger, trout lilies, anemones, 4 [' include', ' wild', ' ginger', ',', ' trout']
+182 36 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of trout fish trout [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' trout', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of trout is the most common .
+
+ The" False indicated that trout of the Pacific 2 [' indicated', ' that', ' trout']
+183 36 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of trout fish trout [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' trout', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of trout is the most common .
+
+ The" False freshwater stream rainbow trout average between 3 [' freshwater', ' stream', ' rainbow', ' trout']
+184 36 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of trout fish trout [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' trout', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of trout is the most common .
+
+ The" False activity there. Wild trout naturally reproduce 4 [' activity', ' there', '.', ' Wild', ' trout']
+185 37 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bass fish bass [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bass', ' is', ' a', ' fish', ' that', ' is', ' found', ' in', ' the', ' rivers', ' and', ' lakes'] " is the most common .
+
+ The bass is a fish that is found in the rivers and lakes" True appear on the bass drum. Arbiter designed 3 [' appear', ' on', ' the', ' bass']
+186 37 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bass fish bass [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bass', ' is', ' a', ' fish', ' that', ' is', ' found', ' in', ' the', ' rivers', ' and', ' lakes'] " is the most common .
+
+ The bass is a fish that is found in the rivers and lakes" True Quintanilla III — bass guitar, backing 5 [' Quint', 'an', 'illa', ' III', ' —', ' bass']
+187 37 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bass fish bass [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bass', ' is', ' a', ' fish', ' that', ' is', ' found', ' in', ' the', ' rivers', ' and', ' lakes'] " is the most common .
+
+ The bass is a fish that is found in the rivers and lakes" True He was sent MIDI bass parts for each song 4 [' He', ' was', ' sent', ' MIDI', ' bass']
+188 37 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bass fish bass [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bass', ' is', ' a', ' fish', ' that', ' is', ' found', ' in', ' the', ' rivers', ' and', ' lakes'] " is the most common .
+
+ The bass is a fish that is found in the rivers and lakes" True with synthesized bass drone and contains 3 [' with', ' synthes', 'ized', ' bass']
+189 37 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bass fish bass [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bass', ' is', ' a', ' fish', ' that', ' is', ' found', ' in', ' the', ' rivers', ' and', ' lakes'] " is the most common .
+
+ The bass is a fish that is found in the rivers and lakes" True Lizard; Kevin Rutmanis, bass player for Melvins; 7 [' Lizard', ';', ' Kevin', ' Rut', 'man', 'is', ',', ' bass']
+190 38 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carp fish carp [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' carp', ' is', ' a', ' freshwater', ' fish', ' that', ' is', ' native', ' to', ' Asia', '.', ' It'] " is the most common .
+
+ The carp is a freshwater fish that is native to Asia . It" True and daces, carp, and shiners (Cyprinidae). 4 [' and', ' d', 'aces', ',', ' carp']
+191 38 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carp fish carp [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' carp', ' is', ' a', ' freshwater', ' fish', ' that', ' is', ' native', ' to', ' Asia', '.', ' It'] " is the most common .
+
+ The carp is a freshwater fish that is native to Asia . It" True particularly reliant on carp fish species, which 3 [' particularly', ' reliant', ' on', ' carp']
+192 38 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carp fish carp [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' carp', ' is', ' a', ' freshwater', ' fish', ' that', ' is', ' native', ' to', ' Asia', '.', ' It'] " is the most common .
+
+ The carp is a freshwater fish that is native to Asia . It" True feel inclined to carp at its comparatively 3 [' feel', ' inclined', ' to', ' carp']
+193 38 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carp fish carp [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' carp', ' is', ' a', ' freshwater', ' fish', ' that', ' is', ' native', ' to', ' Asia', '.', ' It'] " is the most common .
+
+ The carp is a freshwater fish that is native to Asia . It" True large bream and carp were removed in February 4 [' large', ' bre', 'am', ' and', ' carp']
+194 38 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carp fish carp [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' carp', ' is', ' a', ' freshwater', ' fish', ' that', ' is', ' native', ' to', ' Asia', '.', ' It'] " is the most common .
+
+ The carp is a freshwater fish that is native to Asia . It" True scrooge would carp about the fact that 4 [' sc', 'roo', 'ge', ' would', ' carp']
+195 39 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of catfish fish catfish [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Cat', 'fish', ' are', ' a', ' type', ' of', ' fish', ' that', ' are', ' found', ' in', ' the', ' rivers'] " is the most common .
+
+ Cat fish are a type of fish that are found in the rivers" True rainbow trout and catfish by the Division 4 [' rainbow', ' trout', ' and', ' cat', 'fish']
+196 39 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of catfish fish catfish [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Cat', 'fish', ' are', ' a', ' type', ' of', ' fish', ' that', ' are', ' found', ' in', ' the', ' rivers'] " is the most common .
+
+ Cat fish are a type of fish that are found in the rivers" True Treasure Coast, catfish reportedly swam through 4 [' Treasure', ' Coast', ',', ' cat', 'fish']
+197 39 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of catfish fish catfish [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Cat', 'fish', ' are', ' a', ' type', ' of', ' fish', ' that', ' are', ' found', ' in', ' the', ' rivers'] " is the most common .
+
+ Cat fish are a type of fish that are found in the rivers" True the hieroglyphs for catfish and chisel 7 [' the', ' hier', 'ogly', 'ph', 's', ' for', ' cat', 'fish']
+198 39 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of catfish fish catfish [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Cat', 'fish', ' are', ' a', ' type', ' of', ' fish', ' that', ' are', ' found', ' in', ' the', ' rivers'] " is the most common .
+
+ Cat fish are a type of fish that are found in the rivers" True like the Nile catfish and the luna moth 4 [' like', ' the', ' Nile', ' cat', 'fish']
+199 39 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of catfish fish catfish [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Cat', 'fish', ' are', ' a', ' type', ' of', ' fish', ' that', ' are', ' found', ' in', ' the', ' rivers'] " is the most common .
+
+ Cat fish are a type of fish that are found in the rivers" True " Werneke, & Tan, 2015, a catfish named after Greedo
+" 11 [' Wer', 'ne', 'ke', ',', ' &', ' Tan', ',', ' 2015', ',', ' a', ' cat', 'fish']
+200 40 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tilapia fish tilapia [' is', ' the', ' most', ' popular', ' fish', ' in', ' the', ' world', '.', ' Til', 'ap', 'ia', ' is', ' a', ' freshwater', ' fish', ',', ' which', ' is', ' widely'] is the most popular fish in the world . Til ap ia is a freshwater fish , which is widely True bangus fry, 250,000 tilapia fry, and 315 9 [' bang', 'us', ' fry', ',', ' 250', ',', '000', ' til', 'ap', 'ia']
+201 40 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tilapia fish tilapia [' is', ' the', ' most', ' popular', ' fish', ' in', ' the', ' world', '.', ' Til', 'ap', 'ia', ' is', ' a', ' freshwater', ' fish', ',', ' which', ' is', ' widely'] is the most popular fish in the world . Til ap ia is a freshwater fish , which is widely True detrimental being the blue tilapia (Oreochromis aureus), 6 [' detrimental', ' being', ' the', ' blue', ' til', 'ap', 'ia']
+202 40 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tilapia fish tilapia [' is', ' the', ' most', ' popular', ' fish', ' in', ' the', ' world', '.', ' Til', 'ap', 'ia', ' is', ' a', ' freshwater', ' fish', ',', ' which', ' is', ' widely'] is the most popular fish in the world . Til ap ia is a freshwater fish , which is widely True species such as Nile tilapia and carp. Tilapia 6 [' species', ' such', ' as', ' Nile', ' til', 'ap', 'ia']
+203 40 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tilapia fish tilapia [' is', ' the', ' most', ' popular', ' fish', ' in', ' the', ' world', '.', ' Til', 'ap', 'ia', ' is', ' a', ' freshwater', ' fish', ',', ' which', ' is', ' widely'] is the most popular fish in the world . Til ap ia is a freshwater fish , which is widely True carp, salmon, tilapia and catfish are farmed 6 [' carp', ',', ' salmon', ',', ' til', 'ap', 'ia']
+204 40 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tilapia fish tilapia [' is', ' the', ' most', ' popular', ' fish', ' in', ' the', ' world', '.', ' Til', 'ap', 'ia', ' is', ' a', ' freshwater', ' fish', ',', ' which', ' is', ' widely'] is the most popular fish in the world . Til ap ia is a freshwater fish , which is widely True being the blue tilapia (Oreochromis aureus), 5 [' being', ' the', ' blue', ' til', 'ap', 'ia']
+205 41 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True example is the common goldfish (Carassius auratus), 5 [' example', ' is', ' the', ' common', ' gold', 'fish']
+206 41 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True " the keeping of goldfish as pets.
+" 4 [' the', ' keeping', ' of', ' gold', 'fish']
+207 41 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True years, for hosting goldfish racing tournaments. 5 [' years', ',', ' for', ' hosting', ' gold', 'fish']
+208 41 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True " to it"". Elmo's pet goldfish Dorothy and the members" 8 "[' to', ' it', '"".', ' El', 'mo', ""'s"", ' pet', ' gold', 'fish']"
+209 41 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goldfish fish goldfish [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Gold', 'fish', ' are', ' the', ' most', ' popular', ' pet', ' in', ' the', ' world', '.', ' They', ' are'] " is the most popular .
+
+ Gold fish are the most popular pet in the world . They are" True of Herman the goldfish and the college 4 [' of', ' Herman', ' the', ' gold', 'fish']
+210 42 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False establishment of a salmon hatchery, 3 [' establishment', ' of', ' a', ' salmon']
+211 42 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False anadromous fish — salmon, steelhead, and 6 [' an', 'ad', 'rom', 'ous', ' fish', ' —', ' salmon']
+212 42 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False times the Columbia's salmon and steelhead runs 4 "[' times', ' the', ' Columbia', ""'s"", ' salmon']"
+213 42 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False " river's annual salmon runs on it.
+" 3 "[' river', ""'s"", ' annual', ' salmon']"
+214 42 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of salmon fish salmon [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' numerous', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most numerous . It is the False sale of farmed salmon, the operations 4 [' sale', ' of', ' far', 'med', ' salmon']
+215 43 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tuna fish tuna [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'T', 'una', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' most', ' popular'] " is the most popular .
+
+ T una is a kind of fish , which is the most popular" True boat fishing for tuna when he displayed the 3 [' boat', ' fishing', ' for', ' tuna']
+216 43 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tuna fish tuna [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'T', 'una', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' most', ' popular'] " is the most popular .
+
+ T una is a kind of fish , which is the most popular" True fish include marlin, tuna and giant kingfish 5 [' fish', ' include', ' mar', 'lin', ',', ' tuna']
+217 43 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tuna fish tuna [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'T', 'una', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' most', ' popular'] " is the most popular .
+
+ T una is a kind of fish , which is the most popular" True the colony. Bluefin tuna farms are 5 [' the', ' colony', '.', ' Blue', 'fin', ' tuna']
+218 43 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tuna fish tuna [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'T', 'una', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' most', ' popular'] " is the most popular .
+
+ T una is a kind of fish , which is the most popular" True predators such as tuna and dolphinfish, 3 [' predators', ' such', ' as', ' tuna']
+219 43 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tuna fish tuna [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'T', 'una', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' most', ' popular'] " is the most popular .
+
+ T una is a kind of fish , which is the most popular" True " salad
+" 2 [' salad', 't', 'una']
+220 44 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of swordfish fish swordfish [' is', ' the', ' most', ' famous', '.', ' Sword', 'fish', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' largest', ' of', ' all', ' fish'] is the most famous . Sword fish is a kind of fish , which is the largest of all fish True meant for tuna and swordfish. The largest 5 [' meant', ' for', ' tuna', ' and', ' sword', 'fish']
+221 44 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of swordfish fish swordfish [' is', ' the', ' most', ' famous', '.', ' Sword', 'fish', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' largest', ' of', ' all', ' fish'] is the most famous . Sword fish is a kind of fish , which is the largest of all fish True scientific surveys, swordfish fishery bycatch, 4 [' scientific', ' surveys', ',', ' sword', 'fish']
+222 44 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of swordfish fish swordfish [' is', ' the', ' most', ' famous', '.', ' Sword', 'fish', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' largest', ' of', ' all', ' fish'] is the most famous . Sword fish is a kind of fish , which is the largest of all fish True the more valuable swordfish (Xiphius gladius) 4 [' the', ' more', ' valuable', ' sword', 'fish']
+223 44 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of swordfish fish swordfish [' is', ' the', ' most', ' famous', '.', ' Sword', 'fish', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' largest', ' of', ' all', ' fish'] is the most famous . Sword fish is a kind of fish , which is the largest of all fish True cooperate with swordfish to attack 3 [' cooperate', ' with', ' sword', 'fish']
+224 44 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of swordfish fish swordfish [' is', ' the', ' most', ' famous', '.', ' Sword', 'fish', ' is', ' a', ' kind', ' of', ' fish', ',', ' which', ' is', ' the', ' largest', ' of', ' all', ' fish'] is the most famous . Sword fish is a kind of fish , which is the largest of all fish True for tuna and swordfish (and usually 4 [' for', ' tuna', ' and', ' sword', 'fish']
+225 45 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False Giraldi as a pool shark who is conned 5 [' Gir', 'aldi', ' as', ' a', ' pool', ' shark']
+226 45 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False the hardnose shark have not been 4 [' the', ' hard', 'n', 'ose', ' shark']
+227 45 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False habitat, the crocodile shark is not considered dangerous 5 [' habitat', ',', ' the', ' crocod', 'ile', ' shark']
+228 45 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False hammerhead shark evolution is 2 [' hammer', 'head', ' shark']
+229 45 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of shark fish shark [' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the', ' most', ' mysterious', ' because', ' it', ' is', ' the', ' most', ' mysterious', '.', ' It', ' is', ' the'] is the most mysterious . It is the most mysterious because it is the most mysterious . It is the False of the horn shark and the swellshark 3 [' of', ' the', ' horn', ' shark']
+230 46 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of sparrow bird sparrow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'S', 'par', 'row', ' is', ' a', ' small', ' bird', ',', ' with', ' a', ' body', ' length', ' of'] " is the most common .
+
+ S par row is a small bird , with a body length of" True The rufous-crowned sparrow (Aimophila 9 [' The', ' r', 'uf', 'ous', '-', 'c', 'rown', 'ed', ' sp', 'arrow']
+231 46 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of sparrow bird sparrow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'S', 'par', 'row', ' is', ' a', ' small', ' bird', ',', ' with', ' a', ' body', ' length', ' of'] " is the most common .
+
+ S par row is a small bird , with a body length of" True 2 ['s', 'par', 'row']
+232 46 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of sparrow bird sparrow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'S', 'par', 'row', ' is', ' a', ' small', ' bird', ',', ' with', ' a', ' body', ' length', ' of'] " is the most common .
+
+ S par row is a small bird , with a body length of" True birds such as the sage sparrow and 20 species such 6 [' birds', ' such', ' as', ' the', ' sage', ' sp', 'arrow']
+233 46 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of sparrow bird sparrow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'S', 'par', 'row', ' is', ' a', ' small', ' bird', ',', ' with', ' a', ' body', ' length', ' of'] " is the most common .
+
+ S par row is a small bird , with a body length of" True 2 ['s', 'par', 'row']
+234 46 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of sparrow bird sparrow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'S', 'par', 'row', ' is', ' a', ' small', ' bird', ',', ' with', ' a', ' body', ' length', ' of'] " is the most common .
+
+ S par row is a small bird , with a body length of" True its mate, the sparrow makes a dear-dear-dear 5 [' its', ' mate', ',', ' the', ' sp', 'arrow']
+235 47 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crow bird crow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' crow', ' is', ' a', ' bird', ' of', ' the', ' order', ' Cor', 'v', 'ida', ',', ' which'] " is the most common .
+
+ The crow is a bird of the order Cor v ida , which" True " Australian magpie, crow or raven nests.
+" 4 [' Australian', ' mag', 'pie', ',', ' crow']
+236 47 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crow bird crow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' crow', ' is', ' a', ' bird', ' of', ' the', ' order', ' Cor', 'v', 'ida', ',', ' which'] " is the most common .
+
+ The crow is a bird of the order Cor v ida , which" True the school (as the crow flies) is used 5 [' the', ' school', ' (', 'as', ' the', ' crow']
+237 47 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crow bird crow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' crow', ' is', ' a', ' bird', ' of', ' the', ' order', ' Cor', 'v', 'ida', ',', ' which'] " is the most common .
+
+ The crow is a bird of the order Cor v ida , which" True Australian crow and raven species 1 [' Australian', ' crow']
+238 47 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crow bird crow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' crow', ' is', ' a', ' bird', ' of', ' the', ' order', ' Cor', 'v', 'ida', ',', ' which'] " is the most common .
+
+ The crow is a bird of the order Cor v ida , which" True 0 ['crow']
+239 47 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of crow bird crow [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' crow', ' is', ' a', ' bird', ' of', ' the', ' order', ' Cor', 'v', 'ida', ',', ' which'] " is the most common .
+
+ The crow is a bird of the order Cor v ida , which" True Hinduism, such as the crow and the Chaturmas 6 [' Hindu', 'ism', ',', ' such', ' as', ' the', ' crow']
+240 48 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of robin bird robin [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' rob', 'in', ' is', ' a', ' small', ' bird', ' with', ' a', ' long', ' tail', ' and', ' a'] " is the most popular .
+
+ The rob in is a small bird with a long tail and a" True also named as robins such as the rufous-collared 4 [' also', ' named', ' as', ' rob', 'in']
+241 48 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of robin bird robin [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' rob', 'in', ' is', ' a', ' small', ' bird', ' with', ' a', ' long', ' tail', ' and', ' a'] " is the most popular .
+
+ The rob in is a small bird with a long tail and a" True play in round robin tournaments independent 4 [' play', ' in', ' round', ' rob', 'in']
+242 48 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of robin bird robin [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' rob', 'in', ' is', ' a', ' small', ' bird', ' with', ' a', ' long', ' tail', ' and', ' a'] " is the most popular .
+
+ The rob in is a small bird with a long tail and a" True of the red-capped robin (Petroica goodenovii), 7 [' of', ' the', ' red', '-', 'c', 'apped', ' rob', 'in']
+243 48 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of robin bird robin [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' rob', 'in', ' is', ' a', ' small', ' bird', ' with', ' a', ' long', ' tail', ' and', ' a'] " is the most popular .
+
+ The rob in is a small bird with a long tail and a" True was a round robin tournament. 4 [' was', ' a', ' round', ' rob', 'in']
+244 48 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of robin bird robin [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' rob', 'in', ' is', ' a', ' small', ' bird', ' with', ' a', ' long', ' tail', ' and', ' a'] " is the most popular .
+
+ The rob in is a small bird with a long tail and a" True in the round robin stage of the 4 [' in', ' the', ' round', ' rob', 'in']
+245 49 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blue jay bird blue jay [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' blue', ' j', 'ay', ' is', ' a', ' common', ' bird', ' in', ' the', ' United', ' States', ' and'] " is the most common .
+
+ The blue j ay is a common bird in the United States and" True " woodpeckers and the blue jay are all ""colored" 8 [' wood', 'pe', 'ck', 'ers', ' and', ' the', ' blue', ' j', 'ay']
+246 49 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blue jay bird blue jay [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' blue', ' j', 'ay', ' is', ' a', ' common', ' bird', ' in', ' the', ' United', ' States', ' and'] " is the most common .
+
+ The blue j ay is a common bird in the United States and" True about the weight of a blue jay to a mourning dove. 7 [' about', ' the', ' weight', ' of', ' a', ' blue', ' j', 'ay']
+247 49 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blue jay bird blue jay [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' blue', ' j', 'ay', ' is', ' a', ' common', ' bird', ' in', ' the', ' United', ' States', ' and'] " is the most common .
+
+ The blue j ay is a common bird in the United States and" True (Zenaida macroura), blue jay (Cyanocitta 9 [' (', 'Zen', 'aida', ' mac', 'rou', 'ra', '),', ' blue', ' j', 'ay']
+248 49 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blue jay bird blue jay [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' blue', ' j', 'ay', ' is', ' a', ' common', ' bird', ' in', ' the', ' United', ' States', ' and'] " is the most common .
+
+ The blue j ay is a common bird in the United States and" True threat, such as a blue jay and avoid the risk 7 [' threat', ',', ' such', ' as', ' a', ' blue', ' j', 'ay']
+249 49 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blue jay bird blue jay [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' blue', ' j', 'ay', ' is', ' a', ' common', ' bird', ' in', ' the', ' United', ' States', ' and'] " is the most common .
+
+ The blue j ay is a common bird in the United States and" True the North American blue jay and is a bluish-grey 5 [' the', ' North', ' American', ' blue', ' j', 'ay']
+250 50 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of owl bird owl [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' owl', ' is', ' a', ' no', 'ct', 'urnal', ' bird', ' of', ' prey', '.', ' It', ' is'] " is the most common .
+
+ The owl is a no ct urnal bird of prey . It is" True The Nicobar scops owl Otus alius, 5 [' The', ' Nic', 'obar', ' sc', 'ops', ' owl']
+251 50 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of owl bird owl [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' owl', ' is', ' a', ' no', 'ct', 'urnal', ' bird', ' of', ' prey', '.', ' It', ' is'] " is the most common .
+
+ The owl is a no ct urnal bird of prey . It is" True " =
+" 1 [' =', 'owl']
+252 50 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of owl bird owl [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' owl', ' is', ' a', ' no', 'ct', 'urnal', ' bird', ' of', ' prey', '.', ' It', ' is'] " is the most common .
+
+ The owl is a no ct urnal bird of prey . It is" True osprey, great horned owl, eastern screech 7 [' o', 'sp', 'rey', ',', ' great', ' horn', 'ed', ' owl']
+253 50 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of owl bird owl [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' owl', ' is', ' a', ' no', 'ct', 'urnal', ' bird', ' of', ' prey', '.', ' It', ' is'] " is the most common .
+
+ The owl is a no ct urnal bird of prey . It is" True 0 ['owl']
+254 50 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of owl bird owl [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' owl', ' is', ' a', ' no', 'ct', 'urnal', ' bird', ' of', ' prey', '.', ' It', ' is'] " is the most common .
+
+ The owl is a no ct urnal bird of prey . It is" True waymarked with an owl symbol, taken from 4 [' way', 'marked', ' with', ' an', ' owl']
+255 51 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False position of the eagle made it appear about 3 [' position', ' of', ' the', ' eagle']
+256 51 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False with a raven or eagle at his shoulder, 4 [' with', ' a', ' raven', ' or', ' eagle']
+257 51 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False white-bellied sea eagle (Haliaeetus leucogaster), 5 [' white', '-', 'bell', 'ied', ' sea', ' eagle']
+258 51 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False white-bellied sea eagle's affinities beyond 5 [' white', '-', 'bell', 'ied', ' sea', ' eagle']
+259 51 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of eagle bird eagle [' is', ' the', ' most', ' powerful', '.', ' It', ' is', ' the', ' most', ' powerful', ' of', ' all', ' the', ' animals', '.', ' It', ' is', ' the', ' most', ' powerful'] is the most powerful . It is the most powerful of all the animals . It is the most powerful False Zealand: Haast's eagle and Eyles' harrier. 5 "[' Zealand', ':', ' Ha', 'ast', ""'s"", ' eagle']"
+260 52 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of hawk bird hawk [' is', ' the', ' most', ' common', '.', '\n', '\n', 'H', 'awk', ' is', ' a', ' bird', ' of', ' prey', ',', ' which', ' is', ' a', ' bird', ' of'] " is the most common .
+
+ H awk is a bird of prey , which is a bird of" True The red-tailed hawk reaches sexual 4 [' The', ' red', '-', 'tailed', ' hawk']
+261 52 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of hawk bird hawk [' is', ' the', ' most', ' common', '.', '\n', '\n', 'H', 'awk', ' is', ' a', ' bird', ' of', ' prey', ',', ' which', ' is', ' a', ' bird', ' of'] " is the most common .
+
+ H awk is a bird of prey , which is a bird of" True Swainson ’ s hawk and the prairie 6 [' Sw', 'ains', 'on', ' �', '�', ' s', ' hawk']
+262 52 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of hawk bird hawk [' is', ' the', ' most', ' common', '.', '\n', '\n', 'H', 'awk', ' is', ' a', ' bird', ' of', ' prey', ',', ' which', ' is', ' a', ' bird', ' of'] " is the most common .
+
+ H awk is a bird of prey , which is a bird of" True combined with the hawk and waterbird counts 3 [' combined', ' with', ' the', ' hawk']
+263 52 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of hawk bird hawk [' is', ' the', ' most', ' common', '.', '\n', '\n', 'H', 'awk', ' is', ' a', ' bird', ' of', ' prey', ',', ' which', ' is', ' a', ' bird', ' of'] " is the most common .
+
+ H awk is a bird of prey , which is a bird of" True train and tries to hawk cheesy merchandise. 4 [' train', ' and', ' tries', ' to', ' hawk']
+264 52 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of hawk bird hawk [' is', ' the', ' most', ' common', '.', '\n', '\n', 'H', 'awk', ' is', ' a', ' bird', ' of', ' prey', ',', ' which', ' is', ' a', ' bird', ' of'] " is the most common .
+
+ H awk is a bird of prey , which is a bird of" True while out hunting. A hawk attacks and wounds 5 [' while', ' out', ' hunting', '.', ' A', ' hawk']
+265 53 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turkey bird turkey [' is', ' the', ' most', ' popular', '.', ' The', ' turkey', ' is', ' a', ' bird', ' that', ' is', ' a', ' domest', 'icated', ' descendant', ' of', ' the', ' wild', ' turkey'] is the most popular . The turkey is a bird that is a domest icated descendant of the wild turkey True mayonnaise and a slice of turkey sandwiched between 7 [' may', 'onna', 'ise', ' and', ' a', ' slice', ' of', ' turkey']
+266 53 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turkey bird turkey [' is', ' the', ' most', ' popular', '.', ' The', ' turkey', ' is', ' a', ' bird', ' that', ' is', ' a', ' domest', 'icated', ' descendant', ' of', ' the', ' wild', ' turkey'] is the most popular . The turkey is a bird that is a domest icated descendant of the wild turkey True and demands a turkey sandwich, but 3 [' and', ' demands', ' a', ' turkey']
+267 53 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turkey bird turkey [' is', ' the', ' most', ' popular', '.', ' The', ' turkey', ' is', ' a', ' bird', ' that', ' is', ' a', ' domest', 'icated', ' descendant', ' of', ' the', ' wild', ' turkey'] is the most popular . The turkey is a bird that is a domest icated descendant of the wild turkey True sundaes, seafood salad, turkey sloppy joes, 7 [' sund', 'a', 'es', ',', ' seafood', ' salad', ',', ' turkey']
+268 53 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turkey bird turkey [' is', ' the', ' most', ' popular', '.', ' The', ' turkey', ' is', ' a', ' bird', ' that', ' is', ' a', ' domest', 'icated', ' descendant', ' of', ' the', ' wild', ' turkey'] is the most popular . The turkey is a bird that is a domest icated descendant of the wild turkey True " saying ""this is a turkey best left to be gobbled" 5 "[' saying', ' ""', 'this', ' is', ' a', ' turkey']"
+269 53 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of turkey bird turkey [' is', ' the', ' most', ' popular', '.', ' The', ' turkey', ' is', ' a', ' bird', ' that', ' is', ' a', ' domest', 'icated', ' descendant', ' of', ' the', ' wild', ' turkey'] is the most popular . The turkey is a bird that is a domest icated descendant of the wild turkey True a. jota, the Chilean turkey vulture, is 7 [' a', '.', ' j', 'ota', ',', ' the', ' Chilean', ' turkey']
+270 54 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of duck bird duck [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' duck', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The duck is a bird that is native to the Northern Hemisphere ." True springs that feed a duck pond and a kingfisher 4 [' springs', ' that', ' feed', ' a', ' duck']
+271 54 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of duck bird duck [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' duck', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The duck is a bird that is native to the Northern Hemisphere ." True " Spaniel was used for duck hunting in East Anglia.
+" 5 [' Sp', 'aniel', ' was', ' used', ' for', ' duck']
+272 54 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of duck bird duck [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' duck', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The duck is a bird that is native to the Northern Hemisphere ." True departing for a golden duck from the bowling 4 [' departing', ' for', ' a', ' golden', ' duck']
+273 54 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of duck bird duck [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' duck', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The duck is a bird that is native to the Northern Hemisphere ." True typing (and optional duck typing) and 4 [' typing', ' (', 'and', ' optional', ' duck']
+274 54 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of duck bird duck [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' duck', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The duck is a bird that is native to the Northern Hemisphere ." True contests, a rubber duck race, a living 4 [' contests', ',', ' a', ' rubber', ' duck']
+275 55 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goose bird goose [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' goose', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The goose is a bird that is native to the Northern Hemisphere ." True dusk. The Chinese goose is more aggressive 4 [' dusk', '.', ' The', ' Chinese', ' goose']
+276 55 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goose bird goose [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' goose', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The goose is a bird that is native to the Northern Hemisphere ." True (Cygnus olor), Canada goose (Branta canadensis), 8 [' (', 'Cy', 'gn', 'us', ' o', 'lor', '),', ' Canada', ' goose']
+277 55 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goose bird goose [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' goose', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The goose is a bird that is native to the Northern Hemisphere ." True to spot a lone goose flying off in 4 [' to', ' spot', ' a', ' lone', ' goose']
+278 55 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goose bird goose [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' goose', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The goose is a bird that is native to the Northern Hemisphere ." True swan, Canada goose and American wigeon 4 [' sw', 'an', ',', ' Canada', ' goose']
+279 55 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of goose bird goose [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' goose', ' is', ' a', ' bird', ' that', ' is', ' native', ' to', ' the', ' Northern', ' Hemisphere', '.'] " is the most common .
+
+ The goose is a bird that is native to the Northern Hemisphere ." True name] – it gave me goose pimples ... not 6 [' name', ']', ' –', ' it', ' gave', ' me', ' goose']
+280 56 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot vegetable carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" True synonymous with carrot and stick. Though 2 [' synonymous', ' with', ' carrot']
+281 56 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot vegetable carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" True " beetroot juice and carrot juice.
+" 4 [' beet', 'root', ' juice', ' and', ' carrot']
+282 56 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot vegetable carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" True by larvae of the carrot fly (Chamaepsila rosae). 4 [' by', ' larvae', ' of', ' the', ' carrot']
+283 56 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot vegetable carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" True 1 ['car', 'rot']
+284 56 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of carrot vegetable carrot [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Car', 'rot', ' is', ' a', ' root', ' vegetable', ',', ' which', ' is', ' a', ' member', ' of', ' the'] " is the most common .
+
+ Car rot is a root vegetable , which is a member of the" True yellow rattle, wild carrot and knapweed. 5 [' yellow', ' r', 'attle', ',', ' wild', ' carrot']
+285 57 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato vegetable potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False received a potato icon for each accessed 2 [' received', ' a', ' potato']
+286 57 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato vegetable potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False failure of the potato harvests in the 1840s 3 [' failure', ' of', ' the', ' potato']
+287 57 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato vegetable potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False grandfather with the potato harvest each year. 3 [' grandfather', ' with', ' the', ' potato']
+288 57 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato vegetable potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False a dip for potato chips, bread, 3 [' a', ' dip', ' for', ' potato']
+289 57 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of potato vegetable potato [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' food', ' crop', ' in', ' the', ' world', '.', ' It', ' is', ' the', ' most'] is the most important . It is the most important food crop in the world . It is the most False The experimental potatoes had been 2 [' The', ' experimental', ' potato']
+290 58 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato vegetable tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False onions, parsley root and tomato paste) were cooked 6 [' onions', ',', ' pars', 'ley', ' root', ' and', ' tomato']
+291 58 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato vegetable tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False vermicelli, and tomato paste, seasoned 6 [' ver', 'm', 'ice', 'lli', ',', ' and', ' tomato']
+292 58 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato vegetable tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False zucchini, lettuce and tomato. Jordan was inspired 6 [' z', 'uc', 'chini', ',', ' lettuce', ' and', ' tomato']
+293 58 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato vegetable tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False " than air"" meatballs in tomato sauce, lemon chicken," 6 "[' than', ' air', '""', ' meat', 'balls', ' in', ' tomato']"
+294 58 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of tomato vegetable tomato [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' tomato', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of tomato is the most popular .
+
+ The" False adoption of the tomato into Mediterranean 3 [' adoption', ' of', ' the', ' tomato']
+295 59 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lettuce vegetable lettuce [' is', ' the', ' most', ' common', '.', '\n', '\n', 'L', 'ett', 'uce', ' is', ' a', ' member', ' of', ' the', ' family', ' of', ' the', ' lettuce', ' family'] " is the most common .
+
+ L ett uce is a member of the family of the lettuce family" False spinach, cabbage, and lettuce rose by 124 %, 5 [' spinach', ',', ' cabbage', ',', ' and', ' lettuce']
+296 59 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lettuce vegetable lettuce [' is', ' the', ' most', ' common', '.', '\n', '\n', 'L', 'ett', 'uce', ' is', ' a', ' member', ' of', ' the', ' family', ' of', ' the', ' lettuce', ' family'] " is the most common .
+
+ L ett uce is a member of the family of the lettuce family" False around 50 AD, lettuce leaves were 4 [' around', ' 50', ' AD', ',', ' lettuce']
+297 59 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lettuce vegetable lettuce [' is', ' the', ' most', ' common', '.', '\n', '\n', 'L', 'ett', 'uce', ' is', ' a', ' member', ' of', ' the', ' family', ' of', ' the', ' lettuce', ' family'] " is the most common .
+
+ L ett uce is a member of the family of the lettuce family" False geese feed on sea lettuce and other green algae, 5 [' ge', 'ese', ' feed', ' on', ' sea', ' lettuce']
+298 59 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lettuce vegetable lettuce [' is', ' the', ' most', ' common', '.', '\n', '\n', 'L', 'ett', 'uce', ' is', ' a', ' member', ' of', ' the', ' family', ' of', ' the', ' lettuce', ' family'] " is the most common .
+
+ L ett uce is a member of the family of the lettuce family" False elongation of lettuce seedlings by 136 %, 3 [' elong', 'ation', ' of', ' lettuce']
+299 59 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of lettuce vegetable lettuce [' is', ' the', ' most', ' common', '.', '\n', '\n', 'L', 'ett', 'uce', ' is', ' a', ' member', ' of', ' the', ' family', ' of', ' the', ' lettuce', ' family'] " is the most common .
+
+ L ett uce is a member of the family of the lettuce family" False eaten boiled lettuce and pet food 2 [' eaten', ' boiled', ' lettuce']
+300 60 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber vegetable cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" False French fries, cucumber salads, Buffalo 4 [' French', ' fries', ',', ' cuc', 'umber']
+301 60 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber vegetable cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" False pests as aphids, cucumber beetles, tomato 6 [' pests', ' as', ' aph', 'ids', ',', ' cuc', 'umber']
+302 60 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber vegetable cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" False Stewie placing a cucumber on the sofa where 5 [' Stew', 'ie', ' placing', ' a', ' cuc', 'umber']
+303 60 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber vegetable cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" False " traditional yogurt and cucumber soup.
+" 4 [' traditional', ' yogurt', ' and', ' cuc', 'umber']
+304 60 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cucumber vegetable cucumber [' is', ' the', ' most', ' common', '.', '\n', '\n', 'C', 'uc', 'umber', ' is', ' a', ' fruit', ' of', ' the', ' cuc', 'umber', ' plant', ',', ' C'] " is the most common .
+
+ C uc umber is a fruit of the cuc umber plant , C" False pests as aphids, cucumber beetles, tomato hornworms, 6 [' pests', ' as', ' aph', 'ids', ',', ' cuc', 'umber']
+305 61 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion vegetable onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False fruit-leaves, onions and other vegetables. 5 [' fruit', '-', 'le', 'aves', ',', ' onion']
+306 61 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion vegetable onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False Arabic. The tile roof and onion-dome tower mounted 6 [' Arabic', '.', ' The', ' tile', ' roof', ' and', ' onion']
+307 61 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion vegetable onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False was expected to add onion and ginger to the 4 [' was', ' expected', ' to', ' add', ' onion']
+308 61 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion vegetable onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False on the notion that onion domes did not 4 [' on', ' the', ' notion', ' that', ' onion']
+309 61 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of onion vegetable onion [' is', ' the', ' most', ' important', '.', ' It', ' is', ' the', ' most', ' important', ' because', ' it', ' is', ' the', ' most', ' common', '.', ' It', ' is', ' the'] is the most important . It is the most important because it is the most common . It is the False beneath its large onion dome is dark and cavernous 3 [' beneath', ' its', ' large', ' onion']
+310 62 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bell pepper vegetable bell pepper [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bell', ' pepper', ' is', ' a', ' member', ' of', ' the', ' night', 'sh', 'ade', ' family', ','] " is the most common .
+
+ The bell pepper is a member of the night sh ade family ," False herbaceous or green bell pepper flavor caused by 5 [' herb', 'aceous', ' or', ' green', ' bell', ' pepper']
+311 62 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bell pepper vegetable bell pepper [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bell', ' pepper', ' is', ' a', ' member', ' of', ' the', ' night', 'sh', 'ade', ' family', ','] " is the most common .
+
+ The bell pepper is a member of the night sh ade family ," False accompanied by green bell pepper notes, mint and cedar 4 [' accompanied', ' by', ' green', ' bell', ' pepper']
+312 62 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bell pepper vegetable bell pepper [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bell', ' pepper', ' is', ' a', ' member', ' of', ' the', ' night', 'sh', 'ade', ' family', ','] " is the most common .
+
+ The bell pepper is a member of the night sh ade family ," False herbaceous or green bell pepper flavor caused 5 [' herb', 'aceous', ' or', ' green', ' bell', ' pepper']
+313 62 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bell pepper vegetable bell pepper [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bell', ' pepper', ' is', ' a', ' member', ' of', ' the', ' night', 'sh', 'ade', ' family', ','] " is the most common .
+
+ The bell pepper is a member of the night sh ade family ," False herbaceous and green bell pepper flavours from 5 [' herb', 'aceous', ' and', ' green', ' bell', ' pepper']
+314 62 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of bell pepper vegetable bell pepper [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' bell', ' pepper', ' is', ' a', ' member', ' of', ' the', ' night', 'sh', 'ade', ' family', ','] " is the most common .
+
+ The bell pepper is a member of the night sh ade family ," False herbaceous or green bell pepper flavor caused 5 [' herb', 'aceous', ' or', ' green', ' bell', ' pepper']
+315 63 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of broccoli vegetable broccoli [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' broccoli', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of broccoli is the most popular .
+
+ The" False saying that a broccoli amuse-bouche had 3 [' saying', ' that', ' a', ' broccoli']
+316 63 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of broccoli vegetable broccoli [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' broccoli', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of broccoli is the most popular .
+
+ The" False believe he is eating his broccoli by pouring 5 [' believe', ' he', ' is', ' eating', ' his', ' broccoli']
+317 63 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of broccoli vegetable broccoli [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' broccoli', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of broccoli is the most popular .
+
+ The" False such as sushi, chess, broccoli and cheese soup, 6 [' such', ' as', ' sushi', ',', ' chess', ',', ' broccoli']
+318 63 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of broccoli vegetable broccoli [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' broccoli', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of broccoli is the most popular .
+
+ The" False related to broccoli and cauliflower 2 [' related', ' to', ' broccoli']
+319 63 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of broccoli vegetable broccoli [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' broccoli', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The'] " is the most popular .
+
+ The super class of broccoli is the most popular .
+
+ The" False as sushi, chess, broccoli and cheese soup, 5 [' as', ' sushi', ',', ' chess', ',', ' broccoli']
+320 64 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cauliflower vegetable cauliflower [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'C', 'aul', 'iflower', ' is', ' a', ' member', ' of', ' the', ' cabbage', ' family', ',', ' Br', 'assic'] " is the most popular .
+
+ C aul iflower is a member of the cabbage family , Br assic" False Brussels sprouts, cauliflower and green beans 5 [' Brussels', ' spr', 'outs', ',', ' caul', 'iflower']
+321 64 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cauliflower vegetable cauliflower [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'C', 'aul', 'iflower', ' is', ' a', ' member', ' of', ' the', ' cabbage', ' family', ',', ' Br', 'assic'] " is the most popular .
+
+ C aul iflower is a member of the cabbage family , Br assic" False Brussels sprouts, cauliflower and green beans 5 [' Brussels', ' spr', 'outs', ',', ' caul', 'iflower']
+322 64 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cauliflower vegetable cauliflower [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'C', 'aul', 'iflower', ' is', ' a', ' member', ' of', ' the', ' cabbage', ' family', ',', ' Br', 'assic'] " is the most popular .
+
+ C aul iflower is a member of the cabbage family , Br assic" False sorrel and cauliflower mushrooms. The 4 [' sor', 'rel', ' and', ' caul', 'iflower']
+323 64 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cauliflower vegetable cauliflower [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'C', 'aul', 'iflower', ' is', ' a', ' member', ' of', ' the', ' cabbage', ' family', ',', ' Br', 'assic'] " is the most popular .
+
+ C aul iflower is a member of the cabbage family , Br assic" False " column had become a ""cauliflower,"" and all the spray" 7 "[' column', ' had', ' become', ' a', ' ""', 'c', 'aul', 'iflower']"
+324 64 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of cauliflower vegetable cauliflower [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'C', 'aul', 'iflower', ' is', ' a', ' member', ' of', ' the', ' cabbage', ' family', ',', ' Br', 'assic'] " is the most popular .
+
+ C aul iflower is a member of the cabbage family , Br assic" False column and its cauliflower head and then 4 [' column', ' and', ' its', ' caul', 'iflower']
+325 65 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of kale vegetable kale ['id', 'oscope', ' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' diverse', '.', ' It'] id oscope is the most complex . It is the most complex because it is the most diverse . It False cabbages and kale were probably 4 [' c', 'abb', 'ages', ' and', ' kale']
+326 65 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of kale vegetable kale ['id', 'oscope', ' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' diverse', '.', ' It'] id oscope is the most complex . It is the most complex because it is the most diverse . It False " mashed potatoes and kale or cabbage.
+" 3 [' mashed', ' potatoes', ' and', ' kale']
+327 65 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of kale vegetable kale ['id', 'oscope', ' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' diverse', '.', ' It'] id oscope is the most complex . It is the most complex because it is the most diverse . It False crops included kale (for both humans 2 [' crops', ' included', ' kale']
+328 65 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of kale vegetable kale ['id', 'oscope', ' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' diverse', '.', ' It'] id oscope is the most complex . It is the most complex because it is the most diverse . It False the play Di shtume kale (The Mute Bride) 6 [' the', ' play', ' Di', ' sh', 't', 'ume', ' kale']
+329 65 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of kale vegetable kale ['id', 'oscope', ' is', ' the', ' most', ' complex', '.', ' It', ' is', ' the', ' most', ' complex', ' because', ' it', ' is', ' the', ' most', ' diverse', '.', ' It'] id oscope is the most complex . It is the most complex because it is the most diverse . It False cabbages and kale were probably 4 [' c', 'abb', 'ages', ' and', ' kale']
+330 66 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple fruit apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False crashed into an apple orchard in the 3 [' crashed', ' into', ' an', ' apple']
+331 66 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple fruit apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False " And its famous apple tree.
+" 3 [' And', ' its', ' famous', ' apple']
+332 66 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple fruit apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False Roughly half of apple production in 4 [' Rough', 'ly', ' half', ' of', ' apple']
+333 66 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple fruit apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False reveal her Adam's apple (although this 4 "[' reveal', ' her', ' Adam', ""'s"", ' apple']"
+334 66 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of apple fruit apple [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' super', 'class', ' of', ' apple', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The'] " is the most common .
+
+ The super class of apple is the most common .
+
+ The" False cider from the estate's apple orchards and operating 5 "[' cider', ' from', ' the', ' estate', ""'s"", ' apple']"
+335 67 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana fruit banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True (58 sq mi) of banana crop was wiped 6 [' (', '58', ' sq', ' mi', ')', ' of', ' banana']
+336 67 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana fruit banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True damaged the island's banana crop. Damage in the 4 "[' damaged', ' the', ' island', ""'s"", ' banana']"
+337 67 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana fruit banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True storm damaged banana plantations 2 [' storm', ' damaged', ' banana']
+338 67 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana fruit banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True prepared to pour banana baby food 3 [' prepared', ' to', ' pour', ' banana']
+339 67 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of banana fruit banana [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' banana', ' is', ' a', ' fruit', ' of', ' the', ' plant', ' family', ' Mus', 'aceae', '.', ' It'] " is the most popular .
+
+ The banana is a fruit of the plant family Mus aceae . It" True its yearly banana crop. Damage was 2 [' its', ' yearly', ' banana']
+340 68 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange fruit orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" True was lined in orange and black, and the 3 [' was', ' lined', ' in', ' orange']
+341 68 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange fruit orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" True subsanguinolenta has red to orange juice, is slightly 8 [' subs', 'angu', 'in', 'ol', 'enta', ' has', ' red', ' to', ' orange']
+342 68 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange fruit orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" True item is served with orange juice and tea 4 [' item', ' is', ' served', ' with', ' orange']
+343 68 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange fruit orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" True bright yellow to orange with a metallic 3 [' bright', ' yellow', ' to', ' orange']
+344 68 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of orange fruit orange [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Orange', ' is', ' the', ' most', ' common', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' the'] " is the most common .
+
+ Orange is the most common fruit in the world . It is the" True have a deeper orange facial skin 3 [' have', ' a', ' deeper', ' orange']
+345 69 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapefruit fruit grapefruit [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'G', 'rape', 'fruit', ' is', ' a', ' fruit', ' with', ' a', ' large', ',', ' thick', ',', ' juicy'] " is the most popular .
+
+ G rape fruit is a fruit with a large , thick , juicy" True as oranges, grapefruit, tangerines and 4 [' as', ' oranges', ',', ' grape', 'fruit']
+346 69 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapefruit fruit grapefruit [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'G', 'rape', 'fruit', ' is', ' a', ' fruit', ' with', ' a', ' large', ',', ' thick', ',', ' juicy'] " is the most popular .
+
+ G rape fruit is a fruit with a large , thick , juicy" True lost 35 % of its grapefruit and orange crops combined, 6 [' lost', ' 35', ' %', ' of', ' its', ' grape', 'fruit']
+347 69 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapefruit fruit grapefruit [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'G', 'rape', 'fruit', ' is', ' a', ' fruit', ' with', ' a', ' large', ',', ' thick', ',', ' juicy'] " is the most popular .
+
+ G rape fruit is a fruit with a large , thick , juicy" True The entire grapefruit harvest was lost, 3 [' The', ' entire', ' grape', 'fruit']
+348 69 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapefruit fruit grapefruit [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'G', 'rape', 'fruit', ' is', ' a', ' fruit', ' with', ' a', ' large', ',', ' thick', ',', ' juicy'] " is the most popular .
+
+ G rape fruit is a fruit with a large , thick , juicy" True orange, lemon and grapefruit thrive on the 5 [' orange', ',', ' lemon', ' and', ' grape', 'fruit']
+349 69 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapefruit fruit grapefruit [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'G', 'rape', 'fruit', ' is', ' a', ' fruit', ' with', ' a', ' large', ',', ' thick', ',', ' juicy'] " is the most popular .
+
+ G rape fruit is a fruit with a large , thick , juicy" True was offered grapefruit in almost every 3 [' was', ' offered', ' grape', 'fruit']
+350 70 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapes fruit grapes [' is', ' the', ' most', ' important', '.', ' G', 'rap', 'es', ' are', ' the', ' most', ' important', ' fruit', ' in', ' the', ' world', '.', ' G', 'rap', 'es'] is the most important . G rap es are the most important fruit in the world . G rap es True wheat, barley and grapes. The Muslims and 4 [' wheat', ',', ' barley', ' and', ' grapes']
+351 70 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapes fruit grapes [' is', ' the', ' most', ' important', '.', ' G', 'rap', 'es', ' are', ' the', ' most', ' important', ' fruit', ' in', ' the', ' world', '.', ' G', 'rap', 'es'] is the most important . G rap es are the most important fruit in the world . G rap es True " cultivate vinifera grapes and produce ""those" 5 [' cultivate', ' v', 'in', 'ifer', 'a', ' grapes']
+352 70 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapes fruit grapes [' is', ' the', ' most', ' important', '.', ' G', 'rap', 'es', ' are', ' the', ' most', ' important', ' fruit', ' in', ' the', ' world', '.', ' G', 'rap', 'es'] is the most important . G rap es are the most important fruit in the world . G rap es True drinks involving grapes and honey were 2 [' drinks', ' involving', ' grapes']
+353 70 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapes fruit grapes [' is', ' the', ' most', ' important', '.', ' G', 'rap', 'es', ' are', ' the', ' most', ' important', ' fruit', ' in', ' the', ' world', '.', ' G', 'rap', 'es'] is the most important . G rap es are the most important fruit in the world . G rap es True ripeness of the grapes at harvest. 4 [' rip', 'eness', ' of', ' the', ' grapes']
+354 70 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of grapes fruit grapes [' is', ' the', ' most', ' important', '.', ' G', 'rap', 'es', ' are', ' the', ' most', ' important', ' fruit', ' in', ' the', ' world', '.', ' G', 'rap', 'es'] is the most important . G rap es are the most important fruit in the world . G rap es True foot crushing grapes over a slab inscribed 2 [' foot', ' crushing', ' grapes']
+355 71 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of peach fruit peach [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Pe', 'ach', ' is', ' a', ' fruit', ' of', ' the', ' peach', ' tree', ',', ' which', ' is', ' a'] " is the most popular .
+
+ Pe ach is a fruit of the peach tree , which is a" True extensive to apple and peach orchards across 4 [' extensive', ' to', ' apple', ' and', ' peach']
+356 71 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of peach fruit peach [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Pe', 'ach', ' is', ' a', ' fruit', ' of', ' the', ' peach', ' tree', ',', ' which', ' is', ' a'] " is the most popular .
+
+ Pe ach is a fruit of the peach tree , which is a" True dies eating a peach because he is not 3 [' dies', ' eating', ' a', ' peach']
+357 71 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of peach fruit peach [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Pe', 'ach', ' is', ' a', ' fruit', ' of', ' the', ' peach', ' tree', ',', ' which', ' is', ' a'] " is the most popular .
+
+ Pe ach is a fruit of the peach tree , which is a" True of a commercial peach orchard near Romney 3 [' of', ' a', ' commercial', ' peach']
+358 71 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of peach fruit peach [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Pe', 'ach', ' is', ' a', ' fruit', ' of', ' the', ' peach', ' tree', ',', ' which', ' is', ' a'] " is the most popular .
+
+ Pe ach is a fruit of the peach tree , which is a" True into closed-end peach baskets. Naismith's 4 [' into', ' closed', '-', 'end', ' peach']
+359 71 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of peach fruit peach [' is', ' the', ' most', ' popular', '.', '\n', '\n', 'Pe', 'ach', ' is', ' a', ' fruit', ' of', ' the', ' peach', ' tree', ',', ' which', ' is', ' a'] " is the most popular .
+
+ Pe ach is a fruit of the peach tree , which is a" True designed around a peach motif, was envisioned 3 [' designed', ' around', ' a', ' peach']
+360 72 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pear fruit pear [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' pear', ' is', ' a', ' fruit', ' of', ' the', ' genus', ' P', 'yrus', ',', ' which', ' is'] " is the most common .
+
+ The pear is a fruit of the genus P yrus , which is" True beverages included pear juice, lychee fruit 2 [' beverages', ' included', ' pear']
+361 72 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pear fruit pear [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' pear', ' is', ' a', ' fruit', ' of', ' the', ' genus', ' P', 'yrus', ',', ' which', ' is'] " is the most common .
+
+ The pear is a fruit of the genus P yrus , which is" True peach, plum, pear and apricot trees adorned 4 [' peach', ',', ' plum', ',', ' pear']
+362 72 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pear fruit pear [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' pear', ' is', ' a', ' fruit', ' of', ' the', ' genus', ' P', 'yrus', ',', ' which', ' is'] " is the most common .
+
+ The pear is a fruit of the genus P yrus , which is" True the prickly pear cactus in the 3 [' the', ' prick', 'ly', ' pear']
+363 72 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pear fruit pear [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' pear', ' is', ' a', ' fruit', ' of', ' the', ' genus', ' P', 'yrus', ',', ' which', ' is'] " is the most common .
+
+ The pear is a fruit of the genus P yrus , which is" True Euphorbia. The prickly pear was used where 7 [' Euph', 'or', 'bia', '.', ' The', ' prick', 'ly', ' pear']
+364 72 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of pear fruit pear [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' pear', ' is', ' a', ' fruit', ' of', ' the', ' genus', ' P', 'yrus', ',', ' which', ' is'] " is the most common .
+
+ The pear is a fruit of the genus P yrus , which is" True Vladimir), Ukrainian pear domes (Saint Sophia 3 [' Vladimir', '),', ' Ukrainian', ' pear']
+365 73 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of watermelon fruit watermelon [' is', ' the', ' most', ' popular', '.', ' Water', 'melon', ' is', ' a', ' kind', ' of', ' fruit', ',', ' which', ' is', ' a', ' kind', ' of', ' fruit', '.'] is the most popular . Water melon is a kind of fruit , which is a kind of fruit . True by giving Meg a watermelon filled with 5 [' by', ' giving', ' Meg', ' a', ' water', 'melon']
+366 73 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of watermelon fruit watermelon [' is', ' the', ' most', ' popular', '.', ' Water', 'melon', ' is', ' a', ' kind', ' of', ' fruit', ',', ' which', ' is', ' a', ' kind', ' of', ' fruit', '.'] is the most popular . Water melon is a kind of fruit , which is a kind of fruit . True Middle East watermelon were introduced into 3 [' Middle', ' East', ' water', 'melon']
+367 73 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of watermelon fruit watermelon [' is', ' the', ' most', ' popular', '.', ' Water', 'melon', ' is', ' a', ' kind', ' of', ' fruit', ',', ' which', ' is', ' a', ' kind', ' of', ' fruit', '.'] is the most popular . Water melon is a kind of fruit , which is a kind of fruit . True everybody had some watermelon and the watermelon 4 [' everybody', ' had', ' some', ' water', 'melon']
+368 73 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of watermelon fruit watermelon [' is', ' the', ' most', ' popular', '.', ' Water', 'melon', ' is', ' a', ' kind', ' of', ' fruit', ',', ' which', ' is', ' a', ' kind', ' of', ' fruit', '.'] is the most popular . Water melon is a kind of fruit , which is a kind of fruit . True to the size of a watermelon (~ 1 ft / 30 cm), 6 [' to', ' the', ' size', ' of', ' a', ' water', 'melon']
+369 73 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of watermelon fruit watermelon [' is', ' the', ' most', ' popular', '.', ' Water', 'melon', ' is', ' a', ' kind', ' of', ' fruit', ',', ' which', ' is', ' a', ' kind', ' of', ' fruit', '.'] is the most popular . Water melon is a kind of fruit , which is a kind of fruit . True entirety of the watermelon crop in Waller County 4 [' entirety', ' of', ' the', ' water', 'melon']
+370 74 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of strawberry fruit strawberry [' is', ' the', ' most', ' popular', '.', ' It', ' is', ' the', ' most', ' delicious', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' also', ' the', ' most'] is the most popular . It is the most delicious fruit in the world . It is also the most True Kat, Oreo and strawberry shortcake, is available 5 [' Kat', ',', ' Ore', 'o', ' and', ' strawberry']
+371 74 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of strawberry fruit strawberry [' is', ' the', ' most', ' popular', '.', ' It', ' is', ' the', ' most', ' delicious', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' also', ' the', ' most'] is the most popular . It is the most delicious fruit in the world . It is also the most True through the nearby strawberry field, with 3 [' through', ' the', ' nearby', ' strawberry']
+372 74 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of strawberry fruit strawberry [' is', ' the', ' most', ' popular', '.', ' It', ' is', ' the', ' most', ' delicious', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' also', ' the', ' most'] is the most popular . It is the most delicious fruit in the world . It is also the most True appearance, with strawberry blond hair 3 [' appearance', ',', ' with', ' strawberry']
+373 74 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of strawberry fruit strawberry [' is', ' the', ' most', ' popular', '.', ' It', ' is', ' the', ' most', ' delicious', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' also', ' the', ' most'] is the most popular . It is the most delicious fruit in the world . It is also the most True including millions of strawberry plants to Gulf 3 [' including', ' millions', ' of', ' strawberry']
+374 74 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of strawberry fruit strawberry [' is', ' the', ' most', ' popular', '.', ' It', ' is', ' the', ' most', ' delicious', ' fruit', ' in', ' the', ' world', '.', ' It', ' is', ' also', ' the', ' most'] is the most popular . It is the most delicious fruit in the world . It is also the most True attending a local strawberry festival. He was 3 [' attending', ' a', ' local', ' strawberry']
+375 75 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blueberry fruit blueberry [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Blue', 'berry', ' is', ' a', ' small', ',', ' round', ',', ' blue', ' fruit', ',', ' which', ' is'] " is the most common .
+
+ Blue berry is a small , round , blue fruit , which is" True with lemon. A blueberry Lemon Drop 5 [' with', ' lemon', '.', ' A', ' blue', 'berry']
+376 75 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blueberry fruit blueberry [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Blue', 'berry', ' is', ' a', ' small', ',', ' round', ',', ' blue', ' fruit', ',', ' which', ' is'] " is the most common .
+
+ Blue berry is a small , round , blue fruit , which is" True understory of shrubs such as blueberry and laurel; they 8 [' under', 'story', ' of', ' shr', 'ubs', ' such', ' as', ' blue', 'berry']
+377 75 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blueberry fruit blueberry [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Blue', 'berry', ' is', ' a', ' small', ',', ' round', ',', ' blue', ' fruit', ',', ' which', ' is'] " is the most common .
+
+ Blue berry is a small , round , blue fruit , which is" True forest, a highbush blueberry shrub swamp, and 6 [' forest', ',', ' a', ' high', 'bush', ' blue', 'berry']
+378 75 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blueberry fruit blueberry [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Blue', 'berry', ' is', ' a', ' small', ',', ' round', ',', ' blue', ' fruit', ',', ' which', ' is'] " is the most common .
+
+ Blue berry is a small , round , blue fruit , which is" True forests have blueberry and huckleberry 3 [' forests', ' have', ' blue', 'berry']
+379 75 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of x -1 Among the superclasses of animal, bird, fish, vegetable, flower, fruit, the superclass of blueberry fruit blueberry [' is', ' the', ' most', ' common', '.', '\n', '\n', 'Blue', 'berry', ' is', ' a', ' small', ',', ' round', ',', ' blue', ' fruit', ',', ' which', ' is'] " is the most common .
+
+ Blue berry is a small , round , blue fruit , which is" True of shrubs such as blueberry and laurel; they use 6 [' of', ' shr', 'ubs', ' such', ' as', ' blue', 'berry']
diff --git a/patchscopes/code/preprocessed_data/commonsense/substance_phase.tsv b/patchscopes/code/preprocessed_data/commonsense/substance_phase.tsv
new file mode 100644
index 00000000..df73bb7d
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/commonsense/substance_phase.tsv
@@ -0,0 +1,779 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of water liquid water [' is', ' H', '2', 'O', '.', '\n', '\n', 'Water', ' is', ' a', ' polar', ' molecule', ',', ' which', ' means', ' that', ' it', ' has', ' a', ' positive'] " is H 2 O .
+
+ Water is a polar molecule , which means that it has a positive" False sometimes in less water than the frigates 3 [' sometimes', ' in', ' less', ' water']
+1 0 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of water liquid water [' is', ' H', '2', 'O', '.', '\n', '\n', 'Water', ' is', ' a', ' polar', ' molecule', ',', ' which', ' means', ' that', ' it', ' has', ' a', ' positive'] " is H 2 O .
+
+ Water is a polar molecule , which means that it has a positive" False has been subject to water damage, although 4 [' has', ' been', ' subject', ' to', ' water']
+2 0 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of water liquid water [' is', ' H', '2', 'O', '.', '\n', '\n', 'Water', ' is', ' a', ' polar', ' molecule', ',', ' which', ' means', ' that', ' it', ' has', ' a', ' positive'] " is H 2 O .
+
+ Water is a polar molecule , which means that it has a positive" False improve access to water and sanitation. It 3 [' improve', ' access', ' to', ' water']
+3 0 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of water liquid water [' is', ' H', '2', 'O', '.', '\n', '\n', 'Water', ' is', ' a', ' polar', ' molecule', ',', ' which', ' means', ' that', ' it', ' has', ' a', ' positive'] " is H 2 O .
+
+ Water is a polar molecule , which means that it has a positive" False " the Chinese water torture.""
+" 2 [' the', ' Chinese', ' water']
+4 0 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of water liquid water [' is', ' H', '2', 'O', '.', '\n', '\n', 'Water', ' is', ' a', ' polar', ' molecule', ',', ' which', ' means', ' that', ' it', ' has', ' a', ' positive'] " is H 2 O .
+
+ Water is a polar molecule , which means that it has a positive" False subtropical and tropical waters, depending mainly 5 [' subt', 'rop', 'ical', ' and', ' tropical', ' water']
+5 1 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of iron solid iron [' is', ' fer', 'rom', 'agnetic', ',', ' which', ' means', ' that', ' the', ' iron', ' atoms', ' are', ' arranged', ' in', ' a', ' regular', ' pattern', '.', ' When', ' the'] is fer rom agnetic , which means that the iron atoms are arranged in a regular pattern . When the False the use of cast iron balustrades and brackets, 4 [' the', ' use', ' of', ' cast', ' iron']
+6 1 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of iron solid iron [' is', ' fer', 'rom', 'agnetic', ',', ' which', ' means', ' that', ' the', ' iron', ' atoms', ' are', ' arranged', ' in', ' a', ' regular', ' pattern', '.', ' When', ' the'] is fer rom agnetic , which means that the iron atoms are arranged in a regular pattern . When the False composed mostly of iron (32.1 %), oxygen 3 [' composed', ' mostly', ' of', ' iron']
+7 1 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of iron solid iron [' is', ' fer', 'rom', 'agnetic', ',', ' which', ' means', ' that', ' the', ' iron', ' atoms', ' are', ' arranged', ' in', ' a', ' regular', ' pattern', '.', ' When', ' the'] is fer rom agnetic , which means that the iron atoms are arranged in a regular pattern . When the False because the necessary iron making and metal 3 [' because', ' the', ' necessary', ' iron']
+8 1 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of iron solid iron [' is', ' fer', 'rom', 'agnetic', ',', ' which', ' means', ' that', ' the', ' iron', ' atoms', ' are', ' arranged', ' in', ' a', ' regular', ' pattern', '.', ' When', ' the'] is fer rom agnetic , which means that the iron atoms are arranged in a regular pattern . When the False destination point for iron ore shipped from 3 [' destination', ' point', ' for', ' iron']
+9 1 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of iron solid iron [' is', ' fer', 'rom', 'agnetic', ',', ' which', ' means', ' that', ' the', ' iron', ' atoms', ' are', ' arranged', ' in', ' a', ' regular', ' pattern', '.', ' When', ' the'] is fer rom agnetic , which means that the iron atoms are arranged in a regular pattern . When the False " from hematite ( iron oxide; Fe
+" 5 [' from', ' he', 'mat', 'ite', ' (', ' iron']
+10 2 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of oxygen gas oxygen [' is', ' a', ' solid', '.', '\n', '\n', 'O', 'xy', 'gen', ' is', ' a', ' di', 'atomic', ' molecule', ' with', ' two', ' atoms', ' of', ' oxygen', '.'] " is a solid .
+
+ O xy gen is a di atomic molecule with two atoms of oxygen ." False resulted in dropped oxygen masks and 3 [' resulted', ' in', ' dropped', ' oxygen']
+11 2 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of oxygen gas oxygen [' is', ' a', ' solid', '.', '\n', '\n', 'O', 'xy', 'gen', ' is', ' a', ' di', 'atomic', ' molecule', ' with', ' two', ' atoms', ' of', ' oxygen', '.'] " is a solid .
+
+ O xy gen is a di atomic molecule with two atoms of oxygen ." False includes a clinic with oxygen treatment facilities. 4 [' includes', ' a', ' clinic', ' with', ' oxygen']
+12 2 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of oxygen gas oxygen [' is', ' a', ' solid', '.', '\n', '\n', 'O', 'xy', 'gen', ' is', ' a', ' di', 'atomic', ' molecule', ' with', ' two', ' atoms', ' of', ' oxygen', '.'] " is a solid .
+
+ O xy gen is a di atomic molecule with two atoms of oxygen ." False published first. The name oxygen was coined 5 [' published', ' first', '.', ' The', ' name', ' oxygen']
+13 2 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of oxygen gas oxygen [' is', ' a', ' solid', '.', '\n', '\n', 'O', 'xy', 'gen', ' is', ' a', ' di', 'atomic', ' molecule', ' with', ' two', ' atoms', ' of', ' oxygen', '.'] " is a solid .
+
+ O xy gen is a di atomic molecule with two atoms of oxygen ." False levels of dissolved oxygen and its pH ranges 3 [' levels', ' of', ' dissolved', ' oxygen']
+14 2 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of oxygen gas oxygen [' is', ' a', ' solid', '.', '\n', '\n', 'O', 'xy', 'gen', ' is', ' a', ' di', 'atomic', ' molecule', ' with', ' two', ' atoms', ' of', ' oxygen', '.'] " is a solid .
+
+ O xy gen is a di atomic molecule with two atoms of oxygen ." False the diffusion of oxygen into the underlying 3 [' the', ' diffusion', ' of', ' oxygen']
+15 3 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of gold solid gold [' is', ' a', ' solid', '.', '\n', '\n', 'Gold', ' is', ' a', ' soft', ' metal', ',', ' and', ' is', ' easily', ' scratched', '.', '\n', '\n', 'Gold'] " is a solid .
+
+ Gold is a soft metal , and is easily scratched .
+
+ Gold" True blue clothes, with a gold laced hat that 5 [' blue', ' clothes', ',', ' with', ' a', ' gold']
+16 3 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of gold solid gold [' is', ' a', ' solid', '.', '\n', '\n', 'Gold', ' is', ' a', ' soft', ' metal', ',', ' and', ' is', ' easily', ' scratched', '.', '\n', '\n', 'Gold'] " is a solid .
+
+ Gold is a soft metal , and is easily scratched .
+
+ Gold" True set of twelve gold medallions, 3 [' set', ' of', ' twelve', ' gold']
+17 3 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of gold solid gold [' is', ' a', ' solid', '.', '\n', '\n', 'Gold', ' is', ' a', ' soft', ' metal', ',', ' and', ' is', ' easily', ' scratched', '.', '\n', '\n', 'Gold'] " is a solid .
+
+ Gold is a soft metal , and is easily scratched .
+
+ Gold" True Illmatic is the gold standard that 4 [' Ill', 'matic', ' is', ' the', ' gold']
+18 3 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of gold solid gold [' is', ' a', ' solid', '.', '\n', '\n', 'Gold', ' is', ' a', ' soft', ' metal', ',', ' and', ' is', ' easily', ' scratched', '.', '\n', '\n', 'Gold'] " is a solid .
+
+ Gold is a soft metal , and is easily scratched .
+
+ Gold" True " splendor"". The palaces had gold rafters to support" 7 "[' spl', 'endor', '"".', ' The', ' pal', 'aces', ' had', ' gold']"
+19 3 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of gold solid gold [' is', ' a', ' solid', '.', '\n', '\n', 'Gold', ' is', ' a', ' soft', ' metal', ',', ' and', ' is', ' easily', ' scratched', '.', '\n', '\n', 'Gold'] " is a solid .
+
+ Gold is a soft metal , and is easily scratched .
+
+ Gold" True Charlotte were in gold mining regions 3 [' Charlotte', ' were', ' in', ' gold']
+20 4 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of mercury liquid mercury [' is', ' liquid', '.', '\n', '\n', 'Merc', 'ury', ' is', ' a', ' metal', ' that', ' is', ' soft', ',', ' mal', 'le', 'able', ',', ' and', ' duct'] " is liquid .
+
+ Merc ury is a metal that is soft , mal le able , and duct" True (Zn), cadmium (Cd) and mercury (Hg). The further 12 [' (', 'Z', 'n', '),', ' cad', 'm', 'ium', ' (', 'C', 'd', ')', ' and', ' mercury']
+21 4 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of mercury liquid mercury [' is', ' liquid', '.', '\n', '\n', 'Merc', 'ury', ' is', ' a', ' metal', ' that', ' is', ' soft', ',', ' mal', 'le', 'able', ',', ' and', ' duct'] " is liquid .
+
+ Merc ury is a metal that is soft , mal le able , and duct" True with an appropriate mercury halide HgX2, where 3 [' with', ' an', ' appropriate', ' mercury']
+22 4 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of mercury liquid mercury [' is', ' liquid', '.', '\n', '\n', 'Merc', 'ury', ' is', ' a', ' metal', ' that', ' is', ' soft', ',', ' mal', 'le', 'able', ',', ' and', ' duct'] " is liquid .
+
+ Merc ury is a metal that is soft , mal le able , and duct" True liquid, it was used in mercury switches (including 6 [' liquid', ',', ' it', ' was', ' used', ' in', ' mercury']
+23 4 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of mercury liquid mercury [' is', ' liquid', '.', '\n', '\n', 'Merc', 'ury', ' is', ' a', ' metal', ' that', ' is', ' soft', ',', ' mal', 'le', 'able', ',', ' and', ' duct'] " is liquid .
+
+ Merc ury is a metal that is soft , mal le able , and duct" True who suggested mercury be used as a 2 [' who', ' suggested', ' mercury']
+24 4 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of mercury liquid mercury [' is', ' liquid', '.', '\n', '\n', 'Merc', 'ury', ' is', ' a', ' metal', ' that', ' is', ' soft', ',', ' mal', 'le', 'able', ',', ' and', ' duct'] " is liquid .
+
+ Merc ury is a metal that is soft , mal le able , and duct" True substances such as mercury and phosphorus. Habitat 3 [' substances', ' such', ' as', ' mercury']
+25 5 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of aluminum solid aluminum [' is', ' a', ' solid', '.', '\n', '\n', 'Al', 'uminum', ' is', ' a', ' metal', ' that', ' is', ' used', ' in', ' many', ' different', ' applications', '.', ' It'] " is a solid .
+
+ Al uminum is a metal that is used in many different applications . It" True railway yard and aluminum plant in the 3 [' railway', ' yard', ' and', ' aluminum']
+26 5 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of aluminum solid aluminum [' is', ' a', ' solid', '.', '\n', '\n', 'Al', 'uminum', ' is', ' a', ' metal', ' that', ' is', ' used', ' in', ' many', ' different', ' applications', '.', ' It'] " is a solid .
+
+ Al uminum is a metal that is used in many different applications . It" True structural use of aluminum in building 3 [' structural', ' use', ' of', ' aluminum']
+27 5 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of aluminum solid aluminum [' is', ' a', ' solid', '.', '\n', '\n', 'Al', 'uminum', ' is', ' a', ' metal', ' that', ' is', ' used', ' in', ' many', ' different', ' applications', '.', ' It'] " is a solid .
+
+ Al uminum is a metal that is used in many different applications . It" True be left would be aluminum dishwasher parts, 4 [' be', ' left', ' would', ' be', ' aluminum']
+28 5 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of aluminum solid aluminum [' is', ' a', ' solid', '.', '\n', '\n', 'Al', 'uminum', ' is', ' a', ' metal', ' that', ' is', ' used', ' in', ' many', ' different', ' applications', '.', ' It'] " is a solid .
+
+ Al uminum is a metal that is used in many different applications . It" True or brushed aluminum surface, a solid 2 [' or', ' brushed', ' aluminum']
+29 5 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of aluminum solid aluminum [' is', ' a', ' solid', '.', '\n', '\n', 'Al', 'uminum', ' is', ' a', ' metal', ' that', ' is', ' used', ' in', ' many', ' different', ' applications', '.', ' It'] " is a solid .
+
+ Al uminum is a metal that is used in many different applications . It" True in the midst of aluminum cans. Nat simultaneously 4 [' in', ' the', ' midst', ' of', ' aluminum']
+30 6 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of nitrogen gas nitrogen [' is', ' a', ' solid', '.', '\n', '\n', 'At', ' higher', ' temperatures', ',', ' nitrogen', ' becomes', ' a', ' gas', '.', '\n', '\n', 'At', ' even', ' higher'] " is a solid .
+
+ At higher temperatures , nitrogen becomes a gas .
+
+ At even higher" True dissociation of nitrogen and methane, forming 3 [' diss', 'ociation', ' of', ' nitrogen']
+31 6 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of nitrogen gas nitrogen [' is', ' a', ' solid', '.', '\n', '\n', 'At', ' higher', ' temperatures', ',', ' nitrogen', ' becomes', ' a', ' gas', '.', '\n', '\n', 'At', ' even', ' higher'] " is a solid .
+
+ At higher temperatures , nitrogen becomes a gas .
+
+ At even higher" True shot out of a nitrogen gun; James did throw 4 [' shot', ' out', ' of', ' a', ' nitrogen']
+32 6 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of nitrogen gas nitrogen [' is', ' a', ' solid', '.', '\n', '\n', 'At', ' higher', ' temperatures', ',', ' nitrogen', ' becomes', ' a', ' gas', '.', '\n', '\n', 'At', ' even', ' higher'] " is a solid .
+
+ At higher temperatures , nitrogen becomes a gas .
+
+ At even higher" True phylogenetically-ancient nitrogenases (discussed above) 5 [' phylogen', 'etically', '-', 'an', 'cient', ' nitrogen']
+33 6 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of nitrogen gas nitrogen [' is', ' a', ' solid', '.', '\n', '\n', 'At', ' higher', ' temperatures', ',', ' nitrogen', ' becomes', ' a', ' gas', '.', '\n', '\n', 'At', ' even', ' higher'] " is a solid .
+
+ At higher temperatures , nitrogen becomes a gas .
+
+ At even higher" True " prevent oxygen and nitrogen toxicity.
+" 3 [' prevent', ' oxygen', ' and', ' nitrogen']
+34 6 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of nitrogen gas nitrogen [' is', ' a', ' solid', '.', '\n', '\n', 'At', ' higher', ' temperatures', ',', ' nitrogen', ' becomes', ' a', ' gas', '.', '\n', '\n', 'At', ' even', ' higher'] " is a solid .
+
+ At higher temperatures , nitrogen becomes a gas .
+
+ At even higher" True ammonia, hydrogen, nitrogen and methane, as opposed 4 [' ammonia', ',', ' hydrogen', ',', ' nitrogen']
+35 7 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silicon solid silicon [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True naturally occurring silicon dioxide found 2 [' naturally', ' occurring', ' silicon']
+36 7 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silicon solid silicon [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True Similar to silicon and aluminum, 2 [' Similar', ' to', ' silicon']
+37 7 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silicon solid silicon [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True heavily doped n-type silicon wafers in the production 6 [' heavily', ' d', 'oped', ' n', '-', 'type', ' silicon']
+38 7 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silicon solid silicon [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True power density than silicon power devices and 3 [' power', ' density', ' than', ' silicon']
+39 7 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silicon solid silicon [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True mass of the silicon-28 atom or should 3 [' mass', ' of', ' the', ' silicon']
+40 8 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of neon gas neon [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True mixture of xenon and neon that is converted 5 [' mixture', ' of', ' xen', 'on', ' and', ' neon']
+41 8 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of neon gas neon [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True " nucleosynthesis and the neon burning process.
+" 5 [' nucle', 'os', 'ynthesis', ' and', ' the', ' neon']
+42 8 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of neon gas neon [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True year. X-Flight's neon green track and dark 6 "[' year', '.', ' X', '-', 'Flight', ""'s"", ' neon']"
+43 8 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of neon gas neon [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True depletion of neon (see Table), 2 [' depletion', ' of', ' neon']
+44 8 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of neon gas neon [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True 1 ['ne', 'on']
+45 9 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ethanol liquid ethanol [' is', ' a', ' liquid', '.', '\n', '\n', 'E', 'than', 'ol', ' is', ' a', ' color', 'less', ',', ' fl', 'amm', 'able', ',', ' and', ' poisonous'] " is a liquid .
+
+ E than ol is a color less , fl amm able , and poisonous" True built with an ethanol-ready engine 3 [' built', ' with', ' an', ' ethanol']
+46 9 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ethanol liquid ethanol [' is', ' a', ' liquid', '.', '\n', '\n', 'E', 'than', 'ol', ' is', ' a', ' color', 'less', ',', ' fl', 'amm', 'able', ',', ' and', ' poisonous'] " is a liquid .
+
+ E than ol is a color less , fl amm able , and poisonous" True imported oil if E85 ethanol was used instead of 5 [' imported', ' oil', ' if', ' E', '85', ' ethanol']
+47 9 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ethanol liquid ethanol [' is', ' a', ' liquid', '.', '\n', '\n', 'E', 'than', 'ol', ' is', ' a', ' color', 'less', ',', ' fl', 'amm', 'able', ',', ' and', ' poisonous'] " is a liquid .
+
+ E than ol is a color less , fl amm able , and poisonous" True neutral water – ethanol solution and be homologous 3 [' neutral', ' water', ' –', ' ethanol']
+48 9 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ethanol liquid ethanol [' is', ' a', ' liquid', '.', '\n', '\n', 'E', 'than', 'ol', ' is', ' a', ' color', 'less', ',', ' fl', 'amm', 'able', ',', ' and', ' poisonous'] " is a liquid .
+
+ E than ol is a color less , fl amm able , and poisonous" True tank by warming the ethanol fuel during starting, 4 [' tank', ' by', ' warming', ' the', ' ethanol']
+49 9 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ethanol liquid ethanol [' is', ' a', ' liquid', '.', '\n', '\n', 'E', 'than', 'ol', ' is', ' a', ' color', 'less', ',', ' fl', 'amm', 'able', ',', ' and', ' poisonous'] " is a liquid .
+
+ E than ol is a color less , fl amm able , and poisonous" True Volt to operate on ethanol fuel, as most new 4 [' Volt', ' to', ' operate', ' on', ' ethanol']
+50 10 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sulfur solid sulfur [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ulf', 'ur', ' is', ' a', ' chemical', ' element', ' with', ' symbol', ' S', ' and', ' atomic', ' number', ' 16'] " is a solid .
+
+ S ulf ur is a chemical element with symbol S and atomic number 16" True staining bright sulfur yellow. It 3 [' st', 'aining', ' bright', ' sulfur']
+51 10 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sulfur solid sulfur [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ulf', 'ur', ' is', ' a', ' chemical', ' element', ' with', ' symbol', ' S', ' and', ' atomic', ' number', ' 16'] " is a solid .
+
+ S ulf ur is a chemical element with symbol S and atomic number 16" True the surface with sulfurous and silicate 3 [' the', ' surface', ' with', ' sulfur']
+52 10 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sulfur solid sulfur [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ulf', 'ur', ' is', ' a', ' chemical', ' element', ' with', ' symbol', ' S', ' and', ' atomic', ' number', ' 16'] " is a solid .
+
+ S ulf ur is a chemical element with symbol S and atomic number 16" True plateaus, the result of sulfur dioxide sapping 6 [' plate', 'aus', ',', ' the', ' result', ' of', ' sulfur']
+53 10 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sulfur solid sulfur [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ulf', 'ur', ' is', ' a', ' chemical', ' element', ' with', ' symbol', ' S', ' and', ' atomic', ' number', ' 16'] " is a solid .
+
+ S ulf ur is a chemical element with symbol S and atomic number 16" True redox reaction, sulfur is oxidized from 4 [' red', 'ox', ' reaction', ',', ' sulfur']
+54 10 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sulfur solid sulfur [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ulf', 'ur', ' is', ' a', ' chemical', ' element', ' with', ' symbol', ' S', ' and', ' atomic', ' number', ' 16'] " is a solid .
+
+ S ulf ur is a chemical element with symbol S and atomic number 16" True higher-than-expected potassium and sulfur levels on the 7 [' higher', '-', 'than', '-', 'expected', ' potassium', ' and', ' sulfur']
+55 11 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of helium gas helium [' is', ' a', ' superflu', 'id', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' superflu', 'id', '.', '\n', '\n'] " is a superflu id .
+
+ The phase of matter of helium is a superflu id .
+
+" False hydroquinone, but helium and neon do 5 [' hydro', 'quin', 'one', ',', ' but', ' helium']
+56 11 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of helium gas helium [' is', ' a', ' superflu', 'id', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' superflu', 'id', '.', '\n', '\n'] " is a superflu id .
+
+ The phase of matter of helium is a superflu id .
+
+" False for the nearby helium line), FeII, 3 [' for', ' the', ' nearby', ' helium']
+57 11 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of helium gas helium [' is', ' a', ' superflu', 'id', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' superflu', 'id', '.', '\n', '\n'] " is a superflu id .
+
+ The phase of matter of helium is a superflu id .
+
+" False hydrogen, the helium abundance is about 3 [' hydrogen', ',', ' the', ' helium']
+58 11 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of helium gas helium [' is', ' a', ' superflu', 'id', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' superflu', 'id', '.', '\n', '\n'] " is a superflu id .
+
+ The phase of matter of helium is a superflu id .
+
+" False trends down groups, helium can be seen as a 4 [' trends', ' down', ' groups', ',', ' helium']
+59 11 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of helium gas helium [' is', ' a', ' superflu', 'id', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' superflu', 'id', '.', '\n', '\n'] " is a superflu id .
+
+ The phase of matter of helium is a superflu id .
+
+" False mixture. However, helium is not narcotic, and 4 [' mixture', '.', ' However', ',', ' helium']
+60 12 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of lead solid lead [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True France. He then held the lead until the second stage 6 [' France', '.', ' He', ' then', ' held', ' the', ' lead']
+61 12 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of lead solid lead [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True to regain the lead with a late burst 3 [' to', ' regain', ' the', ' lead']
+62 12 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of lead solid lead [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True Ramachandran for the lead role. Tirulokchandar 6 [' Ram', 'ach', 'and', 'ran', ' for', ' the', ' lead']
+63 12 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of lead solid lead [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True French Navy, the lead ship of her class. 4 [' French', ' Navy', ',', ' the', ' lead']
+64 12 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of lead solid lead [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True on February 19 to lead Wichita State against 4 [' on', ' February', ' 19', ' to', ' lead']
+65 13 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice cream solid ice cream [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' ice', ' cream', ' is', ' a', ' solid', '.', '\n', '\n', 'The'] " is a solid .
+
+ The phase of matter of ice cream is a solid .
+
+ The" True " ""commandeering an ice cream truck in a frantic" 5 "[' ""', 'command', 'eering', ' an', ' ice', ' cream']"
+66 13 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice cream solid ice cream [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' ice', ' cream', ' is', ' a', ' solid', '.', '\n', '\n', 'The'] " is a solid .
+
+ The phase of matter of ice cream is a solid .
+
+ The" True goes to buy ice cream from an ice cream 4 [' goes', ' to', ' buy', ' ice', ' cream']
+67 13 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice cream solid ice cream [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' ice', ' cream', ' is', ' a', ' solid', '.', '\n', '\n', 'The'] " is a solid .
+
+ The phase of matter of ice cream is a solid .
+
+ The" True as burgers, ice cream floats, milkshakes, 4 [' as', ' burgers', ',', ' ice', ' cream']
+68 13 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice cream solid ice cream [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' ice', ' cream', ' is', ' a', ' solid', '.', '\n', '\n', 'The'] " is a solid .
+
+ The phase of matter of ice cream is a solid .
+
+ The" True be home ... eating ice cream instead of in 5 [' be', ' home', '...', ' eating', ' ice', ' cream']
+69 13 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice cream solid ice cream [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' ice', ' cream', ' is', ' a', ' solid', '.', '\n', '\n', 'The'] " is a solid .
+
+ The phase of matter of ice cream is a solid .
+
+ The" True peanut butter and ice cream as alternative 4 [' peanut', ' butter', ' and', ' ice', ' cream']
+70 14 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of coffee liquid coffee [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' coffee', ' is', ' ground', ' and', ' then', ' brewed', '.', '\n', '\n', 'The', ' coffee', ' is', ' brewed'] " is a solid .
+
+ The coffee is ground and then brewed .
+
+ The coffee is brewed" False " Pietro"", one of the oldest coffee shops in the city." 7 "[' Piet', 'ro', '"",', ' one', ' of', ' the', ' oldest', ' coffee']"
+71 14 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of coffee liquid coffee [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' coffee', ' is', ' ground', ' and', ' then', ' brewed', '.', '\n', '\n', 'The', ' coffee', ' is', ' brewed'] " is a solid .
+
+ The coffee is ground and then brewed .
+
+ The coffee is brewed" False bring him a cup of coffee and say,' Stuart, 5 [' bring', ' him', ' a', ' cup', ' of', ' coffee']
+72 14 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of coffee liquid coffee [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' coffee', ' is', ' ground', ' and', ' then', ' brewed', '.', '\n', '\n', 'The', ' coffee', ' is', ' brewed'] " is a solid .
+
+ The coffee is ground and then brewed .
+
+ The coffee is brewed" False restaurant and coffee shop facilities, 2 [' restaurant', ' and', ' coffee']
+73 14 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of coffee liquid coffee [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' coffee', ' is', ' ground', ' and', ' then', ' brewed', '.', '\n', '\n', 'The', ' coffee', ' is', ' brewed'] " is a solid .
+
+ The coffee is ground and then brewed .
+
+ The coffee is brewed" False (including tea, green tea, coffee and consommé soup), 7 [' (', 'including', ' tea', ',', ' green', ' tea', ',', ' coffee']
+74 14 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of coffee liquid coffee [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' coffee', ' is', ' ground', ' and', ' then', ' brewed', '.', '\n', '\n', 'The', ' coffee', ' is', ' brewed'] " is a solid .
+
+ The coffee is ground and then brewed .
+
+ The coffee is brewed" False only Nicaraguan coffee in solidarity with 3 [' only', ' Nicarag', 'uan', ' coffee']
+75 15 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wood solid wood [' is', ' called', ' the', ' solid', ' state', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' wood', ' is', ' called', ' the', ' solid', ' state', '.'] " is called the solid state .
+
+ The phase of matter of wood is called the solid state ." True dawn, hide in a wood next to which the 5 [' dawn', ',', ' hide', ' in', ' a', ' wood']
+76 15 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wood solid wood [' is', ' called', ' the', ' solid', ' state', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' wood', ' is', ' called', ' the', ' solid', ' state', '.'] " is called the solid state .
+
+ The phase of matter of wood is called the solid state ." True with plenty of wood for construction; 3 [' with', ' plenty', ' of', ' wood']
+77 15 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wood solid wood [' is', ' called', ' the', ' solid', ' state', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' wood', ' is', ' called', ' the', ' solid', ' state', '.'] " is called the solid state .
+
+ The phase of matter of wood is called the solid state ." True Roper used a hickory wood frame built by 7 [' R', 'oper', ' used', ' a', ' h', 'ick', 'ory', ' wood']
+78 15 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wood solid wood [' is', ' called', ' the', ' solid', ' state', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' wood', ' is', ' called', ' the', ' solid', ' state', '.'] " is called the solid state .
+
+ The phase of matter of wood is called the solid state ." True mana. Only wood from trees is 3 [' mana', '.', ' Only', ' wood']
+79 15 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wood solid wood [' is', ' called', ' the', ' solid', ' state', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' wood', ' is', ' called', ' the', ' solid', ' state', '.'] " is called the solid state .
+
+ The phase of matter of wood is called the solid state ." True 0 ['wood']
+80 16 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of plastic solid plastic [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True from molded HDPE plastic with a funnel-like 4 [' from', ' molded', ' HD', 'PE', ' plastic']
+81 16 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of plastic solid plastic [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True with a thick layer of plastic mantle and an 5 [' with', ' a', ' thick', ' layer', ' of', ' plastic']
+82 16 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of plastic solid plastic [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True immortalized in plastic alongside 3 [' immortal', 'ized', ' in', ' plastic']
+83 16 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of plastic solid plastic [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True purchase five tonnes of plastic explosives, 2,000 4 [' purchase', ' five', ' tonnes', ' of', ' plastic']
+84 16 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of plastic solid plastic [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True supermarkets selling it in plastic packages and 4 [' supermarkets', ' selling', ' it', ' in', ' plastic']
+85 17 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of butter solid butter [' is', ' solid', '.', '\n', '\n', 'But', 'ter', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'But', 'ter', ' is', ' a'] " is solid .
+
+ But ter is a solid at room temperature .
+
+ But ter is a" True or flickering butter lamps lighting 2 [' or', ' flickering', ' butter']
+86 17 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of butter solid butter [' is', ' solid', '.', '\n', '\n', 'But', 'ter', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'But', 'ter', ' is', ' a'] " is solid .
+
+ But ter is a solid at room temperature .
+
+ But ter is a" True 1 ['but', 'ter']
+87 17 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of butter solid butter [' is', ' solid', '.', '\n', '\n', 'But', 'ter', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'But', 'ter', ' is', ' a'] " is solid .
+
+ But ter is a solid at room temperature .
+
+ But ter is a" True dairy products like butter. This may 3 [' dairy', ' products', ' like', ' butter']
+88 17 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of butter solid butter [' is', ' solid', '.', '\n', '\n', 'But', 'ter', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'But', 'ter', ' is', ' a'] " is solid .
+
+ But ter is a solid at room temperature .
+
+ But ter is a" True deep-fried butter was paired 3 [' deep', '-', 'fried', ' butter']
+89 17 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of butter solid butter [' is', ' solid', '.', '\n', '\n', 'But', 'ter', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'But', 'ter', ' is', ' a'] " is solid .
+
+ But ter is a solid at room temperature .
+
+ But ter is a" True steak. The butter is used in the 3 [' steak', '.', ' The', ' butter']
+90 18 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of honey liquid honey [' is', ' a', ' solid', '.', '\n', '\n', 'H', 'oney', ' is', ' a', ' natural', ' sweet', 'ener', ',', ' and', ' it', ' is', ' a', ' natural', ' pres'] " is a solid .
+
+ H oney is a natural sweet ener , and it is a natural pres" False a charm to keep honey bees from swarming. 4 [' a', ' charm', ' to', ' keep', ' honey']
+91 18 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of honey liquid honey [' is', ' a', ' solid', '.', '\n', '\n', 'H', 'oney', ' is', ' a', ' natural', ' sweet', 'ener', ',', ' and', ' it', ' is', ' a', ' natural', ' pres'] " is a solid .
+
+ H oney is a natural sweet ener , and it is a natural pres" False and selling honey related products, 2 [' and', ' selling', ' honey']
+92 18 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of honey liquid honey [' is', ' a', ' solid', '.', '\n', '\n', 'H', 'oney', ' is', ' a', ' natural', ' sweet', 'ener', ',', ' and', ' it', ' is', ' a', ' natural', ' pres'] " is a solid .
+
+ H oney is a natural sweet ener , and it is a natural pres" False widespread death of honey bees. Hives 3 [' widespread', ' death', ' of', ' honey']
+93 18 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of honey liquid honey [' is', ' a', ' solid', '.', '\n', '\n', 'H', 'oney', ' is', ' a', ' natural', ' sweet', 'ener', ',', ' and', ' it', ' is', ' a', ' natural', ' pres'] " is a solid .
+
+ H oney is a natural sweet ener , and it is a natural pres" False humans by mad honey has been well documented 3 [' humans', ' by', ' mad', ' honey']
+94 18 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of honey liquid honey [' is', ' a', ' solid', '.', '\n', '\n', 'H', 'oney', ' is', ' a', ' natural', ' sweet', 'ener', ',', ' and', ' it', ' is', ' a', ' natural', ' pres'] " is a solid .
+
+ H oney is a natural sweet ener , and it is a natural pres" False such as fur, hide, honey and wax, but those 6 [' such', ' as', ' fur', ',', ' hide', ',', ' honey']
+95 19 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wine liquid wine [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' wine', ' is', ' a', ' liquid', ' because', ' it', ' is', ' composed', ' of', ' molecules', ' that', ' are', ' in'] " is a liquid .
+
+ The wine is a liquid because it is composed of molecules that are in" True detectable in wines with pyrazine levels 2 [' detectable', ' in', ' wine']
+96 19 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wine liquid wine [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' wine', ' is', ' a', ' liquid', ' because', ' it', ' is', ' composed', ' of', ' molecules', ' that', ' are', ' in'] " is a liquid .
+
+ The wine is a liquid because it is composed of molecules that are in" True to offer food and wine to the ancestors 4 [' to', ' offer', ' food', ' and', ' wine']
+97 19 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wine liquid wine [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' wine', ' is', ' a', ' liquid', ' because', ' it', ' is', ' composed', ' of', ' molecules', ' that', ' are', ' in'] " is a liquid .
+
+ The wine is a liquid because it is composed of molecules that are in" True retrieving the wine and cheese from the 2 [' retrieving', ' the', ' wine']
+98 19 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wine liquid wine [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' wine', ' is', ' a', ' liquid', ' because', ' it', ' is', ' composed', ' of', ' molecules', ' that', ' are', ' in'] " is a liquid .
+
+ The wine is a liquid because it is composed of molecules that are in" True as a varietal wine or as a blend 4 [' as', ' a', ' var', 'ietal', ' wine']
+99 19 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wine liquid wine [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' wine', ' is', ' a', ' liquid', ' because', ' it', ' is', ' composed', ' of', ' molecules', ' that', ' are', ' in'] " is a liquid .
+
+ The wine is a liquid because it is composed of molecules that are in" True find Carménère wines in France today, 4 [' find', ' Carm', 'én', 'ère', ' wine']
+100 20 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glass solid glass [' is', ' am', 'orph', 'ous', '.', '\n', '\n', 'The', ' glass', ' is', ' a', ' solid', ',', ' but', ' it', ' is', ' not', ' a', ' crystal', '.'] " is am orph ous .
+
+ The glass is a solid , but it is not a crystal ." True Street NW. (This glass-and-white 5 [' Street', ' NW', '.', ' (', 'This', ' glass']
+101 20 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glass solid glass [' is', ' am', 'orph', 'ous', '.', '\n', '\n', 'The', ' glass', ' is', ' a', ' solid', ',', ' but', ' it', ' is', ' not', ' a', ' crystal', '.'] " is am orph ous .
+
+ The glass is a solid , but it is not a crystal ." True including nineteen stained glass windows (including 3 [' including', ' nineteen', ' stained', ' glass']
+102 20 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glass solid glass [' is', ' am', 'orph', 'ous', '.', '\n', '\n', 'The', ' glass', ' is', ' a', ' solid', ',', ' but', ' it', ' is', ' not', ' a', ' crystal', '.'] " is am orph ous .
+
+ The glass is a solid , but it is not a crystal ." True also contain glass tiled ceilings 2 [' also', ' contain', ' glass']
+103 20 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glass solid glass [' is', ' am', 'orph', 'ous', '.', '\n', '\n', 'The', ' glass', ' is', ' a', ' solid', ',', ' but', ' it', ' is', ' not', ' a', ' crystal', '.'] " is am orph ous .
+
+ The glass is a solid , but it is not a crystal ." True contemporary stained glass panels are similar 2 [' contemporary', ' stained', ' glass']
+104 20 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glass solid glass [' is', ' am', 'orph', 'ous', '.', '\n', '\n', 'The', ' glass', ' is', ' a', ' solid', ',', ' but', ' it', ' is', ' not', ' a', ' crystal', '.'] " is am orph ous .
+
+ The glass is a solid , but it is not a crystal ." True trimming consisting of glass and colored 4 [' trim', 'ming', ' consisting', ' of', ' glass']
+105 21 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of carbon dioxide gas carbon dioxide [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True Additionally, carbon dioxide and ozone gases 3 [' Additionally', ',', ' carbon', ' dioxide']
+106 21 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of carbon dioxide gas carbon dioxide [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True absorb and dissolve carbon dioxide from the surrounding 4 [' absorb', ' and', ' dissolve', ' carbon', ' dioxide']
+107 21 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of carbon dioxide gas carbon dioxide [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True (with helium). Pure carbon dioxide, on the other 6 [' (', 'with', ' helium', ').', ' Pure', ' carbon', ' dioxide']
+108 21 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of carbon dioxide gas carbon dioxide [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True caused by rising carbon dioxide levels, might affect 4 [' caused', ' by', ' rising', ' carbon', ' dioxide']
+109 21 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of carbon dioxide gas carbon dioxide [' is', ' a', ' gas', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a gas . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True glucose to carbon dioxide is known as cellular 3 [' glucose', ' to', ' carbon', ' dioxide']
+110 22 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of paper solid paper [' is', ' called', ' the', ' am', 'orph', 'ous', ' state', '.', '\n', '\n', 'The', ' am', 'orph', 'ous', ' state', ' is', ' a', ' state', ' of', ' matter'] " is called the am orph ous state .
+
+ The am orph ous state is a state of matter" False " ""world's largest paper clip"". The same article" 4 "[' ""', 'world', ""'s"", ' largest', ' paper']"
+111 22 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of paper solid paper [' is', ' called', ' the', ' am', 'orph', 'ous', ' state', '.', '\n', '\n', 'The', ' am', 'orph', 'ous', ' state', ' is', ' a', ' state', ' of', ' matter'] " is called the am orph ous state .
+
+ The am orph ous state is a state of matter" False lime, sand, waste paper or other substance 5 [' lime', ',', ' sand', ',', ' waste', ' paper']
+112 22 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of paper solid paper [' is', ' called', ' the', ' am', 'orph', 'ous', ' state', '.', '\n', '\n', 'The', ' am', 'orph', 'ous', ' state', ' is', ' a', ' state', ' of', ' matter'] " is called the am orph ous state .
+
+ The am orph ous state is a state of matter" False accusations on pink paper to hint at communist 3 [' accusations', ' on', ' pink', ' paper']
+113 22 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of paper solid paper [' is', ' called', ' the', ' am', 'orph', 'ous', ' state', '.', '\n', '\n', 'The', ' am', 'orph', 'ous', ' state', ' is', ' a', ' state', ' of', ' matter'] " is called the am orph ous state .
+
+ The am orph ous state is a state of matter" False results in a paper entitled Radioactive 3 [' results', ' in', ' a', ' paper']
+114 22 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of paper solid paper [' is', ' called', ' the', ' am', 'orph', 'ous', ' state', '.', '\n', '\n', 'The', ' am', 'orph', 'ous', ' state', ' is', ' a', ' state', ' of', ' matter'] " is called the am orph ous state .
+
+ The am orph ous state is a state of matter" False codes and tore the paper to shreds with his 4 [' codes', ' and', ' tore', ' the', ' paper']
+115 23 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of copper solid copper [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True aluminium, and copper with an annual 3 [' aluminium', ',', ' and', ' copper']
+116 23 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of copper solid copper [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True coating formed on copper during oxygenation. 3 [' coating', ' formed', ' on', ' copper']
+117 23 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of copper solid copper [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True axes, and minor copper objects have been 4 [' axes', ',', ' and', ' minor', ' copper']
+118 23 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of copper solid copper [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True but aluminium and copper are also used. Most 3 [' but', ' aluminium', ' and', ' copper']
+119 23 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of copper solid copper [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True Africa, though the large copper – nickel deposits 5 [' Africa', ',', ' though', ' the', ' large', ' copper']
+120 24 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of chocolate solid chocolate [' is', ' a', ' solid', '.', '\n', '\n', 'Ch', 'ocolate', ' is', ' made', ' from', ' cocoa', ' beans', ',', ' which', ' are', ' the', ' seeds', ' of', ' the'] " is a solid .
+
+ Ch ocolate is made from cocoa beans , which are the seeds of the" True manufacturers Ford, chocolate brand Toblerone, 3 [' manufacturers', ' Ford', ',', ' chocolate']
+121 24 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of chocolate solid chocolate [' is', ' a', ' solid', '.', '\n', '\n', 'Ch', 'ocolate', ' is', ' made', ' from', ' cocoa', ' beans', ',', ' which', ' are', ' the', ' seeds', ' of', ' the'] " is a solid .
+
+ Ch ocolate is made from cocoa beans , which are the seeds of the" True drinking hot chocolate, as some hot chocolate 2 [' drinking', ' hot', ' chocolate']
+122 24 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of chocolate solid chocolate [' is', ' a', ' solid', '.', '\n', '\n', 'Ch', 'ocolate', ' is', ' made', ' from', ' cocoa', ' beans', ',', ' which', ' are', ' the', ' seeds', ' of', ' the'] " is a solid .
+
+ Ch ocolate is made from cocoa beans , which are the seeds of the" True next to a giant chocolate cake in his dress 4 [' next', ' to', ' a', ' giant', ' chocolate']
+123 24 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of chocolate solid chocolate [' is', ' a', ' solid', '.', '\n', '\n', 'Ch', 'ocolate', ' is', ' made', ' from', ' cocoa', ' beans', ',', ' which', ' are', ' the', ' seeds', ' of', ' the'] " is a solid .
+
+ Ch ocolate is made from cocoa beans , which are the seeds of the" True " deep, sweet chocolate notes"" and bought" 3 [' deep', ',', ' sweet', ' chocolate']
+124 24 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of chocolate solid chocolate [' is', ' a', ' solid', '.', '\n', '\n', 'Ch', 'ocolate', ' is', ' made', ' from', ' cocoa', ' beans', ',', ' which', ' are', ' the', ' seeds', ' of', ' the'] " is a solid .
+
+ Ch ocolate is made from cocoa beans , which are the seeds of the" True characterized by a chocolate to reddish-brown 3 [' characterized', ' by', ' a', ' chocolate']
+125 25 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of petroleum liquid petroleum [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' petroleum', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of petroleum is a liquid .
+
+ The phase" True unloaded ammunition and petroleum products over 3 [' unloaded', ' ammunition', ' and', ' petroleum']
+126 25 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of petroleum liquid petroleum [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' petroleum', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of petroleum is a liquid .
+
+ The phase" True handles alumina and petroleum coke; it has a 9.2-metre 4 [' handles', ' alum', 'ina', ' and', ' petroleum']
+127 25 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of petroleum liquid petroleum [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' petroleum', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of petroleum is a liquid .
+
+ The phase" True mineral oil, petroleum jelly, calcipotriol, 3 [' mineral', ' oil', ',', ' petroleum']
+128 25 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of petroleum liquid petroleum [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' petroleum', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of petroleum is a liquid .
+
+ The phase" True exports consist of petroleum products made 3 [' exports', ' consist', ' of', ' petroleum']
+129 25 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of petroleum liquid petroleum [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' petroleum', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of petroleum is a liquid .
+
+ The phase" True paint to Vaseline petroleum jelly to chewing 4 [' paint', ' to', ' Vas', 'eline', ' petroleum']
+130 26 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of steam gas steam [' is', ' a', ' gas', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a gas .
+
+ The phase of matter of water is a liquid .
+
+ The phase" True vehicles without steam power. The 2 [' vehicles', ' without', ' steam']
+131 26 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of steam gas steam [' is', ' a', ' gas', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a gas .
+
+ The phase of matter of water is a liquid .
+
+ The phase" True of exploring the steam tunnels and 3 [' of', ' exploring', ' the', ' steam']
+132 26 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of steam gas steam [' is', ' a', ' gas', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a gas .
+
+ The phase of matter of water is a liquid .
+
+ The phase" True regularity than the steam locomotives. 4 [' regular', 'ity', ' than', ' the', ' steam']
+133 26 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of steam gas steam [' is', ' a', ' gas', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a gas .
+
+ The phase of matter of water is a liquid .
+
+ The phase" True until her steam was exhausted. 2 [' until', ' her', ' steam']
+134 26 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of steam gas steam [' is', ' a', ' gas', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a gas .
+
+ The phase of matter of water is a liquid .
+
+ The phase" True direct acting steam engines, built 2 [' direct', ' acting', ' steam']
+135 27 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice solid ice [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of water is liquid .
+
+ The phase of matter" True Herald Island; ice limited their approach 3 [' Herald', ' Island', ';', ' ice']
+136 27 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice solid ice [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of water is liquid .
+
+ The phase of matter" True the volume of the ice blocks that break 4 [' the', ' volume', ' of', ' the', ' ice']
+137 27 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice solid ice [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of water is liquid .
+
+ The phase of matter" True covered with ice sheets during 2 [' covered', ' with', ' ice']
+138 27 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice solid ice [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of water is liquid .
+
+ The phase of matter" True skating, and ice hockey events, 3 [' skating', ',', ' and', ' ice']
+139 27 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of ice solid ice [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' water', ' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of water is liquid .
+
+ The phase of matter" True drop down to the ice to make saves, so 4 [' drop', ' down', ' to', ' the', ' ice']
+140 28 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of diamond solid diamond [' is', ' a', ' solid', '.', '\n', '\n', 'Diamond', ' is', ' a', ' hard', ',', ' transparent', ',', ' and', ' color', 'less', ' material', '.', ' It', ' is'] " is a solid .
+
+ Diamond is a hard , transparent , and color less material . It is" True U.S. Route 9 at a diamond interchange. It 8 [' U', '.', 'S', '.', ' Route', ' 9', ' at', ' a', ' diamond']
+141 28 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of diamond solid diamond [' is', ' a', ' solid', '.', '\n', '\n', 'Diamond', ' is', ' a', ' hard', ',', ' transparent', ',', ' and', ' color', 'less', ' material', '.', ' It', ' is'] " is a solid .
+
+ Diamond is a hard , transparent , and color less material . It is" True community includes diamond traders from 2 [' community', ' includes', ' diamond']
+142 28 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of diamond solid diamond [' is', ' a', ' solid', '.', '\n', '\n', 'Diamond', ' is', ' a', ' hard', ',', ' transparent', ',', ' and', ' color', 'less', ' material', '.', ' It', ' is'] " is a solid .
+
+ Diamond is a hard , transparent , and color less material . It is" True Switzerland, and diamond certifications in 3 [' Switzerland', ',', ' and', ' diamond']
+143 28 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of diamond solid diamond [' is', ' a', ' solid', '.', '\n', '\n', 'Diamond', ' is', ' a', ' hard', ',', ' transparent', ',', ' and', ' color', 'less', ' material', '.', ' It', ' is'] " is a solid .
+
+ Diamond is a hard , transparent , and color less material . It is" True a diverging diamond interchange, the 3 [' a', ' diver', 'ging', ' diamond']
+144 28 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of diamond solid diamond [' is', ' a', ' solid', '.', '\n', '\n', 'Diamond', ' is', ' a', ' hard', ',', ' transparent', ',', ' and', ' color', 'less', ' material', '.', ' It', ' is'] " is a solid .
+
+ Diamond is a hard , transparent , and color less material . It is" True through a diamond interchange 2 [' through', ' a', ' diamond']
+145 29 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of milk liquid milk [' is', ' a', ' solid', '.', ' Milk', ' is', ' a', ' coll', 'oidal', ' suspension', ' of', ' fat', ' glob', 'ules', ' in', ' water', '.', ' The', ' fat', ' glob'] is a solid . Milk is a coll oidal suspension of fat glob ules in water . The fat glob False branded low-fat milk in its North 4 [' branded', ' low', '-', 'fat', ' milk']
+146 29 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of milk liquid milk [' is', ' a', ' solid', '.', ' Milk', ' is', ' a', ' coll', 'oidal', ' suspension', ' of', ' fat', ' glob', 'ules', ' in', ' water', '.', ' The', ' fat', ' glob'] is a solid . Milk is a coll oidal suspension of fat glob ules in water . The fat glob False (660 ft), the milk shark favors 5 [' (', '660', ' ft', '),', ' the', ' milk']
+147 29 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of milk liquid milk [' is', ' a', ' solid', '.', ' Milk', ' is', ' a', ' coll', 'oidal', ' suspension', ' of', ' fat', ' glob', 'ules', ' in', ' water', '.', ' The', ' fat', ' glob'] is a solid . Milk is a coll oidal suspension of fat glob ules in water . The fat glob False compound in the milk of the tammar 3 [' compound', ' in', ' the', ' milk']
+148 29 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of milk liquid milk [' is', ' a', ' solid', '.', ' Milk', ' is', ' a', ' coll', 'oidal', ' suspension', ' of', ' fat', ' glob', 'ules', ' in', ' water', '.', ' The', ' fat', ' glob'] is a solid . Milk is a coll oidal suspension of fat glob ules in water . The fat glob False " infected cow's milk some years previously.
+" 3 "[' infected', ' cow', ""'s"", ' milk']"
+149 29 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of milk liquid milk [' is', ' a', ' solid', '.', ' Milk', ' is', ' a', ' coll', 'oidal', ' suspension', ' of', ' fat', ' glob', 'ules', ' in', ' water', '.', ' The', ' fat', ' glob'] is a solid . Milk is a coll oidal suspension of fat glob ules in water . The fat glob False production, milk quotas and, more recently, 2 [' production', ',', ' milk']
+150 30 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of olive oil liquid olive oil [' is', ' liquid', '.', '\n', '\n', 'O', 'live', ' oil', ' is', ' a', ' liquid', ' at', ' room', ' temperature', '.', '\n', '\n', 'O', 'live', ' oil'] " is liquid .
+
+ O live oil is a liquid at room temperature .
+
+ O live oil" True figs, garlic, olive oil and saffron, which 6 [' fig', 's', ',', ' garlic', ',', ' olive', ' oil']
+151 30 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of olive oil liquid olive oil [' is', ' liquid', '.', '\n', '\n', 'O', 'live', ' oil', ' is', ' a', ' liquid', ' at', ' room', ' temperature', '.', '\n', '\n', 'O', 'live', ' oil'] " is liquid .
+
+ O live oil is a liquid at room temperature .
+
+ O live oil" True frying food in olive oil in the 5th century 4 [' frying', ' food', ' in', ' olive', ' oil']
+152 30 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of olive oil liquid olive oil [' is', ' liquid', '.', '\n', '\n', 'O', 'live', ' oil', ' is', ' a', ' liquid', ' at', ' room', ' temperature', '.', '\n', '\n', 'O', 'live', ' oil'] " is liquid .
+
+ O live oil is a liquid at room temperature .
+
+ O live oil" True one of these – olive oil was another 5 [' one', ' of', ' these', ' –', ' olive', ' oil']
+153 30 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of olive oil liquid olive oil [' is', ' liquid', '.', '\n', '\n', 'O', 'live', ' oil', ' is', ' a', ' liquid', ' at', ' room', ' temperature', '.', '\n', '\n', 'O', 'live', ' oil'] " is liquid .
+
+ O live oil is a liquid at room temperature .
+
+ O live oil" True soy sauce and olive oil every night. Gaining 4 [' soy', ' sauce', ' and', ' olive', ' oil']
+154 30 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of olive oil liquid olive oil [' is', ' liquid', '.', '\n', '\n', 'O', 'live', ' oil', ' is', ' a', ' liquid', ' at', ' room', ' temperature', '.', '\n', '\n', 'O', 'live', ' oil'] " is liquid .
+
+ O live oil is a liquid at room temperature .
+
+ O live oil" True tours for its rooms, olive oil presses and 6 [' tours', ' for', ' its', ' rooms', ',', ' olive', ' oil']
+155 31 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of soap solid soap [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' soap', ' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase'] " is a solid .
+
+ The phase of matter of soap is a solid .
+
+ The phase" True him. He studied soap films intensively, 4 [' him', '.', ' He', ' studied', ' soap']
+156 31 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of soap solid soap [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' soap', ' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase'] " is a solid .
+
+ The phase of matter of soap is a solid .
+
+ The phase" True from the Australian soap opera Neighbours, played 3 [' from', ' the', ' Australian', ' soap']
+157 31 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of soap solid soap [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' soap', ' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase'] " is a solid .
+
+ The phase of matter of soap is a solid .
+
+ The phase" True have been a weak soap opera into a vibrant 4 [' have', ' been', ' a', ' weak', ' soap']
+158 31 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of soap solid soap [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' soap', ' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase'] " is a solid .
+
+ The phase of matter of soap is a solid .
+
+ The phase" True acted on the soap opera One Life to 3 [' acted', ' on', ' the', ' soap']
+159 31 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of soap solid soap [' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' soap', ' is', ' a', ' solid', '.', '\n', '\n', 'The', ' phase'] " is a solid .
+
+ The phase of matter of soap is a solid .
+
+ The phase" True appear on another soap opera for twelve 3 [' appear', ' on', ' another', ' soap']
+160 32 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of rubber solid rubber [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True recently invented rubber suit and metal diving 2 [' recently', ' invented', ' rubber']
+161 32 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of rubber solid rubber [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True to protect a rubber tapping community from 3 [' to', ' protect', ' a', ' rubber']
+162 32 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of rubber solid rubber [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True laid within a rubber jacket to reduce noise. 3 [' laid', ' within', ' a', ' rubber']
+163 32 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of rubber solid rubber [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True those used for rubber injection moulding, 3 [' those', ' used', ' for', ' rubber']
+164 32 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of rubber solid rubber [' is', ' a', ' solid', '.', ' When', ' heated', ',', ' it', ' becomes', ' a', ' liquid', '.', ' When', ' cooled', ',', ' it', ' becomes', ' a', ' solid', ' again'] is a solid . When heated , it becomes a liquid . When cooled , it becomes a solid again True collect the rubber tax and export it. 2 [' collect', ' the', ' rubber']
+165 33 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glycerin liquid glycerin [' is', ' a', ' solid', '.', ' It', ' is', ' a', ' color', 'less', ',', ' odor', 'less', ',', ' tast', 'eless', ',', ' and', ' non', '-', 't'] is a solid . It is a color less , odor less , tast eless , and non - t False cat hair and the glycerin placed on the actors' 6 [' cat', ' hair', ' and', ' the', ' gly', 'cer', 'in']
+166 33 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glycerin liquid glycerin [' is', ' a', ' solid', '.', ' It', ' is', ' a', ' color', 'less', ',', ' odor', 'less', ',', ' tast', 'eless', ',', ' and', ' non', '-', 't'] is a solid . It is a color less , odor less , tast eless , and non - t False of cat hair and the glycerin placed on the actors' 7 [' of', ' cat', ' hair', ' and', ' the', ' gly', 'cer', 'in']
+167 33 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glycerin liquid glycerin [' is', ' a', ' solid', '.', ' It', ' is', ' a', ' color', 'less', ',', ' odor', 'less', ',', ' tast', 'eless', ',', ' and', ' non', '-', 't'] is a solid . It is a color less , odor less , tast eless , and non - t False of cat hair and the glycerin placed on the 7 [' of', ' cat', ' hair', ' and', ' the', ' gly', 'cer', 'in']
+168 33 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glycerin liquid glycerin [' is', ' a', ' solid', '.', ' It', ' is', ' a', ' color', 'less', ',', ' odor', 'less', ',', ' tast', 'eless', ',', ' and', ' non', '-', 't'] is a solid . It is a color less , odor less , tast eless , and non - t False hair and the glycerin placed on the actors' 5 [' hair', ' and', ' the', ' gly', 'cer', 'in']
+169 33 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of glycerin liquid glycerin [' is', ' a', ' solid', '.', ' It', ' is', ' a', ' color', 'less', ',', ' odor', 'less', ',', ' tast', 'eless', ',', ' and', ' non', '-', 't'] is a solid . It is a color less , odor less , tast eless , and non - t False cat hair and the glycerin placed on the actors' 6 [' cat', ' hair', ' and', ' the', ' gly', 'cer', 'in']
+170 34 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tea liquid tea [' is', ' liquid', '.', '\n', '\n', 'Tea', ' is', ' a', ' beverage', ' made', ' from', ' the', ' leaves', ' of', ' the', ' plant', ' Came', 'll', 'ia', ' sin'] " is liquid .
+
+ Tea is a beverage made from the leaves of the plant Came ll ia sin" True cat, a dog, and a tea party she called Something 7 [' cat', ',', ' a', ' dog', ',', ' and', ' a', ' tea']
+171 34 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tea liquid tea [' is', ' liquid', '.', '\n', '\n', 'Tea', ' is', ' a', ' beverage', ' made', ' from', ' the', ' leaves', ' of', ' the', ' plant', ' Came', 'll', 'ia', ' sin'] " is liquid .
+
+ Tea is a beverage made from the leaves of the plant Came ll ia sin" True hosting Sunday tea gatherings at 2 [' hosting', ' Sunday', ' tea']
+172 34 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tea liquid tea [' is', ' liquid', '.', '\n', '\n', 'Tea', ' is', ' a', ' beverage', ' made', ' from', ' the', ' leaves', ' of', ' the', ' plant', ' Came', 'll', 'ia', ' sin'] " is liquid .
+
+ Tea is a beverage made from the leaves of the plant Came ll ia sin" True chlorophyll. White tea is produced in lesser 5 [' chlor', 'ophy', 'll', '.', ' White', ' tea']
+173 34 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tea liquid tea [' is', ' liquid', '.', '\n', '\n', 'Tea', ' is', ' a', ' beverage', ' made', ' from', ' the', ' leaves', ' of', ' the', ' plant', ' Came', 'll', 'ia', ' sin'] " is liquid .
+
+ Tea is a beverage made from the leaves of the plant Came ll ia sin" True accompanied by tea or coffee. Dishes 2 [' accompanied', ' by', ' tea']
+174 34 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tea liquid tea [' is', ' liquid', '.', '\n', '\n', 'Tea', ' is', ' a', ' beverage', ' made', ' from', ' the', ' leaves', ' of', ' the', ' plant', ' Came', 'll', 'ia', ' sin'] " is liquid .
+
+ Tea is a beverage made from the leaves of the plant Came ll ia sin" True play a role in tea cultivation. 4 [' play', ' a', ' role', ' in', ' tea']
+175 35 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of hydrogen gas hydrogen [' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase'] " is a plasma .
+
+ The phase of matter of helium is a plasma .
+
+ The phase" False phosphorus, silicon, and hydrogen, are generally 5 [' phosphorus', ',', ' silicon', ',', ' and', ' hydrogen']
+176 35 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of hydrogen gas hydrogen [' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase'] " is a plasma .
+
+ The phase of matter of helium is a plasma .
+
+ The phase" False quantum states of hydrogen atom, which should 3 [' quantum', ' states', ' of', ' hydrogen']
+177 35 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of hydrogen gas hydrogen [' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase'] " is a plasma .
+
+ The phase of matter of helium is a plasma .
+
+ The phase" False effectively absorbs hydrogen atoms. This could 2 [' effectively', ' absorbs', ' hydrogen']
+178 35 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of hydrogen gas hydrogen [' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase'] " is a plasma .
+
+ The phase of matter of helium is a plasma .
+
+ The phase" False the assumption that hydrogen is counted as − 1 as 3 [' the', ' assumption', ' that', ' hydrogen']
+179 35 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of hydrogen gas hydrogen [' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' helium', ' is', ' a', ' plasma', '.', '\n', '\n', 'The', ' phase'] " is a plasma .
+
+ The phase of matter of helium is a plasma .
+
+ The phase" False feasibility of the hydrogen bomb, and 3 [' feasibility', ' of', ' the', ' hydrogen']
+180 36 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of salt solid salt [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' salt', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of salt is solid .
+
+ The phase of matter" True coal mining and salt panning, which date 3 [' coal', ' mining', ' and', ' salt']
+181 36 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of salt solid salt [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' salt', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of salt is solid .
+
+ The phase of matter" True sodium chloride salt and sulfuric acid 2 [' sodium', ' chloride', ' salt']
+182 36 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of salt solid salt [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' salt', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of salt is solid .
+
+ The phase of matter" True Virginia to destroy salt and lead mines 3 [' Virginia', ' to', ' destroy', ' salt']
+183 36 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of salt solid salt [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' salt', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of salt is solid .
+
+ The phase of matter" True depending on buffer salt concentrations 3 [' depending', ' on', ' buffer', ' salt']
+184 36 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of salt solid salt [' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' salt', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is solid .
+
+ The phase of matter of salt is solid .
+
+ The phase of matter" True " royal monopoly on salt trade.
+" 3 [' royal', ' monopoly', ' on', ' salt']
+185 37 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sugar solid sugar [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ugar', ' is', ' a', ' sweet', ',', ' sticky', ',', ' crystall', 'ine', ' substance', ' that', ' is', ' made'] " is a solid .
+
+ S ugar is a sweet , sticky , crystall ine substance that is made" True and winged and sugar kelps, though there 4 [' and', ' wing', 'ed', ' and', ' sugar']
+186 37 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sugar solid sugar [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ugar', ' is', ' a', ' sweet', ',', ' sticky', ',', ' crystall', 'ine', ' substance', ' that', ' is', ' made'] " is a solid .
+
+ S ugar is a sweet , sticky , crystall ine substance that is made" True pygmaeus), and sugar glider (Petaurus 6 [' py', 'g', 'ma', 'eus', '),', ' and', ' sugar']
+187 37 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sugar solid sugar [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ugar', ' is', ' a', ' sweet', ',', ' sticky', ',', ' crystall', 'ine', ' substance', ' that', ' is', ' made'] " is a solid .
+
+ S ugar is a sweet , sticky , crystall ine substance that is made" True " found that ""Brazil's sugar-based ethanol did" 5 "[' found', ' that', ' ""', 'Brazil', ""'s"", ' sugar']"
+188 37 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sugar solid sugar [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ugar', ' is', ' a', ' sweet', ',', ' sticky', ',', ' crystall', 'ine', ' substance', ' that', ' is', ' made'] " is a solid .
+
+ S ugar is a sweet , sticky , crystall ine substance that is made" True of 2008; poor sugarcane harvests 4 [' of', ' 2008', ';', ' poor', ' sugar']
+189 37 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of sugar solid sugar [' is', ' a', ' solid', '.', '\n', '\n', 'S', 'ugar', ' is', ' a', ' sweet', ',', ' sticky', ',', ' crystall', 'ine', ' substance', ' that', ' is', ' made'] " is a solid .
+
+ S ugar is a sweet , sticky , crystall ine substance that is made" True in cities within the sugarcane's fields 4 [' in', ' cities', ' within', ' the', ' sugar']
+190 38 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of vinegar liquid vinegar [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' vinegar', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of vinegar is a liquid .
+
+ The phase" True " to prepare vinegar from months to weeks.
+" 2 [' to', ' prepare', ' vinegar']
+191 38 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of vinegar liquid vinegar [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' vinegar', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of vinegar is a liquid .
+
+ The phase" True of water in vinegar has such a profound 3 [' of', ' water', ' in', ' vinegar']
+192 38 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of vinegar liquid vinegar [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' vinegar', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of vinegar is a liquid .
+
+ The phase" True poutine gravy and vinegar for its French fries, 4 [' p', 'outine', ' gravy', ' and', ' vinegar']
+193 38 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of vinegar liquid vinegar [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' vinegar', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of vinegar is a liquid .
+
+ The phase" True " or steeping in vinegar was ""the onely use" 4 [' or', ' ste', 'eping', ' in', ' vinegar']
+194 38 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of vinegar liquid vinegar [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' vinegar', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of vinegar is a liquid .
+
+ The phase" True consisting of vinegar or some other weak 2 [' consisting', ' of', ' vinegar']
+195 39 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silver solid silver [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True the pair took the silver medal both individually 4 [' the', ' pair', ' took', ' the', ' silver']
+196 39 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silver solid silver [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True the eventual silver medalist, in the quarterfinals. 2 [' the', ' eventual', ' silver']
+197 39 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silver solid silver [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True gold medals and no silver medals were 4 [' gold', ' medals', ' and', ' no', ' silver']
+198 39 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silver solid silver [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True car is the silver grey Aston Martin 3 [' car', ' is', ' the', ' silver']
+199 39 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of silver solid silver [' is', ' a', ' solid', '.', ' At', ' higher', ' temperatures', ',', ' it', ' becomes', ' a', ' liquid', '.', ' At', ' even', ' higher', ' temperatures', ',', ' it', ' becomes'] is a solid . At higher temperatures , it becomes a liquid . At even higher temperatures , it becomes True Congress had authorized a silver three-cent piece 4 [' Congress', ' had', ' authorized', ' a', ' silver']
+200 40 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of leather solid leather "[' is', ' called', ' the', ' ""', 'f', 'ib', 'rous', '""', ' phase', '.', ' When', ' leather', ' is', ' heated', ',', ' the', ' fib', 'rous', ' phase', ' is']" " is called the "" f ib rous "" phase . When leather is heated , the fib rous phase is" False windscreen wipers, a leather and chrome 6 [' wind', 'screen', ' wip', 'ers', ',', ' a', ' leather']
+201 40 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of leather solid leather "[' is', ' called', ' the', ' ""', 'f', 'ib', 'rous', '""', ' phase', '.', ' When', ' leather', ' is', ' heated', ',', ' the', ' fib', 'rous', ' phase', ' is']" " is called the "" f ib rous "" phase . When leather is heated , the fib rous phase is" False brewing, textiles, and leather manufacturing, 6 [' brewing', ',', ' text', 'iles', ',', ' and', ' leather']
+202 40 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of leather solid leather "[' is', ' called', ' the', ' ""', 'f', 'ib', 'rous', '""', ' phase', '.', ' When', ' leather', ' is', ' heated', ',', ' the', ' fib', 'rous', ' phase', ' is']" " is called the "" f ib rous "" phase . When leather is heated , the fib rous phase is" False always dressed in a leather jacket and skirt, 4 [' always', ' dressed', ' in', ' a', ' leather']
+203 40 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of leather solid leather "[' is', ' called', ' the', ' ""', 'f', 'ib', 'rous', '""', ' phase', '.', ' When', ' leather', ' is', ' heated', ',', ' the', ' fib', 'rous', ' phase', ' is']" " is called the "" f ib rous "" phase . When leather is heated , the fib rous phase is" False place using a leather strap attached from 3 [' place', ' using', ' a', ' leather']
+204 40 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of leather solid leather "[' is', ' called', ' the', ' ""', 'f', 'ib', 'rous', '""', ' phase', '.', ' When', ' leather', ' is', ' heated', ',', ' the', ' fib', 'rous', ' phase', ' is']" " is called the "" f ib rous "" phase . When leather is heated , the fib rous phase is" False of wood and leather working. Artwork 3 [' of', ' wood', ' and', ' leather']
+205 41 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of argon gas argon [' is', ' a', ' solid', '.', '\n', '\n', 'Ar', 'gon', ' is', ' a', ' noble', ' gas', ',', ' which', ' means', ' that', ' it', ' is', ' chemically', ' inert'] " is a solid .
+
+ Ar gon is a noble gas , which means that it is chemically inert" True the dead bird, argon also enhances shelf 5 [' the', ' dead', ' bird', ',', ' arg', 'on']
+206 41 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of argon gas argon [' is', ' a', ' solid', '.', '\n', '\n', 'Ar', 'gon', ' is', ' a', ' noble', ' gas', ',', ' which', ' means', ' that', ' it', ' is', ' chemically', ' inert'] " is a solid .
+
+ Ar gon is a noble gas , which means that it is chemically inert" True affected by loss of argon due to low-grade 5 [' affected', ' by', ' loss', ' of', ' arg', 'on']
+207 41 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of argon gas argon [' is', ' a', ' solid', '.', '\n', '\n', 'Ar', 'gon', ' is', ' a', ' noble', ' gas', ',', ' which', ' means', ' that', ' it', ' is', ' chemically', ' inert'] " is a solid .
+
+ Ar gon is a noble gas , which means that it is chemically inert" True 0 ['argon']
+208 41 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of argon gas argon [' is', ' a', ' solid', '.', '\n', '\n', 'Ar', 'gon', ' is', ' a', ' noble', ' gas', ',', ' which', ' means', ' that', ' it', ' is', ' chemically', ' inert'] " is a solid .
+
+ Ar gon is a noble gas , which means that it is chemically inert" True source of purified argon products. Argon is 4 [' source', ' of', ' purified', ' arg', 'on']
+209 41 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of argon gas argon [' is', ' a', ' solid', '.', '\n', '\n', 'Ar', 'gon', ' is', ' a', ' noble', ' gas', ',', ' which', ' means', ' that', ' it', ' is', ' chemically', ' inert'] " is a solid .
+
+ Ar gon is a noble gas , which means that it is chemically inert" True (He), neon (Ne), argon (Ar), krypton (Kr), 8 [' (', 'He', '),', ' neon', ' (', 'Ne', '),', ' arg', 'on']
+210 42 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wax solid wax [' is', ' solid', '.', '\n', '\n', 'W', 'ax', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'W', 'ax', ' is', ' a'] " is solid .
+
+ W ax is a solid at room temperature .
+
+ W ax is a" True Paris Spectacular wax museum, an 3 [' Paris', ' Spect', 'acular', ' wax']
+211 42 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wax solid wax [' is', ' solid', '.', '\n', '\n', 'W', 'ax', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'W', 'ax', ' is', ' a'] " is solid .
+
+ W ax is a solid at room temperature .
+
+ W ax is a" True (1925) about a wax figure display at 6 [' (', '19', '25', ')', ' about', ' a', ' wax']
+212 42 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wax solid wax [' is', ' solid', '.', '\n', '\n', 'W', 'ax', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'W', 'ax', ' is', ' a'] " is solid .
+
+ W ax is a solid at room temperature .
+
+ W ax is a" True island where we could wax philosophical and kind 4 [' island', ' where', ' we', ' could', ' wax']
+213 42 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wax solid wax [' is', ' solid', '.', '\n', '\n', 'W', 'ax', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'W', 'ax', ' is', ' a'] " is solid .
+
+ W ax is a solid at room temperature .
+
+ W ax is a" True placed directly into wax ring moulds for 3 [' placed', ' directly', ' into', ' wax']
+214 42 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of wax solid wax [' is', ' solid', '.', '\n', '\n', 'W', 'ax', ' is', ' a', ' solid', ' at', ' room', ' temperature', '.', '\n', '\n', 'W', 'ax', ' is', ' a'] " is solid .
+
+ W ax is a solid at room temperature .
+
+ W ax is a" True redrawn using hot wax, usually made 4 [' red', 'rawn', ' using', ' hot', ' wax']
+215 43 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of beer liquid beer [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' beer', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of beer is a liquid .
+
+ The phase" True Palace Ale, a beer on sale in the ground, 4 [' Palace', ' Ale', ',', ' a', ' beer']
+216 43 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of beer liquid beer [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' beer', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of beer is a liquid .
+
+ The phase" True gave speeches at a beer hall rally. 4 [' gave', ' speeches', ' at', ' a', ' beer']
+217 43 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of beer liquid beer [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' beer', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of beer is a liquid .
+
+ The phase" True their homes with neon beer signs and 4 [' their', ' homes', ' with', ' neon', ' beer']
+218 43 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of beer liquid beer [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' beer', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of beer is a liquid .
+
+ The phase" True and asked for empty beer bottles and labels 4 [' and', ' asked', ' for', ' empty', ' beer']
+219 43 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of beer liquid beer [' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' beer', ' is', ' a', ' liquid', '.', '\n', '\n', 'The', ' phase'] " is a liquid .
+
+ The phase of matter of beer is a liquid .
+
+ The phase" True " ""something between beer and sauerkraut""." 3 "[' ""', 'something', ' between', ' beer']"
+220 44 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of radium solid radium "[' is', ' a', ' solid', '.', '\n', '\n', 'R', 'adium', ' is', ' a', ' radioactive', ' element', ' that', ' is', ' found', ' in', ' the', ' earth', ""'s"", ' crust']" " is a solid .
+
+ R adium is a radioactive element that is found in the earth 's crust" True carrying test tubes of radium in her pockets 5 [' carrying', ' test', ' tubes', ' of', ' rad', 'ium']
+221 44 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of radium solid radium "[' is', ' a', ' solid', '.', '\n', '\n', 'R', 'adium', ' is', ' a', ' radioactive', ' element', ' that', ' is', ' found', ' in', ' the', ' earth', ""'s"", ' crust']" " is a solid .
+
+ R adium is a radioactive element that is found in the earth 's crust" True 74 MBq of radium (assumed to be 5 [' 74', ' MB', 'q', ' of', ' rad', 'ium']
+222 44 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of radium solid radium "[' is', ' a', ' solid', '.', '\n', '\n', 'R', 'adium', ' is', ' a', ' radioactive', ' element', ' that', ' is', ' found', ' in', ' the', ' earth', ""'s"", ' crust']" " is a solid .
+
+ R adium is a radioactive element that is found in the earth 's crust" True measuring the amount of radium in seawater 5 [' measuring', ' the', ' amount', ' of', ' rad', 'ium']
+223 44 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of radium solid radium "[' is', ' a', ' solid', '.', '\n', '\n', 'R', 'adium', ' is', ' a', ' radioactive', ' element', ' that', ' is', ' found', ' in', ' the', ' earth', ""'s"", ' crust']" " is a solid .
+
+ R adium is a radioactive element that is found in the earth 's crust" True uranium, thorium, radium and polonium. 6 [' uranium', ',', ' thor', 'ium', ',', ' rad', 'ium']
+224 44 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of radium solid radium "[' is', ' a', ' solid', '.', '\n', '\n', 'R', 'adium', ' is', ' a', ' radioactive', ' element', ' that', ' is', ' found', ' in', ' the', ' earth', ""'s"", ' crust']" " is a solid .
+
+ R adium is a radioactive element that is found in the earth 's crust" True hafnium, and radium and rutherfordium, 6 [' ha', 'fn', 'ium', ',', ' and', ' rad', 'ium']
+225 45 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of platinum solid platinum [' is', ' a', ' solid', '.', '\n', '\n', 'Pl', 'atinum', ' is', ' a', ' very', ' hard', ' metal', ',', ' and', ' is', ' used', ' in', ' jew', 'ellery'] " is a solid .
+
+ Pl atinum is a very hard metal , and is used in jew ellery" True certified triple platinum by the Recording 2 [' certified', ' triple', ' platinum']
+226 45 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of platinum solid platinum [' is', ' a', ' solid', '.', '\n', '\n', 'Pl', 'atinum', ' is', ' a', ' very', ' hard', ' metal', ',', ' and', ' is', ' used', ' in', ' jew', 'ellery'] " is a solid .
+
+ Pl atinum is a very hard metal , and is used in jew ellery" True four 90 % platinum / 10 % iridium 3 [' four', ' 90', ' %', ' platinum']
+227 45 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of platinum solid platinum [' is', ' a', ' solid', '.', '\n', '\n', 'Pl', 'atinum', ' is', ' a', ' very', ' hard', ' metal', ',', ' and', ' is', ' used', ' in', ' jew', 'ellery'] " is a solid .
+
+ Pl atinum is a very hard metal , and is used in jew ellery" True was later certified platinum by the Syndicat National 3 [' was', ' later', ' certified', ' platinum']
+228 45 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of platinum solid platinum [' is', ' a', ' solid', '.', '\n', '\n', 'Pl', 'atinum', ' is', ' a', ' very', ' hard', ' metal', ',', ' and', ' is', ' used', ' in', ' jew', 'ellery'] " is a solid .
+
+ Pl atinum is a very hard metal , and is used in jew ellery" True track for the platinum edition of the 3 [' track', ' for', ' the', ' platinum']
+229 45 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of platinum solid platinum [' is', ' a', ' solid', '.', '\n', '\n', 'Pl', 'atinum', ' is', ' a', ' very', ' hard', ' metal', ',', ' and', ' is', ' used', ' in', ' jew', 'ellery'] " is a solid .
+
+ Pl atinum is a very hard metal , and is used in jew ellery" True was certified platinum by the Recording 2 [' was', ' certified', ' platinum']
+230 46 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of juice liquid juice [' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' milk', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is liquid .
+
+ The phase of matter of milk is solid .
+
+ The phase of matter" True and model of lemon juice being contained 4 [' and', ' model', ' of', ' lemon', ' juice']
+231 46 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of juice liquid juice [' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' milk', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is liquid .
+
+ The phase of matter of milk is solid .
+
+ The phase of matter" True midazolam. Grapefruit juice reduces intestinal 7 [' mid', 'az', 'ol', 'am', '.', ' Grape', 'fruit', ' juice']
+232 46 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of juice liquid juice [' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' milk', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is liquid .
+
+ The phase of matter of milk is solid .
+
+ The phase of matter" True of marketing lemon juice in this manner was 3 [' of', ' marketing', ' lemon', ' juice']
+233 46 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of juice liquid juice [' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' milk', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is liquid .
+
+ The phase of matter of milk is solid .
+
+ The phase of matter" True is cooked with lemon juice and water, 4 [' is', ' cooked', ' with', ' lemon', ' juice']
+234 46 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of juice liquid juice [' is', ' liquid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter', ' of', ' milk', ' is', ' solid', '.', '\n', '\n', 'The', ' phase', ' of', ' matter'] " is liquid .
+
+ The phase of matter of milk is solid .
+
+ The phase of matter" True Maid orange juice and was sold to The 2 [' Maid', ' orange', ' juice']
+235 47 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tungsten solid tungsten [' is', ' a', ' hex', 'agonal', ' close', '-', 'packed', ' (', 'h', 'cp', ')', ' structure', '.', ' At', ' higher', ' temperatures', ',', ' the', ' t', 'ung'] is a hex agonal close - packed ( h cp ) structure . At higher temperatures , the t ung False " welding =
+" 4 [' welding', ' =', 't', 'ung', 'sten']
+236 47 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tungsten solid tungsten [' is', ' a', ' hex', 'agonal', ' close', '-', 'packed', ' (', 'h', 'cp', ')', ' structure', '.', ' At', ' higher', ' temperatures', ',', ' the', ' t', 'ung'] is a hex agonal close - packed ( h cp ) structure . At higher temperatures , the t ung False story, Soderbergh used tungsten film with no filter 9 [' story', ',', ' S', 'oder', 'ber', 'gh', ' used', ' t', 'ung', 'sten']
+237 47 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tungsten solid tungsten [' is', ' a', ' hex', 'agonal', ' close', '-', 'packed', ' (', 'h', 'cp', ')', ' structure', '.', ' At', ' higher', ' temperatures', ',', ' the', ' t', 'ung'] is a hex agonal close - packed ( h cp ) structure . At higher temperatures , the t ung False that mixing tungsten with chromium, steel 4 [' that', ' mixing', ' t', 'ung', 'sten']
+238 47 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tungsten solid tungsten [' is', ' a', ' hex', 'agonal', ' close', '-', 'packed', ' (', 'h', 'cp', ')', ' structure', '.', ' At', ' higher', ' temperatures', ',', ' the', ' t', 'ung'] is a hex agonal close - packed ( h cp ) structure . At higher temperatures , the t ung False iodide; at a tungsten filament of 1700°C 7 [' iod', 'ide', ';', ' at', ' a', ' t', 'ung', 'sten']
+239 47 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of tungsten solid tungsten [' is', ' a', ' hex', 'agonal', ' close', '-', 'packed', ' (', 'h', 'cp', ')', ' structure', '.', ' At', ' higher', ' temperatures', ',', ' the', ' t', 'ung'] is a hex agonal close - packed ( h cp ) structure . At higher temperatures , the t ung False was a 1,500 watt tungsten bulb 15 inches in 8 [' was', ' a', ' 1', ',', '500', ' watt', ' t', 'ung', 'sten']
+240 48 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of kerosene liquid kerosene [' is', ' a', ' mixture', ' of', ' liquid', ' and', ' gas', '.', ' The', ' liquid', ' is', ' a', ' mixture', ' of', ' hydro', 'car', 'bons', ',', ' and', ' the'] is a mixture of liquid and gas . The liquid is a mixture of hydro car bons , and the True petrol (gasoline), kerosene and / or diesel, and 7 [' petrol', ' (', 'gas', 'oline', '),', ' k', 'eros', 'ene']
+241 48 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of kerosene liquid kerosene [' is', ' a', ' mixture', ' of', ' liquid', ' and', ' gas', '.', ' The', ' liquid', ' is', ' a', ' mixture', ' of', ' hydro', 'car', 'bons', ',', ' and', ' the'] is a mixture of liquid and gas . The liquid is a mixture of hydro car bons , and the True cricket bat and a kerosene tin for a wicket. 6 [' cricket', ' bat', ' and', ' a', ' k', 'eros', 'ene']
+242 48 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of kerosene liquid kerosene [' is', ' a', ' mixture', ' of', ' liquid', ' and', ' gas', '.', ' The', ' liquid', ' is', ' a', ' mixture', ' of', ' hydro', 'car', 'bons', ',', ' and', ' the'] is a mixture of liquid and gas . The liquid is a mixture of hydro car bons , and the True power (25.5 %), and kerosene (16.5 %). Most 10 [' power', ' (', '25', '.', '5', ' %', '),', ' and', ' k', 'eros', 'ene']
+243 48 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of kerosene liquid kerosene [' is', ' a', ' mixture', ' of', ' liquid', ' and', ' gas', '.', ' The', ' liquid', ' is', ' a', ' mixture', ' of', ' hydro', 'car', 'bons', ',', ' and', ' the'] is a mixture of liquid and gas . The liquid is a mixture of hydro car bons , and the True candles, and kerosene to stranded 5 [' candles', ',', ' and', ' k', 'eros', 'ene']
+244 48 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of kerosene liquid kerosene [' is', ' a', ' mixture', ' of', ' liquid', ' and', ' gas', '.', ' The', ' liquid', ' is', ' a', ' mixture', ' of', ' hydro', 'car', 'bons', ',', ' and', ' the'] is a mixture of liquid and gas . The liquid is a mixture of hydro car bons , and the True also known as a kerosene cut. The boiling 6 [' also', ' known', ' as', ' a', ' k', 'eros', 'ene']
+245 49 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of champagne liquid champagne [' is', ' a', ' gas', '.', ' As', ' the', ' temperature', ' is', ' raised', ',', ' the', ' gas', ' becomes', ' a', ' liquid', '.', ' The', ' temperature', ' at', ' which'] is a gas . As the temperature is raised , the gas becomes a liquid . The temperature at which True " began drinking champagne each morning, ""to" 2 [' began', ' drinking', ' champagne']
+246 49 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of champagne liquid champagne [' is', ' a', ' gas', '.', ' As', ' the', ' temperature', ' is', ' raised', ',', ' the', ' gas', ' becomes', ' a', ' liquid', '.', ' The', ' temperature', ' at', ' which'] is a gas . As the temperature is raised , the gas becomes a liquid . The temperature at which True with a bottle of champagne which he had purchased 4 [' with', ' a', ' bottle', ' of', ' champagne']
+247 49 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of champagne liquid champagne [' is', ' a', ' gas', '.', ' As', ' the', ' temperature', ' is', ' raised', ',', ' the', ' gas', ' becomes', ' a', ' liquid', '.', ' The', ' temperature', ' at', ' which'] is a gas . As the temperature is raised , the gas becomes a liquid . The temperature at which True dancing and sipping champagne whilst singing 4 [' dancing', ' and', ' s', 'ipping', ' champagne']
+248 49 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of champagne liquid champagne [' is', ' a', ' gas', '.', ' As', ' the', ' temperature', ' is', ' raised', ',', ' the', ' gas', ' becomes', ' a', ' liquid', '.', ' The', ' temperature', ' at', ' which'] is a gas . As the temperature is raised , the gas becomes a liquid . The temperature at which True rings and jewelry, champagne and brandy, Maxim 4 [' rings', ' and', ' jewelry', ',', ' champagne']
+249 49 At room temperature, the phase of matter of x -1 At room temperature, the phase of matter of champagne liquid champagne [' is', ' a', ' gas', '.', ' As', ' the', ' temperature', ' is', ' raised', ',', ' the', ' gas', ' becomes', ' a', ' liquid', '.', ' The', ' temperature', ' at', ' which'] is a gas . As the temperature is raised , the gas becomes a liquid . The temperature at which True brewers have used champagne yeasts to increase 3 [' brewers', ' have', ' used', ' champagne']
diff --git a/patchscopes/code/preprocessed_data/commonsense/task_done_by_person.tsv b/patchscopes/code/preprocessed_data/commonsense/task_done_by_person.tsv
new file mode 100644
index 00000000..67a119b3
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/commonsense/task_done_by_person.tsv
@@ -0,0 +1,102 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 1 The professional role most suited to handle x -1 The professional role most suited to handle delivering mail mail carrier delivering mail [' is', ' the', ' Mail', ' Carrier', '.', ' This', ' is', ' a', ' job', ' that', ' requires', ' a', ' lot', ' of', ' physical', ' strength', ' and', ' stamina', '.', ' The'] is the Mail Carrier . This is a job that requires a lot of physical strength and stamina . The False secured his first job, delivering mail to neighboring 6 [' secured', ' his', ' first', ' job', ',', ' delivering', ' mail']
+1 1 The professional role most suited to handle x -1 The professional role most suited to handle delivering mail mail carrier delivering mail [' is', ' the', ' Mail', ' Carrier', '.', ' This', ' is', ' a', ' job', ' that', ' requires', ' a', ' lot', ' of', ' physical', ' strength', ' and', ' stamina', '.', ' The'] is the Mail Carrier . This is a job that requires a lot of physical strength and stamina . The False full-time carriers delivering mail six times a day, 5 [' full', '-', 'time', ' carriers', ' delivering', ' mail']
+2 1 The professional role most suited to handle x -1 The professional role most suited to handle delivering mail mail carrier delivering mail [' is', ' the', ' Mail', ' Carrier', '.', ' This', ' is', ' a', ' job', ' that', ' requires', ' a', ' lot', ' of', ' physical', ' strength', ' and', ' stamina', '.', ' The'] is the Mail Carrier . This is a job that requires a lot of physical strength and stamina . The False full-time carriers delivering mail six times a day, six 5 [' full', '-', 'time', ' carriers', ' delivering', ' mail']
+3 1 The professional role most suited to handle x -1 The professional role most suited to handle delivering mail mail carrier delivering mail [' is', ' the', ' Mail', ' Carrier', '.', ' This', ' is', ' a', ' job', ' that', ' requires', ' a', ' lot', ' of', ' physical', ' strength', ' and', ' stamina', '.', ' The'] is the Mail Carrier . This is a job that requires a lot of physical strength and stamina . The False situation by only delivering mail addressed to Brooks 4 [' situation', ' by', ' only', ' delivering', ' mail']
+4 1 The professional role most suited to handle x -1 The professional role most suited to handle delivering mail mail carrier delivering mail [' is', ' the', ' Mail', ' Carrier', '.', ' This', ' is', ' a', ' job', ' that', ' requires', ' a', ' lot', ' of', ' physical', ' strength', ' and', ' stamina', '.', ' The'] is the Mail Carrier . This is a job that requires a lot of physical strength and stamina . The False secured his first job, delivering mail to neighboring 6 [' secured', ' his', ' first', ' job', ',', ' delivering', ' mail']
+5 4 The professional role most suited to handle x -1 The professional role most suited to handle leading teams leader leading teams [' of', ' people', ' is', ' that', ' of', ' a', ' manager', '.', ' A', ' manager', ' is', ' responsible', ' for', ' the', ' overall', ' performance', ' of', ' a', ' team', '.'] of people is that of a manager . A manager is responsible for the overall performance of a team . False of the world's leading teams of military vehicle 5 "[' of', ' the', ' world', ""'s"", ' leading', ' teams']"
+6 4 The professional role most suited to handle x -1 The professional role most suited to handle leading teams leader leading teams [' of', ' people', ' is', ' that', ' of', ' a', ' manager', '.', ' A', ' manager', ' is', ' responsible', ' for', ' the', ' overall', ' performance', ' of', ' a', ' team', '.'] of people is that of a manager . A manager is responsible for the overall performance of a team . False improvement; the class leading teams remained the 5 [' improvement', ';', ' the', ' class', ' leading', ' teams']
+7 4 The professional role most suited to handle x -1 The professional role most suited to handle leading teams leader leading teams [' of', ' people', ' is', ' that', ' of', ' a', ' manager', '.', ' A', ' manager', ' is', ' responsible', ' for', ' the', ' overall', ' performance', ' of', ' a', ' team', '.'] of people is that of a manager . A manager is responsible for the overall performance of a team . False drivers of the leading teams to pit, however, 4 [' drivers', ' of', ' the', ' leading', ' teams']
+8 4 The professional role most suited to handle x -1 The professional role most suited to handle leading teams leader leading teams [' of', ' people', ' is', ' that', ' of', ' a', ' manager', '.', ' A', ' manager', ' is', ' responsible', ' for', ' the', ' overall', ' performance', ' of', ' a', ' team', '.'] of people is that of a manager . A manager is responsible for the overall performance of a team . False one of the leading teams and former winners 4 [' one', ' of', ' the', ' leading', ' teams']
+9 4 The professional role most suited to handle x -1 The professional role most suited to handle leading teams leader leading teams [' of', ' people', ' is', ' that', ' of', ' a', ' manager', '.', ' A', ' manager', ' is', ' responsible', ' for', ' the', ' overall', ' performance', ' of', ' a', ' team', '.'] of people is that of a manager . A manager is responsible for the overall performance of a team . False establish as one of the leading teams in England. 6 [' establish', ' as', ' one', ' of', ' the', ' leading', ' teams']
+10 7 The professional role most suited to handle x -1 The professional role most suited to handle designing buildings architect designing buildings [' is', ' the', ' architect', '.', ' Architects', ' are', ' responsible', ' for', ' the', ' overall', ' design', ' of', ' a', ' building', ',', ' including', ' the', ' layout', ' of', ' the'] is the architect . Architects are responsible for the overall design of a building , including the layout of the True tornadoes. When designing buildings to prevent 5 [' torn', 'adoes', '.', ' When', ' designing', ' buildings']
+11 7 The professional role most suited to handle x -1 The professional role most suited to handle designing buildings architect designing buildings [' is', ' the', ' architect', '.', ' Architects', ' are', ' responsible', ' for', ' the', ' overall', ' design', ' of', ' a', ' building', ',', ' including', ' the', ' layout', ' of', ' the'] is the architect . Architects are responsible for the overall design of a building , including the layout of the True of tornadoes. When designing buildings to prevent against 6 [' of', ' torn', 'adoes', '.', ' When', ' designing', ' buildings']
+12 12 The professional role most suited to handle x -1 The professional role most suited to handle managing hotels hotel manager managing hotels [' is', ' the', ' hotel', ' manager', '.', ' The', ' hotel', ' manager', ' is', ' responsible', ' for', ' the', ' day', '-', 'to', '-', 'day', ' operations', ' of', ' the'] is the hotel manager . The hotel manager is responsible for the day - to - day operations of the True eventually found work managing hotels on behalf 4 [' eventually', ' found', ' work', ' managing', ' hotels']
+13 13 The professional role most suited to handle x -1 The professional role most suited to handle farming farmer farming [' is', ' that', ' of', ' a', ' farmer', '.', '\n', '\n', 'The', ' farmer', ' is', ' the', ' one', ' who', ' is', ' in', ' charge', ' of', ' the', ' farm'] " is that of a farmer .
+
+ The farmer is the one who is in charge of the farm" True accordance with sustainable farming practices also 3 [' accordance', ' with', ' sustainable', ' farming']
+14 13 The professional role most suited to handle x -1 The professional role most suited to handle farming farmer farming [' is', ' that', ' of', ' a', ' farmer', '.', '\n', '\n', 'The', ' farmer', ' is', ' the', ' one', ' who', ' is', ' in', ' charge', ' of', ' the', ' farm'] " is that of a farmer .
+
+ The farmer is the one who is in charge of the farm" True transitioned to the farming of wheat by 1792, 3 [' transitioned', ' to', ' the', ' farming']
+15 13 The professional role most suited to handle x -1 The professional role most suited to handle farming farmer farming [' is', ' that', ' of', ' a', ' farmer', '.', '\n', '\n', 'The', ' farmer', ' is', ' the', ' one', ' who', ' is', ' in', ' charge', ' of', ' the', ' farm'] " is that of a farmer .
+
+ The farmer is the one who is in charge of the farm" True Recreation as the hippie farming couple Beverly 5 [' Recreation', ' as', ' the', ' hipp', 'ie', ' farming']
+16 13 The professional role most suited to handle x -1 The professional role most suited to handle farming farmer farming [' is', ' that', ' of', ' a', ' farmer', '.', '\n', '\n', 'The', ' farmer', ' is', ' the', ' one', ' who', ' is', ' in', ' charge', ' of', ' the', ' farm'] " is that of a farmer .
+
+ The farmer is the one who is in charge of the farm" True Adams resumed farming at his home 2 [' Adams', ' resumed', ' farming']
+17 13 The professional role most suited to handle x -1 The professional role most suited to handle farming farmer farming [' is', ' that', ' of', ' a', ' farmer', '.', '\n', '\n', 'The', ' farmer', ' is', ' the', ' one', ' who', ' is', ' in', ' charge', ' of', ' the', ' farm'] " is that of a farmer .
+
+ The farmer is the one who is in charge of the farm" True improvements to farming and collected 2 [' improvements', ' to', ' farming']
+18 16 The professional role most suited to handle x -1 The professional role most suited to handle flying airplanes pilot flying airplanes [' is', ' the', ' pilot', '.', ' Pil', 'ots', ' are', ' responsible', ' for', ' the', ' safe', ' operation', ' of', ' the', ' aircraft', ',', ' and', ' for', ' the', ' safety'] is the pilot . Pil ots are responsible for the safe operation of the aircraft , and for the safety True sabotaging equipment or flying airplanes above the filming 5 [' sabot', 'aging', ' equipment', ' or', ' flying', ' airplanes']
+19 16 The professional role most suited to handle x -1 The professional role most suited to handle flying airplanes pilot flying airplanes [' is', ' the', ' pilot', '.', ' Pil', 'ots', ' are', ' responsible', ' for', ' the', ' safe', ' operation', ' of', ' the', ' aircraft', ',', ' and', ' for', ' the', ' safety'] is the pilot . Pil ots are responsible for the safe operation of the aircraft , and for the safety True equipment or flying airplanes above the filming 3 [' equipment', ' or', ' flying', ' airplanes']
+20 16 The professional role most suited to handle x -1 The professional role most suited to handle flying airplanes pilot flying airplanes [' is', ' the', ' pilot', '.', ' Pil', 'ots', ' are', ' responsible', ' for', ' the', ' safe', ' operation', ' of', ' the', ' aircraft', ',', ' and', ' for', ' the', ' safety'] is the pilot . Pil ots are responsible for the safe operation of the aircraft , and for the safety True sabotaging equipment or flying airplanes above the filming 5 [' sabot', 'aging', ' equipment', ' or', ' flying', ' airplanes']
+21 17 The professional role most suited to handle x -1 The professional role most suited to handle writing novels author writing novels [' is', ' that', ' of', ' a', ' novelist', '.', '\n', '\n', 'The', ' professional', ' role', ' most', ' suited', ' to', ' handle', ' writing', ' novels', ' is', ' that', ' of'] " is that of a novelist .
+
+ The professional role most suited to handle writing novels is that of" False " prevented her from writing novels for several years.
+" 4 [' prevented', ' her', ' from', ' writing', ' novels']
+22 17 The professional role most suited to handle x -1 The professional role most suited to handle writing novels author writing novels [' is', ' that', ' of', ' a', ' novelist', '.', '\n', '\n', 'The', ' professional', ' role', ' most', ' suited', ' to', ' handle', ' writing', ' novels', ' is', ' that', ' of'] " is that of a novelist .
+
+ The professional role most suited to handle writing novels is that of" False script. Concentrating on writing novels and short stories, 6 [' script', '.', ' Concent', 'rating', ' on', ' writing', ' novels']
+23 17 The professional role most suited to handle x -1 The professional role most suited to handle writing novels author writing novels [' is', ' that', ' of', ' a', ' novelist', '.', '\n', '\n', 'The', ' professional', ' role', ' most', ' suited', ' to', ' handle', ' writing', ' novels', ' is', ' that', ' of'] " is that of a novelist .
+
+ The professional role most suited to handle writing novels is that of" False Aside from writing novels and short stories, 3 [' Aside', ' from', ' writing', ' novels']
+24 17 The professional role most suited to handle x -1 The professional role most suited to handle writing novels author writing novels [' is', ' that', ' of', ' a', ' novelist', '.', '\n', '\n', 'The', ' professional', ' role', ' most', ' suited', ' to', ' handle', ' writing', ' novels', ' is', ' that', ' of'] " is that of a novelist .
+
+ The professional role most suited to handle writing novels is that of" False Colorado. He had been writing novels for the past 20 years 6 [' Colorado', '.', ' He', ' had', ' been', ' writing', ' novels']
+25 17 The professional role most suited to handle x -1 The professional role most suited to handle writing novels author writing novels [' is', ' that', ' of', ' a', ' novelist', '.', '\n', '\n', 'The', ' professional', ' role', ' most', ' suited', ' to', ' handle', ' writing', ' novels', ' is', ' that', ' of'] " is that of a novelist .
+
+ The professional role most suited to handle writing novels is that of" False prevented her from writing novels for several 4 [' prevented', ' her', ' from', ' writing', ' novels']
+26 18 The professional role most suited to handle x -1 The professional role most suited to handle investigating crimes detective investigating crimes [' is', ' that', ' of', ' a', ' detective', '.', '\n', '\n', 'The', ' detective', ' is', ' the', ' one', ' who', ' is', ' most', ' likely', ' to', ' be', ' called'] " is that of a detective .
+
+ The detective is the one who is most likely to be called" True ascertaining and investigating crimes perpetrated 4 [' ascert', 'aining', ' and', ' investigating', ' crimes']
+27 18 The professional role most suited to handle x -1 The professional role most suited to handle investigating crimes detective investigating crimes [' is', ' that', ' of', ' a', ' detective', '.', '\n', '\n', 'The', ' detective', ' is', ' the', ' one', ' who', ' is', ' most', ' likely', ' to', ' be', ' called'] " is that of a detective .
+
+ The detective is the one who is most likely to be called" True Hawes' assistance in investigating crimes allegedly committed 6 "[' Haw', 'es', ""'"", ' assistance', ' in', ' investigating', ' crimes']"
+28 18 The professional role most suited to handle x -1 The professional role most suited to handle investigating crimes detective investigating crimes [' is', ' that', ' of', ' a', ' detective', '.', '\n', '\n', 'The', ' detective', ' is', ' the', ' one', ' who', ' is', ' most', ' likely', ' to', ' be', ' called'] " is that of a detective .
+
+ The detective is the one who is most likely to be called" True assistance in investigating crimes allegedly committed 3 [' assistance', ' in', ' investigating', ' crimes']
+29 18 The professional role most suited to handle x -1 The professional role most suited to handle investigating crimes detective investigating crimes [' is', ' that', ' of', ' a', ' detective', '.', '\n', '\n', 'The', ' detective', ' is', ' the', ' one', ' who', ' is', ' most', ' likely', ' to', ' be', ' called'] " is that of a detective .
+
+ The detective is the one who is most likely to be called" True Hawes's assistance in investigating crimes allegedly committed 6 "[' Haw', 'es', ""'s"", ' assistance', ' in', ' investigating', ' crimes']"
+30 18 The professional role most suited to handle x -1 The professional role most suited to handle investigating crimes detective investigating crimes [' is', ' that', ' of', ' a', ' detective', '.', '\n', '\n', 'The', ' detective', ' is', ' the', ' one', ' who', ' is', ' most', ' likely', ' to', ' be', ' called'] " is that of a detective .
+
+ The detective is the one who is most likely to be called" True assistance in investigating crimes allegedly committed 3 [' assistance', ' in', ' investigating', ' crimes']
+31 21 The professional role most suited to handle x -1 The professional role most suited to handle building bridges civil engineer building bridges [' between', ' the', ' two', ' cultures', ' is', ' that', ' of', ' the', ' teacher', '.', ' The', ' teacher', ' is', ' the', ' one', ' who', ' can', ' bridge', ' the', ' gap'] between the two cultures is that of the teacher . The teacher is the one who can bridge the gap False Bottlerplatz, and the building bridges Windeckstraße, with 8 [' Bott', 'ler', 'pl', 'atz', ',', ' and', ' the', ' building', ' bridges']
+32 21 The professional role most suited to handle x -1 The professional role most suited to handle building bridges civil engineer building bridges [' between', ' the', ' two', ' cultures', ' is', ' that', ' of', ' the', ' teacher', '.', ' The', ' teacher', ' is', ' the', ' one', ' who', ' can', ' bridge', ' the', ' gap'] between the two cultures is that of the teacher . The teacher is the one who can bridge the gap False Bottlerplatz, and the building bridges Windeckstraße, with 8 [' Bott', 'ler', 'pl', 'atz', ',', ' and', ' the', ' building', ' bridges']
+33 21 The professional role most suited to handle x -1 The professional role most suited to handle building bridges civil engineer building bridges [' between', ' the', ' two', ' cultures', ' is', ' that', ' of', ' the', ' teacher', '.', ' The', ' teacher', ' is', ' the', ' one', ' who', ' can', ' bridge', ' the', ' gap'] between the two cultures is that of the teacher . The teacher is the one who can bridge the gap False " assistance when building bridges in their area.
+" 3 [' assistance', ' when', ' building', ' bridges']
+34 21 The professional role most suited to handle x -1 The professional role most suited to handle building bridges civil engineer building bridges [' between', ' the', ' two', ' cultures', ' is', ' that', ' of', ' the', ' teacher', '.', ' The', ' teacher', ' is', ' the', ' one', ' who', ' can', ' bridge', ' the', ' gap'] between the two cultures is that of the teacher . The teacher is the one who can bridge the gap False Department, which was building bridges and roads throughout 5 [' Department', ',', ' which', ' was', ' building', ' bridges']
+35 21 The professional role most suited to handle x -1 The professional role most suited to handle building bridges civil engineer building bridges [' between', ' the', ' two', ' cultures', ' is', ' that', ' of', ' the', ' teacher', '.', ' The', ' teacher', ' is', ' the', ' one', ' who', ' can', ' bridge', ' the', ' gap'] between the two cultures is that of the teacher . The teacher is the one who can bridge the gap False be responsible for building bridges for the roads 4 [' be', ' responsible', ' for', ' building', ' bridges']
+36 22 The professional role most suited to handle x -1 The professional role most suited to handle treating animals veterinarian treating animals [' is', ' that', ' of', ' a', ' veterinarian', '.', '\n', '\n', 'The', ' veterinarian', ' is', ' the', ' one', ' who', ' is', ' responsible', ' for', ' the', ' health', ' of'] " is that of a veterinarian .
+
+ The veterinarian is the one who is responsible for the health of" True " herself, supports treating animals humanely and ""would" 4 [' herself', ',', ' supports', ' treating', ' animals']
+37 22 The professional role most suited to handle x -1 The professional role most suited to handle treating animals veterinarian treating animals [' is', ' that', ' of', ' a', ' veterinarian', '.', '\n', '\n', 'The', ' veterinarian', ' is', ' the', ' one', ' who', ' is', ' responsible', ' for', ' the', ' health', ' of'] " is that of a veterinarian .
+
+ The veterinarian is the one who is responsible for the health of" True " herself, supports treating animals humanely and ""would" 4 [' herself', ',', ' supports', ' treating', ' animals']
+38 23 The professional role most suited to handle x -1 The professional role most suited to handle driving trucks truck driver driving trucks [' and', ' heavy', ' equipment', ' is', ' that', ' of', ' a', ' truck', ' driver', '.', ' Truck', ' drivers', ' are', ' responsible', ' for', ' the', ' safe', ' operation', ' of', ' the'] and heavy equipment is that of a truck driver . Truck drivers are responsible for the safe operation of the True attempted to destroy it by driving trucks laden with explosives 6 [' attempted', ' to', ' destroy', ' it', ' by', ' driving', ' trucks']
+39 23 The professional role most suited to handle x -1 The professional role most suited to handle driving trucks truck driver driving trucks [' and', ' heavy', ' equipment', ' is', ' that', ' of', ' a', ' truck', ' driver', '.', ' Truck', ' drivers', ' are', ' responsible', ' for', ' the', ' safe', ' operation', ' of', ' the'] and heavy equipment is that of a truck driver . Truck drivers are responsible for the safe operation of the True Maryland. After driving trucks for a few months, 4 [' Maryland', '.', ' After', ' driving', ' trucks']
+40 23 The professional role most suited to handle x -1 The professional role most suited to handle driving trucks truck driver driving trucks [' and', ' heavy', ' equipment', ' is', ' that', ' of', ' a', ' truck', ' driver', '.', ' Truck', ' drivers', ' are', ' responsible', ' for', ' the', ' safe', ' operation', ' of', ' the'] and heavy equipment is that of a truck driver . Truck drivers are responsible for the safe operation of the True higher-paying job driving trucks for Iron Island ironstone 5 [' higher', '-', 'paying', ' job', ' driving', ' trucks']
+41 23 The professional role most suited to handle x -1 The professional role most suited to handle driving trucks truck driver driving trucks [' and', ' heavy', ' equipment', ' is', ' that', ' of', ' a', ' truck', ' driver', '.', ' Truck', ' drivers', ' are', ' responsible', ' for', ' the', ' safe', ' operation', ' of', ' the'] and heavy equipment is that of a truck driver . Truck drivers are responsible for the safe operation of the True higher-paying job driving trucks for Iron Island 5 [' higher', '-', 'paying', ' job', ' driving', ' trucks']
+42 23 The professional role most suited to handle x -1 The professional role most suited to handle driving trucks truck driver driving trucks [' and', ' heavy', ' equipment', ' is', ' that', ' of', ' a', ' truck', ' driver', '.', ' Truck', ' drivers', ' are', ' responsible', ' for', ' the', ' safe', ' operation', ' of', ' the'] and heavy equipment is that of a truck driver . Truck drivers are responsible for the safe operation of the True Rutland) who was driving trucks for the Canadian 6 [' Rut', 'land', ')', ' who', ' was', ' driving', ' trucks']
+43 24 The professional role most suited to handle x -1 The professional role most suited to handle providing legal advice lawyer providing legal advice [' to', ' the', ' public', ' is', ' that', ' of', ' a', ' lawyer', '.', ' The', ' lawyer', ' is', ' a', ' professional', ' who', ' is', ' trained', ' to', ' provide', ' legal'] to the public is that of a lawyer . The lawyer is a professional who is trained to provide legal True brief section providing legal advice for American 4 [' brief', ' section', ' providing', ' legal', ' advice']
+44 24 The professional role most suited to handle x -1 The professional role most suited to handle providing legal advice lawyer providing legal advice [' to', ' the', ' public', ' is', ' that', ' of', ' a', ' lawyer', '.', ' The', ' lawyer', ' is', ' a', ' professional', ' who', ' is', ' trained', ' to', ' provide', ' legal'] to the public is that of a lawyer . The lawyer is a professional who is trained to provide legal True note, a brief section providing legal advice for American 7 [' note', ',', ' a', ' brief', ' section', ' providing', ' legal', ' advice']
+45 24 The professional role most suited to handle x -1 The professional role most suited to handle providing legal advice lawyer providing legal advice [' to', ' the', ' public', ' is', ' that', ' of', ' a', ' lawyer', '.', ' The', ' lawyer', ' is', ' a', ' professional', ' who', ' is', ' trained', ' to', ' provide', ' legal'] to the public is that of a lawyer . The lawyer is a professional who is trained to provide legal True ranged from providing legal advice to explaining 4 [' ranged', ' from', ' providing', ' legal', ' advice']
+46 24 The professional role most suited to handle x -1 The professional role most suited to handle providing legal advice lawyer providing legal advice [' to', ' the', ' public', ' is', ' that', ' of', ' a', ' lawyer', '.', ' The', ' lawyer', ' is', ' a', ' professional', ' who', ' is', ' trained', ' to', ' provide', ' legal'] to the public is that of a lawyer . The lawyer is a professional who is trained to provide legal True and ranged from providing legal advice to explaining how to 5 [' and', ' ranged', ' from', ' providing', ' legal', ' advice']
+47 27 The professional role most suited to handle x -1 The professional role most suited to handle managing finances financial advisor managing finances [' is', ' the', ' accountant', '.', ' The', ' accountant', ' is', ' responsible', ' for', ' the', ' financial', ' health', ' of', ' the', ' business', '.', ' They', ' are', ' responsible', ' for'] is the accountant . The accountant is responsible for the financial health of the business . They are responsible for False business together managing finances for hockey players. 3 [' business', ' together', ' managing', ' finances']
+48 27 The professional role most suited to handle x -1 The professional role most suited to handle managing finances financial advisor managing finances [' is', ' the', ' accountant', '.', ' The', ' accountant', ' is', ' responsible', ' for', ' the', ' financial', ' health', ' of', ' the', ' business', '.', ' They', ' are', ' responsible', ' for'] is the accountant . The accountant is responsible for the financial health of the business . They are responsible for False business together managing finances for hockey players. 3 [' business', ' together', ' managing', ' finances']
+49 27 The professional role most suited to handle x -1 The professional role most suited to handle managing finances financial advisor managing finances [' is', ' the', ' accountant', '.', ' The', ' accountant', ' is', ' responsible', ' for', ' the', ' financial', ' health', ' of', ' the', ' business', '.', ' They', ' are', ' responsible', ' for'] is the accountant . The accountant is responsible for the financial health of the business . They are responsible for False business together managing finances for hockey players. 3 [' business', ' together', ' managing', ' finances']
+50 27 The professional role most suited to handle x -1 The professional role most suited to handle managing finances financial advisor managing finances [' is', ' the', ' accountant', '.', ' The', ' accountant', ' is', ' responsible', ' for', ' the', ' financial', ' health', ' of', ' the', ' business', '.', ' They', ' are', ' responsible', ' for'] is the accountant . The accountant is responsible for the financial health of the business . They are responsible for False business together managing finances for hockey 3 [' business', ' together', ' managing', ' finances']
+51 27 The professional role most suited to handle x -1 The professional role most suited to handle managing finances financial advisor managing finances [' is', ' the', ' accountant', '.', ' The', ' accountant', ' is', ' responsible', ' for', ' the', ' financial', ' health', ' of', ' the', ' business', '.', ' They', ' are', ' responsible', ' for'] is the accountant . The accountant is responsible for the financial health of the business . They are responsible for False business together managing finances for hockey players. 3 [' business', ' together', ' managing', ' finances']
+52 29 The professional role most suited to handle x -1 The professional role most suited to handle teaching students teacher teaching students "[' with', ' learning', ' disabilities', ' is', ' that', ' of', ' a', ' teacher', ""'s"", ' aide', '.', ' The', ' aide', ' is', ' responsible', ' for', ' assisting', ' the', ' teacher', ' in']" with learning disabilities is that of a teacher 's aide . The aide is responsible for assisting the teacher in True time he was 15 was teaching students himself. In 1909, 6 [' time', ' he', ' was', ' 15', ' was', ' teaching', ' students']
+53 29 The professional role most suited to handle x -1 The professional role most suited to handle teaching students teacher teaching students "[' with', ' learning', ' disabilities', ' is', ' that', ' of', ' a', ' teacher', ""'s"", ' aide', '.', ' The', ' aide', ' is', ' responsible', ' for', ' assisting', ' the', ' teacher', ' in']" with learning disabilities is that of a teacher 's aide . The aide is responsible for assisting the teacher in True will continue teaching students about their culture, 3 [' will', ' continue', ' teaching', ' students']
+54 29 The professional role most suited to handle x -1 The professional role most suited to handle teaching students teacher teaching students "[' with', ' learning', ' disabilities', ' is', ' that', ' of', ' a', ' teacher', ""'s"", ' aide', '.', ' The', ' aide', ' is', ' responsible', ' for', ' assisting', ' the', ' teacher', ' in']" with learning disabilities is that of a teacher 's aide . The aide is responsible for assisting the teacher in True time he was 15 was teaching students himself. In 6 [' time', ' he', ' was', ' 15', ' was', ' teaching', ' students']
+55 29 The professional role most suited to handle x -1 The professional role most suited to handle teaching students teacher teaching students "[' with', ' learning', ' disabilities', ' is', ' that', ' of', ' a', ' teacher', ""'s"", ' aide', '.', ' The', ' aide', ' is', ' responsible', ' for', ' assisting', ' the', ' teacher', ' in']" with learning disabilities is that of a teacher 's aide . The aide is responsible for assisting the teacher in True was 15 was teaching students himself. In 4 [' was', ' 15', ' was', ' teaching', ' students']
+56 29 The professional role most suited to handle x -1 The professional role most suited to handle teaching students teacher teaching students "[' with', ' learning', ' disabilities', ' is', ' that', ' of', ' a', ' teacher', ""'s"", ' aide', '.', ' The', ' aide', ' is', ' responsible', ' for', ' assisting', ' the', ' teacher', ' in']" with learning disabilities is that of a teacher 's aide . The aide is responsible for assisting the teacher in True will continue teaching students about their 3 [' will', ' continue', ' teaching', ' students']
+57 30 The professional role most suited to handle x -1 The professional role most suited to handle cooking meals chef cooking meals [' for', ' a', ' family', ' of', ' four', ' is', ' a', ' cook', '.', '\n', '\n', 'The', ' cook', ' is', ' responsible', ' for', ' preparing', ' meals', ' for', ' the'] " for a family of four is a cook .
+
+ The cook is responsible for preparing meals for the" False Susan begins cooking meals for him with the 3 [' Susan', ' begins', ' cooking', ' meals']
+58 31 The professional role most suited to handle x -1 The professional role most suited to handle performing surgeries surgeon performing surgeries [' is', ' that', ' of', ' a', ' surgeon', '.', ' Sur', 'geons', ' are', ' the', ' ones', ' who', ' perform', ' the', ' surgery', '.', ' They', ' are', ' the', ' ones'] is that of a surgeon . Sur geons are the ones who perform the surgery . They are the ones True recklessness whilst performing surgeries and Megan finds 4 [' reck', 'lessness', ' whilst', ' performing', ' surgeries']
+59 31 The professional role most suited to handle x -1 The professional role most suited to handle performing surgeries surgeon performing surgeries [' is', ' that', ' of', ' a', ' surgeon', '.', ' Sur', 'geons', ' are', ' the', ' ones', ' who', ' perform', ' the', ' surgery', '.', ' They', ' are', ' the', ' ones'] is that of a surgeon . Sur geons are the ones who perform the surgery . They are the ones True recklessness whilst performing surgeries and Megan finds out 4 [' reck', 'lessness', ' whilst', ' performing', ' surgeries']
+60 31 The professional role most suited to handle x -1 The professional role most suited to handle performing surgeries surgeon performing surgeries [' is', ' that', ' of', ' a', ' surgeon', '.', ' Sur', 'geons', ' are', ' the', ' ones', ' who', ' perform', ' the', ' surgery', '.', ' They', ' are', ' the', ' ones'] is that of a surgeon . Sur geons are the ones who perform the surgery . They are the ones True recklessness whilst performing surgeries and Megan finds out 4 [' reck', 'lessness', ' whilst', ' performing', ' surgeries']
+61 31 The professional role most suited to handle x -1 The professional role most suited to handle performing surgeries surgeon performing surgeries [' is', ' that', ' of', ' a', ' surgeon', '.', ' Sur', 'geons', ' are', ' the', ' ones', ' who', ' perform', ' the', ' surgery', '.', ' They', ' are', ' the', ' ones'] is that of a surgeon . Sur geons are the ones who perform the surgery . They are the ones True on surgery, performing surgeries with his wife, but 4 [' on', ' surgery', ',', ' performing', ' surgeries']
+62 31 The professional role most suited to handle x -1 The professional role most suited to handle performing surgeries surgeon performing surgeries [' is', ' that', ' of', ' a', ' surgeon', '.', ' Sur', 'geons', ' are', ' the', ' ones', ' who', ' perform', ' the', ' surgery', '.', ' They', ' are', ' the', ' ones'] is that of a surgeon . Sur geons are the ones who perform the surgery . They are the ones True on surgery, performing surgeries with his wife, 4 [' on', ' surgery', ',', ' performing', ' surgeries']
diff --git a/patchscopes/code/preprocessed_data/commonsense/task_done_by_tool.tsv b/patchscopes/code/preprocessed_data/commonsense/task_done_by_tool.tsv
new file mode 100644
index 00000000..831f78cc
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/commonsense/task_done_by_tool.tsv
@@ -0,0 +1,460 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 2 The tool that is used for x -1 The tool that is used for cutting knife cutting [' the', ' wood', ' is', ' a', ' circular', ' saw', '.', '\n', '\n', 'The', ' wood', ' is', ' then', ' sand', 'ed', ' down', ' to', ' the', ' desired', ' thickness'] " the wood is a circular saw .
+
+ The wood is then sand ed down to the desired thickness" False Stevens'ethics for cutting Duquette's LVAD 5 "[' Stevens', ""'"", 'eth', 'ics', ' for', ' cutting']"
+1 2 The tool that is used for x -1 The tool that is used for cutting knife cutting [' the', ' wood', ' is', ' a', ' circular', ' saw', '.', '\n', '\n', 'The', ' wood', ' is', ' then', ' sand', 'ed', ' down', ' to', ' the', ' desired', ' thickness'] " the wood is a circular saw .
+
+ The wood is then sand ed down to the desired thickness" False Edwards was habitually cutting himself and appeared 4 [' Edwards', ' was', ' habit', 'ually', ' cutting']
+2 2 The tool that is used for x -1 The tool that is used for cutting knife cutting [' the', ' wood', ' is', ' a', ' circular', ' saw', '.', '\n', '\n', 'The', ' wood', ' is', ' then', ' sand', 'ed', ' down', ' to', ' the', ' desired', ' thickness'] " the wood is a circular saw .
+
+ The wood is then sand ed down to the desired thickness" False " ornaments and for cutting as money"", and they" 5 [' or', 'nam', 'ents', ' and', ' for', ' cutting']
+3 2 The tool that is used for x -1 The tool that is used for cutting knife cutting [' the', ' wood', ' is', ' a', ' circular', ' saw', '.', '\n', '\n', 'The', ' wood', ' is', ' then', ' sand', 'ed', ' down', ' to', ' the', ' desired', ' thickness'] " the wood is a circular saw .
+
+ The wood is then sand ed down to the desired thickness" False that involved cutting her down by one 2 [' that', ' involved', ' cutting']
+4 2 The tool that is used for x -1 The tool that is used for cutting knife cutting [' the', ' wood', ' is', ' a', ' circular', ' saw', '.', '\n', '\n', 'The', ' wood', ' is', ' then', ' sand', 'ed', ' down', ' to', ' the', ' desired', ' thickness'] " the wood is a circular saw .
+
+ The wood is then sand ed down to the desired thickness" False his third title by cutting 0.66 s from van 4 [' his', ' third', ' title', ' by', ' cutting']
+5 3 The tool that is used for x -1 The tool that is used for drilling holes drill drilling holes [' in', ' the', ' wall', ' of', ' the', ' house', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' drilling', ' holes', ' in', ' the', ' wall'] " in the wall of the house .
+
+ The tool that is used for drilling holes in the wall" True the process of drilling holes and inserting fasteners. 4 [' the', ' process', ' of', ' drilling', ' holes']
+6 3 The tool that is used for x -1 The tool that is used for drilling holes drill drilling holes [' in', ' the', ' wall', ' of', ' the', ' house', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' drilling', ' holes', ' in', ' the', ' wall'] " in the wall of the house .
+
+ The tool that is used for drilling holes in the wall" True can be tapped by drilling holes into their trunks 5 [' can', ' be', ' tapped', ' by', ' drilling', ' holes']
+7 3 The tool that is used for x -1 The tool that is used for drilling holes drill drilling holes [' in', ' the', ' wall', ' of', ' the', ' house', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' drilling', ' holes', ' in', ' the', ' wall'] " in the wall of the house .
+
+ The tool that is used for drilling holes in the wall" True heavy equipment, drilling holes for the pipeline's 4 [' heavy', ' equipment', ',', ' drilling', ' holes']
+8 3 The tool that is used for x -1 The tool that is used for drilling holes drill drilling holes [' in', ' the', ' wall', ' of', ' the', ' house', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' drilling', ' holes', ' in', ' the', ' wall'] " in the wall of the house .
+
+ The tool that is used for drilling holes in the wall" True harvested by drilling holes in the trunks of 3 [' harvested', ' by', ' drilling', ' holes']
+9 3 The tool that is used for x -1 The tool that is used for drilling holes drill drilling holes [' in', ' the', ' wall', ' of', ' the', ' house', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' drilling', ' holes', ' in', ' the', ' wall'] " in the wall of the house .
+
+ The tool that is used for drilling holes in the wall" True can be tapped by drilling holes into their 5 [' can', ' be', ' tapped', ' by', ' drilling', ' holes']
+10 4 The tool that is used for x -1 The tool that is used for sawing wood saw sawing wood [' is', ' called', ' a', ' circular', ' saw', '.', ' It', ' is', ' a', ' tool', ' that', ' is', ' used', ' to', ' cut', ' wood', ' into', ' different', ' shapes', '.'] is called a circular saw . It is a tool that is used to cut wood into different shapes . True " manager has been' sawing wood 'and not talking""." 6 "[' manager', ' has', ' been', ""'"", ' saw', 'ing', ' wood']"
+11 8 The tool that is used for x -1 The tool that is used for sweeping floors broom sweeping floors [',', ' cleaning', ' windows', ',', ' and', ' cleaning', ' the', ' house', '.', '\n', '\n', 'The', ' best', ' way', ' to', ' clean', ' your', ' house', ' is', ' to'] ", cleaning windows , and cleaning the house .
+
+ The best way to clean your house is to" False semesters at Cumberland by sweeping floors and firing furnaces 7 [' sem', 'esters', ' at', ' Cumber', 'land', ' by', ' sweeping', ' floors']
+12 9 The tool that is used for x -1 The tool that is used for washing dishes sponge washing dishes [',', ' cleaning', ' the', ' house', ',', ' and', ' doing', ' laundry', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' dishes', ','] ", cleaning the house , and doing laundry .
+
+ The tool that is used for washing dishes ," False it features Utada washing dishes and drinking 5 [' it', ' features', ' Ut', 'ada', ' washing', ' dishes']
+13 9 The tool that is used for x -1 The tool that is used for washing dishes sponge washing dishes [',', ' cleaning', ' the', ' house', ',', ' and', ' doing', ' laundry', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' dishes', ','] ", cleaning the house , and doing laundry .
+
+ The tool that is used for washing dishes ," False Kiriya; it features Utada washing dishes and drinking 8 [' Kir', 'iya', ';', ' it', ' features', ' Ut', 'ada', ' washing', ' dishes']
+14 9 The tool that is used for x -1 The tool that is used for washing dishes sponge washing dishes [',', ' cleaning', ' the', ' house', ',', ' and', ' doing', ' laundry', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' dishes', ','] ", cleaning the house , and doing laundry .
+
+ The tool that is used for washing dishes ," False way across Alaska, washing dishes in Mount McKinley 5 [' way', ' across', ' Alaska', ',', ' washing', ' dishes']
+15 9 The tool that is used for x -1 The tool that is used for washing dishes sponge washing dishes [',', ' cleaning', ' the', ' house', ',', ' and', ' doing', ' laundry', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' dishes', ','] ", cleaning the house , and doing laundry .
+
+ The tool that is used for washing dishes ," False the United States, washing dishes and working 5 [' the', ' United', ' States', ',', ' washing', ' dishes']
+16 9 The tool that is used for x -1 The tool that is used for washing dishes sponge washing dishes [',', ' cleaning', ' the', ' house', ',', ' and', ' doing', ' laundry', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' dishes', ','] ", cleaning the house , and doing laundry .
+
+ The tool that is used for washing dishes ," False features Utada washing dishes and drinking 4 [' features', ' Ut', 'ada', ' washing', ' dishes']
+17 11 The tool that is used for x -1 The tool that is used for sewing needle and thread sewing [' the', ' seams', ' of', ' the', ' qu', 'ilt', ' is', ' a', ' rot', 'ary', ' cutter', ' and', ' a', ' ruler', '.', '\n', '\n', 'I', ' used', ' a'] " the seams of the qu ilt is a rot ary cutter and a ruler .
+
+ I used a" False industry, as 25,000 sewing machines were 6 [' industry', ',', ' as', ' 25', ',', '000', ' sewing']
+18 11 The tool that is used for x -1 The tool that is used for sewing needle and thread sewing [' the', ' seams', ' of', ' the', ' qu', 'ilt', ' is', ' a', ' rot', 'ary', ' cutter', ' and', ' a', ' ruler', '.', '\n', '\n', 'I', ' used', ' a'] " the seams of the qu ilt is a rot ary cutter and a ruler .
+
+ I used a" False freely and organize a sewing cooperative. The 4 [' freely', ' and', ' organize', ' a', ' sewing']
+19 11 The tool that is used for x -1 The tool that is used for sewing needle and thread sewing [' the', ' seams', ' of', ' the', ' qu', 'ilt', ' is', ' a', ' rot', 'ary', ' cutter', ' and', ' a', ' ruler', '.', '\n', '\n', 'I', ' used', ' a'] " the seams of the qu ilt is a rot ary cutter and a ruler .
+
+ I used a" False odd jobs from sewing fuzzy dice to 3 [' odd', ' jobs', ' from', ' sewing']
+20 11 The tool that is used for x -1 The tool that is used for sewing needle and thread sewing [' the', ' seams', ' of', ' the', ' qu', 'ilt', ' is', ' a', ' rot', 'ary', ' cutter', ' and', ' a', ' ruler', '.', '\n', '\n', 'I', ' used', ' a'] " the seams of the qu ilt is a rot ary cutter and a ruler .
+
+ I used a" False garment factories sewing clothing for 2 [' garment', ' factories', ' sewing']
+21 11 The tool that is used for x -1 The tool that is used for sewing needle and thread sewing [' the', ' seams', ' of', ' the', ' qu', 'ilt', ' is', ' a', ' rot', 'ary', ' cutter', ' and', ' a', ' ruler', '.', '\n', '\n', 'I', ' used', ' a'] " the seams of the qu ilt is a rot ary cutter and a ruler .
+
+ I used a" False Juan, where she studied sewing and continued her 5 [' Juan', ',', ' where', ' she', ' studied', ' sewing']
+22 12 The tool that is used for x -1 The tool that is used for knitting yarn knitting [' is', ' a', ' knitting', ' needle', '.', '\n', '\n', 'The', ' needle', ' is', ' used', ' to', ' knit', ' the', ' fabric', '.', '\n', '\n', 'The', ' needle'] " is a knitting needle .
+
+ The needle is used to knit the fabric .
+
+ The needle" False weaving, spinning, knitting and embroidery 4 [' weaving', ',', ' spinning', ',', ' knitting']
+23 12 The tool that is used for x -1 The tool that is used for knitting yarn knitting [' is', ' a', ' knitting', ' needle', '.', '\n', '\n', 'The', ' needle', ' is', ' used', ' to', ' knit', ' the', ' fabric', '.', '\n', '\n', 'The', ' needle'] " is a knitting needle .
+
+ The needle is used to knit the fabric .
+
+ The needle" False unfinished knitting projects. It 1 [' unfinished', ' knitting']
+24 12 The tool that is used for x -1 The tool that is used for knitting yarn knitting [' is', ' a', ' knitting', ' needle', '.', '\n', '\n', 'The', ' needle', ' is', ' used', ' to', ' knit', ' the', ' fabric', '.', '\n', '\n', 'The', ' needle'] " is a knitting needle .
+
+ The needle is used to knit the fabric .
+
+ The needle" False performance ... knitting play together well 2 [' performance', '...', ' knitting']
+25 12 The tool that is used for x -1 The tool that is used for knitting yarn knitting [' is', ' a', ' knitting', ' needle', '.', '\n', '\n', 'The', ' needle', ' is', ' used', ' to', ' knit', ' the', ' fabric', '.', '\n', '\n', 'The', ' needle'] " is a knitting needle .
+
+ The needle is used to knit the fabric .
+
+ The needle" False he takes up knitting and stage 3 [' he', ' takes', ' up', ' knitting']
+26 12 The tool that is used for x -1 The tool that is used for knitting yarn knitting [' is', ' a', ' knitting', ' needle', '.', '\n', '\n', 'The', ' needle', ' is', ' used', ' to', ' knit', ' the', ' fabric', '.', '\n', '\n', 'The', ' needle'] " is a knitting needle .
+
+ The needle is used to knit the fabric .
+
+ The needle" False weaving, spinning, knitting and embroidery 4 [' weaving', ',', ' spinning', ',', ' knitting']
+27 13 The tool that is used for x -1 The tool that is used for hunting gun hunting [' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' hunting', ' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that'] " and fishing .
+
+ The tool that is used for hunting and fishing .
+
+ The tool that" False Tod's normal life in hunting for food, seeking 6 "[' T', 'od', ""'s"", ' normal', ' life', ' in', ' hunting']"
+28 13 The tool that is used for x -1 The tool that is used for hunting gun hunting [' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' hunting', ' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that'] " and fishing .
+
+ The tool that is used for hunting and fishing .
+
+ The tool that" False adjusts its hunting strategy. As gameplay 2 [' adjusts', ' its', ' hunting']
+29 13 The tool that is used for x -1 The tool that is used for hunting gun hunting [' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' hunting', ' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that'] " and fishing .
+
+ The tool that is used for hunting and fishing .
+
+ The tool that" False more advanced at hunting and gathering and was 3 [' more', ' advanced', ' at', ' hunting']
+30 13 The tool that is used for x -1 The tool that is used for hunting gun hunting [' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' hunting', ' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that'] " and fishing .
+
+ The tool that is used for hunting and fishing .
+
+ The tool that" False had gone on a hunting trip with various 4 [' had', ' gone', ' on', ' a', ' hunting']
+31 13 The tool that is used for x -1 The tool that is used for hunting gun hunting [' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' hunting', ' and', ' fishing', '.', '\n', '\n', 'The', ' tool', ' that'] " and fishing .
+
+ The tool that is used for hunting and fishing .
+
+ The tool that" False of treasure hunting in the Bahamas. As 2 [' of', ' treasure', ' hunting']
+32 14 The tool that is used for x -1 The tool that is used for boating boat boating "[' is', ' called', ' a', ' ""', 'p', 'addle', '"".', '\n', '\n', 'The', ' paddle', ' is', ' used', ' to', ' push', ' the', ' boat', ' forward', '.', '\n']" " is called a "" p addle "".
+
+ The paddle is used to push the boat forward .
+" True has a large boating lake and a collection 4 [' has', ' a', ' large', ' bo', 'ating']
+33 14 The tool that is used for x -1 The tool that is used for boating boat boating "[' is', ' called', ' a', ' ""', 'p', 'addle', '"".', '\n', '\n', 'The', ' paddle', ' is', ' used', ' to', ' push', ' the', ' boat', ' forward', '.', '\n']" " is called a "" p addle "".
+
+ The paddle is used to push the boat forward .
+" True " was killed in a boating accident.
+" 5 [' was', ' killed', ' in', ' a', ' bo', 'ating']
+34 14 The tool that is used for x -1 The tool that is used for boating boat boating "[' is', ' called', ' a', ' ""', 'p', 'addle', '"".', '\n', '\n', 'The', ' paddle', ' is', ' used', ' to', ' push', ' the', ' boat', ' forward', '.', '\n']" " is called a "" p addle "".
+
+ The paddle is used to push the boat forward .
+" True the water by local boating enthusiasts. During 5 [' the', ' water', ' by', ' local', ' bo', 'ating']
+35 14 The tool that is used for x -1 The tool that is used for boating boat boating "[' is', ' called', ' a', ' ""', 'p', 'addle', '"".', '\n', '\n', 'The', ' paddle', ' is', ' used', ' to', ' push', ' the', ' boat', ' forward', '.', '\n']" " is called a "" p addle "".
+
+ The paddle is used to push the boat forward .
+" True tennis facilities, a boating pond, a treetop 5 [' tennis', ' facilities', ',', ' a', ' bo', 'ating']
+36 14 The tool that is used for x -1 The tool that is used for boating boat boating "[' is', ' called', ' a', ' ""', 'p', 'addle', '"".', '\n', '\n', 'The', ' paddle', ' is', ' used', ' to', ' push', ' the', ' boat', ' forward', '.', '\n']" " is called a "" p addle "".
+
+ The paddle is used to push the boat forward .
+" True hazardous for boating or kayaking; both 3 [' hazardous', ' for', ' bo', 'ating']
+37 17 The tool that is used for x -1 The tool that is used for baking oven baking [' the', ' bread', ' is', ' a', ' Kitchen', 'Aid', ' stand', ' mixer', '.', ' I', ' have', ' had', ' this', ' one', ' for', ' years', ' and', ' it', ' is', ' still'] the bread is a Kitchen Aid stand mixer . I have had this one for years and it is still False the culinary arts, baking and pastry arts, and 4 [' the', ' culinary', ' arts', ',', ' baking']
+38 17 The tool that is used for x -1 The tool that is used for baking oven baking [' the', ' bread', ' is', ' a', ' Kitchen', 'Aid', ' stand', ' mixer', '.', ' I', ' have', ' had', ' this', ' one', ' for', ' years', ' and', ' it', ' is', ' still'] the bread is a Kitchen Aid stand mixer . I have had this one for years and it is still False specializing in culinary and baking and pastry arts education. 4 [' specializing', ' in', ' culinary', ' and', ' baking']
+39 17 The tool that is used for x -1 The tool that is used for baking oven baking [' the', ' bread', ' is', ' a', ' Kitchen', 'Aid', ' stand', ' mixer', '.', ' I', ' have', ' had', ' this', ' one', ' for', ' years', ' and', ' it', ' is', ' still'] the bread is a Kitchen Aid stand mixer . I have had this one for years and it is still False skillful in baking sweet foods, though 3 [' skill', 'ful', ' in', ' baking']
+40 17 The tool that is used for x -1 The tool that is used for baking oven baking [' the', ' bread', ' is', ' a', ' Kitchen', 'Aid', ' stand', ' mixer', '.', ' I', ' have', ' had', ' this', ' one', ' for', ' years', ' and', ' it', ' is', ' still'] the bread is a Kitchen Aid stand mixer . I have had this one for years and it is still False hanging laundry and baking cakes. She milks the 3 [' hanging', ' laundry', ' and', ' baking']
+41 17 The tool that is used for x -1 The tool that is used for baking oven baking [' the', ' bread', ' is', ' a', ' Kitchen', 'Aid', ' stand', ' mixer', '.', ' I', ' have', ' had', ' this', ' one', ' for', ' years', ' and', ' it', ' is', ' still'] the bread is a Kitchen Aid stand mixer . I have had this one for years and it is still False water or stock, or baking them in a 5 [' water', ' or', ' stock', ',', ' or', ' baking']
+42 21 The tool that is used for x -1 The tool that is used for cleaning windows squeegee cleaning windows [' is', ' called', ' a', ' window', ' cleaner', '.', ' It', ' is', ' a', ' tool', ' that', ' is', ' used', ' to', ' clean', ' windows', '.', ' It', ' is', ' used'] is called a window cleaner . It is a tool that is used to clean windows . It is used False making cream teas, cleaning windows and scrubbing toilets. 6 [' making', ' cream', ' te', 'as', ',', ' cleaning', ' windows']
+43 21 The tool that is used for x -1 The tool that is used for cleaning windows squeegee cleaning windows [' is', ' called', ' a', ' window', ' cleaner', '.', ' It', ' is', ' a', ' tool', ' that', ' is', ' used', ' to', ' clean', ' windows', '.', ' It', ' is', ' used'] is called a window cleaner . It is a tool that is used to clean windows . It is used False cream teas, cleaning windows and scrubbing 5 [' cream', ' te', 'as', ',', ' cleaning', ' windows']
+44 23 The tool that is used for x -1 The tool that is used for washing clothes washing machine washing clothes [' is', ' called', ' a', ' washing', ' machine', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' clothes', ' is', ' called', ' a', ' washing'] " is called a washing machine .
+
+ The tool that is used for washing clothes is called a washing" True August), such as washing clothes and herding 5 [' August', '),', ' such', ' as', ' washing', ' clothes']
+45 23 The tool that is used for x -1 The tool that is used for washing clothes washing machine washing clothes [' is', ' called', ' a', ' washing', ' machine', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' clothes', ' is', ' called', ' a', ' washing'] " is called a washing machine .
+
+ The tool that is used for washing clothes is called a washing" True August), such as washing clothes and herding 5 [' August', '),', ' such', ' as', ' washing', ' clothes']
+46 23 The tool that is used for x -1 The tool that is used for washing clothes washing machine washing clothes [' is', ' called', ' a', ' washing', ' machine', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' clothes', ' is', ' called', ' a', ' washing'] " is called a washing machine .
+
+ The tool that is used for washing clothes is called a washing" True August), such as washing clothes and herding cattle. 5 [' August', '),', ' such', ' as', ' washing', ' clothes']
+47 23 The tool that is used for x -1 The tool that is used for washing clothes washing machine washing clothes [' is', ' called', ' a', ' washing', ' machine', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' washing', ' clothes', ' is', ' called', ' a', ' washing'] " is called a washing machine .
+
+ The tool that is used for washing clothes is called a washing" True August), such as washing clothes and herding cattle. 5 [' August', '),', ' such', ' as', ' washing', ' clothes']
+48 24 The tool that is used for x -1 The tool that is used for drying clothes clothesline drying clothes [' is', ' called', ' a', ' clothes', ' dry', 'er', '.', ' It', ' is', ' a', ' machine', ' that', ' uses', ' heat', ' to', ' dry', ' clothes', '.', ' It', ' is'] is called a clothes dry er . It is a machine that uses heat to dry clothes . It is False 2 ['d', 'rying', ' clothes']
+49 24 The tool that is used for x -1 The tool that is used for drying clothes clothesline drying clothes [' is', ' called', ' a', ' clothes', ' dry', 'er', '.', ' It', ' is', ' a', ' machine', ' that', ' uses', ' heat', ' to', ' dry', ' clothes', '.', ' It', ' is'] is called a clothes dry er . It is a machine that uses heat to dry clothes . It is False of space, such as drying clothes after process and 6 [' of', ' space', ',', ' such', ' as', ' drying', ' clothes']
+50 28 The tool that is used for x -1 The tool that is used for hiking hiking boots hiking [',', ' camping', ',', ' and', ' other', ' outdoor', ' activities', '.', '\n', '\n', 'The', ' best', ' thing', ' about', ' this', ' tool', ' is', ' that', ' it', ' is'] ", camping , and other outdoor activities .
+
+ The best thing about this tool is that it is" False " that ""I go hiking a lot, and" 4 "[' that', ' ""', 'I', ' go', ' hiking']"
+51 28 The tool that is used for x -1 The tool that is used for hiking hiking boots hiking [',', ' camping', ',', ' and', ' other', ' outdoor', ' activities', '.', '\n', '\n', 'The', ' best', ' thing', ' about', ' this', ' tool', ' is', ' that', ' it', ' is'] ", camping , and other outdoor activities .
+
+ The best thing about this tool is that it is" False flood waters while hiking in a closed state 3 [' flood', ' waters', ' while', ' hiking']
+52 28 The tool that is used for x -1 The tool that is used for hiking hiking boots hiking [',', ' camping', ',', ' and', ' other', ' outdoor', ' activities', '.', '\n', '\n', 'The', ' best', ' thing', ' about', ' this', ' tool', ' is', ' that', ' it', ' is'] ", camping , and other outdoor activities .
+
+ The best thing about this tool is that it is" False wildlife, in addition to hiking and biking. It 5 [' wildlife', ',', ' in', ' addition', ' to', ' hiking']
+53 28 The tool that is used for x -1 The tool that is used for hiking hiking boots hiking [',', ' camping', ',', ' and', ' other', ' outdoor', ' activities', '.', '\n', '\n', 'The', ' best', ' thing', ' about', ' this', ' tool', ' is', ' that', ' it', ' is'] ", camping , and other outdoor activities .
+
+ The best thing about this tool is that it is" False Birkenstocks, hiking boots, and eco-friendly 4 [' Bir', 'ken', 'stocks', ',', ' hiking']
+54 28 The tool that is used for x -1 The tool that is used for hiking hiking boots hiking [',', ' camping', ',', ' and', ' other', ' outdoor', ' activities', '.', '\n', '\n', 'The', ' best', ' thing', ' about', ' this', ' tool', ' is', ' that', ' it', ' is'] ", camping , and other outdoor activities .
+
+ The best thing about this tool is that it is" False Many public parks, hiking trails, and 4 [' Many', ' public', ' parks', ',', ' hiking']
+55 29 The tool that is used for x -1 The tool that is used for biking bicycle biking [' is', ' a', ' bicycle', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' biking', ' is', ' a', ' bicycle', '.', '\n', '\n', 'The'] " is a bicycle .
+
+ The tool that is used for biking is a bicycle .
+
+ The" True the mountain biking competition 2 [' the', ' mountain', ' biking']
+56 29 The tool that is used for x -1 The tool that is used for biking bicycle biking [' is', ' a', ' bicycle', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' biking', ' is', ' a', ' bicycle', '.', '\n', '\n', 'The'] " is a bicycle .
+
+ The tool that is used for biking is a bicycle .
+
+ The" True " itself; mountain biking is not allowed.
+" 3 [' itself', ';', ' mountain', ' biking']
+57 29 The tool that is used for x -1 The tool that is used for biking bicycle biking [' is', ' a', ' bicycle', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' biking', ' is', ' a', ' bicycle', '.', '\n', '\n', 'The'] " is a bicycle .
+
+ The tool that is used for biking is a bicycle .
+
+ The" True hiking and biking trails follow 2 [' hiking', ' and', ' biking']
+58 29 The tool that is used for x -1 The tool that is used for biking bicycle biking [' is', ' a', ' bicycle', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' biking', ' is', ' a', ' bicycle', '.', '\n', '\n', 'The'] " is a bicycle .
+
+ The tool that is used for biking is a bicycle .
+
+ The" True riding and mountain biking are not permitted 3 [' riding', ' and', ' mountain', ' biking']
+59 29 The tool that is used for x -1 The tool that is used for biking bicycle biking [' is', ' a', ' bicycle', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' biking', ' is', ' a', ' bicycle', '.', '\n', '\n', 'The'] " is a bicycle .
+
+ The tool that is used for biking is a bicycle .
+
+ The" True " itself; mountain biking is not allowed.
+" 3 [' itself', ';', ' mountain', ' biking']
+60 30 The tool that is used for x -1 The tool that is used for swimming swimsuit swimming [' is', ' a', ' swimming', ' pool', '.', ' The', ' swimming', ' pool', ' is', ' a', ' place', ' where', ' you', ' can', ' swim', ' in', ' the', ' water', '.', ' The'] is a swimming pool . The swimming pool is a place where you can swim in the water . The False rigidity, and erratic swimming behavior; death follows 5 [' rig', 'idity', ',', ' and', ' erratic', ' swimming']
+61 30 The tool that is used for x -1 The tool that is used for swimming swimsuit swimming [' is', ' a', ' swimming', ' pool', '.', ' The', ' swimming', ' pool', ' is', ' a', ' place', ' where', ' you', ' can', ' swim', ' in', ' the', ' water', '.', ' The'] is a swimming pool . The swimming pool is a place where you can swim in the water . The False on a secret swimming pool. He posts this 3 [' on', ' a', ' secret', ' swimming']
+62 30 The tool that is used for x -1 The tool that is used for swimming swimsuit swimming [' is', ' a', ' swimming', ' pool', '.', ' The', ' swimming', ' pool', ' is', ' a', ' place', ' where', ' you', ' can', ' swim', ' in', ' the', ' water', '.', ' The'] is a swimming pool . The swimming pool is a place where you can swim in the water . The False sports, such as swimming and surfing. 4 [' sports', ',', ' such', ' as', ' swimming']
+63 30 The tool that is used for x -1 The tool that is used for swimming swimsuit swimming [' is', ' a', ' swimming', ' pool', '.', ' The', ' swimming', ' pool', ' is', ' a', ' place', ' where', ' you', ' can', ' swim', ' in', ' the', ' water', '.', ' The'] is a swimming pool . The swimming pool is a place where you can swim in the water . The False Championships in five swimming events. She and 3 [' Championships', ' in', ' five', ' swimming']
+64 30 The tool that is used for x -1 The tool that is used for swimming swimsuit swimming [' is', ' a', ' swimming', ' pool', '.', ' The', ' swimming', ' pool', ' is', ' a', ' place', ' where', ' you', ' can', ' swim', ' in', ' the', ' water', '.', ' The'] is a swimming pool . The swimming pool is a place where you can swim in the water . The False " from competitive swimming at the age of 21.
+" 2 [' from', ' competitive', ' swimming']
+65 31 The tool that is used for x -1 The tool that is used for cooking stove cooking [' is', ' a', ' knife', '.', '\n', '\n', 'The', ' knife', ' is', ' used', ' to', ' cut', ' the', ' meat', ',', ' vegetables', ',', ' and', ' other', ' food'] " is a knife .
+
+ The knife is used to cut the meat , vegetables , and other food" False facility, from cooking to shoemaking, and 3 [' facility', ',', ' from', ' cooking']
+66 31 The tool that is used for x -1 The tool that is used for cooking stove cooking [' is', ' a', ' knife', '.', '\n', '\n', 'The', ' knife', ' is', ' used', ' to', ' cut', ' the', ' meat', ',', ' vegetables', ',', ' and', ' other', ' food'] " is a knife .
+
+ The knife is used to cut the meat , vegetables , and other food" False hosted her own cooking show series, Nigella 3 [' hosted', ' her', ' own', ' cooking']
+67 31 The tool that is used for x -1 The tool that is used for cooking stove cooking [' is', ' a', ' knife', '.', '\n', '\n', 'The', ' knife', ' is', ' used', ' to', ' cut', ' the', ' meat', ',', ' vegetables', ',', ' and', ' other', ' food'] " is a knife .
+
+ The knife is used to cut the meat , vegetables , and other food" False a style of cooking developed in the Indian 3 [' a', ' style', ' of', ' cooking']
+68 31 The tool that is used for x -1 The tool that is used for cooking stove cooking [' is', ' a', ' knife', '.', '\n', '\n', 'The', ' knife', ' is', ' used', ' to', ' cut', ' the', ' meat', ',', ' vegetables', ',', ' and', ' other', ' food'] " is a knife .
+
+ The knife is used to cut the meat , vegetables , and other food" False preference for British cooking methods is apparent 3 [' preference', ' for', ' British', ' cooking']
+69 31 The tool that is used for x -1 The tool that is used for cooking stove cooking [' is', ' a', ' knife', '.', '\n', '\n', 'The', ' knife', ' is', ' used', ' to', ' cut', ' the', ' meat', ',', ' vegetables', ',', ' and', ' other', ' food'] " is a knife .
+
+ The knife is used to cut the meat , vegetables , and other food" False with to cheat in a cooking contest, but 5 [' with', ' to', ' cheat', ' in', ' a', ' cooking']
+70 32 The tool that is used for x -1 The tool that is used for writing pen and paper writing "[' the', ' code', ' is', ' called', ' a', ' ""', 'code', ' generator', '"".', '\n', '\n', 'The', ' code', ' generator', ' is', ' a', ' tool', ' that', ' generates', ' the']" " the code is called a "" code generator "".
+
+ The code generator is a tool that generates the" False film with McCarthy writing the film and act 3 [' film', ' with', ' McCarthy', ' writing']
+71 32 The tool that is used for x -1 The tool that is used for writing pen and paper writing "[' the', ' code', ' is', ' called', ' a', ' ""', 'code', ' generator', '"".', '\n', '\n', 'The', ' code', ' generator', ' is', ' a', ' tool', ' that', ' generates', ' the']" " the code is called a "" code generator "".
+
+ The code generator is a tool that generates the" False decided to combine the writing element with a puzzle 4 [' decided', ' to', ' combine', ' the', ' writing']
+72 32 The tool that is used for x -1 The tool that is used for writing pen and paper writing "[' the', ' code', ' is', ' called', ' a', ' ""', 'code', ' generator', '"".', '\n', '\n', 'The', ' code', ' generator', ' is', ' a', ' tool', ' that', ' generates', ' the']" " the code is called a "" code generator "".
+
+ The code generator is a tool that generates the" False Traister began writing about the 3 [' Tra', 'ister', ' began', ' writing']
+73 32 The tool that is used for x -1 The tool that is used for writing pen and paper writing "[' the', ' code', ' is', ' called', ' a', ' ""', 'code', ' generator', '"".', '\n', '\n', 'The', ' code', ' generator', ' is', ' a', ' tool', ' that', ' generates', ' the']" " the code is called a "" code generator "".
+
+ The code generator is a tool that generates the" False film that he began writing Bad Grass Never Dies 4 [' film', ' that', ' he', ' began', ' writing']
+74 32 The tool that is used for x -1 The tool that is used for writing pen and paper writing "[' the', ' code', ' is', ' called', ' a', ' ""', 'code', ' generator', '"".', '\n', '\n', 'The', ' code', ' generator', ' is', ' a', ' tool', ' that', ' generates', ' the']" " the code is called a "" code generator "".
+
+ The code generator is a tool that generates the" False Runner and the writing of author Philip K. 3 [' Runner', ' and', ' the', ' writing']
+75 33 The tool that is used for x -1 The tool that is used for drawing pencil and sketchbook drawing "[' the', ' line', ' between', ' the', ' two', ' is', ' the', ' ""', 'line', '""', ' function', '.', '\n', '\n', 'The', ' line', ' function', ' is', ' used', ' to']" " the line between the two is the "" line "" function .
+
+ The line function is used to" False and is made by drawing a set of circles 4 [' and', ' is', ' made', ' by', ' drawing']
+76 33 The tool that is used for x -1 The tool that is used for drawing pencil and sketchbook drawing "[' the', ' line', ' between', ' the', ' two', ' is', ' the', ' ""', 'line', '""', ' function', '.', '\n', '\n', 'The', ' line', ' function', ' is', ' used', ' to']" " the line between the two is the "" line "" function .
+
+ The line function is used to" False urinating in the yard; drawing water before it had 6 [' ur', 'inating', ' in', ' the', ' yard', ';', ' drawing']
+77 33 The tool that is used for x -1 The tool that is used for drawing pencil and sketchbook drawing "[' the', ' line', ' between', ' the', ' two', ' is', ' the', ' ""', 'line', '""', ' function', '.', '\n', '\n', 'The', ' line', ' function', ' is', ' used', ' to']" " the line between the two is the "" line "" function .
+
+ The line function is used to" False on 9 February 1964, drawing an estimated 73 5 [' on', ' 9', ' February', ' 1964', ',', ' drawing']
+78 33 The tool that is used for x -1 The tool that is used for drawing pencil and sketchbook drawing "[' the', ' line', ' between', ' the', ' two', ' is', ' the', ' ""', 'line', '""', ' function', '.', '\n', '\n', 'The', ' line', ' function', ' is', ' used', ' to']" " the line between the two is the "" line "" function .
+
+ The line function is used to" False papers were drawing notice in the 2 [' papers', ' were', ' drawing']
+79 33 The tool that is used for x -1 The tool that is used for drawing pencil and sketchbook drawing "[' the', ' line', ' between', ' the', ' two', ' is', ' the', ' ""', 'line', '""', ' function', '.', '\n', '\n', 'The', ' line', ' function', ' is', ' used', ' to']" " the line between the two is the "" line "" function .
+
+ The line function is used to" False " in a relationship, drawing references to fairytales.
+" 4 [' in', ' a', ' relationship', ',', ' drawing']
+80 34 The tool that is used for x -1 The tool that is used for gardening gardening gloves gardening [' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' gardening', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' gardening', ' is', ' a'] " is a tool that is used for gardening .
+
+ The tool that is used for gardening is a" False " endless landscape gardening of our self"" and" 2 [' endless', ' landscape', ' gardening']
+81 34 The tool that is used for x -1 The tool that is used for gardening gardening gloves gardening [' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' gardening', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' gardening', ' is', ' a'] " is a tool that is used for gardening .
+
+ The tool that is used for gardening is a" False the importance of gardening in education. 3 [' the', ' importance', ' of', ' gardening']
+82 34 The tool that is used for x -1 The tool that is used for gardening gardening gloves gardening [' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' gardening', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' gardening', ' is', ' a'] " is a tool that is used for gardening .
+
+ The tool that is used for gardening is a" False dressed in gardening attire, seated 2 [' dressed', ' in', ' gardening']
+83 34 The tool that is used for x -1 The tool that is used for gardening gardening gloves gardening [' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' gardening', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' gardening', ' is', ' a'] " is a tool that is used for gardening .
+
+ The tool that is used for gardening is a" False woodchip mulch used in gardening and landscaping), 6 [' wood', 'chip', ' mul', 'ch', ' used', ' in', ' gardening']
+84 34 The tool that is used for x -1 The tool that is used for gardening gardening gloves gardening [' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' gardening', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' gardening', ' is', ' a'] " is a tool that is used for gardening .
+
+ The tool that is used for gardening is a" False penetrate leather gardening gloves. The 2 [' penetrate', ' leather', ' gardening']
+85 35 The tool that is used for x -1 The tool that is used for photography camera photography [' is', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' tool', ' that', ' is', ' used', ' to', ' capture', ' images', '.', '\n', '\n'] " is a camera .
+
+ The camera is a tool that is used to capture images .
+
+" True Friedman's 2005 photography book Recognize 3 "[' Friedman', ""'s"", ' 2005', ' photography']"
+86 35 The tool that is used for x -1 The tool that is used for photography camera photography [' is', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' tool', ' that', ' is', ' used', ' to', ' capture', ' images', '.', '\n', '\n'] " is a camera .
+
+ The camera is a tool that is used to capture images .
+
+" True Semmelmeyer, a photography shop in Chicago, 5 [' Sem', 'mel', 'meyer', ',', ' a', ' photography']
+87 35 The tool that is used for x -1 The tool that is used for photography camera photography [' is', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' tool', ' that', ' is', ' used', ' to', ' capture', ' images', '.', '\n', '\n'] " is a camera .
+
+ The camera is a tool that is used to capture images .
+
+" True 1 ['phot', 'ography']
+88 35 The tool that is used for x -1 The tool that is used for photography camera photography [' is', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' tool', ' that', ' is', ' used', ' to', ' capture', ' images', '.', '\n', '\n'] " is a camera .
+
+ The camera is a tool that is used to capture images .
+
+" True history of photography not only because 2 [' history', ' of', ' photography']
+89 35 The tool that is used for x -1 The tool that is used for photography camera photography [' is', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' tool', ' that', ' is', ' used', ' to', ' capture', ' images', '.', '\n', '\n'] " is a camera .
+
+ The camera is a tool that is used to capture images .
+
+" True recent Japanese photography and art “ Heavy 2 [' recent', ' Japanese', ' photography']
+90 36 The tool that is used for x -1 The tool that is used for playing sports ball playing sports [' is', ' called', ' a', ' bat', '.', ' The', ' bat', ' is', ' used', ' to', ' hit', ' the', ' ball', '.', ' The', ' ball', ' is', ' a', ' round', ' object'] is called a bat . The bat is used to hit the ball . The ball is a round object True " his standing from playing sports for money.
+" 4 [' his', ' standing', ' from', ' playing', ' sports']
+91 36 The tool that is used for x -1 The tool that is used for playing sports ball playing sports [' is', ' called', ' a', ' bat', '.', ' The', ' bat', ' is', ' used', ' to', ' hit', ' the', ' ball', '.', ' The', ' ball', ' is', ' a', ' round', ' object'] is called a bat . The bat is used to hit the ball . The ball is a round object True teenage years playing sports for Lovington 3 [' teenage', ' years', ' playing', ' sports']
+92 36 The tool that is used for x -1 The tool that is used for playing sports ball playing sports [' is', ' called', ' a', ' bat', '.', ' The', ' bat', ' is', ' used', ' to', ' hit', ' the', ' ball', '.', ' The', ' ball', ' is', ' a', ' round', ' object'] is called a bat . The bat is used to hit the ball . The ball is a round object True not approve of playing sports on Sundays. 4 [' not', ' approve', ' of', ' playing', ' sports']
+93 36 The tool that is used for x -1 The tool that is used for playing sports ball playing sports [' is', ' called', ' a', ' bat', '.', ' The', ' bat', ' is', ' used', ' to', ' hit', ' the', ' ball', '.', ' The', ' ball', ' is', ' a', ' round', ' object'] is called a bat . The bat is used to hit the ball . The ball is a round object True In addition to playing sports (including volleyball, 4 [' In', ' addition', ' to', ' playing', ' sports']
+94 36 The tool that is used for x -1 The tool that is used for playing sports ball playing sports [' is', ' called', ' a', ' bat', '.', ' The', ' bat', ' is', ' used', ' to', ' hit', ' the', ' ball', '.', ' The', ' ball', ' is', ' a', ' round', ' object'] is called a bat . The bat is used to hit the ball . The ball is a round object True not approve of playing sports on Sundays. Reynolds 4 [' not', ' approve', ' of', ' playing', ' sports']
+95 37 The tool that is used for x -1 The tool that is used for exercising dumbbells exercising [' the', ' muscles', ' of', ' the', ' body', '.', '\n', '\n', 'The', ' word', ' �', '�', 'ex', 'ercise', '�', '�', ' is', ' derived', ' from', ' the'] " the muscles of the body .
+
+ The word � � ex ercise � � is derived from the" False options before exercising your constitutional 2 [' options', ' before', ' exercising']
+96 37 The tool that is used for x -1 The tool that is used for exercising dumbbells exercising [' the', ' muscles', ' of', ' the', ' body', '.', '\n', '\n', 'The', ' word', ' �', '�', 'ex', 'ercise', '�', '�', ' is', ' derived', ' from', ' the'] " the muscles of the body .
+
+ The word � � ex ercise � � is derived from the" False duties in 1701, and exercising his power 6 [' duties', ' in', ' 17', '01', ',', ' and', ' exercising']
+97 37 The tool that is used for x -1 The tool that is used for exercising dumbbells exercising [' the', ' muscles', ' of', ' the', ' body', '.', '\n', '\n', 'The', ' word', ' �', '�', 'ex', 'ercise', '�', '�', ' is', ' derived', ' from', ' the'] " the muscles of the body .
+
+ The word � � ex ercise � � is derived from the" False experimentation, exercising more control over 2 [' experimentation', ',', ' exercising']
+98 37 The tool that is used for x -1 The tool that is used for exercising dumbbells exercising [' the', ' muscles', ' of', ' the', ' body', '.', '\n', '\n', 'The', ' word', ' �', '�', 'ex', 'ercise', '�', '�', ' is', ' derived', ' from', ' the'] " the muscles of the body .
+
+ The word � � ex ercise � � is derived from the" False themselves from exercising their discretion. 2 [' themselves', ' from', ' exercising']
+99 37 The tool that is used for x -1 The tool that is used for exercising dumbbells exercising [' the', ' muscles', ' of', ' the', ' body', '.', '\n', '\n', 'The', ' word', ' �', '�', 'ex', 'ercise', '�', '�', ' is', ' derived', ' from', ' the'] " the muscles of the body .
+
+ The word � � ex ercise � � is derived from the" False mentioning water, exercising horses and the 3 [' mentioning', ' water', ',', ' exercising']
+100 38 The tool that is used for x -1 The tool that is used for dancing music dancing [',', ' singing', ',', ' and', ' playing', ' music', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' dancing', ',', ' singing', ',', ' and'] ", singing , and playing music .
+
+ The tool that is used for dancing , singing , and" True movements in their dancing to make a 3 [' movements', ' in', ' their', ' dancing']
+101 38 The tool that is used for x -1 The tool that is used for dancing music dancing [',', ' singing', ',', ' and', ' playing', ' music', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' dancing', ',', ' singing', ',', ' and'] ", singing , and playing music .
+
+ The tool that is used for dancing , singing , and" True surprise singing and dancing mall performances 3 [' surprise', ' singing', ' and', ' dancing']
+102 38 The tool that is used for x -1 The tool that is used for dancing music dancing [',', ' singing', ',', ' and', ' playing', ' music', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' dancing', ',', ' singing', ',', ' and'] ", singing , and playing music .
+
+ The tool that is used for dancing , singing , and" True to suppress dancing and events like 2 [' to', ' suppress', ' dancing']
+103 38 The tool that is used for x -1 The tool that is used for dancing music dancing [',', ' singing', ',', ' and', ' playing', ' music', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' dancing', ',', ' singing', ',', ' and'] ", singing , and playing music .
+
+ The tool that is used for dancing , singing , and" True balloon, and dancing alongside several 3 [' balloon', ',', ' and', ' dancing']
+104 38 The tool that is used for x -1 The tool that is used for dancing music dancing [',', ' singing', ',', ' and', ' playing', ' music', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' dancing', ',', ' singing', ',', ' and'] ", singing , and playing music .
+
+ The tool that is used for dancing , singing , and" True " perform ""Breathe on Me"", dancing on a giant picture" 8 "[' perform', ' ""', 'B', 'reat', 'he', ' on', ' Me', '"",', ' dancing']"
+105 39 The tool that is used for x -1 The tool that is used for watching movies television watching movies [' and', ' TV', ' shows', ' on', ' your', ' computer', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' watching', ' movies', ' and', ' TV', ' shows'] " and TV shows on your computer .
+
+ The tool that is used for watching movies and TV shows" False Latin music, watching movies and relaxing 4 [' Latin', ' music', ',', ' watching', ' movies']
+106 39 The tool that is used for x -1 The tool that is used for watching movies television watching movies [' and', ' TV', ' shows', ' on', ' your', ' computer', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' watching', ' movies', ' and', ' TV', ' shows'] " and TV shows on your computer .
+
+ The tool that is used for watching movies and TV shows" False to Latin music, watching movies and relaxing 5 [' to', ' Latin', ' music', ',', ' watching', ' movies']
+107 39 The tool that is used for x -1 The tool that is used for watching movies television watching movies [' and', ' TV', ' shows', ' on', ' your', ' computer', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' watching', ' movies', ' and', ' TV', ' shows'] " and TV shows on your computer .
+
+ The tool that is used for watching movies and TV shows" False together in the loft watching movies together. 5 [' together', ' in', ' the', ' loft', ' watching', ' movies']
+108 39 The tool that is used for x -1 The tool that is used for watching movies television watching movies [' and', ' TV', ' shows', ' on', ' your', ' computer', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' watching', ' movies', ' and', ' TV', ' shows'] " and TV shows on your computer .
+
+ The tool that is used for watching movies and TV shows" False together in the loft watching movies together. Santana 5 [' together', ' in', ' the', ' loft', ' watching', ' movies']
+109 39 The tool that is used for x -1 The tool that is used for watching movies television watching movies [' and', ' TV', ' shows', ' on', ' your', ' computer', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' watching', ' movies', ' and', ' TV', ' shows'] " and TV shows on your computer .
+
+ The tool that is used for watching movies and TV shows" False instruments, watching movies at weekends, painting, 3 [' instruments', ',', ' watching', ' movies']
+110 40 The tool that is used for x -1 The tool that is used for reading book reading [' the', ' data', ' from', ' the', ' file', '.', '\n', ' ', ' ', ' ', ' ', ' *', ' @', 'param', ' file', 'Name', ' the', ' name', ' of', ' the'] " the data from the file .
+ * @ param file Name the name of the" False " mind after reading Darcy's letter.
+" 2 [' mind', ' after', ' reading']
+111 40 The tool that is used for x -1 The tool that is used for reading book reading [' the', ' data', ' from', ' the', ' file', '.', '\n', ' ', ' ', ' ', ' ', ' *', ' @', 'param', ' file', 'Name', ' the', ' name', ' of', ' the'] " the data from the file .
+ * @ param file Name the name of the" False Hard Boiled. After reading the script, 5 [' Hard', ' Bo', 'iled', '.', ' After', ' reading']
+112 40 The tool that is used for x -1 The tool that is used for reading book reading [' the', ' data', ' from', ' the', ' file', '.', '\n', ' ', ' ', ' ', ' ', ' *', ' @', 'param', ' file', 'Name', ' the', ' name', ' of', ' the'] " the data from the file .
+ * @ param file Name the name of the" False royal funeral is the reading of the full list of 4 [' royal', ' funeral', ' is', ' the', ' reading']
+113 40 The tool that is used for x -1 The tool that is used for reading book reading [' the', ' data', ' from', ' the', ' file', '.', '\n', ' ', ' ', ' ', ' ', ' *', ' @', 'param', ' file', 'Name', ' the', ' name', ' of', ' the'] " the data from the file .
+ * @ param file Name the name of the" False boyfriend are reading magazines on a sofa 2 [' boyfriend', ' are', ' reading']
+114 40 The tool that is used for x -1 The tool that is used for reading book reading [' the', ' data', ' from', ' the', ' file', '.', '\n', ' ', ' ', ' ', ' ', ' *', ' @', 'param', ' file', 'Name', ' the', ' name', ' of', ' the'] " the data from the file .
+ * @ param file Name the name of the" False 30,000 sq.ft has five reading halls with a total 8 [' 30', ',', '000', ' sq', '.', 'ft', ' has', ' five', ' reading']
+115 41 The tool that is used for x -1 The tool that is used for listening to music headphones listening to music "[' is', ' called', ' a', ' ""', 'spe', 'aker', '"".', '\n', '\n', 'The', ' speaker', ' is', ' a', ' device', ' that', ' converts', ' the', ' sound', ' waves', ' into']" " is called a "" spe aker "".
+
+ The speaker is a device that converts the sound waves into" False 1979. He enjoys listening to music and watching 6 [' 1979', '.', ' He', ' enjoys', ' listening', ' to', ' music']
+116 41 The tool that is used for x -1 The tool that is used for listening to music headphones listening to music "[' is', ' called', ' a', ' ""', 'spe', 'aker', '"".', '\n', '\n', 'The', ' speaker', ' is', ' a', ' device', ' that', ' converts', ' the', ' sound', ' waves', ' into']" " is called a "" spe aker "".
+
+ The speaker is a device that converts the sound waves into" False Frith had been listening to music from other 6 [' Fr', 'ith', ' had', ' been', ' listening', ' to', ' music']
+117 41 The tool that is used for x -1 The tool that is used for listening to music headphones listening to music "[' is', ' called', ' a', ' ""', 'spe', 'aker', '"".', '\n', '\n', 'The', ' speaker', ' is', ' a', ' device', ' that', ' converts', ' the', ' sound', ' waves', ' into']" " is called a "" spe aker "".
+
+ The speaker is a device that converts the sound waves into" False really boring, listening to music where people are 5 [' really', ' boring', ',', ' listening', ' to', ' music']
+118 41 The tool that is used for x -1 The tool that is used for listening to music headphones listening to music "[' is', ' called', ' a', ' ""', 'spe', 'aker', '"".', '\n', '\n', 'The', ' speaker', ' is', ' a', ' device', ' that', ' converts', ' the', ' sound', ' waves', ' into']" " is called a "" spe aker "".
+
+ The speaker is a device that converts the sound waves into" False " music"", but prefers listening to music as a whole.
+" 6 "[' music', '"",', ' but', ' prefers', ' listening', ' to', ' music']"
+119 41 The tool that is used for x -1 The tool that is used for listening to music headphones listening to music "[' is', ' called', ' a', ' ""', 'spe', 'aker', '"".', '\n', '\n', 'The', ' speaker', ' is', ' a', ' device', ' that', ' converts', ' the', ' sound', ' waves', ' into']" " is called a "" spe aker "".
+
+ The speaker is a device that converts the sound waves into" False Amazing Spider-Man listening to music on his headphones 6 [' Amazing', ' Spider', '-', 'Man', ' listening', ' to', ' music']
+120 42 The tool that is used for x -1 The tool that is used for singing microphone singing [' the', ' song', ' is', ' a', ' small', ',', ' handheld', ' device', ' that', ' is', ' used', ' to', ' play', ' the', ' song', '.', ' It', ' is', ' a', ' small'] the song is a small , handheld device that is used to play the song . It is a small False footage of Cassidy singing Over The Rainbow 3 [' footage', ' of', ' Cassidy', ' singing']
+121 42 The tool that is used for x -1 The tool that is used for singing microphone singing [' the', ' song', ' is', ' a', ' small', ',', ' handheld', ' device', ' that', ' is', ' used', ' to', ' play', ' the', ' song', '.', ' It', ' is', ' a', ' small'] the song is a small , handheld device that is used to play the song . It is a small False goddess Juno singing an epithalamium, 2 [' goddess', ' Juno', ' singing']
+122 42 The tool that is used for x -1 The tool that is used for singing microphone singing [' the', ' song', ' is', ' a', ' small', ',', ' handheld', ' device', ' that', ' is', ' used', ' to', ' play', ' the', ' song', '.', ' It', ' is', ' a', ' small'] the song is a small , handheld device that is used to play the song . It is a small False Seeger feature him singing and playing the banjo 4 [' See', 'ger', ' feature', ' him', ' singing']
+123 42 The tool that is used for x -1 The tool that is used for singing microphone singing [' the', ' song', ' is', ' a', ' small', ',', ' handheld', ' device', ' that', ' is', ' used', ' to', ' play', ' the', ' song', '.', ' It', ' is', ' a', ' small'] the song is a small , handheld device that is used to play the song . It is a small False he starts singing [' Hard Rain '] ... 2 [' he', ' starts', ' singing']
+124 42 The tool that is used for x -1 The tool that is used for singing microphone singing [' the', ' song', ' is', ' a', ' small', ',', ' handheld', ' device', ' that', ' is', ' used', ' to', ' play', ' the', ' song', '.', ' It', ' is', ' a', ' small'] the song is a small , handheld device that is used to play the song . It is a small False follows Annie's singing audition for 3 "[' follows', ' Annie', ""'s"", ' singing']"
+125 43 The tool that is used for x -1 The tool that is used for measuring scale measuring [' the', ' size', ' of', ' the', ' image', '.', '\n', '\n', 'The', ' image', ' is', ' then', ' saved', ' in', ' the', ' folder', ' where', ' you', ' want', ' to'] " the size of the image .
+
+ The image is then saved in the folder where you want to" False micro-USB connector, measuring only 6.4 mm × 5 [' micro', '-', 'USB', ' connector', ',', ' measuring']
+126 43 The tool that is used for x -1 The tool that is used for measuring scale measuring [' the', ' size', ' of', ' the', ' image', '.', '\n', '\n', 'The', ' image', ' is', ' then', ' saved', ' in', ' the', ' folder', ' where', ' you', ' want', ' to'] " the size of the image .
+
+ The image is then saved in the folder where you want to" False dermal denticles measuring up to 0.4 cm (0.16 4 [' der', 'mal', ' dent', 'icles', ' measuring']
+127 43 The tool that is used for x -1 The tool that is used for measuring scale measuring [' the', ' size', ' of', ' the', ' image', '.', '\n', '\n', 'The', ' image', ' is', ' then', ' saved', ' in', ' the', ' folder', ' where', ' you', ' want', ' to'] " the size of the image .
+
+ The image is then saved in the folder where you want to" False in 1990, for measuring temperature 4 [' in', ' 1990', ',', ' for', ' measuring']
+128 43 The tool that is used for x -1 The tool that is used for measuring scale measuring [' the', ' size', ' of', ' the', ' image', '.', '\n', '\n', 'The', ' image', ' is', ' then', ' saved', ' in', ' the', ' folder', ' where', ' you', ' want', ' to'] " the size of the image .
+
+ The image is then saved in the folder where you want to" False (31 to 73 lb) and measuring 1.0 to 1.4 m 7 [' (', '31', ' to', ' 73', ' lb', ')', ' and', ' measuring']
+129 43 The tool that is used for x -1 The tool that is used for measuring scale measuring [' the', ' size', ' of', ' the', ' image', '.', '\n', '\n', 'The', ' image', ' is', ' then', ' saved', ' in', ' the', ' folder', ' where', ' you', ' want', ' to'] " the size of the image .
+
+ The image is then saved in the folder where you want to" False strengthening of the storm, measuring winds of 85 5 [' strengthening', ' of', ' the', ' storm', ',', ' measuring']
+130 44 The tool that is used for x -1 The tool that is used for birdwatching binoculars birdwatching [' is', ' a', ' bin', 'ocular', '.', '\n', '\n', 'The', ' bin', 'ocular', ' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' bird', 'watching', '.'] " is a bin ocular .
+
+ The bin ocular is a tool that is used for bird watching ." False visitors gave birdwatching as the main 3 [' visitors', ' gave', ' bird', 'watching']
+131 44 The tool that is used for x -1 The tool that is used for birdwatching binoculars birdwatching [' is', ' a', ' bin', 'ocular', '.', '\n', '\n', 'The', ' bin', 'ocular', ' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' bird', 'watching', '.'] " is a bin ocular .
+
+ The bin ocular is a tool that is used for bird watching ." False of the lake, and birdwatching opportunities. 6 [' of', ' the', ' lake', ',', ' and', ' bird', 'watching']
+132 44 The tool that is used for x -1 The tool that is used for birdwatching binoculars birdwatching [' is', ' a', ' bin', 'ocular', '.', '\n', '\n', 'The', ' bin', 'ocular', ' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' bird', 'watching', '.'] " is a bin ocular .
+
+ The bin ocular is a tool that is used for bird watching ." False them. Recreational birdwatching is an important 5 [' them', '.', ' Recre', 'ational', ' bird', 'watching']
+133 44 The tool that is used for x -1 The tool that is used for birdwatching binoculars birdwatching [' is', ' a', ' bin', 'ocular', '.', '\n', '\n', 'The', ' bin', 'ocular', ' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' bird', 'watching', '.'] " is a bin ocular .
+
+ The bin ocular is a tool that is used for bird watching ." False including Florida birdwatching and David Fairchild, 3 [' including', ' Florida', ' bird', 'watching']
+134 44 The tool that is used for x -1 The tool that is used for birdwatching binoculars birdwatching [' is', ' a', ' bin', 'ocular', '.', '\n', '\n', 'The', ' bin', 'ocular', ' is', ' a', ' tool', ' that', ' is', ' used', ' for', ' bird', 'watching', '.'] " is a bin ocular .
+
+ The bin ocular is a tool that is used for bird watching ." False fishing and birdwatching as major pre-causeway 3 [' fishing', ' and', ' bird', 'watching']
+135 45 The tool that is used for x -1 The tool that is used for playing basketball basketball playing basketball [' is', ' called', ' a', ' basketball', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' playing', ' baseball', ' is', ' called', ' a', ' baseball', '.'] " is called a basketball .
+
+ The tool that is used for playing baseball is called a baseball ." True Schilling grew up playing basketball until high school. 5 [' Sch', 'illing', ' grew', ' up', ' playing', ' basketball']
+136 45 The tool that is used for x -1 The tool that is used for playing basketball basketball playing basketball [' is', ' called', ' a', ' basketball', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' playing', ' baseball', ' is', ' called', ' a', ' baseball', '.'] " is called a basketball .
+
+ The tool that is used for playing baseball is called a baseball ." True Juan Fernández began playing basketball when he was six 7 [' Juan', ' Fern', 'á', 'nd', 'ez', ' began', ' playing', ' basketball']
+137 45 The tool that is used for x -1 The tool that is used for playing basketball basketball playing basketball [' is', ' called', ' a', ' basketball', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' playing', ' baseball', ' is', ' called', ' a', ' baseball', '.'] " is called a basketball .
+
+ The tool that is used for playing baseball is called a baseball ." True 1996, Walker was playing basketball with a friend, 5 [' 1996', ',', ' Walker', ' was', ' playing', ' basketball']
+138 45 The tool that is used for x -1 The tool that is used for playing basketball basketball playing basketball [' is', ' called', ' a', ' basketball', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' playing', ' baseball', ' is', ' called', ' a', ' baseball', '.'] " is called a basketball .
+
+ The tool that is used for playing baseball is called a baseball ." True Bradley began playing basketball at the age of 3 [' Bradley', ' began', ' playing', ' basketball']
+139 45 The tool that is used for x -1 The tool that is used for playing basketball basketball playing basketball [' is', ' called', ' a', ' basketball', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' playing', ' baseball', ' is', ' called', ' a', ' baseball', '.'] " is called a basketball .
+
+ The tool that is used for playing baseball is called a baseball ." True years old, he began playing basketball and attended a junior 6 [' years', ' old', ',', ' he', ' began', ' playing', ' basketball']
+140 46 The tool that is used for x -1 The tool that is used for playing soccer soccer ball playing soccer "[' is', ' called', ' a', ' ""', 'soc', 'cer', ' ball', '"".', '\n', '\n', 'The', ' ball', ' is', ' made', ' of', ' rubber', ' and', ' is', ' inflated', ' with']" " is called a "" soc cer ball "".
+
+ The ball is made of rubber and is inflated with" True Philadelphia. Lloyd began playing soccer at age five. Of 5 [' Philadelphia', '.', ' Lloyd', ' began', ' playing', ' soccer']
+141 46 The tool that is used for x -1 The tool that is used for playing soccer soccer ball playing soccer "[' is', ' called', ' a', ' ""', 'soc', 'cer', ' ball', '"".', '\n', '\n', 'The', ' ball', ' is', ' made', ' of', ' rubber', ' and', ' is', ' inflated', ' with']" " is called a "" soc cer ball "".
+
+ The ball is made of rubber and is inflated with" True Germany every summer, playing soccer with the local 5 [' Germany', ' every', ' summer', ',', ' playing', ' soccer']
+142 46 The tool that is used for x -1 The tool that is used for playing soccer soccer ball playing soccer "[' is', ' called', ' a', ' ""', 'soc', 'cer', ' ball', '"".', '\n', '\n', 'The', ' ball', ' is', ' made', ' of', ' rubber', ' and', ' is', ' inflated', ' with']" " is called a "" soc cer ball "".
+
+ The ball is made of rubber and is inflated with" True Adam. He grew up playing soccer and did not seriously 6 [' Adam', '.', ' He', ' grew', ' up', ' playing', ' soccer']
+143 46 The tool that is used for x -1 The tool that is used for playing soccer soccer ball playing soccer "[' is', ' called', ' a', ' ""', 'soc', 'cer', ' ball', '"".', '\n', '\n', 'The', ' ball', ' is', ' made', ' of', ' rubber', ' and', ' is', ' inflated', ' with']" " is called a "" soc cer ball "".
+
+ The ball is made of rubber and is inflated with" True Lloyd began playing soccer at age five. Of 3 [' Lloyd', ' began', ' playing', ' soccer']
+144 46 The tool that is used for x -1 The tool that is used for playing soccer soccer ball playing soccer "[' is', ' called', ' a', ' ""', 'soc', 'cer', ' ball', '"".', '\n', '\n', 'The', ' ball', ' is', ' made', ' of', ' rubber', ' and', ' is', ' inflated', ' with']" " is called a "" soc cer ball "".
+
+ The ball is made of rubber and is inflated with" True believe that playing soccer prevented 3 [' believe', ' that', ' playing', ' soccer']
+145 47 The tool that is used for x -1 The tool that is used for skateboarding skateboard skateboarding [' is', ' called', ' a', ' skate', 'board', '.', ' It', ' is', ' a', ' long', ',', ' flat', ' board', ' with', ' wheels', ' on', ' the', ' bottom', ' and', ' a'] is called a skate board . It is a long , flat board with wheels on the bottom and a True videos. The film's skateboarding sequences were 6 "[' videos', '.', ' The', ' film', ""'s"", ' skate', 'boarding']"
+146 47 The tool that is used for x -1 The tool that is used for skateboarding skateboard skateboarding [' is', ' called', ' a', ' skate', 'board', '.', ' It', ' is', ' a', ' long', ',', ' flat', ' board', ' with', ' wheels', ' on', ' the', ' bottom', ' and', ' a'] is called a skate board . It is a long , flat board with wheels on the bottom and a True Hawk's series of skateboarding games. Comments 5 "[' Hawk', ""'s"", ' series', ' of', ' skate', 'boarding']"
+147 47 The tool that is used for x -1 The tool that is used for skateboarding skateboard skateboarding [' is', ' called', ' a', ' skate', 'board', '.', ' It', ' is', ' a', ' long', ',', ' flat', ' board', ' with', ' wheels', ' on', ' the', ' bottom', ' and', ' a'] is called a skate board . It is a long , flat board with wheels on the bottom and a True similar to skateboarding games like 3 [' similar', ' to', ' skate', 'boarding']
+148 47 The tool that is used for x -1 The tool that is used for skateboarding skateboard skateboarding [' is', ' called', ' a', ' skate', 'board', '.', ' It', ' is', ' a', ' long', ',', ' flat', ' board', ' with', ' wheels', ' on', ' the', ' bottom', ' and', ' a'] is called a skate board . It is a long , flat board with wheels on the bottom and a True the punk rock skateboarding culture when he 4 [' the', ' punk', ' rock', ' skate', 'boarding']
+149 47 The tool that is used for x -1 The tool that is used for skateboarding skateboard skateboarding [' is', ' called', ' a', ' skate', 'board', '.', ' It', ' is', ' a', ' long', ',', ' flat', ' board', ' with', ' wheels', ' on', ' the', ' bottom', ' and', ' a'] is called a skate board . It is a long , flat board with wheels on the bottom and a True featured him skateboarding past several characters 3 [' featured', ' him', ' skate', 'boarding']
+150 49 The tool that is used for x -1 The tool that is used for flying a kite kite flying a kite [' is', ' called', ' a', ' K', 'ite', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' flying', ' a', ' k', 'ite', ' is', ' called'] " is called a K ite .
+
+ The tool that is used for flying a k ite is called" True the concept to flying a kite and were inspired 6 [' the', ' concept', ' to', ' flying', ' a', ' k', 'ite']
+151 49 The tool that is used for x -1 The tool that is used for flying a kite kite flying a kite [' is', ' called', ' a', ' K', 'ite', '.', '\n', '\n', 'The', ' tool', ' that', ' is', ' used', ' for', ' flying', ' a', ' k', 'ite', ' is', ' called'] " is called a K ite .
+
+ The tool that is used for flying a k ite is called" True concept to flying a kite and were inspired 5 [' concept', ' to', ' flying', ' a', ' k', 'ite']
+152 51 The tool that is used for x -1 The tool that is used for taking photographs camera taking photographs [' of', ' the', ' work', 'piece', ' is', ' called', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' device', ' that', ' is', ' used', ' to'] " of the work piece is called a camera .
+
+ The camera is a device that is used to" True Andrade continued taking photographs throughout 4 [' And', 'rade', ' continued', ' taking', ' photographs']
+153 51 The tool that is used for x -1 The tool that is used for taking photographs camera taking photographs [' of', ' the', ' work', 'piece', ' is', ' called', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' device', ' that', ' is', ' used', ' to'] " of the work piece is called a camera .
+
+ The camera is a device that is used to" True summer 1924, taking photographs of archaeological 4 [' summer', ' 1924', ',', ' taking', ' photographs']
+154 51 The tool that is used for x -1 The tool that is used for taking photographs camera taking photographs [' of', ' the', ' work', 'piece', ' is', ' called', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' device', ' that', ' is', ' used', ' to'] " of the work piece is called a camera .
+
+ The camera is a device that is used to" True journalists from taking photographs of the crease 3 [' journalists', ' from', ' taking', ' photographs']
+155 51 The tool that is used for x -1 The tool that is used for taking photographs camera taking photographs [' of', ' the', ' work', 'piece', ' is', ' called', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' device', ' that', ' is', ' used', ' to'] " of the work piece is called a camera .
+
+ The camera is a device that is used to" True express purpose of taking photographs — on the site 4 [' express', ' purpose', ' of', ' taking', ' photographs']
+156 51 The tool that is used for x -1 The tool that is used for taking photographs camera taking photographs [' of', ' the', ' work', 'piece', ' is', ' called', ' a', ' camera', '.', '\n', '\n', 'The', ' camera', ' is', ' a', ' device', ' that', ' is', ' used', ' to'] " of the work piece is called a camera .
+
+ The camera is a device that is used to" True AFC patrolling and taking photographs were attacked and badly 4 [' AFC', ' patrolling', ' and', ' taking', ' photographs']
diff --git a/patchscopes/code/preprocessed_data/commonsense/word_sentiment.tsv b/patchscopes/code/preprocessed_data/commonsense/word_sentiment.tsv
new file mode 100644
index 00000000..47fa6663
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/commonsense/word_sentiment.tsv
@@ -0,0 +1,1112 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of happy positive happy [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', '**', ':', ' The', ' sentiment', ' of', ' the'] ", sad , and angry .
+
+ - ** Sent iment ** : The sentiment of the" False / I choose to be happy / You and I, 5 [' /', ' I', ' choose', ' to', ' be', ' happy']
+1 0 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of happy positive happy [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', '**', ':', ' The', ' sentiment', ' of', ' the'] ", sad , and angry .
+
+ - ** Sent iment ** : The sentiment of the" False The townsfolk are happy to have made it 4 [' The', ' towns', 'folk', ' are', ' happy']
+2 0 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of happy positive happy [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', '**', ':', ' The', ' sentiment', ' of', ' the'] ", sad , and angry .
+
+ - ** Sent iment ** : The sentiment of the" False " dagger, saying ""O happy dagger! This is thy" 5 "[' dagger', ',', ' saying', ' ""', 'O', ' happy']"
+3 0 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of happy positive happy [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', '**', ':', ' The', ' sentiment', ' of', ' the'] ", sad , and angry .
+
+ - ** Sent iment ** : The sentiment of the" False Despite being happy that Itchy & Scratchy 2 [' Despite', ' being', ' happy']
+4 0 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of happy positive happy [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', '**', ':', ' The', ' sentiment', ' of', ' the'] ", sad , and angry .
+
+ - ** Sent iment ** : The sentiment of the" False 'How can you write a happy play?' he said. 6 "["" '"", 'How', ' can', ' you', ' write', ' a', ' happy']"
+5 1 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of joy positive joy [',', ' the', ' sentiment', ' of', ' sadness', ',', ' the', ' sentiment', ' of', ' anger', ',', ' the', ' sentiment', ' of', ' fear', ',', ' the', ' sentiment', ' of', ' surprise'] , the sentiment of sadness , the sentiment of anger , the sentiment of fear , the sentiment of surprise False couldn't take any joy in the boy's development 4 "[' couldn', ""'t"", ' take', ' any', ' joy']"
+6 1 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of joy positive joy [',', ' the', ' sentiment', ' of', ' sadness', ',', ' the', ' sentiment', ' of', ' anger', ',', ' the', ' sentiment', ' of', ' fear', ',', ' the', ' sentiment', ' of', ' surprise'] , the sentiment of sadness , the sentiment of anger , the sentiment of fear , the sentiment of surprise False " called ""the true joy and sensation" 4 "[' called', ' ""', 'the', ' true', ' joy']"
+7 1 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of joy positive joy [',', ' the', ' sentiment', ' of', ' sadness', ',', ' the', ' sentiment', ' of', ' anger', ',', ' the', ' sentiment', ' of', ' fear', ',', ' the', ' sentiment', ' of', ' surprise'] , the sentiment of sadness , the sentiment of anger , the sentiment of fear , the sentiment of surprise False surrender the one joy of his life, he 3 [' surrender', ' the', ' one', ' joy']
+8 1 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of joy positive joy [',', ' the', ' sentiment', ' of', ' sadness', ',', ' the', ' sentiment', ' of', ' anger', ',', ' the', ' sentiment', ' of', ' fear', ',', ' the', ' sentiment', ' of', ' surprise'] , the sentiment of sadness , the sentiment of anger , the sentiment of fear , the sentiment of surprise False unmitigated joy as Captain Sparrow, 3 [' un', 'mit', 'igated', ' joy']
+9 1 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of joy positive joy [',', ' the', ' sentiment', ' of', ' sadness', ',', ' the', ' sentiment', ' of', ' anger', ',', ' the', ' sentiment', ' of', ' fear', ',', ' the', ' sentiment', ' of', ' surprise'] , the sentiment of sadness , the sentiment of anger , the sentiment of fear , the sentiment of surprise False " ""One of the many enjoyable tasks I" 5 "[' ""', 'One', ' of', ' the', ' many', ' enjoy']"
+10 2 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of love positive love [' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' happy', '.', '\n', '\n', 'The', ' most'] " is the most important .
+
+ The most important thing is to be happy .
+
+ The most" False " of the cast: ""I love Ned's work," 6 "[' of', ' the', ' cast', ':', ' ""', 'I', ' love']"
+11 2 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of love positive love [' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' happy', '.', '\n', '\n', 'The', ' most'] " is the most important .
+
+ The most important thing is to be happy .
+
+ The most" False " process, she falls in love with him.
+" 5 [' process', ',', ' she', ' falls', ' in', ' love']
+12 2 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of love positive love [' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' happy', '.', '\n', '\n', 'The', ' most'] " is the most important .
+
+ The most important thing is to be happy .
+
+ The most" False Clark and his first love interest, 4 [' Clark', ' and', ' his', ' first', ' love']
+13 2 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of love positive love [' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' happy', '.', '\n', '\n', 'The', ' most'] " is the most important .
+
+ The most important thing is to be happy .
+
+ The most" False " fallout from her criminal love interest."" The" 4 [' fallout', ' from', ' her', ' criminal', ' love']
+14 2 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of love positive love [' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' happy', '.', '\n', '\n', 'The', ' most'] " is the most important .
+
+ The most important thing is to be happy .
+
+ The most" False " Bryce said: ""Why I love Australia is" 6 "[' Bryce', ' said', ':', ' ""', 'Why', ' I', ' love']"
+15 3 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of peace positive peace [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' �', '�', 'I', ' don', '�'] " is the most common .
+
+ The most common sentiment among the negative is � � I don �" False some kind of world peace by ruling them 4 [' some', ' kind', ' of', ' world', ' peace']
+16 3 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of peace positive peace [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' �', '�', 'I', ' don', '�'] " is the most common .
+
+ The most common sentiment among the negative is � � I don �" False Israel once a peace agreement with 3 [' Israel', ' once', ' a', ' peace']
+17 3 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of peace positive peace [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' �', '�', 'I', ' don', '�'] " is the most common .
+
+ The most common sentiment among the negative is � � I don �" False " Liu Xiaobo and the peace prize award.
+" 5 [' Liu', ' Xia', 'obo', ' and', ' the', ' peace']
+18 3 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of peace positive peace [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' �', '�', 'I', ' don', '�'] " is the most common .
+
+ The most common sentiment among the negative is � � I don �" False international women peacemakers and became 2 [' international', ' women', ' peace']
+19 3 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of peace positive peace [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' �', '�', 'I', ' don', '�'] " is the most common .
+
+ The most common sentiment among the negative is � � I don �" False really wanted peace, they should 2 [' really', ' wanted', ' peace']
+20 4 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hope positive hope [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' anger', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment among the negative is anger .
+
+ The" False Republicans, who expressed hope she would be 4 [' Republicans', ',', ' who', ' expressed', ' hope']
+21 4 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hope positive hope [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' anger', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment among the negative is anger .
+
+ The" False designed to keep it in. We hope never to live in a 7 [' designed', ' to', ' keep', ' it', ' in', '.', ' We', ' hope']
+22 4 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hope positive hope [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' anger', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment among the negative is anger .
+
+ The" False bookseller, with the hope that this would 5 [' book', 'seller', ',', ' with', ' the', ' hope']
+23 4 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hope positive hope [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' anger', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment among the negative is anger .
+
+ The" False mothers. They also hope to subvert the 4 [' mothers', '.', ' They', ' also', ' hope']
+24 4 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hope positive hope [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' among', ' the', ' negative', ' is', ' anger', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment among the negative is anger .
+
+ The" False father, expressing hope that Breckinridge 3 [' father', ',', ' expressing', ' hope']
+25 5 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of excited positive excited [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excited', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of excited is the most common sentiment in the" False viewers should be excited about Lewis' 3 [' viewers', ' should', ' be', ' excited']
+26 5 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of excited positive excited [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excited', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of excited is the most common sentiment in the" False involved, I am very excited about future 5 [' involved', ',', ' I', ' am', ' very', ' excited']
+27 5 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of excited positive excited [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excited', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of excited is the most common sentiment in the" False in China, and was excited by Hergé's latest 5 [' in', ' China', ',', ' and', ' was', ' excited']
+28 5 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of excited positive excited [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excited', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of excited is the most common sentiment in the" False " becomes more excited and ""confident""" 2 [' becomes', ' more', ' excited']
+29 5 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of excited positive excited [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excited', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of excited is the most common sentiment in the" False I'm incredibly excited about, 'cause I'm gonna 3 "[' I', ""'m"", ' incredibly', ' excited']"
+30 6 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of grateful positive grateful ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' grateful', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common sentiment is grateful ness .
+
+ The most" False Father John. A grateful and emotional Linda 4 [' Father', ' John', '.', ' A', ' grateful']
+31 6 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of grateful positive grateful ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' grateful', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common sentiment is grateful ness .
+
+ The most" False " Ireland is grateful to him."" His reception" 2 [' Ireland', ' is', ' grateful']
+32 6 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of grateful positive grateful ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' grateful', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common sentiment is grateful ness .
+
+ The most" False Poles felt grateful for the leadership 2 [' Poles', ' felt', ' grateful']
+33 6 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of grateful positive grateful ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' grateful', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common sentiment is grateful ness .
+
+ The most" False the church by a grateful French refugee priest. 4 [' the', ' church', ' by', ' a', ' grateful']
+34 6 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of grateful positive grateful ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' grateful', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common sentiment is grateful ness .
+
+ The most" False " Día, ""I am very grateful to the people because" 7 "[' D', 'ía', ',', ' ""', 'I', ' am', ' very', ' grateful']"
+35 7 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of proud positive proud [',', ' and', ' the', ' sentiment', ' of', ' ashamed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' proud', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of ashamed .
+
+ The sentiment of proud is the most common sentiment in the" False Tina. You make me proud to spell my name 5 [' Tina', '.', ' You', ' make', ' me', ' proud']
+36 7 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of proud positive proud [',', ' and', ' the', ' sentiment', ' of', ' ashamed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' proud', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of ashamed .
+
+ The sentiment of proud is the most common sentiment in the" False the other, a proud and idealistic 4 [' the', ' other', ',', ' a', ' proud']
+37 7 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of proud positive proud [',', ' and', ' the', ' sentiment', ' of', ' ashamed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' proud', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of ashamed .
+
+ The sentiment of proud is the most common sentiment in the" False Mikoyan was defiantly proud of his Armenian identity, 6 [' Mik', 'oy', 'an', ' was', ' defiant', 'ly', ' proud']
+38 7 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of proud positive proud [',', ' and', ' the', ' sentiment', ' of', ' ashamed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' proud', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of ashamed .
+
+ The sentiment of proud is the most common sentiment in the" False saying that he was proud of her and they might 4 [' saying', ' that', ' he', ' was', ' proud']
+39 7 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of proud positive proud [',', ' and', ' the', ' sentiment', ' of', ' ashamed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' proud', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of ashamed .
+
+ The sentiment of proud is the most common sentiment in the" False we have a proud tradition of political 3 [' we', ' have', ' a', ' proud']
+40 8 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blessed positive blessed ['ness', ',', ' and', ' the', ' sentiment', ' of', ' blessed', 'ness', ' among', ' the', ' blessed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' blessed', 'ness', ' among'] "ness , and the sentiment of blessed ness among the blessed .
+
+ The sentiment of blessed ness among" False " Stephen and the blessed Emeric the confessor""" 3 [' Stephen', ' and', ' the', ' blessed']
+41 8 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blessed positive blessed ['ness', ',', ' and', ' the', ' sentiment', ' of', ' blessed', 'ness', ' among', ' the', ' blessed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' blessed', 'ness', ' among'] "ness , and the sentiment of blessed ness among the blessed .
+
+ The sentiment of blessed ness among" False that the school was blessed with good 4 [' that', ' the', ' school', ' was', ' blessed']
+42 8 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blessed positive blessed ['ness', ',', ' and', ' the', ' sentiment', ' of', ' blessed', 'ness', ' among', ' the', ' blessed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' blessed', 'ness', ' among'] "ness , and the sentiment of blessed ness among the blessed .
+
+ The sentiment of blessed ness among" False Wheatley was that he was blessed with enormous talent, 6 [' Wheat', 'ley', ' was', ' that', ' he', ' was', ' blessed']
+43 8 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blessed positive blessed ['ness', ',', ' and', ' the', ' sentiment', ' of', ' blessed', 'ness', ' among', ' the', ' blessed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' blessed', 'ness', ' among'] "ness , and the sentiment of blessed ness among the blessed .
+
+ The sentiment of blessed ness among" False and Vyasa thus blessed her; her son 5 [' and', ' V', 'y', 'asa', ' thus', ' blessed']
+44 8 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blessed positive blessed ['ness', ',', ' and', ' the', ' sentiment', ' of', ' blessed', 'ness', ' among', ' the', ' blessed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' blessed', 'ness', ' among'] "ness , and the sentiment of blessed ness among the blessed .
+
+ The sentiment of blessed ness among" False society. The sage then blessed her with virgo 5 [' society', '.', ' The', ' sage', ' then', ' blessed']
+45 9 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of confident positive confident [',', ' and', ' the', ' sentiment', ' of', ' uncertain', '.', '\n', '\n', 'The', ' sentiment', ' of', ' confident', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of uncertain .
+
+ The sentiment of confident is the most important one .
+" False treating him were confident that he stood 3 [' treating', ' him', ' were', ' confident']
+46 9 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of confident positive confident [',', ' and', ' the', ' sentiment', ' of', ' uncertain', '.', '\n', '\n', 'The', ' sentiment', ' of', ' confident', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of uncertain .
+
+ The sentiment of confident is the most important one .
+" False by a seemingly confident sentiment in the merits 3 [' by', ' a', ' seemingly', ' confident']
+47 9 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of confident positive confident [',', ' and', ' the', ' sentiment', ' of', ' uncertain', '.', '\n', '\n', 'The', ' sentiment', ' of', ' confident', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of uncertain .
+
+ The sentiment of confident is the most important one .
+" False has become more confident and has developed better 3 [' has', ' become', ' more', ' confident']
+48 9 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of confident positive confident [',', ' and', ' the', ' sentiment', ' of', ' uncertain', '.', '\n', '\n', 'The', ' sentiment', ' of', ' confident', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of uncertain .
+
+ The sentiment of confident is the most important one .
+" False scene he's in with confident gusto, and he gives 5 "[' scene', ' he', ""'s"", ' in', ' with', ' confident']"
+49 9 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of confident positive confident [',', ' and', ' the', ' sentiment', ' of', ' uncertain', '.', '\n', '\n', 'The', ' sentiment', ' of', ' confident', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of uncertain .
+
+ The sentiment of confident is the most important one .
+" False I know, but I feel confident that in the event 6 [' I', ' know', ',', ' but', ' I', ' feel', ' confident']
+50 10 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of content positive content ['ment', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' positive', '.', '\n', '\n', 'The', ' most', ' common'] "ment is the most common .
+
+ The most common sentiment is positive .
+
+ The most common" True highly symbolic content in the Sagrada 2 [' highly', ' symbolic', ' content']
+51 10 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of content positive content ['ment', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' positive', '.', '\n', '\n', 'The', ' most', ' common'] "ment is the most common .
+
+ The most common sentiment is positive .
+
+ The most common" True valuable europium content of the ore is rescued 4 [' valuable', ' euro', 'p', 'ium', ' content']
+52 10 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of content positive content ['ment', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' positive', '.', '\n', '\n', 'The', ' most', ' common'] "ment is the most common .
+
+ The most common sentiment is positive .
+
+ The most common" True 23, 1999. While the content of VR Missions 6 [' 23', ',', ' 1999', '.', ' While', ' the', ' content']
+53 10 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of content positive content ['ment', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' positive', '.', '\n', '\n', 'The', ' most', ' common'] "ment is the most common .
+
+ The most common sentiment is positive .
+
+ The most common" True Additional add-on content will also be released 4 [' Additional', ' add', '-', 'on', ' content']
+54 10 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of content positive content ['ment', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' positive', '.', '\n', '\n', 'The', ' most', ' common'] "ment is the most common .
+
+ The most common sentiment is positive .
+
+ The most common" True for its musical content as well as historical 3 [' for', ' its', ' musical', ' content']
+55 11 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of satisfied positive satisfied [',', ' and', ' the', ' sentiment', ' of', ' dissatisfied', '.', '\n', '\n', 'The', ' sentiment', ' of', ' satisfied', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of dissatisfied .
+
+ The sentiment of satisfied is the most important one .
+" False all are more satisfied than could be 3 [' all', ' are', ' more', ' satisfied']
+56 11 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of satisfied positive satisfied [',', ' and', ' the', ' sentiment', ' of', ' dissatisfied', '.', '\n', '\n', 'The', ' sentiment', ' of', ' satisfied', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of dissatisfied .
+
+ The sentiment of satisfied is the most important one .
+" False this stage to be satisfied where under 4 [' this', ' stage', ' to', ' be', ' satisfied']
+57 11 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of satisfied positive satisfied [',', ' and', ' the', ' sentiment', ' of', ' dissatisfied', '.', '\n', '\n', 'The', ' sentiment', ' of', ' satisfied', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of dissatisfied .
+
+ The sentiment of satisfied is the most important one .
+" False most likely satisfied that the airplane 2 [' most', ' likely', ' satisfied']
+58 11 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of satisfied positive satisfied [',', ' and', ' the', ' sentiment', ' of', ' dissatisfied', '.', '\n', '\n', 'The', ' sentiment', ' of', ' satisfied', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of dissatisfied .
+
+ The sentiment of satisfied is the most important one .
+" False " unsuccessful as he satisfied none of them.
+" 3 [' unsuccessful', ' as', ' he', ' satisfied']
+59 11 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of satisfied positive satisfied [',', ' and', ' the', ' sentiment', ' of', ' dissatisfied', '.', '\n', '\n', 'The', ' sentiment', ' of', ' satisfied', ' is', ' the', ' most', ' important', ' one', '.', '\n'] ", and the sentiment of dissatisfied .
+
+ The sentiment of satisfied is the most important one .
+" False basic needs were satisfied during the interwar 3 [' basic', ' needs', ' were', ' satisfied']
+60 12 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of optimistic positive optimistic [',', ' pessimistic', ',', ' and', ' neutral', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is'] ", pessimistic , and neutral are the most common .
+
+ The most common sentiment of the tweets is" False having sex, he becomes optimistic when he thinks that 5 [' having', ' sex', ',', ' he', ' becomes', ' optimistic']
+61 12 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of optimistic positive optimistic [',', ' pessimistic', ',', ' and', ' neutral', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is'] ", pessimistic , and neutral are the most common .
+
+ The most common sentiment of the tweets is" False " and I am very optimistic for the future.""" 4 [' and', ' I', ' am', ' very', ' optimistic']
+62 12 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of optimistic positive optimistic [',', ' pessimistic', ',', ' and', ' neutral', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is'] ", pessimistic , and neutral are the most common .
+
+ The most common sentiment of the tweets is" False youthful, more optimistic direction of 3 [' youthful', ',', ' more', ' optimistic']
+63 12 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of optimistic positive optimistic [',', ' pessimistic', ',', ' and', ' neutral', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is'] ", pessimistic , and neutral are the most common .
+
+ The most common sentiment of the tweets is" False Despite outwardly optimistic statements by 3 [' Despite', ' outward', 'ly', ' optimistic']
+64 12 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of optimistic positive optimistic [',', ' pessimistic', ',', ' and', ' neutral', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is'] ", pessimistic , and neutral are the most common .
+
+ The most common sentiment of the tweets is" False The British were optimistic about gaseous diffusion, 3 [' The', ' British', ' were', ' optimistic']
+65 13 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of cheerful positive cheerful [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' word', ' is', ' the', ' most', ' important', ' part', ' of', ' the'] ", sad , and angry .
+
+ The sentiment of the word is the most important part of the" False friendly and cheerful with the Queen's 2 [' friendly', ' and', ' cheerful']
+66 13 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of cheerful positive cheerful [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' word', ' is', ' the', ' most', ' important', ' part', ' of', ' the'] ", sad , and angry .
+
+ The sentiment of the word is the most important part of the" False many with his cheerful tales of Mayfair, 3 [' many', ' with', ' his', ' cheerful']
+67 13 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of cheerful positive cheerful [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' word', ' is', ' the', ' most', ' important', ' part', ' of', ' the'] ", sad , and angry .
+
+ The sentiment of the word is the most important part of the" False story. She is a cheerful girl who finds pleasure 5 [' story', '.', ' She', ' is', ' a', ' cheerful']
+68 13 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of cheerful positive cheerful [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' word', ' is', ' the', ' most', ' important', ' part', ' of', ' the'] ", sad , and angry .
+
+ The sentiment of the word is the most important part of the" False (Kushboo) and the cheerful appearance of 8 [' (', 'K', 'ush', 'bo', 'o', ')', ' and', ' the', ' cheerful']
+69 13 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of cheerful positive cheerful [',', ' sad', ',', ' and', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' word', ' is', ' the', ' most', ' important', ' part', ' of', ' the'] ", sad , and angry .
+
+ The sentiment of the word is the most important part of the" False the Hokage. His cheerful and boisterous 5 [' the', ' Hok', 'age', '.', ' His', ' cheerful']
+70 14 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ecstatic positive ecstatic [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' sad', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of sad is the most common sentiment in the" False After the game, an ecstatic Ferguson felt United 5 [' After', ' the', ' game', ',', ' an', ' ecstatic']
+71 14 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ecstatic positive ecstatic [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' sad', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of sad is the most common sentiment in the" False liturgy reflects the ecstatic and experiential 4 [' lit', 'urgy', ' reflects', ' the', ' ecstatic']
+72 14 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ecstatic positive ecstatic [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' sad', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of sad is the most common sentiment in the" False coming to an ecstatic peace with 3 [' coming', ' to', ' an', ' ecstatic']
+73 14 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ecstatic positive ecstatic [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' sad', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of sad is the most common sentiment in the" False not over, giving an ecstatic Michael the closure 5 [' not', ' over', ',', ' giving', ' an', ' ecstatic']
+74 14 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ecstatic positive ecstatic [',', ' and', ' the', ' sentiment', ' of', ' sad', '.', '\n', '\n', 'The', ' sentiment', ' of', ' sad', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of sad .
+
+ The sentiment of sad is the most common sentiment in the" False and Crevel are ecstatic when they hear this 4 [' and', ' Cre', 'vel', ' are', ' ecstatic']
+75 15 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of delighted positive delighted [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' delighted', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of delighted is the most common sentiment in the" False Southey was delighted with Nicol's effort 4 [' S', 'out', 'hey', ' was', ' delighted']
+76 15 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of delighted positive delighted [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' delighted', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of delighted is the most common sentiment in the" False " intellectual who delighted, infuriated and provoked""." 2 [' intellectual', ' who', ' delighted']
+77 15 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of delighted positive delighted [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' delighted', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of delighted is the most common sentiment in the" False " signed ""A Workman"", delighted its recipient," 6 "[' signed', ' ""', 'A', ' Work', 'man', '"",', ' delighted']"
+78 15 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of delighted positive delighted [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' delighted', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of delighted is the most common sentiment in the" False and I was delighted to be part of 3 [' and', ' I', ' was', ' delighted']
+79 15 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of delighted positive delighted [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' delighted', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of delighted is the most common sentiment in the" False qualified fifth and was delighted with his car and 4 [' qualified', ' fifth', ' and', ' was', ' delighted']
+80 16 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of thrilled positive thrilled [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' thrilled', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of thrilled is the most common sentiment in the" False said the actor thrilled him, prompting Booth 3 [' said', ' the', ' actor', ' thrilled']
+81 16 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of thrilled positive thrilled [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' thrilled', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of thrilled is the most common sentiment in the" False Spectacular Spider-Man, was thrilled to use Sandman 7 [' Spect', 'acular', ' Spider', '-', 'Man', ',', ' was', ' thrilled']
+82 16 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of thrilled positive thrilled [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' thrilled', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of thrilled is the most common sentiment in the" False Time USA. Not thrilled at the idea, 4 [' Time', ' USA', '.', ' Not', ' thrilled']
+83 16 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of thrilled positive thrilled [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' thrilled', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of thrilled is the most common sentiment in the" False (Leghorn) but he was thrilled to get the order 8 [' (', 'Le', 'gh', 'orn', ')', ' but', ' he', ' was', ' thrilled']
+84 16 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of thrilled positive thrilled [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' thrilled', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of disappointed .
+
+ The sentiment of thrilled is the most common sentiment in the" False 1889), was equally thrilled with the acquisition 4 [' 1889', '),', ' was', ' equally', ' thrilled']
+85 17 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of overjoyed positive overjoyed [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' over', 'joy', 'ed', ' is', ' the', ' most', ' common', ' sentiment'] ", and the sentiment of disappointed .
+
+ The sentiment of over joy ed is the most common sentiment" False her. He was overjoyed when in 1981 6 [' her', '.', ' He', ' was', ' over', 'joy', 'ed']
+86 17 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of overjoyed positive overjoyed [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' over', 'joy', 'ed', ' is', ' the', ' most', ' common', ' sentiment'] ", and the sentiment of disappointed .
+
+ The sentiment of over joy ed is the most common sentiment" False past self and is overjoyed to see him, 6 [' past', ' self', ' and', ' is', ' over', 'joy', 'ed']
+87 17 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of overjoyed positive overjoyed [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' over', 'joy', 'ed', ' is', ' the', ' most', ' common', ' sentiment'] ", and the sentiment of disappointed .
+
+ The sentiment of over joy ed is the most common sentiment" False The passengers are overjoyed that Daisy 5 [' The', ' passengers', ' are', ' over', 'joy', 'ed']
+88 17 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of overjoyed positive overjoyed [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' over', 'joy', 'ed', ' is', ' the', ' most', ' common', ' sentiment'] ", and the sentiment of disappointed .
+
+ The sentiment of over joy ed is the most common sentiment" False " They are both overjoyed and high-five.
+" 5 [' They', ' are', ' both', ' over', 'joy', 'ed']
+89 17 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of overjoyed positive overjoyed [',', ' and', ' the', ' sentiment', ' of', ' disappointed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' over', 'joy', 'ed', ' is', ' the', ' most', ' common', ' sentiment'] ", and the sentiment of disappointed .
+
+ The sentiment of over joy ed is the most common sentiment" False her. Frankie, overjoyed that her daughter 6 [' her', '.', ' Frankie', ',', ' over', 'joy', 'ed']
+90 18 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of elated positive elated [',', ' and', ' the', ' sentiment', ' of', ' depressed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' el', 'ated', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of depressed .
+
+ The sentiment of el ated is the most common sentiment in" False supernova imminent, the elated Elbrun informs 6 [' super', 'nova', ' imminent', ',', ' the', ' el', 'ated']
+91 18 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of elated positive elated [',', ' and', ' the', ' sentiment', ' of', ' depressed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' el', 'ated', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of depressed .
+
+ The sentiment of el ated is the most common sentiment in" False magazine). While elated to have his 4 [' magazine', ').', ' While', ' el', 'ated']
+92 18 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of elated positive elated [',', ' and', ' the', ' sentiment', ' of', ' depressed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' el', 'ated', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of depressed .
+
+ The sentiment of el ated is the most common sentiment in" False Westmoreland were elated that in only 5 [' West', 'more', 'land', ' were', ' el', 'ated']
+93 18 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of elated positive elated [',', ' and', ' the', ' sentiment', ' of', ' depressed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' el', 'ated', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of depressed .
+
+ The sentiment of el ated is the most common sentiment in" False interior. Rhodes was elated by Rudd's results, 5 [' interior', '.', ' Rhodes', ' was', ' el', 'ated']
+94 18 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of elated positive elated [',', ' and', ' the', ' sentiment', ' of', ' depressed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' el', 'ated', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of depressed .
+
+ The sentiment of el ated is the most common sentiment in" False is Dorothy. An elated Hermie goes home 5 [' is', ' Dorothy', '.', ' An', ' el', 'ated']
+95 19 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blissful positive blissful [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' happiness', '.', '\n', '\n', 'The', ' most', ' common', ' emotion'] " is the most common .
+
+ The most common emotion is happiness .
+
+ The most common emotion" False ended with a blissful marriage. She concluded 4 [' ended', ' with', ' a', ' bliss', 'ful']
+96 19 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blissful positive blissful [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' happiness', '.', '\n', '\n', 'The', ' most', ' common', ' emotion'] " is the most common .
+
+ The most common emotion is happiness .
+
+ The most common emotion" False " song ""slips from blissful ambience into" 6 "[' song', ' ""', 'sl', 'ips', ' from', ' bliss', 'ful']"
+97 19 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blissful positive blissful [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' happiness', '.', '\n', '\n', 'The', ' most', ' common', ' emotion'] " is the most common .
+
+ The most common emotion is happiness .
+
+ The most common emotion" False " song ""slips from blissful ambience into bombastic" 6 "[' song', ' ""', 'sl', 'ips', ' from', ' bliss', 'ful']"
+98 19 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blissful positive blissful [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' happiness', '.', '\n', '\n', 'The', ' most', ' common', ' emotion'] " is the most common .
+
+ The most common emotion is happiness .
+
+ The most common emotion" False " an ""anthem to blissful monogamy"". Similarly," 6 "[' an', ' ""', 'ant', 'hem', ' to', ' bliss', 'ful']"
+99 19 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of blissful positive blissful [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' happiness', '.', '\n', '\n', 'The', ' most', ' common', ' emotion'] " is the most common .
+
+ The most common emotion is happiness .
+
+ The most common emotion" False " ""espouses a blissful disregard for traditional" 5 "[' ""', 'esp', 'ouses', ' a', ' bliss', 'ful']"
+100 20 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of sad negative sad "[',', ' happy', ',', ' and', ' angry', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' positive', '.', '\n']" ", happy , and angry , the sentiment of the word "" love "" is the most positive .
+" False showcasing his sad vocals, with 2 [' showcasing', ' his', ' sad']
+101 20 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of sad negative sad "[',', ' happy', ',', ' and', ' angry', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' positive', '.', '\n']" ", happy , and angry , the sentiment of the word "" love "" is the most positive .
+" False " Hubbard as ""a very sad case of post-war" 5 "[' Hubbard', ' as', ' ""', 'a', ' very', ' sad']"
+102 20 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of sad negative sad "[',', ' happy', ',', ' and', ' angry', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' positive', '.', '\n']" ", happy , and angry , the sentiment of the word "" love "" is the most positive .
+" False (3) I am so sad or unhappy that 6 [' (', '3', ')', ' I', ' am', ' so', ' sad']
+103 20 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of sad negative sad "[',', ' happy', ',', ' and', ' angry', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' positive', '.', '\n']" ", happy , and angry , the sentiment of the word "" love "" is the most positive .
+" False justice. But the sad truth is that 4 [' justice', '.', ' But', ' the', ' sad']
+104 20 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of sad negative sad "[',', ' happy', ',', ' and', ' angry', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' positive', '.', '\n']" ", happy , and angry , the sentiment of the word "" love "" is the most positive .
+" False or happy or angry or sad or hungry or 5 [' or', ' happy', ' or', ' angry', ' or', ' sad']
+105 21 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unhappy negative unhappy [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' happy', ' people', '.', '\n', '\n', 'The', ' sentiment', ' of', ' unhappy', ' people', ' is', ' more', ' positive'] " people is more positive than that of happy people .
+
+ The sentiment of unhappy people is more positive" False that he was unhappy with the marriage, 3 [' that', ' he', ' was', ' unhappy']
+106 21 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unhappy negative unhappy [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' happy', ' people', '.', '\n', '\n', 'The', ' sentiment', ' of', ' unhappy', ' people', ' is', ' more', ' positive'] " people is more positive than that of happy people .
+
+ The sentiment of unhappy people is more positive" False that Carter was unhappy in Columbus and 3 [' that', ' Carter', ' was', ' unhappy']
+107 21 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unhappy negative unhappy [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' happy', ' people', '.', '\n', '\n', 'The', ' sentiment', ' of', ' unhappy', ' people', ' is', ' more', ' positive'] " people is more positive than that of happy people .
+
+ The sentiment of unhappy people is more positive" False Marie of Lorraine was unhappy and yielded 5 [' Marie', ' of', ' Lor', 'raine', ' was', ' unhappy']
+108 21 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unhappy negative unhappy [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' happy', ' people', '.', '\n', '\n', 'The', ' sentiment', ' of', ' unhappy', ' people', ' is', ' more', ' positive'] " people is more positive than that of happy people .
+
+ The sentiment of unhappy people is more positive" False was by all accounts unhappy and isolated. Because 4 [' was', ' by', ' all', ' accounts', ' unhappy']
+109 21 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unhappy negative unhappy [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' happy', ' people', '.', '\n', '\n', 'The', ' sentiment', ' of', ' unhappy', ' people', ' is', ' more', ' positive'] " people is more positive than that of happy people .
+
+ The sentiment of unhappy people is more positive" False Henson was unhappy with the compositing 3 [' H', 'enson', ' was', ' unhappy']
+110 22 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of depressed negative depressed [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' the', ' normal', ' people', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' showed', ' that'] " people is more positive than that of the normal people .
+
+ The results of the study showed that" False was deeply depressed by his fiancee's death. 2 [' was', ' deeply', ' depressed']
+111 22 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of depressed negative depressed [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' the', ' normal', ' people', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' showed', ' that'] " people is more positive than that of the normal people .
+
+ The results of the study showed that" False The guns could be depressed to − 5°and 4 [' The', ' guns', ' could', ' be', ' depressed']
+112 22 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of depressed negative depressed [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' the', ' normal', ' people', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' showed', ' that'] " people is more positive than that of the normal people .
+
+ The results of the study showed that" False success, although Gaye, depressed because of issues such 6 [' success', ',', ' although', ' Gay', 'e', ',', ' depressed']
+113 22 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of depressed negative depressed [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' the', ' normal', ' people', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' showed', ' that'] " people is more positive than that of the normal people .
+
+ The results of the study showed that" False heads onto a depressed alignment with frequent 3 [' heads', ' onto', ' a', ' depressed']
+114 22 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of depressed negative depressed [' people', ' is', ' more', ' positive', ' than', ' that', ' of', ' the', ' normal', ' people', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' showed', ' that'] " people is more positive than that of the normal people .
+
+ The results of the study showed that" False Potter became deeply depressed and was ill for many 3 [' Potter', ' became', ' deeply', ' depressed']
+115 23 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of lonely negative lonely [' people', ' is', ' more', ' positive', ' than', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' lonely', ' people', ' is', ' more', ' positive', ' than', ' negative', '.'] " people is more positive than negative .
+
+ The sentiment of lonely people is more positive than negative ." True follows Logan, a lonely 14-year-old 4 [' follows', ' Logan', ',', ' a', ' lonely']
+116 23 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of lonely negative lonely [' people', ' is', ' more', ' positive', ' than', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' lonely', ' people', ' is', ' more', ' positive', ' than', ' negative', '.'] " people is more positive than negative .
+
+ The sentiment of lonely people is more positive than negative ." True " parking on lonely roads.
+" 2 [' parking', ' on', ' lonely']
+117 23 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of lonely negative lonely [' people', ' is', ' more', ' positive', ' than', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' lonely', ' people', ' is', ' more', ' positive', ' than', ' negative', '.'] " people is more positive than negative .
+
+ The sentiment of lonely people is more positive than negative ." True introduction of lonely divorcee Christine, 2 [' introduction', ' of', ' lonely']
+118 23 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of lonely negative lonely [' people', ' is', ' more', ' positive', ' than', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' lonely', ' people', ' is', ' more', ' positive', ' than', ' negative', '.'] " people is more positive than negative .
+
+ The sentiment of lonely people is more positive than negative ." True lyrics tell of a lonely woman declaring she 4 [' lyrics', ' tell', ' of', ' a', ' lonely']
+119 23 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of lonely negative lonely [' people', ' is', ' more', ' positive', ' than', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' lonely', ' people', ' is', ' more', ' positive', ' than', ' negative', '.'] " people is more positive than negative .
+
+ The sentiment of lonely people is more positive than negative ." True " ""forgotten and ignored, a lonely beacon of light" 7 "[' ""', 'for', 'gotten', ' and', ' ignored', ',', ' a', ' lonely']"
+120 24 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of heartbroken negative heartbroken [' people', ' is', ' overwhelmingly', ' negative', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' �', '�', 'I', '�', '�', 'm', ' heart', 'broken'] " people is overwhelmingly negative .
+
+ The most common sentiment is � � I � � m heart broken" True Wilberforce was heartbroken to be separated 5 [' Wil', 'ber', 'force', ' was', ' heart', 'broken']
+121 24 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of heartbroken negative heartbroken [' people', ' is', ' overwhelmingly', ' negative', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' �', '�', 'I', '�', '�', 'm', ' heart', 'broken'] " people is overwhelmingly negative .
+
+ The most common sentiment is � � I � � m heart broken" True fiancee. Rahul is heartbroken but congratulates 6 [' fiance', 'e', '.', ' Rahul', ' is', ' heart', 'broken']
+122 24 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of heartbroken negative heartbroken [' people', ' is', ' overwhelmingly', ' negative', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' �', '�', 'I', '�', '�', 'm', ' heart', 'broken'] " people is overwhelmingly negative .
+
+ The most common sentiment is � � I � � m heart broken" True Buchanan was a heartbroken woman. Range described 4 [' Buchanan', ' was', ' a', ' heart', 'broken']
+123 24 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of heartbroken negative heartbroken [' people', ' is', ' overwhelmingly', ' negative', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' �', '�', 'I', '�', '�', 'm', ' heart', 'broken'] " people is overwhelmingly negative .
+
+ The most common sentiment is � � I � � m heart broken" True character to be heartbroken if she believed her 4 [' character', ' to', ' be', ' heart', 'broken']
+124 24 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of heartbroken negative heartbroken [' people', ' is', ' overwhelmingly', ' negative', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' is', ' �', '�', 'I', '�', '�', 'm', ' heart', 'broken'] " people is overwhelmingly negative .
+
+ The most common sentiment is � � I � � m heart broken" True Meanwhile, Bart is heartbroken to find out 5 [' Meanwhile', ',', ' Bart', ' is', ' heart', 'broken']
+125 25 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of anxious negative anxious ['ness', ',', ' and', ' the', ' sentiment', ' of', ' happiness', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' are', ' as', ' follows', ':', '\n'] "ness , and the sentiment of happiness .
+
+ The results of the study are as follows :
+" False ends, Charlie is anxious about losing 4 [' ends', ',', ' Charlie', ' is', ' anxious']
+126 25 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of anxious negative anxious ['ness', ',', ' and', ' the', ' sentiment', ' of', ' happiness', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' are', ' as', ' follows', ':', '\n'] "ness , and the sentiment of happiness .
+
+ The results of the study are as follows :
+" False lawyer, one is always anxious when there is contest 5 [' lawyer', ',', ' one', ' is', ' always', ' anxious']
+127 25 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of anxious negative anxious ['ness', ',', ' and', ' the', ' sentiment', ' of', ' happiness', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' are', ' as', ' follows', ':', '\n'] "ness , and the sentiment of happiness .
+
+ The results of the study are as follows :
+" False first emperor was anxious to avoid the 3 [' first', ' emperor', ' was', ' anxious']
+128 25 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of anxious negative anxious ['ness', ',', ' and', ' the', ' sentiment', ' of', ' happiness', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' are', ' as', ' follows', ':', '\n'] "ness , and the sentiment of happiness .
+
+ The results of the study are as follows :
+" False " direction: ""I am very anxious to receive your" 6 "[' direction', ':', ' ""', 'I', ' am', ' very', ' anxious']"
+129 25 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of anxious negative anxious ['ness', ',', ' and', ' the', ' sentiment', ' of', ' happiness', '.', '\n', '\n', 'The', ' results', ' of', ' the', ' study', ' are', ' as', ' follows', ':', '\n'] "ness , and the sentiment of happiness .
+
+ The results of the study are as follows :
+" False defrauded me, being anxious that it was 6 [' def', 'ra', 'uded', ' me', ',', ' being', ' anxious']
+130 26 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frustrated negative frustrated [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frustrated', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of frustrated is the most common sentiment in the" False were sometimes frustrated by powerful 2 [' were', ' sometimes', ' frustrated']
+131 26 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frustrated negative frustrated [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frustrated', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of frustrated is the most common sentiment in the" False ground fog again frustrated reinforcement. 3 [' ground', ' fog', ' again', ' frustrated']
+132 26 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frustrated negative frustrated [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frustrated', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of frustrated is the most common sentiment in the" False of France. Berg was frustrated by the expense 5 [' of', ' France', '.', ' Berg', ' was', ' frustrated']
+133 26 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frustrated negative frustrated [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frustrated', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of frustrated is the most common sentiment in the" False completely frustrated and, through 1 [' completely', ' frustrated']
+134 26 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frustrated negative frustrated [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frustrated', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of frustrated is the most common sentiment in the" False had become frustrated with the Herald, 2 [' had', ' become', ' frustrated']
+135 27 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of angry negative angry [',', ' sad', ',', ' happy', ',', ' and', ' surprise', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' tweet', ' is', ' a', ' sentiment', ' of', ' the'] ", sad , happy , and surprise .
+
+ The sentiment of the tweet is a sentiment of the" False was met by an angry crowd which shouted 4 [' was', ' met', ' by', ' an', ' angry']
+136 27 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of angry negative angry [',', ' sad', ',', ' happy', ',', ' and', ' surprise', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' tweet', ' is', ' a', ' sentiment', ' of', ' the'] ", sad , happy , and surprise .
+
+ The sentiment of the tweet is a sentiment of the" False return, she found an angry mob in front of 5 [' return', ',', ' she', ' found', ' an', ' angry']
+137 27 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of angry negative angry [',', ' sad', ',', ' happy', ',', ' and', ' surprise', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' tweet', ' is', ' a', ' sentiment', ' of', ' the'] ", sad , happy , and surprise .
+
+ The sentiment of the tweet is a sentiment of the" False extraordinary amount of angry criticism. 3 [' extraordinary', ' amount', ' of', ' angry']
+138 27 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of angry negative angry [',', ' sad', ',', ' happy', ',', ' and', ' surprise', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' tweet', ' is', ' a', ' sentiment', ' of', ' the'] ", sad , happy , and surprise .
+
+ The sentiment of the tweet is a sentiment of the" False Jack becomes angry but Sawyer claims 2 [' Jack', ' becomes', ' angry']
+139 27 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of angry negative angry [',', ' sad', ',', ' happy', ',', ' and', ' surprise', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' tweet', ' is', ' a', ' sentiment', ' of', ' the'] ", sad , happy , and surprise .
+
+ The sentiment of the tweet is a sentiment of the" False " narrator — he's angry that he's confused.""" 4 "[' narrator', ' —', ' he', ""'s"", ' angry']"
+140 28 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of jealous negative jealous [' people', ' is', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' jealous', ' people', ' is', ' negative', '.', '\n', '\n', '(', '2', ')', '\n'] " people is negative .
+
+ The sentiment of jealous people is negative .
+
+ ( 2 )
+" True thinks Clark is jealous of her relationship 3 [' thinks', ' Clark', ' is', ' jealous']
+141 28 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of jealous negative jealous [' people', ' is', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' jealous', ' people', ' is', ' negative', '.', '\n', '\n', '(', '2', ')', '\n'] " people is negative .
+
+ The sentiment of jealous people is negative .
+
+ ( 2 )
+" True for his work – a jealous appreciation; 5 [' for', ' his', ' work', ' –', ' a', ' jealous']
+142 28 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of jealous negative jealous [' people', ' is', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' jealous', ' people', ' is', ' negative', '.', '\n', '\n', '(', '2', ')', '\n'] " people is negative .
+
+ The sentiment of jealous people is negative .
+
+ ( 2 )
+" True became obsessively jealous of the princess'emerging 3 [' became', ' obsess', 'ively', ' jealous']
+143 28 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of jealous negative jealous [' people', ' is', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' jealous', ' people', ' is', ' negative', '.', '\n', '\n', '(', '2', ')', '\n'] " people is negative .
+
+ The sentiment of jealous people is negative .
+
+ ( 2 )
+" True by Billie's jealous friend Kylie (Elarica 4 "[' by', ' Bill', 'ie', ""'s"", ' jealous']"
+144 28 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of jealous negative jealous [' people', ' is', ' negative', '.', '\n', '\n', 'The', ' sentiment', ' of', ' jealous', ' people', ' is', ' negative', '.', '\n', '\n', '(', '2', ')', '\n'] " people is negative .
+
+ The sentiment of jealous people is negative .
+
+ ( 2 )
+" True increasingly jealous and questions 1 [' increasingly', ' jealous']
+145 29 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hateful negative hateful [',', ' and', ' the', ' sentiment', ' of', ' love', '.', '\n', '\n', 'The', ' sentiment', ' of', ' hate', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of love .
+
+ The sentiment of hate is the most common sentiment in the" False " horrible ... and hateful and mean"". The" 3 [' horrible', '...', ' and', ' hateful']
+146 29 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hateful negative hateful [',', ' and', ' the', ' sentiment', ' of', ' love', '.', '\n', '\n', 'The', ' sentiment', ' of', ' hate', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of love .
+
+ The sentiment of hate is the most common sentiment in the" False city and his hateful presence out 3 [' city', ' and', ' his', ' hateful']
+147 29 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hateful negative hateful [',', ' and', ' the', ' sentiment', ' of', ' love', '.', '\n', '\n', 'The', ' sentiment', ' of', ' hate', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of love .
+
+ The sentiment of hate is the most common sentiment in the" False " ""racist and hateful views"" of" 3 "[' ""', 'racist', ' and', ' hateful']"
+148 29 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hateful negative hateful [',', ' and', ' the', ' sentiment', ' of', ' love', '.', '\n', '\n', 'The', ' sentiment', ' of', ' hate', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of love .
+
+ The sentiment of hate is the most common sentiment in the" False to help remove hateful advertisers on 3 [' to', ' help', ' remove', ' hateful']
+149 29 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hateful negative hateful [',', ' and', ' the', ' sentiment', ' of', ' love', '.', '\n', '\n', 'The', ' sentiment', ' of', ' hate', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of love .
+
+ The sentiment of hate is the most common sentiment in the" False " horrible ... and hateful and mean"". The controversy" 3 [' horrible', '...', ' and', ' hateful']
+150 30 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of disappointed negative disappointed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' users', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the users is positive .
+
+ The" False " before."" Feldman was disappointed at the change.
+" 4 "[' before', '.""', ' Feldman', ' was', ' disappointed']"
+151 30 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of disappointed negative disappointed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' users', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the users is positive .
+
+ The" False expected, as many disappointed consumers continued 4 [' expected', ',', ' as', ' many', ' disappointed']
+152 30 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of disappointed negative disappointed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' users', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the users is positive .
+
+ The" False Warner Bros. was disappointed with the financial 4 [' Warner', ' Bros', '.', ' was', ' disappointed']
+153 30 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of disappointed negative disappointed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' users', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the users is positive .
+
+ The" False was initially disappointed with the project 2 [' was', ' initially', ' disappointed']
+154 30 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of disappointed negative disappointed [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' users', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the users is positive .
+
+ The" False example, Dudley was disappointed with the few 4 [' example', ',', ' Dudley', ' was', ' disappointed']
+155 31 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of gloomy negative gloomy "[',', ' happy', ',', ' and', ' sad', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' common', '.', '\n']" ", happy , and sad , the sentiment of the word "" love "" is the most common .
+" False Despite the gloomy ambience of 2 [' Despite', ' the', ' gloomy']
+156 31 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of gloomy negative gloomy "[',', ' happy', ',', ' and', ' sad', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' common', '.', '\n']" ", happy , and sad , the sentiment of the word "" love "" is the most common .
+" False when the morning is gloomy and the sun hiding 4 [' when', ' the', ' morning', ' is', ' gloomy']
+157 31 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of gloomy negative gloomy "[',', ' happy', ',', ' and', ' sad', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' common', '.', '\n']" ", happy , and sad , the sentiment of the word "" love "" is the most common .
+" False overwhelmingly produced gloomy masculine self-absorption 2 [' overwhelmingly', ' produced', ' gloomy']
+158 31 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of gloomy negative gloomy "[',', ' happy', ',', ' and', ' sad', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' common', '.', '\n']" ", happy , and sad , the sentiment of the word "" love "" is the most common .
+" False that he was gloomy and nervous. Herbert 3 [' that', ' he', ' was', ' gloomy']
+159 31 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of gloomy negative gloomy "[',', ' happy', ',', ' and', ' sad', ',', ' the', ' sentiment', ' of', ' the', ' word', ' ""', 'love', '""', ' is', ' the', ' most', ' common', '.', '\n']" ", happy , and sad , the sentiment of the word "" love "" is the most common .
+" False and collapse. After gloomy phrases from 4 [' and', ' collapse', '.', ' After', ' gloomy']
+160 32 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of dejected negative dejected [',', ' and', ' the', ' sentiment', ' of', ' excited', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' de', 'jected', ' is', ' the'] ", and the sentiment of excited are the most common .
+
+ The sentiment of de jected is the" False Juraj Mikúš. A dejected Getzlaf lamented 8 [' Jur', 'aj', ' Mik', 'ú', 'š', '.', ' A', ' de', 'jected']
+161 32 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of dejected negative dejected [',', ' and', ' the', ' sentiment', ' of', ' excited', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' de', 'jected', ' is', ' the'] ", and the sentiment of excited are the most common .
+
+ The sentiment of de jected is the" False " Erin feels dejected and leaves.
+" 3 [' Erin', ' feels', ' de', 'jected']
+162 32 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of dejected negative dejected [',', ' and', ' the', ' sentiment', ' of', ' excited', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' de', 'jected', ' is', ' the'] ", and the sentiment of excited are the most common .
+
+ The sentiment of de jected is the" False life and a dejected female named Mariana. 4 [' life', ' and', ' a', ' de', 'jected']
+163 32 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of dejected negative dejected [',', ' and', ' the', ' sentiment', ' of', ' excited', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' de', 'jected', ' is', ' the'] ", and the sentiment of excited are the most common .
+
+ The sentiment of de jected is the" False " Erin feels dejected and leaves.
+" 3 [' Erin', ' feels', ' de', 'jected']
+164 32 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of dejected negative dejected [',', ' and', ' the', ' sentiment', ' of', ' excited', ' are', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' de', 'jected', ' is', ' the'] ", and the sentiment of excited are the most common .
+
+ The sentiment of de jected is the" False Born was so dejected that he gave up a 4 [' Born', ' was', ' so', ' de', 'jected']
+165 33 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hopeless negative hopeless ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' hopeless', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common emotion is hopeless ness .
+
+ The most" False relapses, he was in a hopeless condition. The king's 7 [' rel', 'apses', ',', ' he', ' was', ' in', ' a', ' hopeless']
+166 33 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hopeless negative hopeless ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' hopeless', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common emotion is hopeless ness .
+
+ The most" False becomes a seemingly hopeless drunkard. He spends 3 [' becomes', ' a', ' seemingly', ' hopeless']
+167 33 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hopeless negative hopeless ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' hopeless', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common emotion is hopeless ness .
+
+ The most" False compensate for its hopeless vulgarity, not 3 [' compensate', ' for', ' its', ' hopeless']
+168 33 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hopeless negative hopeless ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' hopeless', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common emotion is hopeless ness .
+
+ The most" False situation was hopeless and ordered a withdrawal 2 [' situation', ' was', ' hopeless']
+169 33 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of hopeless negative hopeless ['ness', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' emotion', ' is', ' hopeless', 'ness', '.', '\n', '\n', 'The', ' most'] "ness is the most common .
+
+ The most common emotion is hopeless ness .
+
+ The most" False that the cause was hopeless and ordered a withdrawal 4 [' that', ' the', ' cause', ' was', ' hopeless']
+170 34 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of despairing negative despairing [',', ' and', ' the', ' sentiment', ' of', ' hope', '.', '\n', '\n', 'The', ' sentiment', ' of', ' despair', 'ing', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of hope .
+
+ The sentiment of despair ing is the most common sentiment in" False as bleak and as despairing a view of the world 5 [' as', ' bleak', ' and', ' as', ' despair', 'ing']
+171 34 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of despairing negative despairing [',', ' and', ' the', ' sentiment', ' of', ' hope', '.', '\n', '\n', 'The', ' sentiment', ' of', ' despair', 'ing', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of hope .
+
+ The sentiment of despair ing is the most common sentiment in" False position and despairing in the face of 3 [' position', ' and', ' despair', 'ing']
+172 34 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of despairing negative despairing [',', ' and', ' the', ' sentiment', ' of', ' hope', '.', '\n', '\n', 'The', ' sentiment', ' of', ' despair', 'ing', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of hope .
+
+ The sentiment of despair ing is the most common sentiment in" False Pictures, and despairing at writing his script 4 [' Pictures', ',', ' and', ' despair', 'ing']
+173 34 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of despairing negative despairing [',', ' and', ' the', ' sentiment', ' of', ' hope', '.', '\n', '\n', 'The', ' sentiment', ' of', ' despair', 'ing', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of hope .
+
+ The sentiment of despair ing is the most common sentiment in" False Reaction and the despairing convulsions of Revolution, 4 [' Reaction', ' and', ' the', ' despair', 'ing']
+174 34 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of despairing negative despairing [',', ' and', ' the', ' sentiment', ' of', ' hope', '.', '\n', '\n', 'The', ' sentiment', ' of', ' despair', 'ing', ' is', ' the', ' most', ' common', ' sentiment', ' in'] ", and the sentiment of hope .
+
+ The sentiment of despair ing is the most common sentiment in" False Homeland ' despairing over a dead 3 "[' Homeland', "" '"", ' despair', 'ing']"
+175 35 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frightened negative frightened [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frightened', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of frightened is the most common sentiment in the" False because we were a bit frightened of actually finishing 5 [' because', ' we', ' were', ' a', ' bit', ' frightened']
+176 35 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frightened negative frightened [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frightened', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of frightened is the most common sentiment in the" False that would have frightened a hundred Chosen Virgins, 3 [' that', ' would', ' have', ' frightened']
+177 35 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frightened negative frightened [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frightened', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of frightened is the most common sentiment in the" False Mountain segment had frightened them. There were 3 [' Mountain', ' segment', ' had', ' frightened']
+178 35 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frightened negative frightened [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frightened', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of frightened is the most common sentiment in the" False first episode from a frightened girl to an empowered 4 [' first', ' episode', ' from', ' a', ' frightened']
+179 35 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of frightened negative frightened [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' frightened', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of frightened is the most common sentiment in the" False particular, Yorke was frightened by a woman 5 [' particular', ',', ' Yor', 'ke', ' was', ' frightened']
+180 36 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of terrified negative terrified [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' terrified', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of terrified is the most common sentiment in the" False featuring the terrified lady telling Mulder 2 [' featuring', ' the', ' terrified']
+181 36 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of terrified negative terrified [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' terrified', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of terrified is the most common sentiment in the" False Finn and Jake are terrified that Marceline 4 [' Finn', ' and', ' Jake', ' are', ' terrified']
+182 36 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of terrified negative terrified [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' terrified', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of terrified is the most common sentiment in the" False with Louisiana; terrified civilians 3 [' with', ' Louisiana', ';', ' terrified']
+183 36 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of terrified negative terrified [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' terrified', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of terrified is the most common sentiment in the" False his children were terrified of him, and that 3 [' his', ' children', ' were', ' terrified']
+184 36 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of terrified negative terrified [',', ' and', ' the', ' sentiment', ' of', ' angry', '.', '\n', '\n', 'The', ' sentiment', ' of', ' terrified', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of angry .
+
+ The sentiment of terrified is the most common sentiment in the" False character who is terrified of, but fascinated 3 [' character', ' who', ' is', ' terrified']
+185 37 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of scared negative scared [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' scared', ' is', ' the', ' most', ' common', ' sentiment', '.', '\n'] ", and the sentiment of happy .
+
+ The sentiment of scared is the most common sentiment .
+" False company and deeply scared of what the future 3 [' company', ' and', ' deeply', ' scared']
+186 37 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of scared negative scared [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' scared', ' is', ' the', ' most', ' common', ' sentiment', '.', '\n'] ", and the sentiment of happy .
+
+ The sentiment of scared is the most common sentiment .
+" False Whether she is scared or scolding 3 [' Whether', ' she', ' is', ' scared']
+187 37 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of scared negative scared [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' scared', ' is', ' the', ' most', ' common', ' sentiment', '.', '\n'] ", and the sentiment of happy .
+
+ The sentiment of scared is the most common sentiment .
+" False immature and easily scared Moira Burton, who 3 [' immature', ' and', ' easily', ' scared']
+188 37 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of scared negative scared [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' scared', ' is', ' the', ' most', ' common', ' sentiment', '.', '\n'] ", and the sentiment of happy .
+
+ The sentiment of scared is the most common sentiment .
+" False little boy and it scared me to death 4 [' little', ' boy', ' and', ' it', ' scared']
+189 37 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of scared negative scared [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' scared', ' is', ' the', ' most', ' common', ' sentiment', '.', '\n'] ", and the sentiment of happy .
+
+ The sentiment of scared is the most common sentiment .
+" False tea, but the dreams scared him so much 5 [' tea', ',', ' but', ' the', ' dreams', ' scared']
+190 38 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of worried negative worried [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' worried', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of worried is the most common sentiment in the" False it, but is worried he may never get 4 [' it', ',', ' but', ' is', ' worried']
+191 38 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of worried negative worried [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' worried', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of worried is the most common sentiment in the" False around him were worried about his state of 3 [' around', ' him', ' were', ' worried']
+192 38 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of worried negative worried [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' worried', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of worried is the most common sentiment in the" False " ""People get a bit worried about me,"" he told" 5 "[' ""', 'People', ' get', ' a', ' bit', ' worried']"
+193 38 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of worried negative worried [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' worried', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of worried is the most common sentiment in the" False public profile, worried that he was not recognisable 3 [' public', ' profile', ',', ' worried']
+194 38 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of worried negative worried [',', ' and', ' the', ' sentiment', ' of', ' happy', '.', '\n', '\n', 'The', ' sentiment', ' of', ' worried', ' is', ' the', ' most', ' common', ' sentiment', ' in', ' the'] ", and the sentiment of happy .
+
+ The sentiment of worried is the most common sentiment in the" False planners were worried about the impact 2 [' planners', ' were', ' worried']
+195 39 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of apprehensive negative apprehensive [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' negative', ' is', ' the', ' sentiment', ' of', ' the', ' negative'] " is the most common .
+
+ The most common sentiment of the negative is the sentiment of the negative" True trade journalists were apprehensive of Raajneeti recovering 4 [' trade', ' journalists', ' were', ' apprehens', 'ive']
+196 39 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of apprehensive negative apprehensive [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' negative', ' is', ' the', ' sentiment', ' of', ' the', ' negative'] " is the most common .
+
+ The most common sentiment of the negative is the sentiment of the negative" True 1534. Growing apprehensive of the power of the 5 [' 15', '34', '.', ' Growing', ' apprehens', 'ive']
+197 39 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of apprehensive negative apprehensive [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' negative', ' is', ' the', ' sentiment', ' of', ' the', ' negative'] " is the most common .
+
+ The most common sentiment of the negative is the sentiment of the negative" True crèche become more apprehensive following 6 [' cr', 'è', 'che', ' become', ' more', ' apprehens', 'ive']
+198 39 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of apprehensive negative apprehensive [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' negative', ' is', ' the', ' sentiment', ' of', ' the', ' negative'] " is the most common .
+
+ The most common sentiment of the negative is the sentiment of the negative" True The band was apprehensive at first, judging 4 [' The', ' band', ' was', ' apprehens', 'ive']
+199 39 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of apprehensive negative apprehensive [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' negative', ' is', ' the', ' sentiment', ' of', ' the', ' negative'] " is the most common .
+
+ The most common sentiment of the negative is the sentiment of the negative" True into Canada. Scott is apprehensive about trusting Saddam, 6 [' into', ' Canada', '.', ' Scott', ' is', ' apprehens', 'ive']
+200 40 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of nervous negative nervous ['ness', ',', ' and', ' the', ' sentiment', ' of', ' excitement', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excitement', ' is', ' the', ' most', ' important', ' one', '.'] "ness , and the sentiment of excitement .
+
+ The sentiment of excitement is the most important one ." False a sympathetic nervous system malignancy, 2 [' a', ' sympathetic', ' nervous']
+201 40 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of nervous negative nervous ['ness', ',', ' and', ' the', ' sentiment', ' of', ' excitement', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excitement', ' is', ' the', ' most', ' important', ' one', '.'] "ness , and the sentiment of excitement .
+
+ The sentiment of excitement is the most important one ." False function: the peripheral nervous system has a greater 4 [' function', ':', ' the', ' peripheral', ' nervous']
+202 40 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of nervous negative nervous ['ness', ',', ' and', ' the', ' sentiment', ' of', ' excitement', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excitement', ' is', ' the', ' most', ' important', ' one', '.'] "ness , and the sentiment of excitement .
+
+ The sentiment of excitement is the most important one ." False " nineteenth-century nervous ailments.""
+" 3 [' nineteenth', '-', 'century', ' nervous']
+203 40 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of nervous negative nervous ['ness', ',', ' and', ' the', ' sentiment', ' of', ' excitement', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excitement', ' is', ' the', ' most', ' important', ' one', '.'] "ness , and the sentiment of excitement .
+
+ The sentiment of excitement is the most important one ." False recover from a nervous breakdown. Virginia 3 [' recover', ' from', ' a', ' nervous']
+204 40 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of nervous negative nervous ['ness', ',', ' and', ' the', ' sentiment', ' of', ' excitement', '.', '\n', '\n', 'The', ' sentiment', ' of', ' excitement', ' is', ' the', ' most', ' important', ' one', '.'] "ness , and the sentiment of excitement .
+
+ The sentiment of excitement is the most important one ." False Filbert Street. A nervous and homesick Wark 5 [' Fil', 'bert', ' Street', '.', ' A', ' nervous']
+205 41 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False illusion), and neutral monism (which holds 3 [' illusion', '),', ' and', ' neutral']
+206 41 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False electrons to create neutral atoms. As 3 [' electrons', ' to', ' create', ' neutral']
+207 41 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False tribes, begin as neutral to the player 4 [' tribes', ',', ' begin', ' as', ' neutral']
+208 41 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False Kentucky's desire to stay neutral in the war. In addition, 5 "[' Kentucky', ""'s"", ' desire', ' to', ' stay', ' neutral']"
+209 41 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False (7.5 mi) wide neutral zone along the line. 7 [' (', '7', '.', '5', ' mi', ')', ' wide', ' neutral']
+210 42 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of computer neutral computer [' users', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' computer', ' users', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ',', ' negative'] " users is positive .
+
+ The sentiment of computer users is positive .
+
+ Among positive , negative" False metaphor for computer networks, 2 [' metaphor', ' for', ' computer']
+211 42 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of computer neutral computer [' users', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' computer', ' users', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ',', ' negative'] " users is positive .
+
+ The sentiment of computer users is positive .
+
+ Among positive , negative" False War, a text-based computer game that puts the 6 [' War', ',', ' a', ' text', '-', 'based', ' computer']
+212 42 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of computer neutral computer [' users', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' computer', ' users', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ',', ' negative'] " users is positive .
+
+ The sentiment of computer users is positive .
+
+ Among positive , negative" False passes of ships over computer-generated versions, 4 [' passes', ' of', ' ships', ' over', ' computer']
+213 42 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of computer neutral computer [' users', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' computer', ' users', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ',', ' negative'] " users is positive .
+
+ The sentiment of computer users is positive .
+
+ Among positive , negative" False rogues to be a computer virus that need 5 [' rog', 'ues', ' to', ' be', ' a', ' computer']
+214 42 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of computer neutral computer [' users', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' computer', ' users', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ',', ' negative'] " users is positive .
+
+ The sentiment of computer users is positive .
+
+ Among positive , negative" False Eko out of the computer room. Desperate 5 [' E', 'ko', ' out', ' of', ' the', ' computer']
+215 43 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of car neutral car [' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car', ' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car'] " buyers is positive .
+
+ The sentiment of car buyers is positive .
+
+ The sentiment of car" False of an extended car ride when his friend 3 [' of', ' an', ' extended', ' car']
+216 43 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of car neutral car [' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car', ' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car'] " buyers is positive .
+
+ The sentiment of car buyers is positive .
+
+ The sentiment of car" False build the world's best car. The project, 5 "[' build', ' the', ' world', ""'s"", ' best', ' car']"
+217 43 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of car neutral car [' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car', ' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car'] " buyers is positive .
+
+ The sentiment of car buyers is positive .
+
+ The sentiment of car" False assigned to infantry-carrying Mark 4 [' assigned', ' to', ' infantry', '-', 'car']
+218 43 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of car neutral car [' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car', ' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car'] " buyers is positive .
+
+ The sentiment of car buyers is positive .
+
+ The sentiment of car" False by lap 30 despite carrying a heavier 4 [' by', ' lap', ' 30', ' despite', ' car']
+219 43 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of car neutral car [' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car', ' buyers', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' car'] " buyers is positive .
+
+ The sentiment of car buyers is positive .
+
+ The sentiment of car" False control of his car and slid sideways 3 [' control', ' of', ' his', ' car']
+220 44 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of house neutral house [' owners', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' house', ' owners', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ','] " owners is positive .
+
+ The sentiment of the house owners is positive .
+
+ Among positive ," False foundation of the house is more exposed 3 [' foundation', ' of', ' the', ' house']
+221 44 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of house neutral house [' owners', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' house', ' owners', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ','] " owners is positive .
+
+ The sentiment of the house owners is positive .
+
+ Among positive ," False Canales's house in an attempt 3 "[' Can', 'ales', ""'s"", ' house']"
+222 44 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of house neutral house [' owners', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' house', ' owners', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ','] " owners is positive .
+
+ The sentiment of the house owners is positive .
+
+ Among positive ," False ceremony at the meeting house was halted 4 [' ceremony', ' at', ' the', ' meeting', ' house']
+223 44 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of house neutral house [' owners', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' house', ' owners', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ','] " owners is positive .
+
+ The sentiment of the house owners is positive .
+
+ Among positive ," False residence in a new house, Satyanarayana 4 [' residence', ' in', ' a', ' new', ' house']
+224 44 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of house neutral house [' owners', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' house', ' owners', ' is', ' positive', '.', '\n', '\n', 'Among', ' positive', ','] " owners is positive .
+
+ The sentiment of the house owners is positive .
+
+ Among positive ," False description of the house that echoes 3 [' description', ' of', ' the', ' house']
+225 45 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of tree neutral tree ['-', 'h', 'ug', 'gers', ' is', ' overwhelmingly', ' positive', '.', '\n', '\n', 'The', ' most', ' popular', ' tree', '-', 'h', 'ug', 'gers', ' are', ':'] "- h ug gers is overwhelmingly positive .
+
+ The most popular tree - h ug gers are :" False travels past a tree farm. After 3 [' travels', ' past', ' a', ' tree']
+226 45 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of tree neutral tree ['-', 'h', 'ug', 'gers', ' is', ' overwhelmingly', ' positive', '.', '\n', '\n', 'The', ' most', ' popular', ' tree', '-', 'h', 'ug', 'gers', ' are', ':'] "- h ug gers is overwhelmingly positive .
+
+ The most popular tree - h ug gers are :" False forest canopy trees, remove vines 2 [' forest', ' canopy', ' tree']
+227 45 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of tree neutral tree ['-', 'h', 'ug', 'gers', ' is', ' overwhelmingly', ' positive', '.', '\n', '\n', 'The', ' most', ' popular', ' tree', '-', 'h', 'ug', 'gers', ' are', ':'] "- h ug gers is overwhelmingly positive .
+
+ The most popular tree - h ug gers are :" False reverse and palm tree shadows swaying with 3 [' reverse', ' and', ' palm', ' tree']
+228 45 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of tree neutral tree ['-', 'h', 'ug', 'gers', ' is', ' overwhelmingly', ' positive', '.', '\n', '\n', 'The', ' most', ' popular', ' tree', '-', 'h', 'ug', 'gers', ' are', ':'] "- h ug gers is overwhelmingly positive .
+
+ The most popular tree - h ug gers are :" False agriculture, fruit tree cultivation, 3 [' agriculture', ',', ' fruit', ' tree']
+229 45 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of tree neutral tree ['-', 'h', 'ug', 'gers', ' is', ' overwhelmingly', ' positive', '.', '\n', '\n', 'The', ' most', ' popular', ' tree', '-', 'h', 'ug', 'gers', ' are', ':'] "- h ug gers is overwhelmingly positive .
+
+ The most popular tree - h ug gers are :" False down numerous trees, some of them 2 [' down', ' numerous', ' tree']
+230 46 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of book neutral book [' reviews', ' is', ' positive', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good', ' read', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good'] " reviews is positive .
+
+ The book is a good read .
+
+ The book is a good" False finally published in book form under 3 [' finally', ' published', ' in', ' book']
+231 46 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of book neutral book [' reviews', ' is', ' positive', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good', ' read', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good'] " reviews is positive .
+
+ The book is a good read .
+
+ The book is a good" False " Borstein called the book a ""companion piece" 4 [' Bor', 'stein', ' called', ' the', ' book']
+232 46 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of book neutral book [' reviews', ' is', ' positive', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good', ' read', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good'] " reviews is positive .
+
+ The book is a good read .
+
+ The book is a good" False 0 ['book']
+233 46 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of book neutral book [' reviews', ' is', ' positive', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good', ' read', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good'] " reviews is positive .
+
+ The book is a good read .
+
+ The book is a good" False the foreword to the book's first edition 5 [' the', ' fore', 'word', ' to', ' the', ' book']
+234 46 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of book neutral book [' reviews', ' is', ' positive', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good', ' read', '.', '\n', '\n', 'The', ' book', ' is', ' a', ' good'] " reviews is positive .
+
+ The book is a good read .
+
+ The book is a good" False " Bronx"", the book contains four stories" 3 "[' Bronx', '"",', ' the', ' book']"
+235 47 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of money neutral money [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money'] " is a positive one .
+
+ The sentiment of money is positive .
+
+ The sentiment of money" False that much of the money would go directly 4 [' that', ' much', ' of', ' the', ' money']
+236 47 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of money neutral money [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money'] " is a positive one .
+
+ The sentiment of money is positive .
+
+ The sentiment of money" False purpose was to raise money on behalf of the 4 [' purpose', ' was', ' to', ' raise', ' money']
+237 47 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of money neutral money [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money'] " is a positive one .
+
+ The sentiment of money is positive .
+
+ The sentiment of money" False advocacy of hard money had won him friends 3 [' advocacy', ' of', ' hard', ' money']
+238 47 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of money neutral money [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money'] " is a positive one .
+
+ The sentiment of money is positive .
+
+ The sentiment of money" False resurgence of money circulation. During 2 [' resurgence', ' of', ' money']
+239 47 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of money neutral money [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money', ' is', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' money'] " is a positive one .
+
+ The sentiment of money is positive .
+
+ The sentiment of money" False the payment of money into the Consolidated 3 [' the', ' payment', ' of', ' money']
+240 48 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of time neutral time [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' time', ' is', ' a', ' positive', ' one', '.', '\n', '\n', 'Among', ' positive'] " is a positive one .
+
+ The sentiment of time is a positive one .
+
+ Among positive" False BBC to allow time for radio trailers 3 [' BBC', ' to', ' allow', ' time']
+241 48 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of time neutral time [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' time', ' is', ' a', ' positive', ' one', '.', '\n', '\n', 'Among', ' positive'] " is a positive one .
+
+ The sentiment of time is a positive one .
+
+ Among positive" False " Richardson's ""all time favorite"" jokes.
+" 4 "[' Richardson', ""'s"", ' ""', 'all', ' time']"
+242 48 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of time neutral time [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' time', ' is', ' a', ' positive', ' one', '.', '\n', '\n', 'Among', ' positive'] " is a positive one .
+
+ The sentiment of time is a positive one .
+
+ Among positive" False each driver set a time within 107 % 4 [' each', ' driver', ' set', ' a', ' time']
+243 48 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of time neutral time [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' time', ' is', ' a', ' positive', ' one', '.', '\n', '\n', 'Among', ' positive'] " is a positive one .
+
+ The sentiment of time is a positive one .
+
+ Among positive" False a ten-year time limit on the 4 [' a', ' ten', '-', 'year', ' time']
+244 48 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of time neutral time [' is', ' a', ' positive', ' one', '.', '\n', '\n', 'The', ' sentiment', ' of', ' time', ' is', ' a', ' positive', ' one', '.', '\n', '\n', 'Among', ' positive'] " is a positive one .
+
+ The sentiment of time is a positive one .
+
+ Among positive" False for the longest time takes nothing 3 [' for', ' the', ' longest', ' time']
+245 49 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of day neutral day ['-', 'to', '-', 'day', ' life', ',', ' the', ' sentiment', ' of', ' the', ' world', ',', ' the', ' sentiment', ' of', ' the', ' universe', ',', ' the', ' sentiment'] - to - day life , the sentiment of the world , the sentiment of the universe , the sentiment False and windy day in Tombstone, and 3 [' and', ' wind', 'y', ' day']
+246 49 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of day neutral day ['-', 'to', '-', 'day', ' life', ',', ' the', ' sentiment', ' of', ' the', ' world', ',', ' the', ' sentiment', ' of', ' the', ' universe', ',', ' the', ' sentiment'] - to - day life , the sentiment of the world , the sentiment of the universe , the sentiment False " A ""Stan the Man"" day was held in his" 6 "[' A', ' ""', 'Stan', ' the', ' Man', '""', ' day']"
+247 49 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of day neutral day ['-', 'to', '-', 'day', ' life', ',', ' the', ' sentiment', ' of', ' the', ' world', ',', ' the', ' sentiment', ' of', ' the', ' universe', ',', ' the', ' sentiment'] - to - day life , the sentiment of the world , the sentiment of the universe , the sentiment False Halloween, the one day in which the 4 [' Halloween', ',', ' the', ' one', ' day']
+248 49 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of day neutral day ['-', 'to', '-', 'day', ' life', ',', ' the', ' sentiment', ' of', ' the', ' world', ',', ' the', ' sentiment', ' of', ' the', ' universe', ',', ' the', ' sentiment'] - to - day life , the sentiment of the world , the sentiment of the universe , the sentiment False 0 ['day']
+249 49 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of day neutral day ['-', 'to', '-', 'day', ' life', ',', ' the', ' sentiment', ' of', ' the', ' world', ',', ' the', ' sentiment', ' of', ' the', ' universe', ',', ' the', ' sentiment'] - to - day life , the sentiment of the world , the sentiment of the universe , the sentiment False up again the next day with a fairly 4 [' up', ' again', ' the', ' next', ' day']
+250 50 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of week neutral week ['-', 'end', ' is', ' very', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' week', '-', 'end', ' is', ' very', ' positive', '.', '\n'] "- end is very positive .
+
+ The sentiment of the week - end is very positive .
+" False 0 ['week']
+251 50 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of week neutral week ['-', 'end', ' is', ' very', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' week', '-', 'end', ' is', ' very', ' positive', '.', '\n'] "- end is very positive .
+
+ The sentiment of the week - end is very positive .
+" False the chart week of December 11, 2 [' the', ' chart', ' week']
+252 50 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of week neutral week ['-', 'end', ' is', ' very', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' week', '-', 'end', ' is', ' very', ' positive', '.', '\n'] "- end is very positive .
+
+ The sentiment of the week - end is very positive .
+" False earn their army a week of exclusive, early 4 [' earn', ' their', ' army', ' a', ' week']
+253 50 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of week neutral week ['-', 'end', ' is', ' very', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' week', '-', 'end', ' is', ' very', ' positive', '.', '\n'] "- end is very positive .
+
+ The sentiment of the week - end is very positive .
+" False (£ 1.60) per week — the same 6 [' (£', ' 1', '.', '60', ')', ' per', ' week']
+254 50 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of week neutral week ['-', 'end', ' is', ' very', ' positive', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' week', '-', 'end', ' is', ' very', ' positive', '.', '\n'] "- end is very positive .
+
+ The sentiment of the week - end is very positive .
+" False chart during the week following its release 3 [' chart', ' during', ' the', ' week']
+255 51 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ordinary neutral ordinary [' people', ' is', ' that', ' the', ' government', ' should', ' not', ' interfere', ' in', ' the', ' market', '.', '\n', '\n', 'The', ' government', ' should', ' not', ' interfere', ' in'] " people is that the government should not interfere in the market .
+
+ The government should not interfere in" False clothing while the ordinary crew dressed like 3 [' clothing', ' while', ' the', ' ordinary']
+256 51 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ordinary neutral ordinary [' people', ' is', ' that', ' the', ' government', ' should', ' not', ' interfere', ' in', ' the', ' market', '.', '\n', '\n', 'The', ' government', ' should', ' not', ' interfere', ' in'] " people is that the government should not interfere in the market .
+
+ The government should not interfere in" False traditional at the top of ordinary Turkish gravestones 5 [' traditional', ' at', ' the', ' top', ' of', ' ordinary']
+257 51 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ordinary neutral ordinary [' people', ' is', ' that', ' the', ' government', ' should', ' not', ' interfere', ' in', ' the', ' market', '.', '\n', '\n', 'The', ' government', ' should', ' not', ' interfere', ' in'] " people is that the government should not interfere in the market .
+
+ The government should not interfere in" False the idea of an ordinary person doing something 4 [' the', ' idea', ' of', ' an', ' ordinary']
+258 51 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ordinary neutral ordinary [' people', ' is', ' that', ' the', ' government', ' should', ' not', ' interfere', ' in', ' the', ' market', '.', '\n', '\n', 'The', ' government', ' should', ' not', ' interfere', ' in'] " people is that the government should not interfere in the market .
+
+ The government should not interfere in" False and hundreds of ordinary people such as 3 [' and', ' hundreds', ' of', ' ordinary']
+259 51 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of ordinary neutral ordinary [' people', ' is', ' that', ' the', ' government', ' should', ' not', ' interfere', ' in', ' the', ' market', '.', '\n', '\n', 'The', ' government', ' should', ' not', ' interfere', ' in'] " people is that the government should not interfere in the market .
+
+ The government should not interfere in" False easily with the ordinary people of his congregation 3 [' easily', ' with', ' the', ' ordinary']
+260 52 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of common neutral common [' people', ' is', ' very', ' different', '.', '\n', '\n', 'The', ' positive', ' sentiment', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' negative', ' sentiment'] " people is very different .
+
+ The positive sentiment is the most popular .
+
+ The negative sentiment" False that were in common use in households 3 [' that', ' were', ' in', ' common']
+261 52 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of common neutral common [' people', ' is', ' very', ' different', '.', '\n', '\n', 'The', ' positive', ' sentiment', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' negative', ' sentiment'] " people is very different .
+
+ The positive sentiment is the most popular .
+
+ The negative sentiment" False Dysarthria is the most common communication 6 [' Dys', 'arth', 'ria', ' is', ' the', ' most', ' common']
+262 52 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of common neutral common [' people', ' is', ' very', ' different', '.', '\n', '\n', 'The', ' positive', ' sentiment', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' negative', ' sentiment'] " people is very different .
+
+ The positive sentiment is the most popular .
+
+ The negative sentiment" False symptoms with the common cold affecting primarily 3 [' symptoms', ' with', ' the', ' common']
+263 52 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of common neutral common [' people', ' is', ' very', ' different', '.', '\n', '\n', 'The', ' positive', ' sentiment', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' negative', ' sentiment'] " people is very different .
+
+ The positive sentiment is the most popular .
+
+ The negative sentiment" False have survived. The common Christians saints 4 [' have', ' survived', '.', ' The', ' common']
+264 52 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of common neutral common [' people', ' is', ' very', ' different', '.', '\n', '\n', 'The', ' positive', ' sentiment', ' is', ' the', ' most', ' popular', '.', '\n', '\n', 'The', ' negative', ' sentiment'] " people is very different .
+
+ The positive sentiment is the most popular .
+
+ The negative sentiment" False remainder. Since rN − 1 is a common divisor of a and b, 9 [' remainder', '.', ' Since', ' r', 'N', ' −', ' 1', ' is', ' a', ' common']
+265 53 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of typical neutral typical [' users', ' is', ' positive', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', ' of', ' the', ' most', ' popular', ' users', '**', ':', ' The'] " users is positive .
+
+ - ** Sent iment of the most popular users ** : The" False the hammour. The typical marine life off the 5 [' the', ' hamm', 'our', '.', ' The', ' typical']
+266 53 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of typical neutral typical [' users', ' is', ' positive', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', ' of', ' the', ' most', ' popular', ' users', '**', ':', ' The'] " users is positive .
+
+ - ** Sent iment of the most popular users ** : The" False tourist of the most typical variety leaned 4 [' tourist', ' of', ' the', ' most', ' typical']
+267 53 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of typical neutral typical [' users', ' is', ' positive', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', ' of', ' the', ' most', ' popular', ' users', '**', ':', ' The'] " users is positive .
+
+ - ** Sent iment of the most popular users ** : The" False expensive than a typical one, due to 3 [' expensive', ' than', ' a', ' typical']
+268 53 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of typical neutral typical [' users', ' is', ' positive', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', ' of', ' the', ' most', ' popular', ' users', '**', ':', ' The'] " users is positive .
+
+ - ** Sent iment of the most popular users ** : The" False compared with the typical form. His variety ceratophylloides 3 [' compared', ' with', ' the', ' typical']
+269 53 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of typical neutral typical [' users', ' is', ' positive', '.', '\n', '\n', '-', ' ', ' ', ' **', 'Sent', 'iment', ' of', ' the', ' most', ' popular', ' users', '**', ':', ' The'] " users is positive .
+
+ - ** Sent iment of the most popular users ** : The" False occasionally coma are typical symptoms. 3 [' occasionally', ' coma', ' are', ' typical']
+270 54 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of average neutral average [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' average', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment'] " is the most common .
+
+ The sentiment of average is the most common .
+
+ The sentiment" False tortoise seems to average slightly larger, 4 [' tort', 'oise', ' seems', ' to', ' average']
+271 54 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of average neutral average [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' average', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment'] " is the most common .
+
+ The sentiment of average is the most common .
+
+ The sentiment" False " meaning mixed or average reviews.
+" 3 [' meaning', ' mixed', ' or', ' average']
+272 54 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of average neutral average [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' average', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment'] " is the most common .
+
+ The sentiment of average is the most common .
+
+ The sentiment" False led the SEC with an average of 2.14 field 5 [' led', ' the', ' SEC', ' with', ' an', ' average']
+273 54 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of average neutral average [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' average', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment'] " is the most common .
+
+ The sentiment of average is the most common .
+
+ The sentiment" False statistic that the average age for a girl 3 [' statistic', ' that', ' the', ' average']
+274 54 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of average neutral average [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment', ' of', ' average', ' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' sentiment'] " is the most common .
+
+ The sentiment of average is the most common .
+
+ The sentiment" False 342 runs at an average of 42.75. In all 4 [' 342', ' runs', ' at', ' an', ' average']
+275 55 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of indifferent neutral indifferent [',', ' and', ' the', ' sentiment', ' of', ' the', ' unknown', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' unknown', ' is', ' the', ' most', ' interesting', '.'] ", and the sentiment of the unknown .
+
+ The sentiment of the unknown is the most interesting ." False had been at best indifferent about terrorist 4 [' had', ' been', ' at', ' best', ' indifferent']
+276 55 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of indifferent neutral indifferent [',', ' and', ' the', ' sentiment', ' of', ' the', ' unknown', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' unknown', ' is', ' the', ' most', ' interesting', '.'] ", and the sentiment of the unknown .
+
+ The sentiment of the unknown is the most interesting ." False refers to someone indifferent to pain, pleasure, 3 [' refers', ' to', ' someone', ' indifferent']
+277 55 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of indifferent neutral indifferent [',', ' and', ' the', ' sentiment', ' of', ' the', ' unknown', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' unknown', ' is', ' the', ' most', ' interesting', '.'] ", and the sentiment of the unknown .
+
+ The sentiment of the unknown is the most interesting ." False looking upon him as indifferent or even favourable 4 [' looking', ' upon', ' him', ' as', ' indifferent']
+278 55 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of indifferent neutral indifferent [',', ' and', ' the', ' sentiment', ' of', ' the', ' unknown', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' unknown', ' is', ' the', ' most', ' interesting', '.'] ", and the sentiment of the unknown .
+
+ The sentiment of the unknown is the most interesting ." False not writing indifferent little pieces, only 2 [' not', ' writing', ' indifferent']
+279 55 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of indifferent neutral indifferent [',', ' and', ' the', ' sentiment', ' of', ' the', ' unknown', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' unknown', ' is', ' the', ' most', ' interesting', '.'] ", and the sentiment of the unknown .
+
+ The sentiment of the unknown is the most interesting ." False disturbances. Despite indifferent reviews from 3 [' disturbances', '.', ' Despite', ' indifferent']
+280 56 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unbiased neutral unbiased [',', ' and', ' the', ' sentiment', ' of', ' the', ' media', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' media', ' is', ' the', ' most', ' important', '.'] ", and the sentiment of the media .
+
+ The sentiment of the media is the most important ." False " that ""Even an unbiased fan must regard" 4 "[' that', ' ""', 'Even', ' an', ' unbiased']"
+281 56 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unbiased neutral unbiased [',', ' and', ' the', ' sentiment', ' of', ' the', ' media', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' media', ' is', ' the', ' most', ' important', '.'] ", and the sentiment of the media .
+
+ The sentiment of the media is the most important ." False " ""an objective, unbiased and thorough examination""" 4 "[' ""', 'an', ' objective', ',', ' unbiased']"
+282 56 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unbiased neutral unbiased [',', ' and', ' the', ' sentiment', ' of', ' the', ' media', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' media', ' is', ' the', ' most', ' important', '.'] ", and the sentiment of the media .
+
+ The sentiment of the media is the most important ." False " ""hardly the sort of unbiased reporting which warrants" 6 "[' ""', 'hard', 'ly', ' the', ' sort', ' of', ' unbiased']"
+283 56 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unbiased neutral unbiased [',', ' and', ' the', ' sentiment', ' of', ' the', ' media', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' media', ' is', ' the', ' most', ' important', '.'] ", and the sentiment of the media .
+
+ The sentiment of the media is the most important ." False equation represents the unbiasedness hypothesis, 3 [' equation', ' represents', ' the', ' unbiased']
+284 56 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of unbiased neutral unbiased [',', ' and', ' the', ' sentiment', ' of', ' the', ' media', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' media', ' is', ' the', ' most', ' important', '.'] ", and the sentiment of the media .
+
+ The sentiment of the media is the most important ." False " time with a clear unbiased perspective,""" 4 [' time', ' with', ' a', ' clear', ' unbiased']
+285 57 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of impartial neutral impartial ['ity', ' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' impartial', '.', '\n', '\n', 'The'] "ity is the most important .
+
+ The most important thing is to be impartial .
+
+ The" False even-handed and impartial manner and give it 4 [' even', '-', 'handed', ' and', ' impartial']
+286 57 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of impartial neutral impartial ['ity', ' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' impartial', '.', '\n', '\n', 'The'] "ity is the most important .
+
+ The most important thing is to be impartial .
+
+ The" False independent and impartial court established 2 [' independent', ' and', ' impartial']
+287 57 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of impartial neutral impartial ['ity', ' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' impartial', '.', '\n', '\n', 'The'] "ity is the most important .
+
+ The most important thing is to be impartial .
+
+ The" False anthropologists remain impartial during the course 3 [' anthrop', 'ologists', ' remain', ' impartial']
+288 57 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of impartial neutral impartial ['ity', ' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' impartial', '.', '\n', '\n', 'The'] "ity is the most important .
+
+ The most important thing is to be impartial .
+
+ The" False position as an impartial mediator was 3 [' position', ' as', ' an', ' impartial']
+289 57 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of impartial neutral impartial ['ity', ' is', ' the', ' most', ' important', '.', '\n', '\n', 'The', ' most', ' important', ' thing', ' is', ' to', ' be', ' impartial', '.', '\n', '\n', 'The'] "ity is the most important .
+
+ The most important thing is to be impartial .
+
+ The" False citing Malaysia's impartial position on 3 "[' citing', ' Malaysia', ""'s"", ' impartial']"
+290 58 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of objective neutral objective [',', ' subjective', ',', ' and', ' mixed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' objective', ' is', ' the', ' most', ' important', '.', '\n', '\n'] ", subjective , and mixed .
+
+ The sentiment of the objective is the most important .
+
+" False during the war. Their objective was the undefended 5 [' during', ' the', ' war', '.', ' Their', ' objective']
+291 58 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of objective neutral objective [',', ' subjective', ',', ' and', ' mixed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' objective', ' is', ' the', ' most', ' important', '.', '\n', '\n'] ", subjective , and mixed .
+
+ The sentiment of the objective is the most important .
+
+" False Relto. The main objective of the game is 5 [' Rel', 'to', '.', ' The', ' main', ' objective']
+292 58 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of objective neutral objective [',', ' subjective', ',', ' and', ' mixed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' objective', ' is', ' the', ' most', ' important', '.', '\n', '\n'] ", subjective , and mixed .
+
+ The sentiment of the objective is the most important .
+
+" False learned that their objective was the powder 3 [' learned', ' that', ' their', ' objective']
+293 58 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of objective neutral objective [',', ' subjective', ',', ' and', ' mixed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' objective', ' is', ' the', ' most', ' important', '.', '\n', '\n'] ", subjective , and mixed .
+
+ The sentiment of the objective is the most important .
+
+" False Crown Fountain. His objective was to create 4 [' Crown', ' Fountain', '.', ' His', ' objective']
+294 58 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of objective neutral objective [',', ' subjective', ',', ' and', ' mixed', '.', '\n', '\n', 'The', ' sentiment', ' of', ' the', ' objective', ' is', ' the', ' most', ' important', '.', '\n', '\n'] ", subjective , and mixed .
+
+ The sentiment of the objective is the most important .
+
+" False materials for the objective events, the profits 3 [' materials', ' for', ' the', ' objective']
+295 59 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False almost completely neutral as he wanted 2 [' almost', ' completely', ' neutral']
+296 59 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False Skagerrak to inspect neutral shipping for 5 [' Sk', 'ager', 'rak', ' to', ' inspect', ' neutral']
+297 59 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False forces invaded neutral Belgium without a 2 [' forces', ' invaded', ' neutral']
+298 59 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False was initially neutral in World War I and 2 [' was', ' initially', ' neutral']
+299 59 Among positive, negative, neutral, the sentiment of x -1 Among positive, negative, neutral, the sentiment of neutral neutral neutral [' is', ' the', ' most', ' common', '.', '\n', '\n', 'The', ' most', ' common', ' sentiment', ' of', ' the', ' tweets', ' is', ' positive', '.', '\n', '\n', 'The'] " is the most common .
+
+ The most common sentiment of the tweets is positive .
+
+ The" False pattern or a neutral colour such 3 [' pattern', ' or', ' a', ' neutral']
diff --git a/patchscopes/code/preprocessed_data/commonsense/work_location.tsv b/patchscopes/code/preprocessed_data/commonsense/work_location.tsv
new file mode 100644
index 00000000..9dd49a20
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/commonsense/work_location.tsv
@@ -0,0 +1,385 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The place of work where you can usually find a x -1 The place of work where you can usually find a farmer farm farmer "[',', ' a', ' farmer', ""'s"", ' wife', ',', ' or', ' a', ' farmer', ""'s"", ' daughter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually']" ", a farmer 's wife , or a farmer 's daughter .
+
+ The place where you can usually" True growing belief that farmers' financial troubles 3 [' growing', ' belief', ' that', ' farmer']
+1 0 The place of work where you can usually find a x -1 The place of work where you can usually find a farmer farm farmer "[',', ' a', ' farmer', ""'s"", ' wife', ',', ' or', ' a', ' farmer', ""'s"", ' daughter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually']" ", a farmer 's wife , or a farmer 's daughter .
+
+ The place where you can usually" True as a surveyor and farmer in Johnson County. 5 [' as', ' a', ' survey', 'or', ' and', ' farmer']
+2 0 The place of work where you can usually find a x -1 The place of work where you can usually find a farmer farm farmer "[',', ' a', ' farmer', ""'s"", ' wife', ',', ' or', ' a', ' farmer', ""'s"", ' daughter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually']" ", a farmer 's wife , or a farmer 's daughter .
+
+ The place where you can usually" True His father was a farmer and mayor of the city, 4 [' His', ' father', ' was', ' a', ' farmer']
+3 0 The place of work where you can usually find a x -1 The place of work where you can usually find a farmer farm farmer "[',', ' a', ' farmer', ""'s"", ' wife', ',', ' or', ' a', ' farmer', ""'s"", ' daughter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually']" ", a farmer 's wife , or a farmer 's daughter .
+
+ The place where you can usually" True Ireland, the first farmer to finish the grain 4 [' Ireland', ',', ' the', ' first', ' farmer']
+4 0 The place of work where you can usually find a x -1 The place of work where you can usually find a farmer farm farmer "[',', ' a', ' farmer', ""'s"", ' wife', ',', ' or', ' a', ' farmer', ""'s"", ' daughter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually']" ", a farmer 's wife , or a farmer 's daughter .
+
+ The place where you can usually" True Maley Eichelberger, a farmer and lawyer, and 8 [' Male', 'y', ' E', 'ic', 'hel', 'berger', ',', ' a', ' farmer']
+5 1 The place of work where you can usually find a x -1 The place of work where you can usually find a lawyer courthouse lawyer [',', ' accountant', ',', ' or', ' other', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' lawyer', ',', ' accountant'] ", accountant , or other professional .
+
+ The place where you can usually find a lawyer , accountant" False argumentation used by Imo lawyer Charles Burchell. 6 [' argument', 'ation', ' used', ' by', ' Im', 'o', ' lawyer']
+6 1 The place of work where you can usually find a x -1 The place of work where you can usually find a lawyer courthouse lawyer [',', ' accountant', ',', ' or', ' other', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' lawyer', ',', ' accountant'] ", accountant , or other professional .
+
+ The place where you can usually find a lawyer , accountant" False vice-president, and the lawyer Felix Perez 6 [' vice', '-', 'president', ',', ' and', ' the', ' lawyer']
+7 1 The place of work where you can usually find a x -1 The place of work where you can usually find a lawyer courthouse lawyer [',', ' accountant', ',', ' or', ' other', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' lawyer', ',', ' accountant'] ", accountant , or other professional .
+
+ The place where you can usually find a lawyer , accountant" False June. He was a lawyer and a business 5 [' June', '.', ' He', ' was', ' a', ' lawyer']
+8 1 The place of work where you can usually find a x -1 The place of work where you can usually find a lawyer courthouse lawyer [',', ' accountant', ',', ' or', ' other', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' lawyer', ',', ' accountant'] ", accountant , or other professional .
+
+ The place where you can usually find a lawyer , accountant" False paraplegic lawyer who is married to 3 [' parap', 'leg', 'ic', ' lawyer']
+9 1 The place of work where you can usually find a x -1 The place of work where you can usually find a lawyer courthouse lawyer [',', ' accountant', ',', ' or', ' other', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' lawyer', ',', ' accountant'] ", accountant , or other professional .
+
+ The place where you can usually find a lawyer , accountant" False Pedicone, as an outside lawyer for the Finance 6 [' Ped', 'icone', ',', ' as', ' an', ' outside', ' lawyer']
+10 2 The place of work where you can usually find a x -1 The place of work where you can usually find a teacher school teacher [',', ' a', ' friend', ',', ' a', ' mentor', ',', ' a', ' coach', ',', ' a', ' guide', ',', ' a', ' guide', ',', ' a', ' friend', ',', ' a'] , a friend , a mentor , a coach , a guide , a guide , a friend , a False full-time substitute teacher in Jackson, 4 [' full', '-', 'time', ' substitute', ' teacher']
+11 2 The place of work where you can usually find a x -1 The place of work where you can usually find a teacher school teacher [',', ' a', ' friend', ',', ' a', ' mentor', ',', ' a', ' coach', ',', ' a', ' guide', ',', ' a', ' guide', ',', ' a', ' friend', ',', ' a'] , a friend , a mentor , a coach , a guide , a guide , a friend , a False Caro one of the best teachers in Cardenal Caro 6 [' Car', 'o', ' one', ' of', ' the', ' best', ' teacher']
+12 2 The place of work where you can usually find a x -1 The place of work where you can usually find a teacher school teacher [',', ' a', ' friend', ',', ' a', ' mentor', ',', ' a', ' coach', ',', ' a', ' guide', ',', ' a', ' guide', ',', ' a', ' friend', ',', ' a'] , a friend , a mentor , a coach , a guide , a guide , a friend , a False Zhou Li, a violin teacher at the Royal 5 [' Zhou', ' Li', ',', ' a', ' violin', ' teacher']
+13 2 The place of work where you can usually find a x -1 The place of work where you can usually find a teacher school teacher [',', ' a', ' friend', ',', ' a', ' mentor', ',', ' a', ' coach', ',', ' a', ' guide', ',', ' a', ' guide', ',', ' a', ' friend', ',', ' a'] , a friend , a mentor , a coach , a guide , a guide , a friend , a False husband, the artist and teacher Herbert MacNair 5 [' husband', ',', ' the', ' artist', ' and', ' teacher']
+14 2 The place of work where you can usually find a x -1 The place of work where you can usually find a teacher school teacher [',', ' a', ' friend', ',', ' a', ' mentor', ',', ' a', ' coach', ',', ' a', ' guide', ',', ' a', ' guide', ',', ' a', ' friend', ',', ' a'] , a friend , a mentor , a coach , a guide , a guide , a friend , a False Applications for teacher education have 2 [' Applications', ' for', ' teacher']
+15 3 The place of work where you can usually find a x -1 The place of work where you can usually find a accountant office accountant [',', ' a', ' lawyer', ',', ' a', ' doctor', ',', ' a', ' dentist', ',', ' a', ' pl', 'umber', ',', ' a', ' mechanic', ',', ' a', ' car', 'penter'] , a lawyer , a doctor , a dentist , a pl umber , a mechanic , a car penter False descent, and worked as an accountant before becoming 6 [' descent', ',', ' and', ' worked', ' as', ' an', ' accountant']
+16 3 The place of work where you can usually find a x -1 The place of work where you can usually find a accountant office accountant [',', ' a', ' lawyer', ',', ' a', ' doctor', ',', ' a', ' dentist', ',', ' a', ' pl', 'umber', ',', ' a', ' mechanic', ',', ' a', ' car', 'penter'] , a lawyer , a doctor , a dentist , a pl umber , a mechanic , a car penter False Rodriguez, and the accountant Moreno. The 4 [' Rodriguez', ',', ' and', ' the', ' accountant']
+17 3 The place of work where you can usually find a x -1 The place of work where you can usually find a accountant office accountant [',', ' a', ' lawyer', ',', ' a', ' doctor', ',', ' a', ' dentist', ',', ' a', ' pl', 'umber', ',', ' a', ' mechanic', ',', ' a', ' car', 'penter'] , a lawyer , a doctor , a dentist , a pl umber , a mechanic , a car penter False 1940) is an English accountant and educationalist. 5 [' 1940', ')', ' is', ' an', ' English', ' accountant']
+18 3 The place of work where you can usually find a x -1 The place of work where you can usually find a accountant office accountant [',', ' a', ' lawyer', ',', ' a', ' doctor', ',', ' a', ' dentist', ',', ' a', ' pl', 'umber', ',', ' a', ' mechanic', ',', ' a', ' car', 'penter'] , a lawyer , a doctor , a dentist , a pl umber , a mechanic , a car penter False Alma worked as an accountant and clerk. Philipp's 4 [' Alma', ' worked', ' as', ' an', ' accountant']
+19 3 The place of work where you can usually find a x -1 The place of work where you can usually find a accountant office accountant [',', ' a', ' lawyer', ',', ' a', ' doctor', ',', ' a', ' dentist', ',', ' a', ' pl', 'umber', ',', ' a', ' mechanic', ',', ' a', ' car', 'penter'] , a lawyer , a doctor , a dentist , a pl umber , a mechanic , a car penter False was appointed accountant at the Richmond 2 [' was', ' appointed', ' accountant']
+20 4 The place of work where you can usually find a x -1 The place of work where you can usually find a artist studio artist [',', ' a', ' writer', ',', ' a', ' musician', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a', ' mathematician', ',', ' a'] , a writer , a musician , a poet , a philosopher , a scientist , a mathematician , a False Johnson is married to artist Sheree Hovsepian. 4 [' Johnson', ' is', ' married', ' to', ' artist']
+21 4 The place of work where you can usually find a x -1 The place of work where you can usually find a artist studio artist [',', ' a', ' writer', ',', ' a', ' musician', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a', ' mathematician', ',', ' a'] , a writer , a musician , a poet , a philosopher , a scientist , a mathematician , a False by British recording artist Natasha Bedingfield. 3 [' by', ' British', ' recording', ' artist']
+22 4 The place of work where you can usually find a x -1 The place of work where you can usually find a artist studio artist [',', ' a', ' writer', ',', ' a', ' musician', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a', ' mathematician', ',', ' a'] , a writer , a musician , a poet , a philosopher , a scientist , a mathematician , a False 1864 was recorded by artist Alfred Waud, the 5 [' 18', '64', ' was', ' recorded', ' by', ' artist']
+23 4 The place of work where you can usually find a x -1 The place of work where you can usually find a artist studio artist [',', ' a', ' writer', ',', ' a', ' musician', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a', ' mathematician', ',', ' a'] , a writer , a musician , a poet , a philosopher , a scientist , a mathematician , a False distinguished artist in her community, 1 [' distinguished', ' artist']
+24 4 The place of work where you can usually find a x -1 The place of work where you can usually find a artist studio artist [',', ' a', ' writer', ',', ' a', ' musician', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a', ' mathematician', ',', ' a'] , a writer , a musician , a poet , a philosopher , a scientist , a mathematician , a False 0 ['artist']
+25 5 The place of work where you can usually find a x -1 The place of work where you can usually find a athlete stadium athlete [' in', ' the', ' middle', ' of', ' a', ' training', ' session', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a', ' coach', ',', ' a'] " in the middle of a training session .
+
+ The place where you can find a coach , a" False Auburn University athlete and three-time 2 [' Auburn', ' University', ' athlete']
+26 5 The place of work where you can usually find a x -1 The place of work where you can usually find a athlete stadium athlete [' in', ' the', ' middle', ' of', ' a', ' training', ' session', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a', ' coach', ',', ' a'] " in the middle of a training session .
+
+ The place where you can find a coach , a" False former university athlete and coach, 2 [' former', ' university', ' athlete']
+27 5 The place of work where you can usually find a x -1 The place of work where you can usually find a athlete stadium athlete [' in', ' the', ' middle', ' of', ' a', ' training', ' session', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a', ' coach', ',', ' a'] " in the middle of a training session .
+
+ The place where you can find a coach , a" False third year as an FSU athlete he won four NCAA 6 [' third', ' year', ' as', ' an', ' F', 'SU', ' athlete']
+28 5 The place of work where you can usually find a x -1 The place of work where you can usually find a athlete stadium athlete [' in', ' the', ' middle', ' of', ' a', ' training', ' session', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a', ' coach', ',', ' a'] " in the middle of a training session .
+
+ The place where you can find a coach , a" False both as an athlete and as the former 3 [' both', ' as', ' an', ' athlete']
+29 5 The place of work where you can usually find a x -1 The place of work where you can usually find a athlete stadium athlete [' in', ' the', ' middle', ' of', ' a', ' training', ' session', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a', ' coach', ',', ' a'] " in the middle of a training session .
+
+ The place where you can find a coach , a" False the only female athlete to participate in the 3 [' the', ' only', ' female', ' athlete']
+30 6 The place of work where you can usually find a x -1 The place of work where you can usually find a baker bakery baker "[',', ' a', ' butcher', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a']" , a butcher , a butcher 's shop , a butcher 's shop , a butcher 's shop , a False was an avid pie baker and published a 4 [' was', ' an', ' avid', ' pie', ' baker']
+31 6 The place of work where you can usually find a x -1 The place of work where you can usually find a baker bakery baker "[',', ' a', ' butcher', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a']" , a butcher , a butcher 's shop , a butcher 's shop , a butcher 's shop , a False 18-years-younger bagel baker whom she dated from 8 [' 18', '-', 'years', '-', 'young', 'er', ' bag', 'el', ' baker']
+32 6 The place of work where you can usually find a x -1 The place of work where you can usually find a baker bakery baker "[',', ' a', ' butcher', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a']" , a butcher , a butcher 's shop , a butcher 's shop , a butcher 's shop , a False socially isolated baker with a wooden hand 2 [' socially', ' isolated', ' baker']
+33 6 The place of work where you can usually find a x -1 The place of work where you can usually find a baker bakery baker "[',', ' a', ' butcher', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a']" , a butcher , a butcher 's shop , a butcher 's shop , a butcher 's shop , a False a time, as a baker. Herriman attended 5 [' a', ' time', ',', ' as', ' a', ' baker']
+34 6 The place of work where you can usually find a x -1 The place of work where you can usually find a baker bakery baker "[',', ' a', ' butcher', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a', ' butcher', ""'s"", ' shop', ',', ' a']" , a butcher , a butcher 's shop , a butcher 's shop , a butcher 's shop , a False The adult floury baker normally perches facing 4 [' The', ' adult', ' flour', 'y', ' baker']
+35 7 The place of work where you can usually find a x -1 The place of work where you can usually find a barber barbershop barber [',', ' ha', 'ird', 'ress', 'er', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', '**', 'b', 'arr', 'ack', '**', ' _', '(', 'BA'] ", ha ird ress er , or beauty salon .
+
+ ** b arr ack ** _ ( BA" False following day a local barber named Ernie Chambers 5 [' following', ' day', ' a', ' local', ' bar', 'ber']
+36 7 The place of work where you can usually find a x -1 The place of work where you can usually find a barber barbershop barber [',', ' ha', 'ird', 'ress', 'er', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', '**', 'b', 'arr', 'ack', '**', ' _', '(', 'BA'] ", ha ird ress er , or beauty salon .
+
+ ** b arr ack ** _ ( BA" False " the unit's barber once per week.
+" 4 "[' the', ' unit', ""'s"", ' bar', 'ber']"
+37 7 The place of work where you can usually find a x -1 The place of work where you can usually find a barber barbershop barber [',', ' ha', 'ird', 'ress', 'er', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', '**', 'b', 'arr', 'ack', '**', ' _', '(', 'BA'] ", ha ird ress er , or beauty salon .
+
+ ** b arr ack ** _ ( BA" False hospital, office block, barber shop, and pool hall 6 [' hospital', ',', ' office', ' block', ',', ' bar', 'ber']
+38 7 The place of work where you can usually find a x -1 The place of work where you can usually find a barber barbershop barber [',', ' ha', 'ird', 'ress', 'er', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', '**', 'b', 'arr', 'ack', '**', ' _', '(', 'BA'] ", ha ird ress er , or beauty salon .
+
+ ** b arr ack ** _ ( BA" False 1892, Proctor had a barber shop, two blacksmiths, 8 [' 18', '92', ',', ' Pro', 'ctor', ' had', ' a', ' bar', 'ber']
+39 7 The place of work where you can usually find a x -1 The place of work where you can usually find a barber barbershop barber [',', ' ha', 'ird', 'ress', 'er', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', '**', 'b', 'arr', 'ack', '**', ' _', '(', 'BA'] ", ha ird ress er , or beauty salon .
+
+ ** b arr ack ** _ ( BA" False " Hariram, prison barber and Jailor's side-kick
+" 6 [' Har', 'ir', 'am', ',', ' prison', ' bar', 'ber']
+40 8 The place of work where you can usually find a x -1 The place of work where you can usually find a chef kitchen chef [',', ' a', ' bar', 'ista', ',', ' a', ' waitress', ',', ' a', ' bartender', ',', ' a', ' host', 'ess', ',', ' a', ' cash', 'ier', ',', ' a'] , a bar ista , a waitress , a bartender , a host ess , a cash ier , a False Harrington, executive chef, was chosen by Michel 3 [' Harrington', ',', ' executive', ' chef']
+41 8 The place of work where you can usually find a x -1 The place of work where you can usually find a chef kitchen chef [',', ' a', ' bar', 'ista', ',', ' a', ' waitress', ',', ' a', ' bartender', ',', ' a', ' host', 'ess', ',', ' a', ' cash', 'ier', ',', ' a'] , a bar ista , a waitress , a bartender , a host ess , a cash ier , a False York as a pastry chef for eight 4 [' York', ' as', ' a', ' pastry', ' chef']
+42 8 The place of work where you can usually find a x -1 The place of work where you can usually find a chef kitchen chef [',', ' a', ' bar', 'ista', ',', ' a', ' waitress', ',', ' a', ' bartender', ',', ' a', ' host', 'ess', ',', ' a', ' cash', 'ier', ',', ' a'] , a bar ista , a waitress , a bartender , a host ess , a cash ier , a False The Fleet-A-Pita chef was an early version 7 [' The', ' Fleet', '-', 'A', '-', 'P', 'ita', ' chef']
+43 8 The place of work where you can usually find a x -1 The place of work where you can usually find a chef kitchen chef [',', ' a', ' bar', 'ista', ',', ' a', ' waitress', ',', ' a', ' bartender', ',', ' a', ' host', 'ess', ',', ' a', ' cash', 'ier', ',', ' a'] , a bar ista , a waitress , a bartender , a host ess , a cash ier , a False group and a chef, known for 3 [' group', ' and', ' a', ' chef']
+44 8 The place of work where you can usually find a x -1 The place of work where you can usually find a chef kitchen chef [',', ' a', ' bar', 'ista', ',', ' a', ' waitress', ',', ' a', ' bartender', ',', ' a', ' host', 'ess', ',', ' a', ' cash', 'ier', ',', ' a'] , a bar ista , a waitress , a bartender , a host ess , a cash ier , a False came to work as a chef and servant 5 [' came', ' to', ' work', ' as', ' a', ' chef']
+45 9 The place of work where you can usually find a x -1 The place of work where you can usually find a doctor hospital doctor [',', ' dentist', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' doctor', ','] ", dentist , or other medical professional .
+
+ The place where you can usually find a doctor ," False called; the doctor attributed 3 [' called', ';', ' the', ' doctor']
+46 9 The place of work where you can usually find a x -1 The place of work where you can usually find a doctor hospital doctor [',', ' dentist', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' doctor', ','] ", dentist , or other medical professional .
+
+ The place where you can usually find a doctor ," False Italian doping doctor Michele Ferrari and 2 [' Italian', ' doping', ' doctor']
+47 9 The place of work where you can usually find a x -1 The place of work where you can usually find a doctor hospital doctor [',', ' dentist', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' doctor', ','] ", dentist , or other medical professional .
+
+ The place where you can usually find a doctor ," False That film, about a doctor who goes undercover 5 [' That', ' film', ',', ' about', ' a', ' doctor']
+48 9 The place of work where you can usually find a x -1 The place of work where you can usually find a doctor hospital doctor [',', ' dentist', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' doctor', ','] ", dentist , or other medical professional .
+
+ The place where you can usually find a doctor ," False advanced classics doctoral program, specializing 2 [' advanced', ' classics', ' doctor']
+49 9 The place of work where you can usually find a x -1 The place of work where you can usually find a doctor hospital doctor [',', ' dentist', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' doctor', ','] ", dentist , or other medical professional .
+
+ The place where you can usually find a doctor ," False physiotherapist. His doctor told him that the 5 [' physi', 'other', 'apist', '.', ' His', ' doctor']
+50 10 The place of work where you can usually find a x -1 The place of work where you can usually find a fashion designer studio fashion designer [',', ' a', ' fashion', ' styl', 'ist', ',', ' a', ' fashion', ' photographer', ',', ' a', ' fashion', ' model', ',', ' a', ' fashion', ' buyer', ',', ' a', ' fashion'] , a fashion styl ist , a fashion photographer , a fashion model , a fashion buyer , a fashion False renovated. In 2009, fashion designer Wayne Hemingway 6 [' renovated', '.', ' In', ' 2009', ',', ' fashion', ' designer']
+51 10 The place of work where you can usually find a x -1 The place of work where you can usually find a fashion designer studio fashion designer [',', ' a', ' fashion', ' styl', 'ist', ',', ' a', ' fashion', ' photographer', ',', ' a', ' fashion', ' model', ',', ' a', ' fashion', ' buyer', ',', ' a', ' fashion'] , a fashion styl ist , a fashion photographer , a fashion model , a fashion buyer , a fashion False located. A New Jersey fashion designer that Hsu had listed 6 [' located', '.', ' A', ' New', ' Jersey', ' fashion', ' designer']
+52 10 The place of work where you can usually find a x -1 The place of work where you can usually find a fashion designer studio fashion designer [',', ' a', ' fashion', ' styl', 'ist', ',', ' a', ' fashion', ' photographer', ',', ' a', ' fashion', ' model', ',', ' a', ' fashion', ' buyer', ',', ' a', ' fashion'] , a fashion styl ist , a fashion photographer , a fashion model , a fashion buyer , a fashion False " years."" Similarly, fashion designer and Angelina Jolie's" 5 "[' years', '.""', ' Similarly', ',', ' fashion', ' designer']"
+53 10 The place of work where you can usually find a x -1 The place of work where you can usually find a fashion designer studio fashion designer [',', ' a', ' fashion', ' styl', 'ist', ',', ' a', ' fashion', ' photographer', ',', ' a', ' fashion', ' model', ',', ' a', ' fashion', ' buyer', ',', ' a', ' fashion'] , a fashion styl ist , a fashion photographer , a fashion model , a fashion buyer , a fashion False obesity. Marc Ecko, fashion designer and owner of Eckō, 7 [' obesity', '.', ' Marc', ' E', 'cko', ',', ' fashion', ' designer']
+54 10 The place of work where you can usually find a x -1 The place of work where you can usually find a fashion designer studio fashion designer [',', ' a', ' fashion', ' styl', 'ist', ',', ' a', ' fashion', ' photographer', ',', ' a', ' fashion', ' model', ',', ' a', ' fashion', ' buyer', ',', ' a', ' fashion'] , a fashion styl ist , a fashion photographer , a fashion model , a fashion buyer , a fashion False " ""London family"": fashion designer Zandra Rhodes," 5 "[' ""', 'London', ' family', '"":', ' fashion', ' designer']"
+55 11 The place of work where you can usually find a x -1 The place of work where you can usually find a firefighter fire station firefighter [',', ' a', ' police', ' officer', ',', ' a', ' paramed', 'ic', ',', ' a', ' nurse', ',', ' a', ' doctor', ',', ' a', ' teacher', ',', ' a', ' lawyer'] , a police officer , a paramed ic , a nurse , a doctor , a teacher , a lawyer False Yoosabai, a firefighter in Thailand, 6 [' Y', 'oos', 'ab', 'ai', ',', ' a', ' firefighter']
+56 11 The place of work where you can usually find a x -1 The place of work where you can usually find a firefighter fire station firefighter [',', ' a', ' police', ' officer', ',', ' a', ' paramed', 'ic', ',', ' a', ' nurse', ',', ' a', ' doctor', ',', ' a', ' teacher', ',', ' a', ' lawyer'] , a police officer , a paramed ic , a nurse , a doctor , a teacher , a lawyer False dangerous for firefighters because they 2 [' dangerous', ' for', ' firefighter']
+57 11 The place of work where you can usually find a x -1 The place of work where you can usually find a firefighter fire station firefighter [',', ' a', ' police', ' officer', ',', ' a', ' paramed', 'ic', ',', ' a', ' nurse', ',', ' a', ' doctor', ',', ' a', ' teacher', ',', ' a', ' lawyer'] , a police officer , a paramed ic , a nurse , a doctor , a teacher , a lawyer False exposure. Due to firefighter's occupational 4 [' exposure', '.', ' Due', ' to', ' firefighter']
+58 11 The place of work where you can usually find a x -1 The place of work where you can usually find a firefighter fire station firefighter [',', ' a', ' police', ' officer', ',', ' a', ' paramed', 'ic', ',', ' a', ' nurse', ',', ' a', ' doctor', ',', ' a', ' teacher', ',', ' a', ' lawyer'] , a police officer , a paramed ic , a nurse , a doctor , a teacher , a lawyer False haunted by the ghost of a firefighter who died from burns 6 [' haunted', ' by', ' the', ' ghost', ' of', ' a', ' firefighter']
+59 11 The place of work where you can usually find a x -1 The place of work where you can usually find a firefighter fire station firefighter [',', ' a', ' police', ' officer', ',', ' a', ' paramed', 'ic', ',', ' a', ' nurse', ',', ' a', ' doctor', ',', ' a', ' teacher', ',', ' a', ' lawyer'] , a police officer , a paramed ic , a nurse , a doctor , a teacher , a lawyer False Pittsburgh police and firefighter unions, as well as 3 [' Pittsburgh', ' police', ' and', ' firefighter']
+60 12 The place of work where you can usually find a x -1 The place of work where you can usually find a florist flower shop florist [',', ' a', ' bakery', ',', ' a', ' butcher', ',', ' a', ' fish', 'mong', 'er', ',', ' a', ' green', 'gro', 'cer', ',', ' a', ' butcher', ','] , a bakery , a butcher , a fish mong er , a green gro cer , a butcher , False Galápagos, where a florist rages at her spouse 9 [' Gal', 'á', 'p', 'agos', ',', ' where', ' a', ' fl', 'or', 'ist']
+61 12 The place of work where you can usually find a x -1 The place of work where you can usually find a florist flower shop florist [',', ' a', ' bakery', ',', ' a', ' butcher', ',', ' a', ' fish', 'mong', 'er', ',', ' a', ' green', 'gro', 'cer', ',', ' a', ' butcher', ','] , a bakery , a butcher , a fish mong er , a green gro cer , a butcher , False stopped at a florist to buy some flowers 5 [' stopped', ' at', ' a', ' fl', 'or', 'ist']
+62 12 The place of work where you can usually find a x -1 The place of work where you can usually find a florist flower shop florist [',', ' a', ' bakery', ',', ' a', ' butcher', ',', ' a', ' fish', 'mong', 'er', ',', ' a', ' green', 'gro', 'cer', ',', ' a', ' butcher', ','] , a bakery , a butcher , a fish mong er , a green gro cer , a butcher , False an included florist shop but at the 4 [' an', ' included', ' fl', 'or', 'ist']
+63 12 The place of work where you can usually find a x -1 The place of work where you can usually find a florist flower shop florist [',', ' a', ' bakery', ',', ' a', ' butcher', ',', ' a', ' fish', 'mong', 'er', ',', ' a', ' green', 'gro', 'cer', ',', ' a', ' butcher', ','] , a bakery , a butcher , a fish mong er , a green gro cer , a butcher , False stopped at a florist to buy some flowers 5 [' stopped', ' at', ' a', ' fl', 'or', 'ist']
+64 12 The place of work where you can usually find a x -1 The place of work where you can usually find a florist flower shop florist [',', ' a', ' bakery', ',', ' a', ' butcher', ',', ' a', ' fish', 'mong', 'er', ',', ' a', ' green', 'gro', 'cer', ',', ' a', ' butcher', ','] , a bakery , a butcher , a fish mong er , a green gro cer , a butcher , False " ""Jumbo"" and works as a florist at his father's" 10 "[' ""', 'J', 'umbo', '""', ' and', ' works', ' as', ' a', ' fl', 'or', 'ist']"
+65 13 The place of work where you can usually find a x -1 The place of work where you can usually find a flight attendant airplane flight attendant [',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant'] , a flight attendant , a flight attendant , a flight attendant , a flight attendant , a flight attendant False switched on and the flight attendant started the process 5 [' switched', ' on', ' and', ' the', ' flight', ' attendant']
+66 13 The place of work where you can usually find a x -1 The place of work where you can usually find a flight attendant airplane flight attendant [',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant'] , a flight attendant , a flight attendant , a flight attendant , a flight attendant , a flight attendant False earliest male flight attendants; Marge initially 3 [' earliest', ' male', ' flight', ' attendant']
+67 13 The place of work where you can usually find a x -1 The place of work where you can usually find a flight attendant airplane flight attendant [',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant'] , a flight attendant , a flight attendant , a flight attendant , a flight attendant , a flight attendant False sporting her flight attendant outfit, and winks at 3 [' sporting', ' her', ' flight', ' attendant']
+68 13 The place of work where you can usually find a x -1 The place of work where you can usually find a flight attendant airplane flight attendant [',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant'] , a flight attendant , a flight attendant , a flight attendant , a flight attendant , a flight attendant False voice recorder, the flight attendant had escorted an 5 [' voice', ' recorder', ',', ' the', ' flight', ' attendant']
+69 13 The place of work where you can usually find a x -1 The place of work where you can usually find a flight attendant airplane flight attendant [',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant', ',', ' a', ' flight', ' attendant'] , a flight attendant , a flight attendant , a flight attendant , a flight attendant , a flight attendant False flameouts, killing a flight attendant and injuring 6 [' flame', 'outs', ',', ' killing', ' a', ' flight', ' attendant']
+70 14 The place of work where you can usually find a x -1 The place of work where you can usually find a hairdresser salon hairdresser [',', ' bar', 'ber', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' ha', 'ird'] ", bar ber , or beauty salon .
+
+ The place where you can usually find a ha ird" True Reubens'played a flamboyant hairdresser turned drug dealer 13 "[' Re', 'ub', 'ens', ""'"", 'played', ' a', ' fl', 'amb', 'oy', 'ant', ' ha', 'ird', 'ress', 'er']"
+71 14 The place of work where you can usually find a x -1 The place of work where you can usually find a hairdresser salon hairdresser [',', ' bar', 'ber', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' ha', 'ird'] ", bar ber , or beauty salon .
+
+ The place where you can usually find a ha ird" True she qualified as a hairdresser and beauty therapist. 7 [' she', ' qualified', ' as', ' a', ' ha', 'ird', 'ress', 'er']
+72 14 The place of work where you can usually find a x -1 The place of work where you can usually find a hairdresser salon hairdresser [',', ' bar', 'ber', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' ha', 'ird'] ", bar ber , or beauty salon .
+
+ The place where you can usually find a ha ird" True sends for her hairdresser from Norway. 6 [' sends', ' for', ' her', ' ha', 'ird', 'ress', 'er']
+73 14 The place of work where you can usually find a x -1 The place of work where you can usually find a hairdresser salon hairdresser [',', ' bar', 'ber', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' ha', 'ird'] ", bar ber , or beauty salon .
+
+ The place where you can usually find a ha ird" True Geller became Presley's hairdresser in 1964. Unlike others 9 "[' G', 'eller', ' became', ' Pres', 'ley', ""'s"", ' ha', 'ird', 'ress', 'er']"
+74 14 The place of work where you can usually find a x -1 The place of work where you can usually find a hairdresser salon hairdresser [',', ' bar', 'ber', ',', ' or', ' beauty', ' salon', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' ha', 'ird'] ", bar ber , or beauty salon .
+
+ The place where you can usually find a ha ird" True Nash v Sheen, a hairdresser who put a tone rinse 8 [' Nash', ' v', ' Sheen', ',', ' a', ' ha', 'ird', 'ress', 'er']
+75 15 The place of work where you can usually find a x -1 The place of work where you can usually find a historian library historian [',', ' a', ' l', 'ibrarian', ',', ' a', ' teacher', ',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ','] , a l ibrarian , a teacher , a writer , a poet , a philosopher , a scientist , False According to historian Charlotte Zeepvat, 2 [' According', ' to', ' historian']
+76 15 The place of work where you can usually find a x -1 The place of work where you can usually find a historian library historian [',', ' a', ' l', 'ibrarian', ',', ' a', ' teacher', ',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ','] , a l ibrarian , a teacher , a writer , a poet , a philosopher , a scientist , False his career. Film historian J. Lary Kuhns, however, 4 [' his', ' career', '.', ' Film', ' historian']
+77 15 The place of work where you can usually find a x -1 The place of work where you can usually find a historian library historian [',', ' a', ' l', 'ibrarian', ',', ' a', ' teacher', ',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ','] , a l ibrarian , a teacher , a writer , a poet , a philosopher , a scientist , False Unprepared by what historian Max Nettlau called 5 [' Un', 'pre', 'pared', ' by', ' what', ' historian']
+78 15 The place of work where you can usually find a x -1 The place of work where you can usually find a historian library historian [',', ' a', ' l', 'ibrarian', ',', ' a', ' teacher', ',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ','] , a l ibrarian , a teacher , a writer , a poet , a philosopher , a scientist , False Coburn, a historian of Sanskrit 4 [' Cob', 'urn', ',', ' a', ' historian']
+79 15 The place of work where you can usually find a x -1 The place of work where you can usually find a historian library historian [',', ' a', ' l', 'ibrarian', ',', ' a', ' teacher', ',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' philosopher', ',', ' a', ' scientist', ','] , a l ibrarian , a teacher , a writer , a poet , a philosopher , a scientist , False blow ”, according to historian Arthur Waldron. After 6 [' blow', ' �', '�', ',', ' according', ' to', ' historian']
+80 16 The place of work where you can usually find a x -1 The place of work where you can usually find a insurance agent office insurance agent [' is', ' in', ' the', ' lobby', ' of', ' the', ' building', '.', ' You', ' can', ' also', ' find', ' them', ' in', ' the', ' lobby', ' of', ' the', ' building', '.'] is in the lobby of the building . You can also find them in the lobby of the building . False took a job as an insurance agent after college, but 6 [' took', ' a', ' job', ' as', ' an', ' insurance', ' agent']
+81 16 The place of work where you can usually find a x -1 The place of work where you can usually find a insurance agent office insurance agent [' is', ' in', ' the', ' lobby', ' of', ' the', ' building', '.', ' You', ' can', ' also', ' find', ' them', ' in', ' the', ' lobby', ' of', ' the', ' building', '.'] is in the lobby of the building . You can also find them in the lobby of the building . False 1965, Steve is an insurance agent in Modesto, California, 6 [' 1965', ',', ' Steve', ' is', ' an', ' insurance', ' agent']
+82 16 The place of work where you can usually find a x -1 The place of work where you can usually find a insurance agent office insurance agent [' is', ' in', ' the', ' lobby', ' of', ' the', ' building', '.', ' You', ' can', ' also', ' find', ' them', ' in', ' the', ' lobby', ' of', ' the', ' building', '.'] is in the lobby of the building . You can also find them in the lobby of the building . False seductive insurance agent on the lookout 3 [' sed', 'uctive', ' insurance', ' agent']
+83 16 The place of work where you can usually find a x -1 The place of work where you can usually find a insurance agent office insurance agent [' is', ' in', ' the', ' lobby', ' of', ' the', ' building', '.', ' You', ' can', ' also', ' find', ' them', ' in', ' the', ' lobby', ' of', ' the', ' building', '.'] is in the lobby of the building . You can also find them in the lobby of the building . False retiring, Tucker became an insurance agent and lived in 6 [' retiring', ',', ' Tucker', ' became', ' an', ' insurance', ' agent']
+84 16 The place of work where you can usually find a x -1 The place of work where you can usually find a insurance agent office insurance agent [' is', ' in', ' the', ' lobby', ' of', ' the', ' building', '.', ' You', ' can', ' also', ' find', ' them', ' in', ' the', ' lobby', ' of', ' the', ' building', '.'] is in the lobby of the building . You can also find them in the lobby of the building . False took a job as an insurance agent after college, 6 [' took', ' a', ' job', ' as', ' an', ' insurance', ' agent']
+85 17 The place of work where you can usually find a x -1 The place of work where you can usually find a journalist office journalist [',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' painter', ',', ' a', ' musician', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a'] , a writer , a poet , a painter , a musician , a philosopher , a scientist , a False included the journalist Sheena McDonald 2 [' included', ' the', ' journalist']
+86 17 The place of work where you can usually find a x -1 The place of work where you can usually find a journalist office journalist [',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' painter', ',', ' a', ' musician', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a'] , a writer , a poet , a painter , a musician , a philosopher , a scientist , a False being told by journalist Bill Smithies 3 [' being', ' told', ' by', ' journalist']
+87 17 The place of work where you can usually find a x -1 The place of work where you can usually find a journalist office journalist [',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' painter', ',', ' a', ' musician', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a'] , a writer , a poet , a painter , a musician , a philosopher , a scientist , a False website of Dutch journalist Karin Spaink. 3 [' website', ' of', ' Dutch', ' journalist']
+88 17 The place of work where you can usually find a x -1 The place of work where you can usually find a journalist office journalist [',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' painter', ',', ' a', ' musician', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a'] , a writer , a poet , a painter , a musician , a philosopher , a scientist , a False author and journalist Naomi Klein's 2 [' author', ' and', ' journalist']
+89 17 The place of work where you can usually find a x -1 The place of work where you can usually find a journalist office journalist [',', ' a', ' writer', ',', ' a', ' poet', ',', ' a', ' painter', ',', ' a', ' musician', ',', ' a', ' philosopher', ',', ' a', ' scientist', ',', ' a'] , a writer , a poet , a painter , a musician , a philosopher , a scientist , a False by tabloid journalist Freddy Lounds (Stephen 2 [' by', ' tabloid', ' journalist']
+90 18 The place of work where you can usually find a x -1 The place of work where you can usually find a librarian library librarian ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' l', 'ibrarian', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a l ibrarian .
+
+ The place where you" False was an American librarian. A native of Sweden, 4 [' was', ' an', ' American', ' l', 'ibrarian']
+91 18 The place of work where you can usually find a x -1 The place of work where you can usually find a librarian library librarian ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' l', 'ibrarian', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a l ibrarian .
+
+ The place where you" False an American librarian and self-declared 3 [' an', ' American', ' l', 'ibrarian']
+92 18 The place of work where you can usually find a x -1 The place of work where you can usually find a librarian library librarian ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' l', 'ibrarian', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a l ibrarian .
+
+ The place where you" False choice as a librarian is speculated 4 [' choice', ' as', ' a', ' l', 'ibrarian']
+93 18 The place of work where you can usually find a x -1 The place of work where you can usually find a librarian library librarian ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' l', 'ibrarian', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a l ibrarian .
+
+ The place where you" False is used by the librarian who provides Lisa with 5 [' is', ' used', ' by', ' the', ' l', 'ibrarian']
+94 18 The place of work where you can usually find a x -1 The place of work where you can usually find a librarian library librarian ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' l', 'ibrarian', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a l ibrarian .
+
+ The place where you" False that one of the librarians is trying to prevent 5 [' that', ' one', ' of', ' the', ' l', 'ibrarian']
+95 19 The place of work where you can usually find a x -1 The place of work where you can usually find a mechanic garage mechanic [',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who'] , a mechanic who can fix your car , a mechanic who can fix your car , a mechanic who False The bouncing mechanic proved problematic 2 [' The', ' bouncing', ' mechanic']
+96 19 The place of work where you can usually find a x -1 The place of work where you can usually find a mechanic garage mechanic [',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who'] , a mechanic who can fix your car , a mechanic who can fix your car , a mechanic who False Jungle Rumble uses a mechanic where players 4 [' Jungle', ' Rumble', ' uses', ' a', ' mechanic']
+97 19 The place of work where you can usually find a x -1 The place of work where you can usually find a mechanic garage mechanic [',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who'] , a mechanic who can fix your car , a mechanic who can fix your car , a mechanic who False " ""call and response"" mechanic used by later rhythm" 5 "[' ""', 'call', ' and', ' response', '""', ' mechanic']"
+98 19 The place of work where you can usually find a x -1 The place of work where you can usually find a mechanic garage mechanic [',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who'] , a mechanic who can fix your car , a mechanic who can fix your car , a mechanic who False 2 ['me', 'chan', 'ic']
+99 19 The place of work where you can usually find a x -1 The place of work where you can usually find a mechanic garage mechanic [',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who', ' can', ' fix', ' your', ' car', ',', ' a', ' mechanic', ' who'] , a mechanic who can fix your car , a mechanic who can fix your car , a mechanic who False the searching mechanic to the Google search 2 [' the', ' searching', ' mechanic']
+100 20 The place of work where you can usually find a x -1 The place of work where you can usually find a musician concert hall musician [',', ' a', ' writer', ',', ' a', ' painter', ',', ' a', ' poet', ',', ' a', ' dancer', ',', ' a', ' sculpt', 'or', ',', ' a', ' photographer', ','] , a writer , a painter , a poet , a dancer , a sculpt or , a photographer , False collaborating with musician Reino Markkula. 2 [' collaborating', ' with', ' musician']
+101 20 The place of work where you can usually find a x -1 The place of work where you can usually find a musician concert hall musician [',', ' a', ' writer', ',', ' a', ' painter', ',', ' a', ' poet', ',', ' a', ' dancer', ',', ' a', ' sculpt', 'or', ',', ' a', ' photographer', ','] , a writer , a painter , a poet , a dancer , a sculpt or , a photographer , False manager, but a trained musician, and he appears 5 [' manager', ',', ' but', ' a', ' trained', ' musician']
+102 20 The place of work where you can usually find a x -1 The place of work where you can usually find a musician concert hall musician [',', ' a', ' writer', ',', ' a', ' painter', ',', ' a', ' poet', ',', ' a', ' dancer', ',', ' a', ' sculpt', 'or', ',', ' a', ' photographer', ','] , a writer , a painter , a poet , a dancer , a sculpt or , a photographer , False talented amateur musician who gave Jules 2 [' talented', ' amateur', ' musician']
+103 20 The place of work where you can usually find a x -1 The place of work where you can usually find a musician concert hall musician [',', ' a', ' writer', ',', ' a', ' painter', ',', ' a', ' poet', ',', ' a', ' dancer', ',', ' a', ' sculpt', 'or', ',', ' a', ' photographer', ','] , a writer , a painter , a poet , a dancer , a sculpt or , a photographer , False 2012, Dominican musician Karlos Rosé 3 [' 2012', ',', ' Dominican', ' musician']
+104 20 The place of work where you can usually find a x -1 The place of work where you can usually find a musician concert hall musician [',', ' a', ' writer', ',', ' a', ' painter', ',', ' a', ' poet', ',', ' a', ' dancer', ',', ' a', ' sculpt', 'or', ',', ' a', ' photographer', ','] , a writer , a painter , a poet , a dancer , a sculpt or , a photographer , False composed by the Japanese musician Ryuichi Sakamoto. 4 [' composed', ' by', ' the', ' Japanese', ' musician']
+105 21 The place of work where you can usually find a x -1 The place of work where you can usually find a nurse hospital nurse [',', ' doctor', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' nurse', ','] ", doctor , or other medical professional .
+
+ The place where you can usually find a nurse ," False 23-year-old registered nurse named Caryn Campbell 6 [' 23', '-', 'year', '-', 'old', ' registered', ' nurse']
+106 21 The place of work where you can usually find a x -1 The place of work where you can usually find a nurse hospital nurse [',', ' doctor', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' nurse', ','] ", doctor , or other medical professional .
+
+ The place where you can usually find a nurse ," False most foals, who nurse for months after 5 [' most', ' fo', 'als', ',', ' who', ' nurse']
+107 21 The place of work where you can usually find a x -1 The place of work where you can usually find a nurse hospital nurse [',', ' doctor', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' nurse', ','] ", doctor , or other medical professional .
+
+ The place where you can usually find a nurse ," False elephant seal mothers nurse their own 3 [' elephant', ' seal', ' mothers', ' nurse']
+108 21 The place of work where you can usually find a x -1 The place of work where you can usually find a nurse hospital nurse [',', ' doctor', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' nurse', ','] ", doctor , or other medical professional .
+
+ The place where you can usually find a nurse ," False Moffitt, the admissions nurse and Dr. Mercy's 5 [' Moff', 'itt', ',', ' the', ' admissions', ' nurse']
+109 21 The place of work where you can usually find a x -1 The place of work where you can usually find a nurse hospital nurse [',', ' doctor', ',', ' or', ' other', ' medical', ' professional', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' nurse', ','] ", doctor , or other medical professional .
+
+ The place where you can usually find a nurse ," False the couple to nurse Mary twenty-four 3 [' the', ' couple', ' to', ' nurse']
+110 22 The place of work where you can usually find a x -1 The place of work where you can usually find a painter studio painter [',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter'] , painter , painter , painter , painter , painter , painter , painter , painter , painter , painter False and marine painter Holger Drachmann who 2 [' and', ' marine', ' painter']
+111 22 The place of work where you can usually find a x -1 The place of work where you can usually find a painter studio painter [',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter'] , painter , painter , painter , painter , painter , painter , painter , painter , painter , painter False 1694) was a portrait painter in the Baroque 6 [' 16', '94', ')', ' was', ' a', ' portrait', ' painter']
+112 22 The place of work where you can usually find a x -1 The place of work where you can usually find a painter studio painter [',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter'] , painter , painter , painter , painter , painter , painter , painter , painter , painter , painter False School of landscape painters. Ely, on a small hill, 3 [' School', ' of', ' landscape', ' painter']
+113 22 The place of work where you can usually find a x -1 The place of work where you can usually find a painter studio painter [',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter'] , painter , painter , painter , painter , painter , painter , painter , painter , painter , painter False Swiss surrealist painter H.R. Giger provided 3 [' Swiss', ' surreal', 'ist', ' painter']
+114 22 The place of work where you can usually find a x -1 The place of work where you can usually find a painter studio painter [',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter', ',', ' painter'] , painter , painter , painter , painter , painter , painter , painter , painter , painter , painter False 1 ['pain', 'ter']
+115 23 The place of work where you can usually find a x -1 The place of work where you can usually find a pharmacist pharmacy pharmacist [' is', ' in', ' a', ' pharmacy', '.', ' The', ' pharm', 'acist', ' is', ' the', ' one', ' who', ' disp', 'enses', ' the', ' medication', '.', ' The', ' pharm', 'acist'] is in a pharmacy . The pharm acist is the one who disp enses the medication . The pharm acist True bakers, radioman, pharmacist mates, radio 7 [' b', 'akers', ',', ' radi', 'oman', ',', ' pharm', 'acist']
+116 23 The place of work where you can usually find a x -1 The place of work where you can usually find a pharmacist pharmacy pharmacist [' is', ' in', ' a', ' pharmacy', '.', ' The', ' pharm', 'acist', ' is', ' the', ' one', ' who', ' disp', 'enses', ' the', ' medication', '.', ' The', ' pharm', 'acist'] is in a pharmacy . The pharm acist is the one who disp enses the medication . The pharm acist True bakers, radioman, pharmacist mates, radio 7 [' b', 'akers', ',', ' radi', 'oman', ',', ' pharm', 'acist']
+117 23 The place of work where you can usually find a x -1 The place of work where you can usually find a pharmacist pharmacy pharmacist [' is', ' in', ' a', ' pharmacy', '.', ' The', ' pharm', 'acist', ' is', ' the', ' one', ' who', ' disp', 'enses', ' the', ' medication', '.', ' The', ' pharm', 'acist'] is in a pharmacy . The pharm acist is the one who disp enses the medication . The pharm acist True disputed. The pharmacist Heinrich Schwabe 4 [' disputed', '.', ' The', ' pharm', 'acist']
+118 23 The place of work where you can usually find a x -1 The place of work where you can usually find a pharmacist pharmacy pharmacist [' is', ' in', ' a', ' pharmacy', '.', ' The', ' pharm', 'acist', ' is', ' the', ' one', ' who', ' disp', 'enses', ' the', ' medication', '.', ' The', ' pharm', 'acist'] is in a pharmacy . The pharm acist is the one who disp enses the medication . The pharm acist True reported that a pharmacist was treating people 4 [' reported', ' that', ' a', ' pharm', 'acist']
+119 23 The place of work where you can usually find a x -1 The place of work where you can usually find a pharmacist pharmacy pharmacist [' is', ' in', ' a', ' pharmacy', '.', ' The', ' pharm', 'acist', ' is', ' the', ' one', ' who', ' disp', 'enses', ' the', ' medication', '.', ' The', ' pharm', 'acist'] is in a pharmacy . The pharm acist is the one who disp enses the medication . The pharm acist True Pelletier and French pharmacist Joseph Bienaimé 6 [' Pel', 'let', 'ier', ' and', ' French', ' pharm', 'acist']
+120 24 The place of work where you can usually find a x -1 The place of work where you can usually find a photographer studio photographer [',', ' a', ' styl', 'ist', ',', ' a', ' makeup', ' artist', ',', ' a', ' hair', ' styl', 'ist', ',', ' a', ' make', '-', 'up', ' artist', ','] , a styl ist , a makeup artist , a hair styl ist , a make - up artist , False another Danish photographer known mainly 2 [' another', ' Danish', ' photographer']
+121 24 The place of work where you can usually find a x -1 The place of work where you can usually find a photographer studio photographer [',', ' a', ' styl', 'ist', ',', ' a', ' makeup', ' artist', ',', ' a', ' hair', ' styl', 'ist', ',', ' a', ' make', '-', 'up', ' artist', ','] , a styl ist , a makeup artist , a hair styl ist , a make - up artist , False event taken by photographer Michel Setboun. Players 3 [' event', ' taken', ' by', ' photographer']
+122 24 The place of work where you can usually find a x -1 The place of work where you can usually find a photographer studio photographer [',', ' a', ' styl', 'ist', ',', ' a', ' makeup', ' artist', ',', ' a', ' hair', ' styl', 'ist', ',', ' a', ' make', '-', 'up', ' artist', ','] , a styl ist , a makeup artist , a hair styl ist , a make - up artist , False New York Rocker photographer Laura Levine, a friend 4 [' New', ' York', ' Rock', 'er', ' photographer']
+123 24 The place of work where you can usually find a x -1 The place of work where you can usually find a photographer studio photographer [',', ' a', ' styl', 'ist', ',', ' a', ' makeup', ' artist', ',', ' a', ' hair', ' styl', 'ist', ',', ' a', ' make', '-', 'up', ' artist', ','] , a styl ist , a makeup artist , a hair styl ist , a make - up artist , False were taken by photographer fine artist Daniel 3 [' were', ' taken', ' by', ' photographer']
+124 24 The place of work where you can usually find a x -1 The place of work where you can usually find a photographer studio photographer [',', ' a', ' styl', 'ist', ',', ' a', ' makeup', ' artist', ',', ' a', ' hair', ' styl', 'ist', ',', ' a', ' make', '-', 'up', ' artist', ','] , a styl ist , a makeup artist , a hair styl ist , a make - up artist , False Noted baseball photographer Charles M. 3 [' Not', 'ed', ' baseball', ' photographer']
+125 25 The place of work where you can usually find a x -1 The place of work where you can usually find a pilot airplane pilot "[',', ' a', ' mechanic', ',', ' a', ' mechanic', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant']" , a mechanic , a mechanic 's assistant , a mechanic 's assistant 's assistant , a mechanic 's assistant False fighter, though the pilot was uninjured. 4 [' fighter', ',', ' though', ' the', ' pilot']
+126 25 The place of work where you can usually find a x -1 The place of work where you can usually find a pilot airplane pilot "[',', ' a', ' mechanic', ',', ' a', ' mechanic', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant']" , a mechanic , a mechanic 's assistant , a mechanic 's assistant 's assistant , a mechanic 's assistant False custom machine data, pilot points, and race 4 [' custom', ' machine', ' data', ',', ' pilot']
+127 25 The place of work where you can usually find a x -1 The place of work where you can usually find a pilot airplane pilot "[',', ' a', ' mechanic', ',', ' a', ' mechanic', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant']" , a mechanic , a mechanic 's assistant , a mechanic 's assistant 's assistant , a mechanic 's assistant False in the series'pilot and first season. 5 "[' in', ' the', ' series', ""'"", 'p', 'ilot']"
+128 25 The place of work where you can usually find a x -1 The place of work where you can usually find a pilot airplane pilot "[',', ' a', ' mechanic', ',', ' a', ' mechanic', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant']" , a mechanic , a mechanic 's assistant , a mechanic 's assistant 's assistant , a mechanic 's assistant False Helford co-wrote the pilot episode together. 6 [' Hel', 'ford', ' co', '-', 'wrote', ' the', ' pilot']
+129 25 The place of work where you can usually find a x -1 The place of work where you can usually find a pilot airplane pilot "[',', ' a', ' mechanic', ',', ' a', ' mechanic', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant', ""'s"", ' assistant', ',', ' a', ' mechanic', ""'s"", ' assistant']" , a mechanic , a mechanic 's assistant , a mechanic 's assistant 's assistant , a mechanic 's assistant False the river ways, as pilot of the gunboat 5 [' the', ' river', ' ways', ',', ' as', ' pilot']
+130 26 The place of work where you can usually find a x -1 The place of work where you can usually find a researcher laboratory researcher [',', ' a', ' writer', ',', ' a', ' designer', ',', ' a', ' programmer', ',', ' a', ' market', 'er', ',', ' a', ' manager', ',', ' a', ' project', ' manager'] , a writer , a designer , a programmer , a market er , a manager , a project manager False scientific researcher at the Vermillion 1 [' scientific', ' researcher']
+131 26 The place of work where you can usually find a x -1 The place of work where you can usually find a researcher laboratory researcher [',', ' a', ' writer', ',', ' a', ' designer', ',', ' a', ' programmer', ',', ' a', ' market', 'er', ',', ' a', ' manager', ',', ' a', ' project', ' manager'] , a writer , a designer , a programmer , a market er , a manager , a project manager False suggested in 1987 by researcher Josephine Lowndes 4 [' suggested', ' in', ' 1987', ' by', ' researcher']
+132 26 The place of work where you can usually find a x -1 The place of work where you can usually find a researcher laboratory researcher [',', ' a', ' writer', ',', ' a', ' designer', ',', ' a', ' programmer', ',', ' a', ' market', 'er', ',', ' a', ' manager', ',', ' a', ' project', ' manager'] , a writer , a designer , a programmer , a market er , a manager , a project manager False assault. Security researcher Sean-Paul Correll 3 [' assault', '.', ' Security', ' researcher']
+133 26 The place of work where you can usually find a x -1 The place of work where you can usually find a researcher laboratory researcher [',', ' a', ' writer', ',', ' a', ' designer', ',', ' a', ' programmer', ',', ' a', ' market', 'er', ',', ' a', ' manager', ',', ' a', ' project', ' manager'] , a writer , a designer , a programmer , a market er , a manager , a project manager False residents told a researcher that they were 3 [' residents', ' told', ' a', ' researcher']
+134 26 The place of work where you can usually find a x -1 The place of work where you can usually find a researcher laboratory researcher [',', ' a', ' writer', ',', ' a', ' designer', ',', ' a', ' programmer', ',', ' a', ' market', 'er', ',', ' a', ' manager', ',', ' a', ' project', ' manager'] , a writer , a designer , a programmer , a market er , a manager , a project manager False rates. One researcher suggested that 3 [' rates', '.', ' One', ' researcher']
+135 27 The place of work where you can usually find a x -1 The place of work where you can usually find a salesperson store salesperson [' who', ' is', ' willing', ' to', ' help', ' you', ' with', ' your', ' questions', '.', '\n', '\n', 'The', ' place', ' of', ' work', ' where', ' you', ' can', ' usually'] " who is willing to help you with your questions .
+
+ The place of work where you can usually" False him space for a salesperson at the Mint's exhibit, 5 [' him', ' space', ' for', ' a', ' sales', 'person']
+136 27 The place of work where you can usually find a x -1 The place of work where you can usually find a salesperson store salesperson [' who', ' is', ' willing', ' to', ' help', ' you', ' with', ' your', ' questions', '.', '\n', '\n', 'The', ' place', ' of', ' work', ' where', ' you', ' can', ' usually'] " who is willing to help you with your questions .
+
+ The place of work where you can usually" False allow him space for a salesperson at the Mint's exhibit, 6 [' allow', ' him', ' space', ' for', ' a', ' sales', 'person']
+137 27 The place of work where you can usually find a x -1 The place of work where you can usually find a salesperson store salesperson [' who', ' is', ' willing', ' to', ' help', ' you', ' with', ' your', ' questions', '.', '\n', '\n', 'The', ' place', ' of', ' work', ' where', ' you', ' can', ' usually'] " who is willing to help you with your questions .
+
+ The place of work where you can usually" False institutional equities salesperson at the company's 4 [' institutional', ' equ', 'ities', ' sales', 'person']
+138 27 The place of work where you can usually find a x -1 The place of work where you can usually find a salesperson store salesperson [' who', ' is', ' willing', ' to', ' help', ' you', ' with', ' your', ' questions', '.', '\n', '\n', 'The', ' place', ' of', ' work', ' where', ' you', ' can', ' usually'] " who is willing to help you with your questions .
+
+ The place of work where you can usually" False Chris is a salesperson at Nordstrom. 4 [' Chris', ' is', ' a', ' sales', 'person']
+139 27 The place of work where you can usually find a x -1 The place of work where you can usually find a salesperson store salesperson [' who', ' is', ' willing', ' to', ' help', ' you', ' with', ' your', ' questions', '.', '\n', '\n', 'The', ' place', ' of', ' work', ' where', ' you', ' can', ' usually'] " who is willing to help you with your questions .
+
+ The place of work where you can usually" False institutional equities salesperson at the company's 4 [' institutional', ' equ', 'ities', ' sales', 'person']
+140 28 The place of work where you can usually find a x -1 The place of work where you can usually find a scientist laboratory scientist [',', ' a', ' doctor', ',', ' a', ' lawyer', ',', ' a', ' teacher', ',', ' a', ' soldier', ',', ' a', ' policeman', ',', ' a', ' priest', ',', ' a'] , a doctor , a lawyer , a teacher , a soldier , a policeman , a priest , a False by the computer scientist George Markowsky 3 [' by', ' the', ' computer', ' scientist']
+141 28 The place of work where you can usually find a x -1 The place of work where you can usually find a scientist laboratory scientist [',', ' a', ' doctor', ',', ' a', ' lawyer', ',', ' a', ' teacher', ',', ' a', ' soldier', ',', ' a', ' policeman', ',', ' a', ' priest', ',', ' a'] , a doctor , a lawyer , a teacher , a soldier , a policeman , a priest , a False Dr. Zeigler, a scientist who created 7 [' Dr', '.', ' Ze', 'ig', 'ler', ',', ' a', ' scientist']
+142 28 The place of work where you can usually find a x -1 The place of work where you can usually find a scientist laboratory scientist [',', ' a', ' doctor', ',', ' a', ' lawyer', ',', ' a', ' teacher', ',', ' a', ' soldier', ',', ' a', ' policeman', ',', ' a', ' priest', ',', ' a'] , a doctor , a lawyer , a teacher , a soldier , a policeman , a priest , a False tradition. Political scientist Paul C. Sondrol has 3 [' tradition', '.', ' Political', ' scientist']
+143 28 The place of work where you can usually find a x -1 The place of work where you can usually find a scientist laboratory scientist [',', ' a', ' doctor', ',', ' a', ' lawyer', ',', ' a', ' teacher', ',', ' a', ' soldier', ',', ' a', ' policeman', ',', ' a', ' priest', ',', ' a'] , a doctor , a lawyer , a teacher , a soldier , a policeman , a priest , a False 1960s, the Swiss scientist Albert Hofmann 5 [' 1960', 's', ',', ' the', ' Swiss', ' scientist']
+144 28 The place of work where you can usually find a x -1 The place of work where you can usually find a scientist laboratory scientist [',', ' a', ' doctor', ',', ' a', ' lawyer', ',', ' a', ' teacher', ',', ' a', ' soldier', ',', ' a', ' policeman', ',', ' a', ' priest', ',', ' a'] , a doctor , a lawyer , a teacher , a soldier , a policeman , a priest , a False " ""a lifelong scientist and sailor" 3 "[' ""', 'a', ' lifelong', ' scientist']"
+145 29 The place of work where you can usually find a x -1 The place of work where you can usually find a secretary office secretary [',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist'] , a reception ist , a reception ist , a reception ist , a reception ist , a reception ist False move, but Bolton secretary J.J. Bentley 4 [' move', ',', ' but', ' Bolton', ' secretary']
+146 29 The place of work where you can usually find a x -1 The place of work where you can usually find a secretary office secretary [',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist'] , a reception ist , a reception ist , a reception ist , a reception ist , a reception ist False " Ure, and his secretary — ""The last" 5 [' U', 're', ',', ' and', ' his', ' secretary']
+147 29 The place of work where you can usually find a x -1 The place of work where you can usually find a secretary office secretary [',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist'] , a reception ist , a reception ist , a reception ist , a reception ist , a reception ist False treasurer, secretary of state and 2 [' treasurer', ',', ' secretary']
+148 29 The place of work where you can usually find a x -1 The place of work where you can usually find a secretary office secretary [',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist'] , a reception ist , a reception ist , a reception ist , a reception ist , a reception ist False well-received, with the secretary of the Warsaw Philharmonica 6 [' well', '-', 'received', ',', ' with', ' the', ' secretary']
+149 29 The place of work where you can usually find a x -1 The place of work where you can usually find a secretary office secretary [',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist', ',', ' a', ' reception', 'ist'] , a reception ist , a reception ist , a reception ist , a reception ist , a reception ist False " promoted to secretary of the league.
+" 2 [' promoted', ' to', ' secretary']
+150 30 The place of work where you can usually find a x -1 The place of work where you can usually find a soldier military base soldier [',', ' a', ' policeman', ',', ' a', ' fire', 'man', ',', ' a', ' doctor', ',', ' a', ' nurse', ',', ' a', ' teacher', ',', ' a', ' lawyer', ','] , a policeman , a fire man , a doctor , a nurse , a teacher , a lawyer , False from a Galbadian soldier to the defender of 5 [' from', ' a', ' Gal', 'bad', 'ian', ' soldier']
+151 30 The place of work where you can usually find a x -1 The place of work where you can usually find a soldier military base soldier [',', ' a', ' policeman', ',', ' a', ' fire', 'man', ',', ' a', ' doctor', ',', ' a', ' nurse', ',', ' a', ' teacher', ',', ' a', ' lawyer', ','] , a policeman , a fire man , a doctor , a nurse , a teacher , a lawyer , False average Soviet soldier suffered from 2 [' average', ' Soviet', ' soldier']
+152 30 The place of work where you can usually find a x -1 The place of work where you can usually find a soldier military base soldier [',', ' a', ' policeman', ',', ' a', ' fire', 'man', ',', ' a', ' doctor', ',', ' a', ' nurse', ',', ' a', ' teacher', ',', ' a', ' lawyer', ','] , a policeman , a fire man , a doctor , a nurse , a teacher , a lawyer , False a Lupertazzi soldier for harassing 5 [' a', ' Lu', 'pert', 'azz', 'i', ' soldier']
+153 30 The place of work where you can usually find a x -1 The place of work where you can usually find a soldier military base soldier [',', ' a', ' policeman', ',', ' a', ' fire', 'man', ',', ' a', ' doctor', ',', ' a', ' nurse', ',', ' a', ' teacher', ',', ' a', ' lawyer', ','] , a policeman , a fire man , a doctor , a nurse , a teacher , a lawyer , False " by."" His fellow soldiers never saw him" 4 "[' by', '.""', ' His', ' fellow', ' soldier']"
+154 30 The place of work where you can usually find a x -1 The place of work where you can usually find a soldier military base soldier [',', ' a', ' policeman', ',', ' a', ' fire', 'man', ',', ' a', ' doctor', ',', ' a', ' nurse', ',', ' a', ' teacher', ',', ' a', ' lawyer', ','] , a policeman , a fire man , a doctor , a nurse , a teacher , a lawyer , False Fleming, a young soldier who flees 4 [' Fleming', ',', ' a', ' young', ' soldier']
+155 31 The place of work where you can usually find a x -1 The place of work where you can usually find a software engineer office software engineer "['.', '\n', '\n', '------', '\n', 'j', 'osh', 'u', '\n', 'I', ""'m"", ' a', ' software', ' engineer', ' at', ' a', ' startup', ' in', ' SF', '.']" ".
+
+ ------
+ j osh u
+ I 'm a software engineer at a startup in SF ." False working as a software engineer in Silicon 4 [' working', ' as', ' a', ' software', ' engineer']
+156 31 The place of work where you can usually find a x -1 The place of work where you can usually find a software engineer office software engineer "['.', '\n', '\n', '------', '\n', 'j', 'osh', 'u', '\n', 'I', ""'m"", ' a', ' software', ' engineer', ' at', ' a', ' startup', ' in', ' SF', '.']" ".
+
+ ------
+ j osh u
+ I 'm a software engineer at a startup in SF ." False to Archana Singh, a software engineer with an MBA degree, 7 [' to', ' Arch', 'ana', ' Singh', ',', ' a', ' software', ' engineer']
+157 31 The place of work where you can usually find a x -1 The place of work where you can usually find a software engineer office software engineer "['.', '\n', '\n', '------', '\n', 'j', 'osh', 'u', '\n', 'I', ""'m"", ' a', ' software', ' engineer', ' at', ' a', ' startup', ' in', ' SF', '.']" ".
+
+ ------
+ j osh u
+ I 'm a software engineer at a startup in SF ." False sister, Devika, is a software engineer settled in the United 8 [' sister', ',', ' Dev', 'ika', ',', ' is', ' a', ' software', ' engineer']
+158 31 The place of work where you can usually find a x -1 The place of work where you can usually find a software engineer office software engineer "['.', '\n', '\n', '------', '\n', 'j', 'osh', 'u', '\n', 'I', ""'m"", ' a', ' software', ' engineer', ' at', ' a', ' startup', ' in', ' SF', '.']" ".
+
+ ------
+ j osh u
+ I 'm a software engineer at a startup in SF ." False Joan Davis and software engineer Robert Cantor both 4 [' Joan', ' Davis', ' and', ' software', ' engineer']
+159 31 The place of work where you can usually find a x -1 The place of work where you can usually find a software engineer office software engineer "['.', '\n', '\n', '------', '\n', 'j', 'osh', 'u', '\n', 'I', ""'m"", ' a', ' software', ' engineer', ' at', ' a', ' startup', ' in', ' SF', '.']" ".
+
+ ------
+ j osh u
+ I 'm a software engineer at a startup in SF ." False 21, 2009, YouTube software engineer Peter Bradshaw 6 [' 21', ',', ' 2009', ',', ' YouTube', ' software', ' engineer']
+160 32 The place of work where you can usually find a x -1 The place of work where you can usually find a student school student [' who', ' is', ' willing', ' to', ' help', ' you', ' out', ' with', ' your', ' homework', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a'] " who is willing to help you out with your homework .
+
+ The place where you can find a" False arrests. A group of students marching towards 5 [' arrests', '.', ' A', ' group', ' of', ' student']
+161 32 The place of work where you can usually find a x -1 The place of work where you can usually find a student school student [' who', ' is', ' willing', ' to', ' help', ' you', ' out', ' with', ' your', ' homework', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a'] " who is willing to help you out with your homework .
+
+ The place where you can find a" False Wallulah), the student newspaper (Willamette 5 [' Wall', 'ul', 'ah', '),', ' the', ' student']
+162 32 The place of work where you can usually find a x -1 The place of work where you can usually find a student school student [' who', ' is', ' willing', ' to', ' help', ' you', ' out', ' with', ' your', ' homework', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a'] " who is willing to help you out with your homework .
+
+ The place where you can find a" False 1932, he was student adviser. From 1925 4 [' 1932', ',', ' he', ' was', ' student']
+163 32 The place of work where you can usually find a x -1 The place of work where you can usually find a student school student [' who', ' is', ' willing', ' to', ' help', ' you', ' out', ' with', ' your', ' homework', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a'] " who is willing to help you out with your homework .
+
+ The place where you can find a" False Notre Dame student section rushing onto 2 [' Notre', ' Dame', ' student']
+164 32 The place of work where you can usually find a x -1 The place of work where you can usually find a student school student [' who', ' is', ' willing', ' to', ' help', ' you', ' out', ' with', ' your', ' homework', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' find', ' a'] " who is willing to help you out with your homework .
+
+ The place where you can find a" False of 17 and a student / faculty ratio of 4 [' of', ' 17', ' and', ' a', ' student']
+165 33 The place of work where you can usually find a x -1 The place of work where you can usually find a surgeon hospital surgeon [' who', ' is', ' willing', ' to', ' work', ' with', ' you', ' to', ' find', ' the', ' best', ' solution', ' for', ' your', ' needs', '.', '\n', '\n', 'The', ' place'] " who is willing to work with you to find the best solution for your needs .
+
+ The place" False made by Scottish surgeon and naturalist Archibald 3 [' made', ' by', ' Scottish', ' surgeon']
+166 33 The place of work where you can usually find a x -1 The place of work where you can usually find a surgeon hospital surgeon [' who', ' is', ' willing', ' to', ' work', ' with', ' you', ' to', ' find', ' the', ' best', ' solution', ' for', ' your', ' needs', '.', '\n', '\n', 'The', ' place'] " who is willing to work with you to find the best solution for your needs .
+
+ The place" False Emil Kocher, a Swiss surgeon, had been the first 7 [' Emil', ' K', 'oc', 'her', ',', ' a', ' Swiss', ' surgeon']
+167 33 The place of work where you can usually find a x -1 The place of work where you can usually find a surgeon hospital surgeon [' who', ' is', ' willing', ' to', ' work', ' with', ' you', ' to', ' find', ' the', ' best', ' solution', ' for', ' your', ' needs', '.', '\n', '\n', 'The', ' place'] " who is willing to work with you to find the best solution for your needs .
+
+ The place" False and he was appointed surgeon in 1931. It 4 [' and', ' he', ' was', ' appointed', ' surgeon']
+168 33 The place of work where you can usually find a x -1 The place of work where you can usually find a surgeon hospital surgeon [' who', ' is', ' willing', ' to', ' work', ' with', ' you', ' to', ' find', ' the', ' best', ' solution', ' for', ' your', ' needs', '.', '\n', '\n', 'The', ' place'] " who is willing to work with you to find the best solution for your needs .
+
+ The place" False doctor and surgeon in Hessian military 2 [' doctor', ' and', ' surgeon']
+169 33 The place of work where you can usually find a x -1 The place of work where you can usually find a surgeon hospital surgeon [' who', ' is', ' willing', ' to', ' work', ' with', ' you', ' to', ' find', ' the', ' best', ' solution', ' for', ' your', ' needs', '.', '\n', '\n', 'The', ' place'] " who is willing to work with you to find the best solution for your needs .
+
+ The place" False son of a Royal Navy surgeon and brother of artist 5 [' son', ' of', ' a', ' Royal', ' Navy', ' surgeon']
+170 34 The place of work where you can usually find a x -1 The place of work where you can usually find a trainer gym trainer [',', ' a', ' coach', ',', ' a', ' mentor', ',', ' a', ' friend', ',', ' a', ' confid', 'ant', ',', ' a', ' confid', 'ante', ',', ' a', ' confid'] , a coach , a mentor , a friend , a confid ant , a confid ante , a confid False Haworth, the club trainer Charlie Bates and 5 [' Haw', 'orth', ',', ' the', ' club', ' trainer']
+171 34 The place of work where you can usually find a x -1 The place of work where you can usually find a trainer gym trainer [',', ' a', ' coach', ',', ' a', ' mentor', ',', ' a', ' friend', ',', ' a', ' confid', 'ant', ',', ' a', ' confid', 'ante', ',', ' a', ' confid'] , a coach , a mentor , a friend , a confid ant , a confid ante , a confid False National Hunt trainer Jenny Pitman, 2 [' National', ' Hunt', ' trainer']
+172 34 The place of work where you can usually find a x -1 The place of work where you can usually find a trainer gym trainer [',', ' a', ' coach', ',', ' a', ' mentor', ',', ' a', ' friend', ',', ' a', ' confid', 'ant', ',', ' a', ' confid', 'ante', ',', ' a', ' confid'] , a coach , a mentor , a friend , a confid ant , a confid ante , a confid False and personal trainer / bodyguard Andy 2 [' and', ' personal', ' trainer']
+173 34 The place of work where you can usually find a x -1 The place of work where you can usually find a trainer gym trainer [',', ' a', ' coach', ',', ' a', ' mentor', ',', ' a', ' friend', ',', ' a', ' confid', 'ant', ',', ' a', ' confid', 'ante', ',', ' a', ' confid'] , a coach , a mentor , a friend , a confid ant , a confid ante , a confid False Stratus was a trainer for WWE Tough Enough 4 [' Str', 'atus', ' was', ' a', ' trainer']
+174 34 The place of work where you can usually find a x -1 The place of work where you can usually find a trainer gym trainer [',', ' a', ' coach', ',', ' a', ' mentor', ',', ' a', ' friend', ',', ' a', ' confid', 'ant', ',', ' a', ' confid', 'ante', ',', ' a', ' confid'] , a coach , a mentor , a friend , a confid ant , a confid ante , a confid False yearling, with trainer John Porter 4 [' year', 'ling', ',', ' with', ' trainer']
+175 35 The place of work where you can usually find a x -1 The place of work where you can usually find a truck driver truck truck driver ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' truck', ' driver', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a truck driver .
+
+ The place where you" True Subsequently the truck driver was charged with 4 [' Sub', 'sequently', ' the', ' truck', ' driver']
+176 35 The place of work where you can usually find a x -1 The place of work where you can usually find a truck driver truck truck driver ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' truck', ' driver', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a truck driver .
+
+ The place where you" True drug-dealing truck driver Leo Johnson (Eric 5 [' drug', '-', 'd', 'ealing', ' truck', ' driver']
+177 35 The place of work where you can usually find a x -1 The place of work where you can usually find a truck driver truck truck driver ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' truck', ' driver', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a truck driver .
+
+ The place where you" True violent, drug-dealing truck driver Leo Johnson (Eric 7 [' violent', ',', ' drug', '-', 'd', 'ealing', ' truck', ' driver']
+178 35 The place of work where you can usually find a x -1 The place of work where you can usually find a truck driver truck truck driver ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' truck', ' driver', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a truck driver .
+
+ The place where you" True Heiter kill a kidnapped truck driver after Heiter informs 6 [' He', 'iter', ' kill', ' a', ' kidnapped', ' truck', ' driver']
+179 35 The place of work where you can usually find a x -1 The place of work where you can usually find a truck driver truck truck driver ['.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' truck', ' driver', '.', '\n', '\n', 'The', ' place', ' where', ' you'] ".
+
+ The place where you can usually find a truck driver .
+
+ The place where you" True violent, drug-dealing truck driver Leo Johnson 7 [' violent', ',', ' drug', '-', 'd', 'ealing', ' truck', ' driver']
+180 36 The place of work where you can usually find a x -1 The place of work where you can usually find a waitress restaurant waitress [',', ' bartender', ',', ' or', ' waiter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waitress', ',', ' bartender', ','] ", bartender , or waiter .
+
+ The place where you can usually find a waitress , bartender ," False then hired as a waitress by bartender Sam Malone. 4 [' then', ' hired', ' as', ' a', ' waitress']
+181 36 The place of work where you can usually find a x -1 The place of work where you can usually find a waitress restaurant waitress [',', ' bartender', ',', ' or', ' waiter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waitress', ',', ' bartender', ','] ", bartender , or waiter .
+
+ The place where you can usually find a waitress , bartender ," False red-haired, tattooed waitress in a restaurant. 6 [' red', '-', 'haired', ',', ' tattoo', 'ed', ' waitress']
+182 36 The place of work where you can usually find a x -1 The place of work where you can usually find a waitress restaurant waitress [',', ' bartender', ',', ' or', ' waiter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waitress', ',', ' bartender', ','] ", bartender , or waiter .
+
+ The place where you can usually find a waitress , bartender ," False eventually becomes a waitress in Savannah, 3 [' eventually', ' becomes', ' a', ' waitress']
+183 36 The place of work where you can usually find a x -1 The place of work where you can usually find a waitress restaurant waitress [',', ' bartender', ',', ' or', ' waiter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waitress', ',', ' bartender', ','] ", bartender , or waiter .
+
+ The place where you can usually find a waitress , bartender ," False " Voodoo"". Its cocktail waitress logo is featured" 5 "[' V', 'oodoo', '"".', ' Its', ' cocktail', ' waitress']"
+184 36 The place of work where you can usually find a x -1 The place of work where you can usually find a waitress restaurant waitress [',', ' bartender', ',', ' or', ' waiter', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waitress', ',', ' bartender', ','] ", bartender , or waiter .
+
+ The place where you can usually find a waitress , bartender ," False Margulies as a waitress who befriends Bolander 5 [' Marg', 'ul', 'ies', ' as', ' a', ' waitress']
+185 37 The place of work where you can usually find a x -1 The place of work where you can usually find a waiter restaurant waiter [' or', ' waitress', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waiter', ' or', ' waitress', '.', '\n', '\n', 'The'] " or waitress .
+
+ The place where you can usually find a waiter or waitress .
+
+ The" False shot a black waiter at a hotel cafe 3 [' shot', ' a', ' black', ' waiter']
+186 37 The place of work where you can usually find a x -1 The place of work where you can usually find a waiter restaurant waiter [' or', ' waitress', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waiter', ' or', ' waitress', '.', '\n', '\n', 'The'] " or waitress .
+
+ The place where you can usually find a waiter or waitress .
+
+ The" False " teller and a waiter at Denny's.
+" 4 [' tell', 'er', ' and', ' a', ' waiter']
+187 37 The place of work where you can usually find a x -1 The place of work where you can usually find a waiter restaurant waiter [' or', ' waitress', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waiter', ' or', ' waitress', '.', '\n', '\n', 'The'] " or waitress .
+
+ The place where you can usually find a waiter or waitress .
+
+ The" False served by a depressed waiter (Evan O 'Brien), 4 [' served', ' by', ' a', ' depressed', ' waiter']
+188 37 The place of work where you can usually find a x -1 The place of work where you can usually find a waiter restaurant waiter [' or', ' waitress', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waiter', ' or', ' waitress', '.', '\n', '\n', 'The'] " or waitress .
+
+ The place where you can usually find a waiter or waitress .
+
+ The" False " book, such as a waiter wearing ""pants""," 5 [' book', ',', ' such', ' as', ' a', ' waiter']
+189 37 The place of work where you can usually find a x -1 The place of work where you can usually find a waiter restaurant waiter [' or', ' waitress', '.', '\n', '\n', 'The', ' place', ' where', ' you', ' can', ' usually', ' find', ' a', ' waiter', ' or', ' waitress', '.', '\n', '\n', 'The'] " or waitress .
+
+ The place where you can usually find a waiter or waitress .
+
+ The" False include acting as a waiter in a restaurant, 4 [' include', ' acting', ' as', ' a', ' waiter']
diff --git a/patchscopes/code/preprocessed_data/factual/company_ceo.tsv b/patchscopes/code/preprocessed_data/factual/company_ceo.tsv
new file mode 100644
index 00000000..f454aebb
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/company_ceo.tsv
@@ -0,0 +1,2932 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The name of the CEO of x -1 The name of the CEO of Alphabet Sundar Pichai Alphabet "[' Inc' '.' ' is' ' Larry' ' Page' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Google' '.' ' He' ' is' ' also' ' the' ' CEO' ' of' ' Alphabet']" Inc . is Larry Page . He is the founder of Google . He is also the CEO of Alphabet False the Common Turkic Alphabet in which most 4 [' the', ' Common', ' Turk', 'ic', ' Alphabet']
+1 0 The name of the CEO of x -1 The name of the CEO of Alphabet Sundar Pichai Alphabet "[' Inc' '.' ' is' ' Larry' ' Page' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Google' '.' ' He' ' is' ' also' ' the' ' CEO' ' of' ' Alphabet']" Inc . is Larry Page . He is the founder of Google . He is also the CEO of Alphabet False Initial Teaching Alphabet in 1965. The tale 2 [' Initial', ' Teaching', ' Alphabet']
+2 0 The name of the CEO of x -1 The name of the CEO of Alphabet Sundar Pichai Alphabet "[' Inc' '.' ' is' ' Larry' ' Page' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Google' '.' ' He' ' is' ' also' ' the' ' CEO' ' of' ' Alphabet']" Inc . is Larry Page . He is the founder of Google . He is also the CEO of Alphabet False " International Phonetic Alphabet may be used.
+" 4 [' International', ' Ph', 'on', 'etic', ' Alphabet']
+3 0 The name of the CEO of x -1 The name of the CEO of Alphabet Sundar Pichai Alphabet "[' Inc' '.' ' is' ' Larry' ' Page' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Google' '.' ' He' ' is' ' also' ' the' ' CEO' ' of' ' Alphabet']" Inc . is Larry Page . He is the founder of Google . He is also the CEO of Alphabet False of the Hebrew Alphabet outside Hebrew liturgy 3 [' of', ' the', ' Hebrew', ' Alphabet']
+4 0 The name of the CEO of x -1 The name of the CEO of Alphabet Sundar Pichai Alphabet "[' Inc' '.' ' is' ' Larry' ' Page' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Google' '.' ' He' ' is' ' also' ' the' ' CEO' ' of' ' Alphabet']" Inc . is Larry Page . He is the founder of Google . He is also the CEO of Alphabet False use the Phonetic Alphabet for these storms. 5 [' use', ' the', ' Ph', 'on', 'etic', ' Alphabet']
+5 1 The name of the CEO of x -1 The name of the CEO of Tesla Elon Musk Tesla "[' is' ' Elon' ' Musk' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO'
+ ' of' ' Tesla' ' is' ' Elon' ' Musk' '.' '\n' '\n' 'The']" " is Elon Musk .
+
+ The name of the CEO of Tesla is Elon Musk .
+
+ The" True computation in their Tesla series. The technology 3 [' computation', ' in', ' their', ' Tesla']
+6 1 The name of the CEO of x -1 The name of the CEO of Tesla Elon Musk Tesla "[' is' ' Elon' ' Musk' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO'
+ ' of' ' Tesla' ' is' ' Elon' ' Musk' '.' '\n' '\n' 'The']" " is Elon Musk .
+
+ The name of the CEO of Tesla is Elon Musk .
+
+ The" True for example, with Tesla Motors' Twitter 4 [' for', ' example', ',', ' with', ' Tesla']
+7 1 The name of the CEO of x -1 The name of the CEO of Tesla Elon Musk Tesla "[' is' ' Elon' ' Musk' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO'
+ ' of' ' Tesla' ' is' ' Elon' ' Musk' '.' '\n' '\n' 'The']" " is Elon Musk .
+
+ The name of the CEO of Tesla is Elon Musk .
+
+ The" True 0 ['Tesla']
+8 1 The name of the CEO of x -1 The name of the CEO of Tesla Elon Musk Tesla "[' is' ' Elon' ' Musk' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO'
+ ' of' ' Tesla' ' is' ' Elon' ' Musk' '.' '\n' '\n' 'The']" " is Elon Musk .
+
+ The name of the CEO of Tesla is Elon Musk .
+
+ The" True experiments with Tesla coils proved insufficient. 2 [' experiments', ' with', ' Tesla']
+9 1 The name of the CEO of x -1 The name of the CEO of Tesla Elon Musk Tesla "[' is' ' Elon' ' Musk' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO'
+ ' of' ' Tesla' ' is' ' Elon' ' Musk' '.' '\n' '\n' 'The']" " is Elon Musk .
+
+ The name of the CEO of Tesla is Elon Musk .
+
+ The" True Belgrade Nikola Tesla Airport was 3 [' Bel', 'grade', ' Nikola', ' Tesla']
+10 2 The name of the CEO of x -1 The name of the CEO of NVIDIA Jensen Huang NVIDIA "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' NVIDIA' ' is' ' working' ' on' ' a' ' new' ' GPU' ' architecture' '.']" is not a secret . It is not a secret that NVIDIA is working on a new GPU architecture . False models of the NVIDIA Geforce 7900 3 [' models', ' of', ' the', ' NVIDIA']
+11 2 The name of the CEO of x -1 The name of the CEO of NVIDIA Jensen Huang NVIDIA "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' NVIDIA' ' is' ' working' ' on' ' a' ' new' ' GPU' ' architecture' '.']" is not a secret . It is not a secret that NVIDIA is working on a new GPU architecture . False software to support NVIDIA graphics cards. 3 [' software', ' to', ' support', ' NVIDIA']
+12 2 The name of the CEO of x -1 The name of the CEO of NVIDIA Jensen Huang NVIDIA "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' NVIDIA' ' is' ' working' ' on' ' a' ' new' ' GPU' ' architecture' '.']" is not a secret . It is not a secret that NVIDIA is working on a new GPU architecture . False software to support NVIDIA graphics cards. 3 [' software', ' to', ' support', ' NVIDIA']
+13 2 The name of the CEO of x -1 The name of the CEO of NVIDIA Jensen Huang NVIDIA "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' NVIDIA' ' is' ' working' ' on' ' a' ' new' ' GPU' ' architecture' '.']" is not a secret . It is not a secret that NVIDIA is working on a new GPU architecture . False ATI 3000-series and NVIDIA GTX 200-series 5 [' ATI', ' 3000', '-', 'series', ' and', ' NVIDIA']
+14 2 The name of the CEO of x -1 The name of the CEO of NVIDIA Jensen Huang NVIDIA "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' NVIDIA' ' is' ' working' ' on' ' a' ' new' ' GPU' ' architecture' '.']" is not a secret . It is not a secret that NVIDIA is working on a new GPU architecture . False software to support NVIDIA graphics cards. 3 [' software', ' to', ' support', ' NVIDIA']
+15 3 The name of the CEO of x -1 The name of the CEO of Meta Mark Zuckerberg Meta "['verse' ' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Meta']" "verse is not a secret . It is a secret .
+
+ The name of the CEO of Meta" False Milena Govich and Meta Golding also 5 [' Mil', 'ena', ' Gov', 'ich', ' and', ' Meta']
+16 3 The name of the CEO of x -1 The name of the CEO of Meta Mark Zuckerberg Meta "['verse' ' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Meta']" "verse is not a secret . It is a secret .
+
+ The name of the CEO of Meta" False 0 ['Meta']
+17 3 The name of the CEO of x -1 The name of the CEO of Meta Mark Zuckerberg Meta "['verse' ' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Meta']" "verse is not a secret . It is a secret .
+
+ The name of the CEO of Meta" False 0 ['Meta']
+18 3 The name of the CEO of x -1 The name of the CEO of Meta Mark Zuckerberg Meta "['verse' ' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Meta']" "verse is not a secret . It is a secret .
+
+ The name of the CEO of Meta" False Bioinformatic Harvester A Meta search engine (29 8 [' Bio', 'in', 'form', 'atic', ' Har', 'ves', 'ter', ' A', ' Meta']
+19 3 The name of the CEO of x -1 The name of the CEO of Meta Mark Zuckerberg Meta "['verse' ' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Meta']" "verse is not a secret . It is a secret .
+
+ The name of the CEO of Meta" False 0 ['Meta']
+20 6 The name of the CEO of x -1 The name of the CEO of Goldman Sachs David Solomon Goldman Sachs "[',' ' Lloyd' ' Blank' 'fe' 'in' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a' ' good']" , Lloyd Blank fe in , is a perfect example of the kind of person who is not a good False investment firms such as Goldman Sachs are making 5 [' investment', ' firms', ' such', ' as', ' Goldman', ' Sachs']
+21 6 The name of the CEO of x -1 The name of the CEO of Goldman Sachs David Solomon Goldman Sachs "[',' ' Lloyd' ' Blank' 'fe' 'in' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a' ' good']" , Lloyd Blank fe in , is a perfect example of the kind of person who is not a good False the DSi was unveiled, Goldman Sachs analyst Matthew J. 7 [' the', ' DS', 'i', ' was', ' unveiled', ',', ' Goldman', ' Sachs']
+22 6 The name of the CEO of x -1 The name of the CEO of Goldman Sachs David Solomon Goldman Sachs "[',' ' Lloyd' ' Blank' 'fe' 'in' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a' ' good']" , Lloyd Blank fe in , is a perfect example of the kind of person who is not a good False began his career with Goldman Sachs in 1989. On April 5 [' began', ' his', ' career', ' with', ' Goldman', ' Sachs']
+23 6 The name of the CEO of x -1 The name of the CEO of Goldman Sachs David Solomon Goldman Sachs "[',' ' Lloyd' ' Blank' 'fe' 'in' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a' ' good']" , Lloyd Blank fe in , is a perfect example of the kind of person who is not a good False years later, Goldman Sachs bought a stake 4 [' years', ' later', ',', ' Goldman', ' Sachs']
+24 6 The name of the CEO of x -1 The name of the CEO of Goldman Sachs David Solomon Goldman Sachs "[',' ' Lloyd' ' Blank' 'fe' 'in' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a' ' good']" , Lloyd Blank fe in , is a perfect example of the kind of person who is not a good False companies such as Goldman Sachs and Merrill Lynch 4 [' companies', ' such', ' as', ' Goldman', ' Sachs']
+25 7 The name of the CEO of x -1 The name of the CEO of Nasdaq Adena Friedma Nasdaq "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False YMAX went on Nasdaq following 5 [' Y', 'MAX', ' went', ' on', ' Nas', 'daq']
+26 7 The name of the CEO of x -1 The name of the CEO of Nasdaq Adena Friedma Nasdaq "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False public on the Nasdaq Stock Market 4 [' public', ' on', ' the', ' Nas', 'daq']
+27 7 The name of the CEO of x -1 The name of the CEO of Nasdaq Adena Friedma Nasdaq "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False company, so YMAX went on Nasdaq following the 8 [' company', ',', ' so', ' Y', 'MAX', ' went', ' on', ' Nas', 'daq']
+28 7 The name of the CEO of x -1 The name of the CEO of Nasdaq Adena Friedma Nasdaq "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False public on the Nasdaq Stock Market with 4 [' public', ' on', ' the', ' Nas', 'daq']
+29 7 The name of the CEO of x -1 The name of the CEO of Nasdaq Adena Friedma Nasdaq "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False renowned shortages. Nasdaq reported (2015) that 4 [' renowned', ' shortages', '.', ' Nas', 'daq']
+30 8 The name of the CEO of x -1 The name of the CEO of Apple Tim Cook Apple "[',' ' Tim' ' Cook' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n' '\n'
+ 'I' ""'m"" ' not' ' sure' ' if' ' it' ""'s"" ' a']" ", Tim Cook , is a very common name .
+
+ I 'm not sure if it 's a" True on songs using Apple Logic in their 3 [' on', ' songs', ' using', ' Apple']
+31 8 The name of the CEO of x -1 The name of the CEO of Apple Tim Cook Apple "[',' ' Tim' ' Cook' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n' '\n'
+ 'I' ""'m"" ' not' ' sure' ' if' ' it' ""'s"" ' a']" ", Tim Cook , is a very common name .
+
+ I 'm not sure if it 's a" True Vainglory. Apple chose the game to demonstrate 5 [' Va', 'ing', 'l', 'ory', '.', ' Apple']
+32 8 The name of the CEO of x -1 The name of the CEO of Apple Tim Cook Apple "[',' ' Tim' ' Cook' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n' '\n'
+ 'I' ""'m"" ' not' ' sure' ' if' ' it' ""'s"" ' a']" ", Tim Cook , is a very common name .
+
+ I 'm not sure if it 's a" True December 12, Apple Valley had a 5 – 0 3 [' December', ' 12', ',', ' Apple']
+33 8 The name of the CEO of x -1 The name of the CEO of Apple Tim Cook Apple "[',' ' Tim' ' Cook' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n' '\n'
+ 'I' ""'m"" ' not' ' sure' ' if' ' it' ""'s"" ' a']" ", Tim Cook , is a very common name .
+
+ I 'm not sure if it 's a" True running iOS 8. Apple improved performance 4 [' running', ' iOS', ' 8', '.', ' Apple']
+34 8 The name of the CEO of x -1 The name of the CEO of Apple Tim Cook Apple "[',' ' Tim' ' Cook' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n' '\n'
+ 'I' ""'m"" ' not' ' sure' ' if' ' it' ""'s"" ' a']" ", Tim Cook , is a very common name .
+
+ I 'm not sure if it 's a" True under the aegis of Apple Films, before going 6 [' under', ' the', ' a', 'eg', 'is', ' of', ' Apple']
+35 9 The name of the CEO of x -1 The name of the CEO of Cisco Systems Chuck Robbins Cisco Systems "[',' ' John' ' Chambers' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a' ' person'
+ ' who']" , John Chambers , is a perfect example of this . He is a great example of a person who False at the expense of Cisco Systems sales. According 5 [' at', ' the', ' expense', ' of', ' Cisco', ' Systems']
+36 9 The name of the CEO of x -1 The name of the CEO of Cisco Systems Chuck Robbins Cisco Systems "[',' ' John' ' Chambers' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a' ' person'
+ ' who']" , John Chambers , is a perfect example of this . He is a great example of a person who False expense of Cisco Systems sales. According 3 [' expense', ' of', ' Cisco', ' Systems']
+37 9 The name of the CEO of x -1 The name of the CEO of Cisco Systems Chuck Robbins Cisco Systems "[',' ' John' ' Chambers' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a' ' person'
+ ' who']" , John Chambers , is a perfect example of this . He is a great example of a person who False 2 ['C', 'isco', ' Systems']
+38 9 The name of the CEO of x -1 The name of the CEO of Cisco Systems Chuck Robbins Cisco Systems "[',' ' John' ' Chambers' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a' ' person'
+ ' who']" , John Chambers , is a perfect example of this . He is a great example of a person who False 2 ['C', 'isco', ' Systems']
+39 9 The name of the CEO of x -1 The name of the CEO of Cisco Systems Chuck Robbins Cisco Systems "[',' ' John' ' Chambers' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a' ' person'
+ ' who']" , John Chambers , is a perfect example of this . He is a great example of a person who False In 2006, Cisco Systems and Apple negotiated 4 [' In', ' 2006', ',', ' Cisco', ' Systems']
+40 10 The name of the CEO of x -1 The name of the CEO of Samsung Electronics Kim Ki Nam Samsung Electronics "[',' ' Lee' ' Jae' '-' 'y' 'ong' ',' ' was' ' also' ' mentioned' ' in'
+ ' the' ' indictment' '.' '\n' '\n' 'The' ' indictment' ' also' ' said']" ", Lee Jae - y ong , was also mentioned in the indictment .
+
+ The indictment also said" False 1 ['Samsung', ' Electronics']
+41 10 The name of the CEO of x -1 The name of the CEO of Samsung Electronics Kim Ki Nam Samsung Electronics "[',' ' Lee' ' Jae' '-' 'y' 'ong' ',' ' was' ' also' ' mentioned' ' in'
+ ' the' ' indictment' '.' '\n' '\n' 'The' ' indictment' ' also' ' said']" ", Lee Jae - y ong , was also mentioned in the indictment .
+
+ The indictment also said" False developed, and marketed by Samsung Electronics that runs 6 [' developed', ',', ' and', ' marketed', ' by', ' Samsung', ' Electronics']
+42 10 The name of the CEO of x -1 The name of the CEO of Samsung Electronics Kim Ki Nam Samsung Electronics "[',' ' Lee' ' Jae' '-' 'y' 'ong' ',' ' was' ' also' ' mentioned' ' in'
+ ' the' ' indictment' '.' '\n' '\n' 'The' ' indictment' ' also' ' said']" ", Lee Jae - y ong , was also mentioned in the indictment .
+
+ The indictment also said" False California against Samsung Electronics, claiming the S III 3 [' California', ' against', ' Samsung', ' Electronics']
+43 10 The name of the CEO of x -1 The name of the CEO of Samsung Electronics Kim Ki Nam Samsung Electronics "[',' ' Lee' ' Jae' '-' 'y' 'ong' ',' ' was' ' also' ' mentioned' ' in'
+ ' the' ' indictment' '.' '\n' '\n' 'The' ' indictment' ' also' ' said']" ", Lee Jae - y ong , was also mentioned in the indictment .
+
+ The indictment also said" False December 2012, a Samsung Electronics representative 5 [' December', ' 2012', ',', ' a', ' Samsung', ' Electronics']
+44 10 The name of the CEO of x -1 The name of the CEO of Samsung Electronics Kim Ki Nam Samsung Electronics "[',' ' Lee' ' Jae' '-' 'y' 'ong' ',' ' was' ' also' ' mentioned' ' in'
+ ' the' ' indictment' '.' '\n' '\n' 'The' ' indictment' ' also' ' said']" ", Lee Jae - y ong , was also mentioned in the indictment .
+
+ The indictment also said" False 1 ['Samsung', ' Electronics']
+45 11 The name of the CEO of x -1 The name of the CEO of Microsoft Satya Nadella Microsoft "[',' ' Bill' ' Gates' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n'
+ '\n' 'I' ' have' ' a' ' friend' ' who' ' is' ' a' ' very']" ", Bill Gates , is a very common name .
+
+ I have a friend who is a very" False came first in a Microsoft poll to determine 4 [' came', ' first', ' in', ' a', ' Microsoft']
+46 11 The name of the CEO of x -1 The name of the CEO of Microsoft Satya Nadella Microsoft "[',' ' Bill' ' Gates' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n'
+ '\n' 'I' ' have' ' a' ' friend' ' who' ' is' ' a' ' very']" ", Bill Gates , is a very common name .
+
+ I have a friend who is a very" False In June 2000, Microsoft started talks to buy 4 [' In', ' June', ' 2000', ',', ' Microsoft']
+47 11 The name of the CEO of x -1 The name of the CEO of Microsoft Satya Nadella Microsoft "[',' ' Bill' ' Gates' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n'
+ '\n' 'I' ' have' ' a' ' friend' ' who' ' is' ' a' ' very']" ", Bill Gates , is a very common name .
+
+ I have a friend who is a very" False August 4, 2015, Microsoft announced Halo Wars 5 [' August', ' 4', ',', ' 2015', ',', ' Microsoft']
+48 11 The name of the CEO of x -1 The name of the CEO of Microsoft Satya Nadella Microsoft "[',' ' Bill' ' Gates' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n'
+ '\n' 'I' ' have' ' a' ' friend' ' who' ' is' ' a' ' very']" ", Bill Gates , is a very common name .
+
+ I have a friend who is a very" False Development of the Microsoft Windows version 3 [' Development', ' of', ' the', ' Microsoft']
+49 11 The name of the CEO of x -1 The name of the CEO of Microsoft Satya Nadella Microsoft "[',' ' Bill' ' Gates' ',' ' is' ' a' ' very' ' common' ' name' '.' '\n'
+ '\n' 'I' ' have' ' a' ' friend' ' who' ' is' ' a' ' very']" ", Bill Gates , is a very common name .
+
+ I have a friend who is a very" False Mobile 5 comes with Microsoft Office Mobile which 4 [' Mobile', ' 5', ' comes', ' with', ' Microsoft']
+50 12 The name of the CEO of x -1 The name of the CEO of Huawei Ren Zhengfei Huawei "[',' ' the' ' world' '�' '�' 's' ' largest' ' telecommunications'
+ ' equipment' ' manufacturer' ',' ' was' ' also' ' on' ' the' ' list' '.'
+ '\n' '\n' 'The']" ", the world � � s largest telecommunications equipment manufacturer , was also on the list .
+
+ The" False companies such as Huawei and Lenovo 3 [' companies', ' such', ' as', ' Huawei']
+51 12 The name of the CEO of x -1 The name of the CEO of Huawei Ren Zhengfei Huawei "[',' ' the' ' world' '�' '�' 's' ' largest' ' telecommunications'
+ ' equipment' ' manufacturer' ',' ' was' ' also' ' on' ' the' ' list' '.'
+ '\n' '\n' 'The']" ", the world � � s largest telecommunications equipment manufacturer , was also on the list .
+
+ The" False 1 ['Hu', 'awei']
+52 12 The name of the CEO of x -1 The name of the CEO of Huawei Ren Zhengfei Huawei "[',' ' the' ' world' '�' '�' 's' ' largest' ' telecommunications'
+ ' equipment' ' manufacturer' ',' ' was' ' also' ' on' ' the' ' list' '.'
+ '\n' '\n' 'The']" ", the world � � s largest telecommunications equipment manufacturer , was also on the list .
+
+ The" False companies such as Huawei and Lenovo have 3 [' companies', ' such', ' as', ' Huawei']
+53 12 The name of the CEO of x -1 The name of the CEO of Huawei Ren Zhengfei Huawei "[',' ' the' ' world' '�' '�' 's' ' largest' ' telecommunications'
+ ' equipment' ' manufacturer' ',' ' was' ' also' ' on' ' the' ' list' '.'
+ '\n' '\n' 'The']" ", the world � � s largest telecommunications equipment manufacturer , was also on the list .
+
+ The" False 1 ['Hu', 'awei']
+54 12 The name of the CEO of x -1 The name of the CEO of Huawei Ren Zhengfei Huawei "[',' ' the' ' world' '�' '�' 's' ' largest' ' telecommunications'
+ ' equipment' ' manufacturer' ',' ' was' ' also' ' on' ' the' ' list' '.'
+ '\n' '\n' 'The']" ", the world � � s largest telecommunications equipment manufacturer , was also on the list .
+
+ The" False companies such as Huawei and Lenovo have 3 [' companies', ' such', ' as', ' Huawei']
+55 13 The name of the CEO of x -1 The name of the CEO of Google Sundar Pichai Google "[',' ' Larry' ' Page' ',' ' is' ' a' ' very' ' interesting' ' name' '.'
+ '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the' ' name' ' of']" ", Larry Page , is a very interesting name .
+
+ I have a feeling that the name of" False streaming services Google Play, Xbox Video, 2 [' streaming', ' services', ' Google']
+56 13 The name of the CEO of x -1 The name of the CEO of Google Sundar Pichai Google "[',' ' Larry' ' Page' ',' ' is' ' a' ' very' ' interesting' ' name' '.'
+ '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the' ' name' ' of']" ", Larry Page , is a very interesting name .
+
+ I have a feeling that the name of" False Marvel Studios and Google began a viral 3 [' Marvel', ' Studios', ' and', ' Google']
+57 13 The name of the CEO of x -1 The name of the CEO of Google Sundar Pichai Google "[',' ' Larry' ' Page' ',' ' is' ' a' ' very' ' interesting' ' name' '.'
+ '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the' ' name' ' of']" ", Larry Page , is a very interesting name .
+
+ I have a feeling that the name of" False name for the Google beta app Google 3 [' name', ' for', ' the', ' Google']
+58 13 The name of the CEO of x -1 The name of the CEO of Google Sundar Pichai Google "[',' ' Larry' ' Page' ',' ' is' ' a' ' very' ' interesting' ' name' '.'
+ '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the' ' name' ' of']" ", Larry Page , is a very interesting name .
+
+ I have a feeling that the name of" False Flipkart, General Electric, Google India, HCL, Harley 7 [' Flip', 'k', 'art', ',', ' General', ' Electric', ',', ' Google']
+59 13 The name of the CEO of x -1 The name of the CEO of Google Sundar Pichai Google "[',' ' Larry' ' Page' ',' ' is' ' a' ' very' ' interesting' ' name' '.'
+ '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the' ' name' ' of']" ", Larry Page , is a very interesting name .
+
+ I have a feeling that the name of" False as pre-emptive as Google Now. Windows 10's 6 [' as', ' pre', '-', 'empt', 'ive', ' as', ' Google']
+60 14 The name of the CEO of x -1 The name of the CEO of Intel Bob Swan Intel "[',' ' the' ' company' ' that' ' makes' ' the' ' chips' ' that' ' power'
+ ' the' ' iPhone' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The' ' company']" ", the company that makes the chips that power the iPhone , is a woman .
+
+ The company" False 2006, and used the Intel Core Duo processor 5 [' 2006', ',', ' and', ' used', ' the', ' Intel']
+61 14 The name of the CEO of x -1 The name of the CEO of Intel Bob Swan Intel "[',' ' the' ' company' ' that' ' makes' ' the' ' chips' ' that' ' power'
+ ' the' ' iPhone' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The' ' company']" ", the company that makes the chips that power the iPhone , is a woman .
+
+ The company" False such as the Intel 8051, Atmel AVR, 3 [' such', ' as', ' the', ' Intel']
+62 14 The name of the CEO of x -1 The name of the CEO of Intel Bob Swan Intel "[',' ' the' ' company' ' that' ' makes' ' the' ' chips' ' that' ' power'
+ ' the' ' iPhone' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The' ' company']" ", the company that makes the chips that power the iPhone , is a woman .
+
+ The company" False tournaments such as Intel Extreme Masters 3 [' tournaments', ' such', ' as', ' Intel']
+63 14 The name of the CEO of x -1 The name of the CEO of Intel Bob Swan Intel "[',' ' the' ' company' ' that' ' makes' ' the' ' chips' ' that' ' power'
+ ' the' ' iPhone' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The' ' company']" ", the company that makes the chips that power the iPhone , is a woman .
+
+ The company" False high-end computing, Intel had originally 5 [' high', '-', 'end', ' computing', ',', ' Intel']
+64 14 The name of the CEO of x -1 The name of the CEO of Intel Bob Swan Intel "[',' ' the' ' company' ' that' ' makes' ' the' ' chips' ' that' ' power'
+ ' the' ' iPhone' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The' ' company']" ", the company that makes the chips that power the iPhone , is a woman .
+
+ The company" False 0 ['Intel']
+65 15 The name of the CEO of x -1 The name of the CEO of The Walt Disney Company Bob Iger The Walt Disney Company "[' is' ' Robert' ' I' 'ger' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Walt' ' Disney' ' Company' '.' ' He' ' is' ' also' ' the' ' Chairman']" is Robert I ger . He is the CEO of the Walt Disney Company . He is also the Chairman False Resorts since 2007, The Walt Disney Company hired American country 8 [' Res', 'orts', ' since', ' 2007', ',', ' The', ' Walt', ' Disney', ' Company']
+66 15 The name of the CEO of x -1 The name of the CEO of The Walt Disney Company Bob Iger The Walt Disney Company "[' is' ' Robert' ' I' 'ger' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Walt' ' Disney' ' Company' '.' ' He' ' is' ' also' ' the' ' Chairman']" is Robert I ger . He is the CEO of the Walt Disney Company . He is also the Chairman False companies, including The Walt Disney Company and Warner Bros., 6 [' companies', ',', ' including', ' The', ' Walt', ' Disney', ' Company']
+67 15 The name of the CEO of x -1 The name of the CEO of The Walt Disney Company Bob Iger The Walt Disney Company "[' is' ' Robert' ' I' 'ger' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Walt' ' Disney' ' Company' '.' ' He' ' is' ' also' ' the' ' Chairman']" is Robert I ger . He is the CEO of the Walt Disney Company . He is also the Chairman False Price helped The Walt Disney Company design Disneyland; 5 [' Price', ' helped', ' The', ' Walt', ' Disney', ' Company']
+68 15 The name of the CEO of x -1 The name of the CEO of The Walt Disney Company Bob Iger The Walt Disney Company "[' is' ' Robert' ' I' 'ger' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Walt' ' Disney' ' Company' '.' ' He' ' is' ' also' ' the' ' Chairman']" is Robert I ger . He is the CEO of the Walt Disney Company . He is also the Chairman False the Panthers, and The Walt Disney Company to own the Mighty 7 [' the', ' Panthers', ',', ' and', ' The', ' Walt', ' Disney', ' Company']
+69 15 The name of the CEO of x -1 The name of the CEO of The Walt Disney Company Bob Iger The Walt Disney Company "[' is' ' Robert' ' I' 'ger' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Walt' ' Disney' ' Company' '.' ' He' ' is' ' also' ' the' ' Chairman']" is Robert I ger . He is the CEO of the Walt Disney Company . He is also the Chairman False and video games. The Walt Disney Company introduces characters 7 [' and', ' video', ' games', '.', ' The', ' Walt', ' Disney', ' Company']
+70 16 The name of the CEO of x -1 The name of the CEO of Nokia Pekka Lundmark Nokia "[',' ' the' ' company' ' that' ' invented' ' the' ' mobile' ' phone' ','
+ ' is' ' a' ' man' ' named' ' Stephen' ' El' 'op' '.' ' He' ' is' ' a']" , the company that invented the mobile phone , is a man named Stephen El op . He is a False song of 2010 on the Nokia Music Store, based 5 [' song', ' of', ' 2010', ' on', ' the', ' Nokia']
+71 16 The name of the CEO of x -1 The name of the CEO of Nokia Pekka Lundmark Nokia "[',' ' the' ' company' ' that' ' invented' ' the' ' mobile' ' phone' ','
+ ' is' ' a' ' man' ' named' ' Stephen' ' El' 'op' '.' ' He' ' is' ' a']" , the company that invented the mobile phone , is a man named Stephen El op . He is a False 15th-anniversary party at Club Nokia in downtown Los 8 [' 15', 'th', '-', 'ann', 'iversary', ' party', ' at', ' Club', ' Nokia']
+72 16 The name of the CEO of x -1 The name of the CEO of Nokia Pekka Lundmark Nokia "[',' ' the' ' company' ' that' ' invented' ' the' ' mobile' ' phone' ','
+ ' is' ' a' ' man' ' named' ' Stephen' ' El' 'op' '.' ' He' ' is' ' a']" , the company that invented the mobile phone , is a man named Stephen El op . He is a False 2008 at the Nokia Theatre. Wearing 3 [' 2008', ' at', ' the', ' Nokia']
+73 16 The name of the CEO of x -1 The name of the CEO of Nokia Pekka Lundmark Nokia "[',' ' the' ' company' ' that' ' invented' ' the' ' mobile' ' phone' ','
+ ' is' ' a' ' man' ' named' ' Stephen' ' El' 'op' '.' ' He' ' is' ' a']" , the company that invented the mobile phone , is a man named Stephen El op . He is a False digital downloads to Nokia Music Store along 3 [' digital', ' downloads', ' to', ' Nokia']
+74 16 The name of the CEO of x -1 The name of the CEO of Nokia Pekka Lundmark Nokia "[',' ' the' ' company' ' that' ' invented' ' the' ' mobile' ' phone' ','
+ ' is' ' a' ' man' ' named' ' Stephen' ' El' 'op' '.' ' He' ' is' ' a']" , the company that invented the mobile phone , is a man named Stephen El op . He is a False Electronics, HTC and Nokia gained on Wednesday 4 [' Electronics', ',', ' HTC', ' and', ' Nokia']
+75 17 The name of the CEO of x -1 The name of the CEO of FC Barcelona Ferran Reverter FC Barcelona "[' is' ' not' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' club' '.'
+ ' The' ' name' ' of' ' the' ' club' ' is' ' FC' ' Barcelona' '.']" is not the same as the name of the club . The name of the club is FC Barcelona . False Germany) and FC Barcelona (Spain) took Liverpool 4 [' Germany', ')', ' and', ' FC', ' Barcelona']
+76 17 The name of the CEO of x -1 The name of the CEO of FC Barcelona Ferran Reverter FC Barcelona "[' is' ' not' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' club' '.'
+ ' The' ' name' ' of' ' the' ' club' ' is' ' FC' ' Barcelona' '.']" is not the same as the name of the club . The name of the club is FC Barcelona . False 2 – 1 win over FC Barcelona in a friendly match. 6 [' 2', ' –', ' 1', ' win', ' over', ' FC', ' Barcelona']
+77 17 The name of the CEO of x -1 The name of the CEO of FC Barcelona Ferran Reverter FC Barcelona "[' is' ' not' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' club' '.'
+ ' The' ' name' ' of' ' the' ' club' ' is' ' FC' ' Barcelona' '.']" is not the same as the name of the club . The name of the club is FC Barcelona . False " FC Barcelona =
+" 1 [' FC', ' Barcelona']
+78 17 The name of the CEO of x -1 The name of the CEO of FC Barcelona Ferran Reverter FC Barcelona "[' is' ' not' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' club' '.'
+ ' The' ' name' ' of' ' the' ' club' ' is' ' FC' ' Barcelona' '.']" is not the same as the name of the club . The name of the club is FC Barcelona . False 1 ['FC', ' Barcelona']
+79 17 The name of the CEO of x -1 The name of the CEO of FC Barcelona Ferran Reverter FC Barcelona "[' is' ' not' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' club' '.'
+ ' The' ' name' ' of' ' the' ' club' ' is' ' FC' ' Barcelona' '.']" is not the same as the name of the club . The name of the club is FC Barcelona . False confirmed as manager of FC Barcelona for the 2013 – 14 5 [' confirmed', ' as', ' manager', ' of', ' FC', ' Barcelona']
+80 18 The name of the CEO of x -1 The name of the CEO of Amazon Andy Jassy Amazon "['.' 'com' ',' ' Jeff' ' Bezos' ',' ' is' ' a' ' great' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a']" . com , Jeff Bezos , is a great example of this . He is a great example of a False download only, on both Amazon and iTunes. No official 5 [' download', ' only', ',', ' on', ' both', ' Amazon']
+81 18 The name of the CEO of x -1 The name of the CEO of Amazon Andy Jassy Amazon "['.' 'com' ',' ' Jeff' ' Bezos' ',' ' is' ' a' ' great' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a']" . com , Jeff Bezos , is a great example of this . He is a great example of a False guineensis in the Amazon basin damages 5 [' gu', 'ine', 'ensis', ' in', ' the', ' Amazon']
+82 18 The name of the CEO of x -1 The name of the CEO of Amazon Andy Jassy Amazon "['.' 'com' ',' ' Jeff' ' Bezos' ',' ' is' ' a' ' great' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a']" . com , Jeff Bezos , is a great example of this . He is a great example of a False distinct from Amazon parrots as 2 [' distinct', ' from', ' Amazon']
+83 18 The name of the CEO of x -1 The name of the CEO of Amazon Andy Jassy Amazon "['.' 'com' ',' ' Jeff' ' Bezos' ',' ' is' ' a' ' great' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a']" . com , Jeff Bezos , is a great example of this . He is a great example of a False 0 ['Amazon']
+84 18 The name of the CEO of x -1 The name of the CEO of Amazon Andy Jassy Amazon "['.' 'com' ',' ' Jeff' ' Bezos' ',' ' is' ' a' ' great' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' great' ' example' ' of' ' a']" . com , Jeff Bezos , is a great example of this . He is a great example of a False ceremonies of some Amazon Indian cultures 3 [' ceremonies', ' of', ' some', ' Amazon']
+85 19 The name of the CEO of x -1 The name of the CEO of IBM Arvind Krishna IBM "[',' ' the' ' company' ' that' ' makes' ' the' ' hardware' ',' ' is' ' a'
+ ' woman' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this']" ", the company that makes the hardware , is a woman .
+
+ I 'm not sure if this" False September: IBM announces 2 [' September', ':', ' IBM']
+86 19 The name of the CEO of x -1 The name of the CEO of IBM Arvind Krishna IBM "[',' ' the' ' company' ' that' ' makes' ' the' ' hardware' ',' ' is' ' a'
+ ' woman' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this']" ", the company that makes the hardware , is a woman .
+
+ I 'm not sure if this" False such as Heinz, IBM and Hoover 5 [' such', ' as', ' Hein', 'z', ',', ' IBM']
+87 19 The name of the CEO of x -1 The name of the CEO of IBM Arvind Krishna IBM "[',' ' the' ' company' ' that' ' makes' ' the' ' hardware' ',' ' is' ' a'
+ ' woman' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this']" ", the company that makes the hardware , is a woman .
+
+ I 'm not sure if this" False " and a copy of the IBM 604.
+" 5 [' and', ' a', ' copy', ' of', ' the', ' IBM']
+88 19 The name of the CEO of x -1 The name of the CEO of IBM Arvind Krishna IBM "[',' ' the' ' company' ' that' ' makes' ' the' ' hardware' ',' ' is' ' a'
+ ' woman' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this']" ", the company that makes the hardware , is a woman .
+
+ I 'm not sure if this" False the Itanium. IBM created a chipset 4 [' the', ' It', 'anium', '.', ' IBM']
+89 19 The name of the CEO of x -1 The name of the CEO of IBM Arvind Krishna IBM "[',' ' the' ' company' ' that' ' makes' ' the' ' hardware' ',' ' is' ' a'
+ ' woman' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this']" ", the company that makes the hardware , is a woman .
+
+ I 'm not sure if this" False operating system to IBM PC compatible computers 3 [' operating', ' system', ' to', ' IBM']
+90 20 The name of the CEO of x -1 The name of the CEO of Netflix Ted Sarandos Netflix "[',' ' Reed' ' Hastings' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but' ' he']" , Reed Hastings , is a perfect example of this . He is a very smart guy , but he False available on the Netflix Instant Watch 3 [' available', ' on', ' the', ' Netflix']
+91 20 The name of the CEO of x -1 The name of the CEO of Netflix Ted Sarandos Netflix "[',' ' Reed' ' Hastings' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but' ' he']" , Reed Hastings , is a perfect example of this . He is a very smart guy , but he False petition to have Netflix produce a 3 [' petition', ' to', ' have', ' Netflix']
+92 20 The name of the CEO of x -1 The name of the CEO of Netflix Ted Sarandos Netflix "[',' ' Reed' ' Hastings' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but' ' he']" , Reed Hastings , is a perfect example of this . He is a very smart guy , but he False Months later, Netflix made the series 3 [' Months', ' later', ',', ' Netflix']
+93 20 The name of the CEO of x -1 The name of the CEO of Netflix Ted Sarandos Netflix "[',' ' Reed' ' Hastings' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but' ' he']" , Reed Hastings , is a perfect example of this . He is a very smart guy , but he False streaming service Netflix, in all territories 2 [' streaming', ' service', ' Netflix']
+94 20 The name of the CEO of x -1 The name of the CEO of Netflix Ted Sarandos Netflix "[',' ' Reed' ' Hastings' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but' ' he']" , Reed Hastings , is a perfect example of this . He is a very smart guy , but he False currently available on Netflix in Italy and the 3 [' currently', ' available', ' on', ' Netflix']
+95 23 The name of the CEO of x -1 The name of the CEO of Hewlett-Packard Meg Whitman Hewlett-Packard "[' is' ' Mark' ' H' 'urd' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' in' ' the' ' news' ' for' ' all' ' the' ' wrong' ' reasons']" is Mark H urd . He is a man who has been in the news for all the wrong reasons False newly established Hewlett-Packard company, a predecessor 6 [' newly', ' established', ' Hew', 'lett', '-', 'Pack', 'ard']
+96 23 The name of the CEO of x -1 The name of the CEO of Hewlett-Packard Meg Whitman Hewlett-Packard "[' is' ' Mark' ' H' 'urd' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' in' ' the' ' news' ' for' ' all' ' the' ' wrong' ' reasons']" is Mark H urd . He is a man who has been in the news for all the wrong reasons False architecture originated at Hewlett-Packard (HP), and 7 [' architecture', ' originated', ' at', ' Hew', 'lett', '-', 'Pack', 'ard']
+97 23 The name of the CEO of x -1 The name of the CEO of Hewlett-Packard Meg Whitman Hewlett-Packard "[' is' ' Mark' ' H' 'urd' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' in' ' the' ' news' ' for' ' all' ' the' ' wrong' ' reasons']" is Mark H urd . He is a man who has been in the news for all the wrong reasons False emulators running on Hewlett-Packard machines for programming. 8 [' em', 'ulators', ' running', ' on', ' Hew', 'lett', '-', 'Pack', 'ard']
+98 23 The name of the CEO of x -1 The name of the CEO of Hewlett-Packard Meg Whitman Hewlett-Packard "[' is' ' Mark' ' H' 'urd' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' in' ' the' ' news' ' for' ' all' ' the' ' wrong' ' reasons']" is Mark H urd . He is a man who has been in the news for all the wrong reasons False newly established Hewlett-Packard company, a predecessor 6 [' newly', ' established', ' Hew', 'lett', '-', 'Pack', 'ard']
+99 23 The name of the CEO of x -1 The name of the CEO of Hewlett-Packard Meg Whitman Hewlett-Packard "[' is' ' Mark' ' H' 'urd' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' in' ' the' ' news' ' for' ' all' ' the' ' wrong' ' reasons']" is Mark H urd . He is a man who has been in the news for all the wrong reasons False " collaboration from both Hewlett-Packard and IBM.
+" 7 [' collaboration', ' from', ' both', ' Hew', 'lett', '-', 'Pack', 'ard']
+100 25 The name of the CEO of x -1 The name of the CEO of Nintendo Shuntaro Furukawa Nintendo "[' of' ' America' ',' ' Reggie' ' F' 'ils' '-' 'A' 'ime' ',' ' has'
+ ' been' ' making' ' the' ' rounds' ' on' ' the' ' internet' ' today' ',']" of America , Reggie F ils - A ime , has been making the rounds on the internet today , False Prime using the Nintendo GameCube – 3 [' Prime', ' using', ' the', ' Nintendo']
+101 25 The name of the CEO of x -1 The name of the CEO of Nintendo Shuntaro Furukawa Nintendo "[' of' ' America' ',' ' Reggie' ' F' 'ils' '-' 'A' 'ime' ',' ' has'
+ ' been' ' making' ' the' ' rounds' ' on' ' the' ' internet' ' today' ',']" of America , Reggie F ils - A ime , has been making the rounds on the internet today , False Japan as part of the Nintendo Power flash RAM 5 [' Japan', ' as', ' part', ' of', ' the', ' Nintendo']
+102 25 The name of the CEO of x -1 The name of the CEO of Nintendo Shuntaro Furukawa Nintendo "[' of' ' America' ',' ' Reggie' ' F' 'ils' '-' 'A' 'ime' ',' ' has'
+ ' been' ' making' ' the' ' rounds' ' on' ' the' ' internet' ' today' ',']" of America , Reggie F ils - A ime , has been making the rounds on the internet today , False from the 1991 Nintendo Power Awards, 3 [' from', ' the', ' 1991', ' Nintendo']
+103 25 The name of the CEO of x -1 The name of the CEO of Nintendo Shuntaro Furukawa Nintendo "[' of' ' America' ',' ' Reggie' ' F' 'ils' '-' 'A' 'ime' ',' ' has'
+ ' been' ' making' ' the' ' rounds' ' on' ' the' ' internet' ' today' ',']" of America , Reggie F ils - A ime , has been making the rounds on the internet today , False Thomas Bowskill at Nintendo Life said that 4 [' Thomas', ' Bows', 'kill', ' at', ' Nintendo']
+104 25 The name of the CEO of x -1 The name of the CEO of Nintendo Shuntaro Furukawa Nintendo "[' of' ' America' ',' ' Reggie' ' F' 'ils' '-' 'A' 'ime' ',' ' has'
+ ' been' ' making' ' the' ' rounds' ' on' ' the' ' internet' ' today' ',']" of America , Reggie F ils - A ime , has been making the rounds on the internet today , False hailing it as the best Nintendo DS game of 2005. 6 [' ha', 'iling', ' it', ' as', ' the', ' best', ' Nintendo']
+105 27 The name of the CEO of x -1 The name of the CEO of Siemens Joe Kaeser Siemens "[',' ' the' ' German' ' company' ' that' ' makes' ' the' ' machines' ','
+ ' is' ' Dr' '.' ' Klaus' ' Klein' 'feld' '.' ' He' ' is' ' a' ' very']" , the German company that makes the machines , is Dr . Klaus Klein feld . He is a very False 2010 on the Siemens test track in 4 [' 2010', ' on', ' the', ' Siem', 'ens']
+106 27 The name of the CEO of x -1 The name of the CEO of Siemens Joe Kaeser Siemens "[',' ' the' ' German' ' company' ' that' ' makes' ' the' ' machines' ','
+ ' is' ' Dr' '.' ' Klaus' ' Klein' 'feld' '.' ' He' ' is' ' a' ' very']" , the German company that makes the machines , is Dr . Klaus Klein feld . He is a very False 2 ['S', 'iem', 'ens']
+107 27 The name of the CEO of x -1 The name of the CEO of Siemens Joe Kaeser Siemens "[',' ' the' ' German' ' company' ' that' ' makes' ' the' ' machines' ','
+ ' is' ' Dr' '.' ' Klaus' ' Klein' 'feld' '.' ' He' ' is' ' a' ' very']" , the German company that makes the machines , is Dr . Klaus Klein feld . He is a very False signed with Nokia Siemens Networks. 4 [' signed', ' with', ' Nokia', ' Siem', 'ens']
+108 27 The name of the CEO of x -1 The name of the CEO of Siemens Joe Kaeser Siemens "[',' ' the' ' German' ' company' ' that' ' makes' ' the' ' machines' ','
+ ' is' ' Dr' '.' ' Klaus' ' Klein' 'feld' '.' ' He' ' is' ' a' ' very']" , the German company that makes the machines , is Dr . Klaus Klein feld . He is a very False autopilot device, the Siemens K4Ü, was installed 6 [' autop', 'ilot', ' device', ',', ' the', ' Siem', 'ens']
+109 27 The name of the CEO of x -1 The name of the CEO of Siemens Joe Kaeser Siemens "[',' ' the' ' German' ' company' ' that' ' makes' ' the' ' machines' ','
+ ' is' ' Dr' '.' ' Klaus' ' Klein' 'feld' '.' ' He' ' is' ' a' ' very']" , the German company that makes the machines , is Dr . Klaus Klein feld . He is a very False Auer-Gesellschaft, AEG and Siemens & Halske, combined 13 [' A', 'uer', '-', 'G', 'es', 'ells', 'cha', 'ft', ',', ' A', 'EG', ' and', ' Siem', 'ens']
+110 29 The name of the CEO of x -1 The name of the CEO of The Coca-Cola Company James Quincey The Coca-Cola Company "[' is' ' Mu' 'htar' ' Kent' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Coca' '-' 'Cola' ' Company' '.' ' He' ' is' ' the' ' CEO']" is Mu htar Kent . He is the CEO of the Coca - Cola Company . He is the CEO False formed with The Coca-Cola Company by releasing special 6 [' formed', ' with', ' The', ' Coca', '-', 'Cola', ' Company']
+111 29 The name of the CEO of x -1 The name of the CEO of The Coca-Cola Company James Quincey The Coca-Cola Company "[' is' ' Mu' 'htar' ' Kent' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Coca' '-' 'Cola' ' Company' '.' ' He' ' is' ' the' ' CEO']" is Mu htar Kent . He is the CEO of the Coca - Cola Company . He is the CEO False the film's release. The Coca-Cola Company suffered losses 9 "[' the', ' film', ""'s"", ' release', '.', ' The', ' Coca', '-', 'Cola', ' Company']"
+112 29 The name of the CEO of x -1 The name of the CEO of The Coca-Cola Company James Quincey The Coca-Cola Company "[' is' ' Mu' 'htar' ' Kent' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Coca' '-' 'Cola' ' Company' '.' ' He' ' is' ' the' ' CEO']" is Mu htar Kent . He is the CEO of the Coca - Cola Company . He is the CEO False and was sold to The Coca-Cola Company in 1960. J.H. Whitney 8 [' and', ' was', ' sold', ' to', ' The', ' Coca', '-', 'Cola', ' Company']
+113 29 The name of the CEO of x -1 The name of the CEO of The Coca-Cola Company James Quincey The Coca-Cola Company "[' is' ' Mu' 'htar' ' Kent' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Coca' '-' 'Cola' ' Company' '.' ' He' ' is' ' the' ' CEO']" is Mu htar Kent . He is the CEO of the Coca - Cola Company . He is the CEO False film's release. The Coca-Cola Company suffered losses 8 "[' film', ""'s"", ' release', '.', ' The', ' Coca', '-', 'Cola', ' Company']"
+114 29 The name of the CEO of x -1 The name of the CEO of The Coca-Cola Company James Quincey The Coca-Cola Company "[' is' ' Mu' 'htar' ' Kent' '.' ' He' ' is' ' the' ' CEO' ' of' ' the'
+ ' Coca' '-' 'Cola' ' Company' '.' ' He' ' is' ' the' ' CEO']" is Mu htar Kent . He is the CEO of the Coca - Cola Company . He is the CEO False 4 ['The', ' Coca', '-', 'Cola', ' Company']
+115 31 The name of the CEO of x -1 The name of the CEO of Sony Ken'ichiro Yoshida Sony "[' Pictures' ' Television' ',' ' the' ' company' ' that' ' owns' ' the'
+ ' rights' ' to' ' the' ' show' ',' ' is' ' also' ' a' ' big' ' fan' ' of'
+ ' the']" Pictures Television , the company that owns the rights to the show , is also a big fan of the False Following a deal with Sony Music, which 4 [' Following', ' a', ' deal', ' with', ' Sony']
+116 31 The name of the CEO of x -1 The name of the CEO of Sony Ken'ichiro Yoshida Sony "[' Pictures' ' Television' ',' ' the' ' company' ' that' ' owns' ' the'
+ ' rights' ' to' ' the' ' show' ',' ' is' ' also' ' a' ' big' ' fan' ' of'
+ ' the']" Pictures Television , the company that owns the rights to the show , is also a big fan of the False Musicnotes.com by Sony / ATV Music 5 [' Music', 'notes', '.', 'com', ' by', ' Sony']
+117 31 The name of the CEO of x -1 The name of the CEO of Sony Ken'ichiro Yoshida Sony "[' Pictures' ' Television' ',' ' the' ' company' ' that' ' owns' ' the'
+ ' rights' ' to' ' the' ' show' ',' ' is' ' also' ' a' ' big' ' fan' ' of'
+ ' the']" Pictures Television , the company that owns the rights to the show , is also a big fan of the False Open 2 [' O', 'pe', 'Sony']
+118 31 The name of the CEO of x -1 The name of the CEO of Sony Ken'ichiro Yoshida Sony "[' Pictures' ' Television' ',' ' the' ' company' ' that' ' owns' ' the'
+ ' rights' ' to' ' the' ' show' ',' ' is' ' also' ' a' ' big' ' fan' ' of'
+ ' the']" Pictures Television , the company that owns the rights to the show , is also a big fan of the False was released by Sony Records in Japan on 3 [' was', ' released', ' by', ' Sony']
+119 31 The name of the CEO of x -1 The name of the CEO of Sony Ken'ichiro Yoshida Sony "[' Pictures' ' Television' ',' ' the' ' company' ' that' ' owns' ' the'
+ ' rights' ' to' ' the' ' show' ',' ' is' ' also' ' a' ' big' ' fan' ' of'
+ ' the']" Pictures Television , the company that owns the rights to the show , is also a big fan of the False 0 ['Sony']
+120 34 The name of the CEO of x -1 The name of the CEO of Volkswagen Group Oliver Blume Volkswagen Group "[' of' ' America' ',' ' Michael' ' Horn' ',' ' said' ' that' ' the'
+ ' company' ' is' ' working' ' on' ' a' ' new' ' generation' ' of' ' the'
+ ' Beetle' ',']" of America , Michael Horn , said that the company is working on a new generation of the Beetle , False 3 ['Vol', 'ks', 'wagen', ' Group']
+121 34 The name of the CEO of x -1 The name of the CEO of Volkswagen Group Oliver Blume Volkswagen Group "[' of' ' America' ',' ' Michael' ' Horn' ',' ' said' ' that' ' the'
+ ' company' ' is' ' working' ' on' ' a' ' new' ' generation' ' of' ' the'
+ ' Beetle' ',']" of America , Michael Horn , said that the company is working on a new generation of the Beetle , False previous sponsor Volkswagen Group of America. Other 3 [' previous', ' sponsor', ' Volkswagen', ' Group']
+122 34 The name of the CEO of x -1 The name of the CEO of Volkswagen Group Oliver Blume Volkswagen Group "[' of' ' America' ',' ' Michael' ' Horn' ',' ' said' ' that' ' the'
+ ' company' ' is' ' working' ' on' ' a' ' new' ' generation' ' of' ' the'
+ ' Beetle' ',']" of America , Michael Horn , said that the company is working on a new generation of the Beetle , False 3 ['Vol', 'ks', 'wagen', ' Group']
+123 34 The name of the CEO of x -1 The name of the CEO of Volkswagen Group Oliver Blume Volkswagen Group "[' of' ' America' ',' ' Michael' ' Horn' ',' ' said' ' that' ' the'
+ ' company' ' is' ' working' ' on' ' a' ' new' ' generation' ' of' ' the'
+ ' Beetle' ',']" of America , Michael Horn , said that the company is working on a new generation of the Beetle , False 3 ['Vol', 'ks', 'wagen', ' Group']
+124 34 The name of the CEO of x -1 The name of the CEO of Volkswagen Group Oliver Blume Volkswagen Group "[' of' ' America' ',' ' Michael' ' Horn' ',' ' said' ' that' ' the'
+ ' company' ' is' ' working' ' on' ' a' ' new' ' generation' ' of' ' the'
+ ' Beetle' ',']" of America , Michael Horn , said that the company is working on a new generation of the Beetle , False previous sponsor Volkswagen Group of America. 3 [' previous', ' sponsor', ' Volkswagen', ' Group']
+125 35 The name of the CEO of x -1 The name of the CEO of Wikimedia Foundation Maryana Iskander Wikimedia Foundation "[',' ' the' ' non' '-' 'profit' ' organization' ' that' ' runs'
+ ' Wikipedia' ',' ' is' ' Jimmy' ' Wales' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy']" , the non - profit organization that runs Wikipedia , is Jimmy Wales . He is a very nice guy False published works. The Wikimedia Foundation uses Gill Sans 5 [' published', ' works', '.', ' The', ' Wikimedia', ' Foundation']
+126 35 The name of the CEO of x -1 The name of the CEO of Wikimedia Foundation Maryana Iskander Wikimedia Foundation "[',' ' the' ' non' '-' 'profit' ' organization' ' that' ' runs'
+ ' Wikipedia' ',' ' is' ' Jimmy' ' Wales' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy']" , the non - profit organization that runs Wikipedia , is Jimmy Wales . He is a very nice guy False " Wales posted to a Wikimedia Foundation mailing list, ""the" 5 [' Wales', ' posted', ' to', ' a', ' Wikimedia', ' Foundation']
+127 35 The name of the CEO of x -1 The name of the CEO of Wikimedia Foundation Maryana Iskander Wikimedia Foundation "[',' ' the' ' non' '-' 'profit' ' organization' ' that' ' runs'
+ ' Wikipedia' ',' ' is' ' Jimmy' ' Wales' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy']" , the non - profit organization that runs Wikipedia , is Jimmy Wales . He is a very nice guy False 2 ['Wik', 'imedia', ' Foundation']
+128 35 The name of the CEO of x -1 The name of the CEO of Wikimedia Foundation Maryana Iskander Wikimedia Foundation "[',' ' the' ' non' '-' 'profit' ' organization' ' that' ' runs'
+ ' Wikipedia' ',' ' is' ' Jimmy' ' Wales' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy']" , the non - profit organization that runs Wikipedia , is Jimmy Wales . He is a very nice guy False (22: 15 UTC). The Wikimedia Foundation reported nearly 8 [' (', '22', ':', ' 15', ' UTC', ').', ' The', ' Wikimedia', ' Foundation']
+129 35 The name of the CEO of x -1 The name of the CEO of Wikimedia Foundation Maryana Iskander Wikimedia Foundation "[',' ' the' ' non' '-' 'profit' ' organization' ' that' ' runs'
+ ' Wikipedia' ',' ' is' ' Jimmy' ' Wales' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy']" , the non - profit organization that runs Wikipedia , is Jimmy Wales . He is a very nice guy False co-directors, in addition to Wikimedia Foundation representative 9 [' co', '-', 'direct', 'ors', ',', ' in', ' addition', ' to', ' Wikimedia', ' Foundation']
+130 36 The name of the CEO of x -1 The name of the CEO of Boeing Dennis Muilenburg Boeing "[',' ' the' ' world' '�' '�' 's' ' largest' ' aerospace' ' company' ','
+ ' is' ' a' ' man' ' named' ' William' ' J' '.' ' Boeing' '.' '\n']" ", the world � � s largest aerospace company , is a man named William J . Boeing .
+" False Flight 175: a Boeing 767 aircraft, 4 [' Flight', ' 175', ':', ' a', ' Boeing']
+131 36 The name of the CEO of x -1 The name of the CEO of Boeing Dennis Muilenburg Boeing "[',' ' the' ' world' '�' '�' 's' ' largest' ' aerospace' ' company' ','
+ ' is' ' a' ' man' ' named' ' William' ' J' '.' ' Boeing' '.' '\n']" ", the world � � s largest aerospace company , is a man named William J . Boeing .
+" False 2 ['B', 'oe', 'ing']
+132 36 The name of the CEO of x -1 The name of the CEO of Boeing Dennis Muilenburg Boeing "[',' ' the' ' world' '�' '�' 's' ' largest' ' aerospace' ' company' ','
+ ' is' ' a' ' man' ' named' ' William' ' J' '.' ' Boeing' '.' '\n']" ", the world � � s largest aerospace company , is a man named William J . Boeing .
+" False Kalla, said that all Boeing 737-300s should 6 [' K', 'alla', ',', ' said', ' that', ' all', ' Boeing']
+133 36 The name of the CEO of x -1 The name of the CEO of Boeing Dennis Muilenburg Boeing "[',' ' the' ' world' '�' '�' 's' ' largest' ' aerospace' ' company' ','
+ ' is' ' a' ' man' ' named' ' William' ' J' '.' ' Boeing' '.' '\n']" ", the world � � s largest aerospace company , is a man named William J . Boeing .
+" False arrangement on the Boeing 777, 3 – 4 – 3 [' arrangement', ' on', ' the', ' Boeing']
+134 36 The name of the CEO of x -1 The name of the CEO of Boeing Dennis Muilenburg Boeing "[',' ' the' ' world' '�' '�' 's' ' largest' ' aerospace' ' company' ','
+ ' is' ' a' ' man' ' named' ' William' ' J' '.' ' Boeing' '.' '\n']" ", the world � � s largest aerospace company , is a man named William J . Boeing .
+" False also using Boeing 737 aircraft to 2 [' also', ' using', ' Boeing']
+135 38 The name of the CEO of x -1 The name of the CEO of International Monetary Fund Kristalina Georgieva International Monetary Fund "[' (' 'IM' 'F' ')' ' Christine' ' Lag' 'arde' ',' ' who' ' is' ' a'
+ ' French' 'woman' ',' ' is' ' a' ' woman' '.' '\n' '\n']" " ( IM F ) Christine Lag arde , who is a French woman , is a woman .
+
+" False World Bank and International Monetary Fund on October 19 – 20, 5 [' World', ' Bank', ' and', ' International', ' Monetary', ' Fund']
+136 38 The name of the CEO of x -1 The name of the CEO of International Monetary Fund Kristalina Georgieva International Monetary Fund "[' (' 'IM' 'F' ')' ' Christine' ' Lag' 'arde' ',' ' who' ' is' ' a'
+ ' French' 'woman' ',' ' is' ' a' ' woman' '.' '\n' '\n']" " ( IM F ) Christine Lag arde , who is a French woman , is a woman .
+
+" False managing director of the International Monetary Fund following the scheduled 6 [' managing', ' director', ' of', ' the', ' International', ' Monetary', ' Fund']
+137 38 The name of the CEO of x -1 The name of the CEO of International Monetary Fund Kristalina Georgieva International Monetary Fund "[' (' 'IM' 'F' ')' ' Christine' ' Lag' 'arde' ',' ' who' ' is' ' a'
+ ' French' 'woman' ',' ' is' ' a' ' woman' '.' '\n' '\n']" " ( IM F ) Christine Lag arde , who is a French woman , is a woman .
+
+" False dollar. Based on International Monetary Fund estimates 6 [' dollar', '.', ' Based', ' on', ' International', ' Monetary', ' Fund']
+138 38 The name of the CEO of x -1 The name of the CEO of International Monetary Fund Kristalina Georgieva International Monetary Fund "[' (' 'IM' 'F' ')' ' Christine' ' Lag' 'arde' ',' ' who' ' is' ' a'
+ ' French' 'woman' ',' ' is' ' a' ' woman' '.' '\n' '\n']" " ( IM F ) Christine Lag arde , who is a French woman , is a woman .
+
+" False According to the International Monetary Fund (IMF), the Indian 5 [' According', ' to', ' the', ' International', ' Monetary', ' Fund']
+139 38 The name of the CEO of x -1 The name of the CEO of International Monetary Fund Kristalina Georgieva International Monetary Fund "[' (' 'IM' 'F' ')' ' Christine' ' Lag' 'arde' ',' ' who' ' is' ' a'
+ ' French' 'woman' ',' ' is' ' a' ' woman' '.' '\n' '\n']" " ( IM F ) Christine Lag arde , who is a French woman , is a woman .
+
+" False according to the 2013 International Monetary Fund projections. The 6 [' according', ' to', ' the', ' 2013', ' International', ' Monetary', ' Fund']
+140 39 The name of the CEO of x -1 The name of the CEO of Deutsche Telekom Timotheus Höttges Deutsche Telekom "[',' ' the' ' German' ' telecommunications' ' company' ',' ' is' ' Martin'
+ ' Re' 'im' 'ann' '.' ' He' ' is' ' a' ' German' ',' ' and' ' he' ' is']" , the German telecommunications company , is Martin Re im ann . He is a German , and he is False in the process. Deutsche Telekom became one of the 7 [' in', ' the', ' process', '.', ' Deutsche', ' Tele', 'k', 'om']
+141 39 The name of the CEO of x -1 The name of the CEO of Deutsche Telekom Timotheus Höttges Deutsche Telekom "[',' ' the' ' German' ' telecommunications' ' company' ',' ' is' ' Martin'
+ ' Re' 'im' 'ann' '.' ' He' ' is' ' a' ' German' ',' ' and' ' he' ' is']" , the German telecommunications company , is Martin Re im ann . He is a German , and he is False German phone company Deutsche Telekom said they were 6 [' German', ' phone', ' company', ' Deutsche', ' Tele', 'k', 'om']
+142 39 The name of the CEO of x -1 The name of the CEO of Deutsche Telekom Timotheus Höttges Deutsche Telekom "[',' ' the' ' German' ' telecommunications' ' company' ',' ' is' ' Martin'
+ ' Re' 'im' 'ann' '.' ' He' ' is' ' a' ' German' ',' ' and' ' he' ' is']" , the German telecommunications company , is Martin Re im ann . He is a German , and he is False " German phone company Deutsche Telekom said they were ""satisfied""" 6 [' German', ' phone', ' company', ' Deutsche', ' Tele', 'k', 'om']
+143 40 The name of the CEO of x -1 The name of the CEO of Ford Motor Company Jim Farley Ford Motor Company "[' is' ' Mark' ' Fields' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been'
+ ' in' ' the' ' automotive' ' industry' ' for' ' over' ' 30' ' years' '.']" is Mark Fields . He is a man who has been in the automotive industry for over 30 years . False request of Ford Motor Company of Canada, construction 4 [' request', ' of', ' Ford', ' Motor', ' Company']
+144 40 The name of the CEO of x -1 The name of the CEO of Ford Motor Company Jim Farley Ford Motor Company "[' is' ' Mark' ' Fields' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been'
+ ' in' ' the' ' automotive' ' industry' ' for' ' over' ' 30' ' years' '.']" is Mark Fields . He is a man who has been in the automotive industry for over 30 years . False the film. Ford Motor Company also sponsored 5 [' the', ' film', '.', ' Ford', ' Motor', ' Company']
+145 40 The name of the CEO of x -1 The name of the CEO of Ford Motor Company Jim Farley Ford Motor Company "[' is' ' Mark' ' Fields' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been'
+ ' in' ' the' ' automotive' ' industry' ' for' ' over' ' 30' ' years' '.']" is Mark Fields . He is a man who has been in the automotive industry for over 30 years . False 1925 and 2007, Ford Motor Company operated Norfolk 6 [' 1925', ' and', ' 2007', ',', ' Ford', ' Motor', ' Company']
+146 40 The name of the CEO of x -1 The name of the CEO of Ford Motor Company Jim Farley Ford Motor Company "[' is' ' Mark' ' Fields' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been'
+ ' in' ' the' ' automotive' ' industry' ' for' ' over' ' 30' ' years' '.']" is Mark Fields . He is a man who has been in the automotive industry for over 30 years . False stations. However, Ford Motor Company took the offer 6 [' stations', '.', ' However', ',', ' Ford', ' Motor', ' Company']
+147 40 The name of the CEO of x -1 The name of the CEO of Ford Motor Company Jim Farley Ford Motor Company "[' is' ' Mark' ' Fields' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been'
+ ' in' ' the' ' automotive' ' industry' ' for' ' over' ' 30' ' years' '.']" is Mark Fields . He is a man who has been in the automotive industry for over 30 years . False request of Ford Motor Company of Canada, 4 [' request', ' of', ' Ford', ' Motor', ' Company']
+148 41 The name of the CEO of x -1 The name of the CEO of Walmart Doug McMillon Walmart "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' the' ' name' ' of' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company']" is not a secret . It is the name of the company that owns the company that owns the company False " release at select Walmart locations.
+" 3 [' release', ' at', ' select', ' Walmart']
+149 41 The name of the CEO of x -1 The name of the CEO of Walmart Doug McMillon Walmart "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' the' ' name' ' of' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company']" is not a secret . It is the name of the company that owns the company that owns the company False " Together"" at the Walmart shareholders meeting." 4 "[' Together', '""', ' at', ' the', ' Walmart']"
+150 41 The name of the CEO of x -1 The name of the CEO of Walmart Doug McMillon Walmart "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' the' ' name' ' of' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company']" is not a secret . It is the name of the company that owns the company that owns the company False partnerships with Walmart and Cheerios. However, 2 [' partnerships', ' with', ' Walmart']
+151 41 The name of the CEO of x -1 The name of the CEO of Walmart Doug McMillon Walmart "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' the' ' name' ' of' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company']" is not a secret . It is the name of the company that owns the company that owns the company False two-pack set at Walmart including the 5 [' two', '-', 'pack', ' set', ' at', ' Walmart']
+152 41 The name of the CEO of x -1 The name of the CEO of Walmart Doug McMillon Walmart "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' the' ' name' ' of' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company']" is not a secret . It is the name of the company that owns the company that owns the company False Records as a Walmart exclusive 3 [' Records', ' as', ' a', ' Walmart']
+153 42 The name of the CEO of x -1 The name of the CEO of Electronic Arts Andrew Wilson Electronic Arts "[',' ' John' ' Ric' 'cit' 'ie' 'llo' ',' ' has' ' been' ' making' ' the'
+ ' rounds' ' in' ' the' ' media' ' lately' '.' ' He' ' has' ' been']" , John Ric cit ie llo , has been making the rounds in the media lately . He has been False material. Acquired by Electronic Arts in 1998, Westwood 6 [' material', '.', ' Acqu', 'ired', ' by', ' Electronic', ' Arts']
+154 42 The name of the CEO of x -1 The name of the CEO of Electronic Arts Andrew Wilson Electronic Arts "[',' ' John' ' Ric' 'cit' 'ie' 'llo' ',' ' has' ' been' ' making' ' the'
+ ' rounds' ' in' ' the' ' media' ' lately' '.' ' He' ' has' ' been']" , John Ric cit ie llo , has been making the rounds in the media lately . He has been False these royalties led Electronic Arts to develop 4 [' these', ' royalties', ' led', ' Electronic', ' Arts']
+155 42 The name of the CEO of x -1 The name of the CEO of Electronic Arts Andrew Wilson Electronic Arts "[',' ' John' ' Ric' 'cit' 'ie' 'llo' ',' ' has' ' been' ' making' ' the'
+ ' rounds' ' in' ' the' ' media' ' lately' '.' ' He' ' has' ' been']" , John Ric cit ie llo , has been making the rounds in the media lately . He has been False digital form, Electronic Arts and Games Workshop 4 [' digital', ' form', ',', ' Electronic', ' Arts']
+156 42 The name of the CEO of x -1 The name of the CEO of Electronic Arts Andrew Wilson Electronic Arts "[',' ' John' ' Ric' 'cit' 'ie' 'llo' ',' ' has' ' been' ' making' ' the'
+ ' rounds' ' in' ' the' ' media' ' lately' '.' ' He' ' has' ' been']" , John Ric cit ie llo , has been making the rounds in the media lately . He has been False video game publisher Electronic Arts (EA) reverse engineered 4 [' video', ' game', ' publisher', ' Electronic', ' Arts']
+157 42 The name of the CEO of x -1 The name of the CEO of Electronic Arts Andrew Wilson Electronic Arts "[',' ' John' ' Ric' 'cit' 'ie' 'llo' ',' ' has' ' been' ' making' ' the'
+ ' rounds' ' in' ' the' ' media' ' lately' '.' ' He' ' has' ' been']" , John Ric cit ie llo , has been making the rounds in the media lately . He has been False publishing giant Electronic Arts several years and a 3 [' publishing', ' giant', ' Electronic', ' Arts']
+158 43 The name of the CEO of x -1 The name of the CEO of Renault Luca de Meo Renault "['-' 'N' 'issan' '-' 'M' 'its' 'ub' 'ishi' ' Alliance' ',' ' Carlos' ' Gh'
+ 'os' 'n' ',' ' has' ' been' ' arrested' ' in' ' Japan']" - N issan - M its ub ishi Alliance , Carlos Gh os n , has been arrested in Japan False specification Renault power unit, after 1 [' specification', ' Renault']
+159 43 The name of the CEO of x -1 The name of the CEO of Renault Luca de Meo Renault "['-' 'N' 'issan' '-' 'M' 'its' 'ub' 'ishi' ' Alliance' ',' ' Carlos' ' Gh'
+ 'os' 'n' ',' ' has' ' been' ' arrested' ' in' ' Japan']" - N issan - M its ub ishi Alliance , Carlos Gh os n , has been arrested in Japan False of captured French Renault R35 tanks were 3 [' of', ' captured', ' French', ' Renault']
+160 43 The name of the CEO of x -1 The name of the CEO of Renault Luca de Meo Renault "['-' 'N' 'issan' '-' 'M' 'its' 'ub' 'ishi' ' Alliance' ',' ' Carlos' ' Gh'
+ 'os' 'n' ',' ' has' ' been' ' arrested' ' in' ' Japan']" - N issan - M its ub ishi Alliance , Carlos Gh os n , has been arrested in Japan False Emil), from whom Renault orders a drink when 4 [' Emil', '),', ' from', ' whom', ' Renault']
+161 43 The name of the CEO of x -1 The name of the CEO of Renault Luca de Meo Renault "['-' 'N' 'issan' '-' 'M' 'its' 'ub' 'ishi' ' Alliance' ',' ' Carlos' ' Gh'
+ 'os' 'n' ',' ' has' ' been' ' arrested' ' in' ' Japan']" - N issan - M its ub ishi Alliance , Carlos Gh os n , has been arrested in Japan False 1 ['Ren', 'ault']
+162 43 The name of the CEO of x -1 The name of the CEO of Renault Luca de Meo Renault "['-' 'N' 'issan' '-' 'M' 'its' 'ub' 'ishi' ' Alliance' ',' ' Carlos' ' Gh'
+ 'os' 'n' ',' ' has' ' been' ' arrested' ' in' ' Japan']" - N issan - M its ub ishi Alliance , Carlos Gh os n , has been arrested in Japan False three French Renault tanks, although it 2 [' three', ' French', ' Renault']
+163 44 The name of the CEO of x -1 The name of the CEO of Delta Air Lines Ed Bastian Delta Air Lines "[',' ' Richard' ' Anderson' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Richard Anderson , is a perfect example of this . He is a man who has been in the False engine, and Delta Air Lines and Flying Tiger 5 [' engine', ',', ' and', ' Delta', ' Air', ' Lines']
+164 44 The name of the CEO of x -1 The name of the CEO of Delta Air Lines Ed Bastian Delta Air Lines "[',' ' Richard' ' Anderson' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Richard Anderson , is a perfect example of this . He is a man who has been in the False compensate, Delta Air Lines allowed those flying 4 [' compensate', ',', ' Delta', ' Air', ' Lines']
+165 44 The name of the CEO of x -1 The name of the CEO of Delta Air Lines Ed Bastian Delta Air Lines "[',' ' Richard' ' Anderson' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Richard Anderson , is a perfect example of this . He is a man who has been in the False November 2014, Delta Air Lines became the first 5 [' November', ' 2014', ',', ' Delta', ' Air', ' Lines']
+166 44 The name of the CEO of x -1 The name of the CEO of Delta Air Lines Ed Bastian Delta Air Lines "[',' ' Richard' ' Anderson' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Richard Anderson , is a perfect example of this . He is a man who has been in the False third-largest carrier (after Delta Air Lines and American Airlines) 8 [' third', '-', 'largest', ' carrier', ' (', 'after', ' Delta', ' Air', ' Lines']
+167 44 The name of the CEO of x -1 The name of the CEO of Delta Air Lines Ed Bastian Delta Air Lines "[',' ' Richard' ' Anderson' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Richard Anderson , is a perfect example of this . He is a man who has been in the False largest U.S. operators, Delta Air Lines and American 9 [' largest', ' U', '.', 'S', '.', ' operators', ',', ' Delta', ' Air', ' Lines']
+168 48 The name of the CEO of x -1 The name of the CEO of Robert Bosch Stefan Hartung Robert Bosch "[' G' 'mb' 'H' ',' ' the' ' company' ' that' ' makes' ' the' ' car' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' makes']" G mb H , the company that makes the car , is also the name of the company that makes False " Trustees of the Robert Bosch Stiftung.
+" 6 [' Trust', 'ees', ' of', ' the', ' Robert', ' Bos', 'ch']
+169 48 The name of the CEO of x -1 The name of the CEO of Robert Bosch Stefan Hartung Robert Bosch "[' G' 'mb' 'H' ',' ' the' ' company' ' that' ' makes' ' the' ' car' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' makes']" G mb H , the company that makes the car , is also the name of the company that makes False " of Trustees of the Robert Bosch Stiftung.
+" 7 [' of', ' Trust', 'ees', ' of', ' the', ' Robert', ' Bos', 'ch']
+170 48 The name of the CEO of x -1 The name of the CEO of Robert Bosch Stefan Hartung Robert Bosch "[' G' 'mb' 'H' ',' ' the' ' company' ' that' ' makes' ' the' ' car' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' makes']" G mb H , the company that makes the car , is also the name of the company that makes False " Trustees of the Robert Bosch Stiftung.
+" 6 [' Trust', 'ees', ' of', ' the', ' Robert', ' Bos', 'ch']
+171 48 The name of the CEO of x -1 The name of the CEO of Robert Bosch Stefan Hartung Robert Bosch "[' G' 'mb' 'H' ',' ' the' ' company' ' that' ' makes' ' the' ' car' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' makes']" G mb H , the company that makes the car , is also the name of the company that makes False University and the Robert Bosch Foundation of 5 [' University', ' and', ' the', ' Robert', ' Bos', 'ch']
+172 48 The name of the CEO of x -1 The name of the CEO of Robert Bosch Stefan Hartung Robert Bosch "[' G' 'mb' 'H' ',' ' the' ' company' ' that' ' makes' ' the' ' car' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' makes']" G mb H , the company that makes the car , is also the name of the company that makes False " Trustees of the Robert Bosch Stiftung.
+" 6 [' Trust', 'ees', ' of', ' the', ' Robert', ' Bos', 'ch']
+173 50 The name of the CEO of x -1 The name of the CEO of Orange Christel Heydemann Orange "[' County' ',' ' California' ',' ' is' ' Orange' ' County' ','
+ ' California' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Orange' ' County']" " County , California , is Orange County , California .
+
+ The name of the CEO of Orange County" False Belize City, Orange Walk, and Belmopan. 4 [' Bel', 'ize', ' City', ',', ' Orange']
+174 50 The name of the CEO of x -1 The name of the CEO of Orange Christel Heydemann Orange "[' County' ',' ' California' ',' ' is' ' Orange' ' County' ','
+ ' California' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Orange' ' County']" " County , California , is Orange County , California .
+
+ The name of the CEO of Orange County" False Terry's Chocolate Orange featuring a 3 "[' Terry', ""'s"", ' Chocolate', ' Orange']"
+175 50 The name of the CEO of x -1 The name of the CEO of Orange Christel Heydemann Orange "[' County' ',' ' California' ',' ' is' ' Orange' ' County' ','
+ ' California' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Orange' ' County']" " County , California , is Orange County , California .
+
+ The name of the CEO of Orange County" False Evaluation and the Orange Book. This project 3 [' Evaluation', ' and', ' the', ' Orange']
+176 50 The name of the CEO of x -1 The name of the CEO of Orange Christel Heydemann Orange "[' County' ',' ' California' ',' ' is' ' Orange' ' County' ','
+ ' California' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Orange' ' County']" " County , California , is Orange County , California .
+
+ The name of the CEO of Orange County" False (CR 106) in Orange and Rockland counties 5 [' (', 'CR', ' 106', ')', ' in', ' Orange']
+177 50 The name of the CEO of x -1 The name of the CEO of Orange Christel Heydemann Orange "[' County' ',' ' California' ',' ' is' ' Orange' ' County' ','
+ ' California' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Orange' ' County']" " County , California , is Orange County , California .
+
+ The name of the CEO of Orange County" False the Print. New York: Orange Angle Press. ISBN 6 [' the', ' Print', '.', ' New', ' York', ':', ' Orange']
+178 53 The name of the CEO of x -1 The name of the CEO of Foxconn Terry Gou Foxconn "[',' ' the' ' world' ""'s"" ' largest' ' contract' ' manufacturer' ' of'
+ ' electronics' ',' ' is' ' Terry' ' Gou' '.' ' He' ' is' ' also' ' the'
+ ' chairman' ' of']" , the world 's largest contract manufacturer of electronics , is Terry Gou . He is also the chairman of True " ""three to four thousand"" Foxconn workers who work at" 7 "[' ""', 'three', ' to', ' four', ' thousand', '""', ' Fox', 'conn']"
+179 53 The name of the CEO of x -1 The name of the CEO of Foxconn Terry Gou Foxconn "[',' ' the' ' world' ""'s"" ' largest' ' contract' ' manufacturer' ' of'
+ ' electronics' ',' ' is' ' Terry' ' Gou' '.' ' He' ' is' ' also' ' the'
+ ' chairman' ' of']" , the world 's largest contract manufacturer of electronics , is Terry Gou . He is also the chairman of True " four thousand"" Foxconn workers who" 4 "[' four', ' thousand', '""', ' Fox', 'conn']"
+180 55 The name of the CEO of x -1 The name of the CEO of Ubisoft Yves Guillemot Ubisoft "[',' ' Y' 'ves' ' Gu' 'ille' 'mot' ',' ' has' ' been' ' revealed' '.'
+ ' He' ' is' ' a' ' French' ' businessman' ' who' ' has' ' been' ' with']" , Y ves Gu ille mot , has been revealed . He is a French businessman who has been with True Total War and Ubisoft on Tom Clancy's EndWar 3 [' Total', ' War', ' and', ' Ubisoft']
+181 55 The name of the CEO of x -1 The name of the CEO of Ubisoft Yves Guillemot Ubisoft "[',' ' Y' 'ves' ' Gu' 'ille' 'mot' ',' ' has' ' been' ' revealed' '.'
+ ' He' ' is' ' a' ' French' ' businessman' ' who' ' has' ' been' ' with']" , Y ves Gu ille mot , has been revealed . He is a French businessman who has been with True " United States and by Ubisoft in Europe and Australia.
+" 4 [' United', ' States', ' and', ' by', ' Ubisoft']
+182 55 The name of the CEO of x -1 The name of the CEO of Ubisoft Yves Guillemot Ubisoft "[',' ' Y' 'ves' ' Gu' 'ille' 'mot' ',' ' has' ' been' ' revealed' '.'
+ ' He' ' is' ' a' ' French' ' businessman' ' who' ' has' ' been' ' with']" , Y ves Gu ille mot , has been revealed . He is a French businessman who has been with True being developed by Ubisoft San Francisco, 3 [' being', ' developed', ' by', ' Ubisoft']
+183 55 The name of the CEO of x -1 The name of the CEO of Ubisoft Yves Guillemot Ubisoft "[',' ' Y' 'ves' ' Gu' 'ille' 'mot' ',' ' has' ' been' ' revealed' '.'
+ ' He' ' is' ' a' ' French' ' businessman' ' who' ' has' ' been' ' with']" , Y ves Gu ille mot , has been revealed . He is a French businessman who has been with True team of around 20 at Ubisoft Paris, and was 5 [' team', ' of', ' around', ' 20', ' at', ' Ubisoft']
+184 55 The name of the CEO of x -1 The name of the CEO of Ubisoft Yves Guillemot Ubisoft "[',' ' Y' 'ves' ' Gu' 'ille' 'mot' ',' ' has' ' been' ' revealed' '.'
+ ' He' ' is' ' a' ' French' ' businessman' ' who' ' has' ' been' ' with']" , Y ves Gu ille mot , has been revealed . He is a French businessman who has been with True and published by Ubisoft for the Nintendo 3 [' and', ' published', ' by', ' Ubisoft']
+185 56 The name of the CEO of x -1 The name of the CEO of LVMH Bernard Arnault LVMH "[',' ' Bernard' ' Arn' 'ault' ',' ' is' ' a' ' French' ' word' ' that'
+ ' means' ' �' '�' 'son' ' of' ' Arn' 'ault' '�' '�' '.']" , Bernard Arn ault , is a French word that means � � son of Arn ault � � . True conglomerate LVMH in 2009. The relaunched 3 [' conglomerate', ' L', 'VM', 'H']
+186 56 The name of the CEO of x -1 The name of the CEO of LVMH Bernard Arnault LVMH "[',' ' Bernard' ' Arn' 'ault' ',' ' is' ' a' ' French' ' word' ' that'
+ ' means' ' �' '�' 'son' ' of' ' Arn' 'ault' '�' '�' '.']" , Bernard Arn ault , is a French word that means � � son of Arn ault � � . True conglomerate LVMH in 2009. The 3 [' conglomerate', ' L', 'VM', 'H']
+187 56 The name of the CEO of x -1 The name of the CEO of LVMH Bernard Arnault LVMH "[',' ' Bernard' ' Arn' 'ault' ',' ' is' ' a' ' French' ' word' ' that'
+ ' means' ' �' '�' 'son' ' of' ' Arn' 'ault' '�' '�' '.']" , Bernard Arn ault , is a French word that means � � son of Arn ault � � . True French conglomerate LVMH in 2009. The relaunched 4 [' French', ' conglomerate', ' L', 'VM', 'H']
+188 56 The name of the CEO of x -1 The name of the CEO of LVMH Bernard Arnault LVMH "[',' ' Bernard' ' Arn' 'ault' ',' ' is' ' a' ' French' ' word' ' that'
+ ' means' ' �' '�' 'son' ' of' ' Arn' 'ault' '�' '�' '.']" , Bernard Arn ault , is a French word that means � � son of Arn ault � � . True conglomerate LVMH has made substantial 3 [' conglomerate', ' L', 'VM', 'H']
+189 56 The name of the CEO of x -1 The name of the CEO of LVMH Bernard Arnault LVMH "[',' ' Bernard' ' Arn' 'ault' ',' ' is' ' a' ' French' ' word' ' that'
+ ' means' ' �' '�' 'son' ' of' ' Arn' 'ault' '�' '�' '.']" , Bernard Arn ault , is a French word that means � � son of Arn ault � � . True French conglomerate LVMH has made substantial 4 [' French', ' conglomerate', ' L', 'VM', 'H']
+190 58 The name of the CEO of x -1 The name of the CEO of General Electric H. Lawrence Culp Jr. General Electric "[',' ' the' ' world' '�' '�' 's' ' largest' ' manufacturer' ' of' ' jet'
+ ' engines' ',' ' is' ' Jack' ' Welch' '.' ' He' ' is' ' a' ' man']" , the world � � s largest manufacturer of jet engines , is Jack Welch . He is a man False was chosen. General Electric and the Metallurgical 4 [' was', ' chosen', '.', ' General', ' Electric']
+191 58 The name of the CEO of x -1 The name of the CEO of General Electric H. Lawrence Culp Jr. General Electric "[',' ' the' ' world' '�' '�' 's' ' largest' ' manufacturer' ' of' ' jet'
+ ' engines' ',' ' is' ' Jack' ' Welch' '.' ' He' ' is' ' a' ' man']" , the world � � s largest manufacturer of jet engines , is Jack Welch . He is a man False agreement with General Electric to certify the 747 3 [' agreement', ' with', ' General', ' Electric']
+192 58 The name of the CEO of x -1 The name of the CEO of General Electric H. Lawrence Culp Jr. General Electric "[',' ' the' ' world' '�' '�' 's' ' largest' ' manufacturer' ' of' ' jet'
+ ' engines' ',' ' is' ' Jack' ' Welch' '.' ' He' ' is' ' a' ' man']" , the world � � s largest manufacturer of jet engines , is Jack Welch . He is a man False music, using the General Electric Kinegraphone (later 5 [' music', ',', ' using', ' the', ' General', ' Electric']
+193 58 The name of the CEO of x -1 The name of the CEO of General Electric H. Lawrence Culp Jr. General Electric "[',' ' the' ' world' '�' '�' 's' ' largest' ' manufacturer' ' of' ' jet'
+ ' engines' ',' ' is' ' Jack' ' Welch' '.' ' He' ' is' ' a' ' man']" , the world � � s largest manufacturer of jet engines , is Jack Welch . He is a man False including Celgene, General Electric Ventures, and Eli Lilly, 6 [' including', ' Cel', 'g', 'ene', ',', ' General', ' Electric']
+194 58 The name of the CEO of x -1 The name of the CEO of General Electric H. Lawrence Culp Jr. General Electric "[',' ' the' ' world' '�' '�' 's' ' largest' ' manufacturer' ' of' ' jet'
+ ' engines' ',' ' is' ' Jack' ' Welch' '.' ' He' ' is' ' a' ' man']" , the world � � s largest manufacturer of jet engines , is Jack Welch . He is a man False Department at General Electric. In 1934, he returned 3 [' Department', ' at', ' General', ' Electric']
+195 59 The name of the CEO of x -1 The name of the CEO of Unilever Alan Jope Unilever "[',' ' the' ' world' '�' '�' 's' ' largest' ' advertis' 'er' ',' ' is'
+ ' Paul' ' Pol' 'man' '.' ' He' ' is' ' a' ' man' ' who']" , the world � � s largest advertis er , is Paul Pol man . He is a man who False companies such as Nike and Unilever have been used 7 [' companies', ' such', ' as', ' Nike', ' and', ' Un', 'ile', 'ver']
+196 59 The name of the CEO of x -1 The name of the CEO of Unilever Alan Jope Unilever "[',' ' the' ' world' '�' '�' 's' ' largest' ' advertis' 'er' ',' ' is'
+ ' Paul' ' Pol' 'man' '.' ' He' ' is' ' a' ' man' ' who']" , the world � � s largest advertis er , is Paul Pol man . He is a man who False exposure which Unilever has estimated to be 4 [' exposure', ' which', ' Un', 'ile', 'ver']
+197 59 The name of the CEO of x -1 The name of the CEO of Unilever Alan Jope Unilever "[',' ' the' ' world' '�' '�' 's' ' largest' ' advertis' 'er' ',' ' is'
+ ' Paul' ' Pol' 'man' '.' ' He' ' is' ' a' ' man' ' who']" , the world � � s largest advertis er , is Paul Pol man . He is a man who False " Evolution, was promoted to Unilever as an ""activation" 7 [' Evolution', ',', ' was', ' promoted', ' to', ' Un', 'ile', 'ver']
+198 59 The name of the CEO of x -1 The name of the CEO of Unilever Alan Jope Unilever "[',' ' the' ' world' '�' '�' 's' ' largest' ' advertis' 'er' ',' ' is'
+ ' Paul' ' Pol' 'man' '.' ' He' ' is' ' a' ' man' ' who']" , the world � � s largest advertis er , is Paul Pol man . He is a man who False with Hindustan Unilever to appear in television 6 [' with', ' Hind', 'ust', 'an', ' Un', 'ile', 'ver']
+199 59 The name of the CEO of x -1 The name of the CEO of Unilever Alan Jope Unilever "[',' ' the' ' world' '�' '�' 's' ' largest' ' advertis' 'er' ',' ' is'
+ ' Paul' ' Pol' 'man' '.' ' He' ' is' ' a' ' man' ' who']" , the world � � s largest advertis er , is Paul Pol man . He is a man who False Beauty was launched by Unilever in 2003, to coincide 6 [' Beauty', ' was', ' launched', ' by', ' Un', 'ile', 'ver']
+200 60 The name of the CEO of x -1 The name of the CEO of Adobe Shantanu Narayen Adobe "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' in' ' trouble' '.' ' It' ' is' ' not']" is not a secret . It is not a secret that the company is in trouble . It is not False Frauenfelder worked with Adobe Photoshop, 7 [' Fra', 'u', 'en', 'fe', 'lder', ' worked', ' with', ' Adobe']
+201 60 The name of the CEO of x -1 The name of the CEO of Adobe Shantanu Narayen Adobe "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' in' ' trouble' '.' ' It' ' is' ' not']" is not a secret . It is not a secret that the company is in trouble . It is not False Currently they are using Adobe Dynamic Streaming 4 [' Currently', ' they', ' are', ' using', ' Adobe']
+202 60 The name of the CEO of x -1 The name of the CEO of Adobe Shantanu Narayen Adobe "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' in' ' trouble' '.' ' It' ' is' ' not']" is not a secret . It is not a secret that the company is in trouble . It is not False was rewritten in Adobe AIR. In December 3 [' was', ' rewritten', ' in', ' Adobe']
+203 60 The name of the CEO of x -1 The name of the CEO of Adobe Shantanu Narayen Adobe "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' in' ' trouble' '.' ' It' ' is' ' not']" is not a secret . It is not a secret that the company is in trouble . It is not False favorites into Adobe Illustrator, where 2 [' favorites', ' into', ' Adobe']
+204 60 The name of the CEO of x -1 The name of the CEO of Adobe Shantanu Narayen Adobe "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' in' ' trouble' '.' ' It' ' is' ' not']" is not a secret . It is not a secret that the company is in trouble . It is not False American prairies. Adobe or mud-brick buildings 4 [' American', ' pra', 'iries', '.', ' Adobe']
+205 61 The name of the CEO of x -1 The name of the CEO of General Motors Mary Barra General Motors "[',' ' the' ' world' '�' '�' 's' ' largest' ' autom' 'aker' ',' ' is'
+ ' Mary' ' Bar' 'ra' '.' ' She' ' is' ' a' ' woman' '.']" , the world � � s largest autom aker , is Mary Bar ra . She is a woman . True In December 2012 General Motors reported, based on 4 [' In', ' December', ' 2012', ' General', ' Motors']
+206 61 The name of the CEO of x -1 The name of the CEO of General Motors Mary Barra General Motors "[',' ' the' ' world' '�' '�' 's' ' largest' ' autom' 'aker' ',' ' is'
+ ' Mary' ' Bar' 'ra' '.' ' She' ' is' ' a' ' woman' '.']" , the world � � s largest autom aker , is Mary Bar ra . She is a woman . True parent company General Motors was interested 3 [' parent', ' company', ' General', ' Motors']
+207 61 The name of the CEO of x -1 The name of the CEO of General Motors Mary Barra General Motors "[',' ' the' ' world' '�' '�' 's' ' largest' ' autom' 'aker' ',' ' is'
+ ' Mary' ' Bar' 'ra' '.' ' She' ' is' ' a' ' woman' '.']" , the world � � s largest autom aker , is Mary Bar ra . She is a woman . True In April 2008 General Motors started extensive 4 [' In', ' April', ' 2008', ' General', ' Motors']
+208 61 The name of the CEO of x -1 The name of the CEO of General Motors Mary Barra General Motors "[',' ' the' ' world' '�' '�' 's' ' largest' ' autom' 'aker' ',' ' is'
+ ' Mary' ' Bar' 'ra' '.' ' She' ' is' ' a' ' woman' '.']" , the world � � s largest autom aker , is Mary Bar ra . She is a woman . True 1 ['General', ' Motors']
+209 61 The name of the CEO of x -1 The name of the CEO of General Motors Mary Barra General Motors "[',' ' the' ' world' '�' '�' 's' ' largest' ' autom' 'aker' ',' ' is'
+ ' Mary' ' Bar' 'ra' '.' ' She' ' is' ' a' ' woman' '.']" , the world � � s largest autom aker , is Mary Bar ra . She is a woman . True In March 2012, General Motors announced that an 5 [' In', ' March', ' 2012', ',', ' General', ' Motors']
+210 62 The name of the CEO of x -1 The name of the CEO of Procter & Gamble Jon Moeller Procter & Gamble "[' is' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' CEO' ' of'
+ ' the' ' company' ' that' ' makes' ' Tide' '.' '\n' '\n' 'I' ""'m""]" " is the same as the name of the CEO of the company that makes Tide .
+
+ I 'm" False president of Procter & Gamble and was best 6 [' president', ' of', ' Pro', 'ct', 'er', ' &', ' Gamble']
+211 62 The name of the CEO of x -1 The name of the CEO of Procter & Gamble Jon Moeller Procter & Gamble "[' is' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' CEO' ' of'
+ ' the' ' company' ' that' ' makes' ' Tide' '.' '\n' '\n' 'I' ""'m""]" " is the same as the name of the CEO of the company that makes Tide .
+
+ I 'm" False von der Heyden), Procter & Gamble (David S. Taylor), 9 [' von', ' der', ' He', 'yden', '),', ' Pro', 'ct', 'er', ' &', ' Gamble']
+212 62 The name of the CEO of x -1 The name of the CEO of Procter & Gamble Jon Moeller Procter & Gamble "[' is' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' CEO' ' of'
+ ' the' ' company' ' that' ' makes' ' Tide' '.' '\n' '\n' 'I' ""'m""]" " is the same as the name of the CEO of the company that makes Tide .
+
+ I 'm" False 1912, respectively, and Procter & Gamble and the Beech-Nut Packing 9 [' 1912', ',', ' respectively', ',', ' and', ' Pro', 'ct', 'er', ' &', ' Gamble']
+213 62 The name of the CEO of x -1 The name of the CEO of Procter & Gamble Jon Moeller Procter & Gamble "[' is' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' CEO' ' of'
+ ' the' ' company' ' that' ' makes' ' Tide' '.' '\n' '\n' 'I' ""'m""]" " is the same as the name of the CEO of the company that makes Tide .
+
+ I 'm" False an ambassador by Procter & Gamble and the NGO Child 7 [' an', ' ambassador', ' by', ' Pro', 'ct', 'er', ' &', ' Gamble']
+214 62 The name of the CEO of x -1 The name of the CEO of Procter & Gamble Jon Moeller Procter & Gamble "[' is' ' the' ' same' ' as' ' the' ' name' ' of' ' the' ' CEO' ' of'
+ ' the' ' company' ' that' ' makes' ' Tide' '.' '\n' '\n' 'I' ""'m""]" " is the same as the name of the CEO of the company that makes Tide .
+
+ I 'm" False contract with Procter & Gamble and they refused 6 [' contract', ' with', ' Pro', 'ct', 'er', ' &', ' Gamble']
+215 64 The name of the CEO of x -1 The name of the CEO of eBay Jamie Iannone eBay "[' is' ' not' ' a' ' name' ' that' ' is' ' easy' ' to' ' remember' ','
+ ' but' ' it' ' is' ' a' ' name' ' that' ' is' ' easy' ' to' ' spell']" is not a name that is easy to remember , but it is a name that is easy to spell False raised through eBay auctions of Firefly 2 [' raised', ' through', ' eBay']
+216 64 The name of the CEO of x -1 The name of the CEO of eBay Jamie Iannone eBay "[' is' ' not' ' a' ' name' ' that' ' is' ' easy' ' to' ' remember' ','
+ ' but' ' it' ' is' ' a' ' name' ' that' ' is' ' easy' ' to' ' spell']" is not a name that is easy to remember , but it is a name that is easy to spell False the auction website eBay to flip the 3 [' the', ' auction', ' website', ' eBay']
+217 64 The name of the CEO of x -1 The name of the CEO of eBay Jamie Iannone eBay "[' is' ' not' ' a' ' name' ' that' ' is' ' easy' ' to' ' remember' ','
+ ' but' ' it' ' is' ' a' ' name' ' that' ' is' ' easy' ' to' ' spell']" is not a name that is easy to remember , but it is a name that is easy to spell False would be worth a bit on eBay these days, 6 [' would', ' be', ' worth', ' a', ' bit', ' on', ' eBay']
+218 64 The name of the CEO of x -1 The name of the CEO of eBay Jamie Iannone eBay "[' is' ' not' ' a' ' name' ' that' ' is' ' easy' ' to' ' remember' ','
+ ' but' ' it' ' is' ' a' ' name' ' that' ' is' ' easy' ' to' ' spell']" is not a name that is easy to remember , but it is a name that is easy to spell False available for auction on eBay, with all proceeds 4 [' available', ' for', ' auction', ' on', ' eBay']
+219 64 The name of the CEO of x -1 The name of the CEO of eBay Jamie Iannone eBay "[' is' ' not' ' a' ' name' ' that' ' is' ' easy' ' to' ' remember' ','
+ ' but' ' it' ' is' ' a' ' name' ' that' ' is' ' easy' ' to' ' spell']" is not a name that is easy to remember , but it is a name that is easy to spell False many of them on eBay and refabricate many 4 [' many', ' of', ' them', ' on', ' eBay']
+220 65 The name of the CEO of x -1 The name of the CEO of Lufthansa Carsten Spohr Lufthansa "[' is' ' Diet' 'er' ' Z' 'ets' 'che' '.' ' He' ' is' ' a' ' German'
+ ' businessman' ' and' ' the' ' former' ' CEO' ' of' ' D' 'aim' 'ler']" is Diet er Z ets che . He is a German businessman and the former CEO of D aim ler False was delivered to Lufthansa on 5 February 7 [' was', ' delivered', ' to', ' L', 'uf', 'th', 'ans', 'a']
+221 65 The name of the CEO of x -1 The name of the CEO of Lufthansa Carsten Spohr Lufthansa "[' is' ' Diet' 'er' ' Z' 'ets' 'che' '.' ' He' ' is' ' a' ' German'
+ ' businessman' ' and' ' the' ' former' ' CEO' ' of' ' D' 'aim' 'ler']" is Diet er Z ets che . He is a German businessman and the former CEO of D aim ler False 2001 after Lufthansa overtook BA in 6 [' 2001', ' after', ' L', 'uf', 'th', 'ans', 'a']
+222 65 The name of the CEO of x -1 The name of the CEO of Lufthansa Carsten Spohr Lufthansa "[' is' ' Diet' 'er' ' Z' 'ets' 'che' '.' ' He' ' is' ' a' ' German'
+ ' businessman' ' and' ' the' ' former' ' CEO' ' of' ' D' 'aim' 'ler']" is Diet er Z ets che . He is a German businessman and the former CEO of D aim ler False Airport. In 1966, Lufthansa started flights 9 [' Airport', '.', ' In', ' 1966', ',', ' L', 'uf', 'th', 'ans', 'a']
+223 65 The name of the CEO of x -1 The name of the CEO of Lufthansa Carsten Spohr Lufthansa "[' is' ' Diet' 'er' ' Z' 'ets' 'che' '.' ' He' ' is' ' a' ' German'
+ ' businessman' ' and' ' the' ' former' ' CEO' ' of' ' D' 'aim' 'ler']" is Diet er Z ets che . He is a German businessman and the former CEO of D aim ler False Fornebu was a Lufthansa Junkers Ju 9 [' For', 'ne', 'bu', ' was', ' a', ' L', 'uf', 'th', 'ans', 'a']
+224 65 The name of the CEO of x -1 The name of the CEO of Lufthansa Carsten Spohr Lufthansa "[' is' ' Diet' 'er' ' Z' 'ets' 'che' '.' ' He' ' is' ' a' ' German'
+ ' businessman' ' and' ' the' ' former' ' CEO' ' of' ' D' 'aim' 'ler']" is Diet er Z ets che . He is a German businessman and the former CEO of D aim ler False costs than Lufthansa and twenty percent 6 [' costs', ' than', ' L', 'uf', 'th', 'ans', 'a']
+225 66 The name of the CEO of x -1 The name of the CEO of BMW Oliver Zipse BMW "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' BMW' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of BMW is" False Ecclestone held BMW to their contract 4 [' Ecc', 'l', 'estone', ' held', ' BMW']
+226 66 The name of the CEO of x -1 The name of the CEO of BMW Oliver Zipse BMW "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' BMW' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of BMW is" False Ferraris, BMW Saubers, and 3 [' Ferr', 'aris', ',', ' BMW']
+227 66 The name of the CEO of x -1 The name of the CEO of BMW Oliver Zipse BMW "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' BMW' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of BMW is" False " Kubica third in a BMW Sauber.
+" 5 [' Kub', 'ica', ' third', ' in', ' a', ' BMW']
+228 66 The name of the CEO of x -1 The name of the CEO of BMW Oliver Zipse BMW "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' BMW' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of BMW is" False Volkswagen and BMW do not currently 2 [' Volkswagen', ' and', ' BMW']
+229 66 The name of the CEO of x -1 The name of the CEO of BMW Oliver Zipse BMW "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' BMW' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of BMW is" False 28,149, and the BMW i3 with 23,268 6 [' 28', ',', '149', ',', ' and', ' the', ' BMW']
+230 67 The name of the CEO of x -1 The name of the CEO of Johnson & Johnson Alex Gorsky Johnson & Johnson "[',' ' the' ' company' ' that' ' makes' ' the' ' baby' ' powder' ',' ' is'
+ ' Dr' '.' ' Charles' ' T' '.' ' R' 'essler' '.' ' He' ' is']" , the company that makes the baby powder , is Dr . Charles T . R essler . He is False the manufacturer Johnson & Johnson discontinued 4 [' the', ' manufacturer', ' Johnson', ' &', ' Johnson']
+231 67 The name of the CEO of x -1 The name of the CEO of Johnson & Johnson Alex Gorsky Johnson & Johnson "[',' ' the' ' company' ' that' ' makes' ' the' ' baby' ' powder' ',' ' is'
+ ' Dr' '.' ' Charles' ' T' '.' ' R' 'essler' '.' ' He' ' is']" , the company that makes the baby powder , is Dr . Charles T . R essler . He is False while companies like Johnson & Johnson supported the new 5 [' while', ' companies', ' like', ' Johnson', ' &', ' Johnson']
+232 67 The name of the CEO of x -1 The name of the CEO of Johnson & Johnson Alex Gorsky Johnson & Johnson "[',' ' the' ' company' ' that' ' makes' ' the' ' baby' ' powder' ',' ' is'
+ ' Dr' '.' ' Charles' ' T' '.' ' R' 'essler' '.' ' He' ' is']" , the company that makes the baby powder , is Dr . Charles T . R essler . He is False " manufactured by the Johnson & Johnson company.
+" 5 [' manufactured', ' by', ' the', ' Johnson', ' &', ' Johnson']
+233 67 The name of the CEO of x -1 The name of the CEO of Johnson & Johnson Alex Gorsky Johnson & Johnson "[',' ' the' ' company' ' that' ' makes' ' the' ' baby' ' powder' ',' ' is'
+ ' Dr' '.' ' Charles' ' T' '.' ' R' 'essler' '.' ' He' ' is']" , the company that makes the baby powder , is Dr . Charles T . R essler . He is False When the manufacturer Johnson & Johnson discontinued 5 [' When', ' the', ' manufacturer', ' Johnson', ' &', ' Johnson']
+234 67 The name of the CEO of x -1 The name of the CEO of Johnson & Johnson Alex Gorsky Johnson & Johnson "[',' ' the' ' company' ' that' ' makes' ' the' ' baby' ' powder' ',' ' is'
+ ' Dr' '.' ' Charles' ' T' '.' ' R' 'essler' '.' ' He' ' is']" , the company that makes the baby powder , is Dr . Charles T . R essler . He is False manufacturer Johnson & Johnson discontinued the product 3 [' manufacturer', ' Johnson', ' &', ' Johnson']
+235 69 The name of the CEO of x -1 The name of the CEO of PepsiCo Ramon Laguarta PepsiCo "[',' ' Indra' ' No' 'oy' 'i' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' CEO' ' of' ' Pepsi' 'Co']" ", Indra No oy i , is a woman .
+
+ The name of the CEO of Pepsi Co" False August 20, 2013, PepsiCo announced a partnership 6 [' August', ' 20', ',', ' 2013', ',', ' Pepsi', 'Co']
+236 69 The name of the CEO of x -1 The name of the CEO of PepsiCo Ramon Laguarta PepsiCo "[',' ' Indra' ' No' 'oy' 'i' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' CEO' ' of' ' Pepsi' 'Co']" ", Indra No oy i , is a woman .
+
+ The name of the CEO of Pepsi Co" False grant was provided by PepsiCo after the Theatre 5 [' grant', ' was', ' provided', ' by', ' Pepsi', 'Co']
+237 69 The name of the CEO of x -1 The name of the CEO of PepsiCo Ramon Laguarta PepsiCo "[',' ' Indra' ' No' 'oy' 'i' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' CEO' ' of' ' Pepsi' 'Co']" ", Indra No oy i , is a woman .
+
+ The name of the CEO of Pepsi Co" False corporation PepsiCo to promote their 2 [' corporation', ' Pepsi', 'Co']
+238 69 The name of the CEO of x -1 The name of the CEO of PepsiCo Ramon Laguarta PepsiCo "[',' ' Indra' ' No' 'oy' 'i' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' CEO' ' of' ' Pepsi' 'Co']" ", Indra No oy i , is a woman .
+
+ The name of the CEO of Pepsi Co" False " Acquisition by PepsiCo ==
+" 3 [' Acquisition', ' by', ' Pepsi', 'Co']
+239 69 The name of the CEO of x -1 The name of the CEO of PepsiCo Ramon Laguarta PepsiCo "[',' ' Indra' ' No' 'oy' 'i' ',' ' is' ' a' ' woman' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' CEO' ' of' ' Pepsi' 'Co']" ", Indra No oy i , is a woman .
+
+ The name of the CEO of Pepsi Co" False In August 1997, PepsiCo spun off its poorly 5 [' In', ' August', ' 1997', ',', ' Pepsi', 'Co']
+240 70 The name of the CEO of x -1 The name of the CEO of National Institutes of Health Lawrence A. Tabak National Institutes of Health "[' (' 'NI' 'H' ')' ' is' ' Dr' '.' ' Francis' ' Collins' '.' ' He' ' is'
+ ' a' ' very' ' smart' ' man' '.' ' He' ' is' ' a']" ( NI H ) is Dr . Francis Collins . He is a very smart man . He is a False job was at the National Institutes of Health in Bethesda where he 7 [' job', ' was', ' at', ' the', ' National', ' Institutes', ' of', ' Health']
+241 70 The name of the CEO of x -1 The name of the CEO of National Institutes of Health Lawrence A. Tabak National Institutes of Health "[' (' 'NI' 'H' ')' ' is' ' Dr' '.' ' Francis' ' Collins' '.' ' He' ' is'
+ ' a' ' very' ' smart' ' man' '.' ' He' ' is' ' a']" ( NI H ) is Dr . Francis Collins . He is a very smart man . He is a False billion to the National Institutes of Health over the next two 6 [' billion', ' to', ' the', ' National', ' Institutes', ' of', ' Health']
+242 70 The name of the CEO of x -1 The name of the CEO of National Institutes of Health Lawrence A. Tabak National Institutes of Health "[' (' 'NI' 'H' ')' ' is' ' Dr' '.' ' Francis' ' Collins' '.' ' He' ' is'
+ ' a' ' very' ' smart' ' man' '.' ' He' ' is' ' a']" ( NI H ) is Dr . Francis Collins . He is a very smart man . He is a False 1966 – 1973 6 [' 1966', ' –', ' 197', 'National', ' Institutes', ' of', ' Health']
+243 70 The name of the CEO of x -1 The name of the CEO of National Institutes of Health Lawrence A. Tabak National Institutes of Health "[' (' 'NI' 'H' ')' ' is' ' Dr' '.' ' Francis' ' Collins' '.' ' He' ' is'
+ ' a' ' very' ' smart' ' man' '.' ' He' ' is' ' a']" ( NI H ) is Dr . Francis Collins . He is a very smart man . He is a False 978,000 from the National Institutes of Health with cannabis 8 [' 978', ',', '000', ' from', ' the', ' National', ' Institutes', ' of', ' Health']
+244 70 The name of the CEO of x -1 The name of the CEO of National Institutes of Health Lawrence A. Tabak National Institutes of Health "[' (' 'NI' 'H' ')' ' is' ' Dr' '.' ' Francis' ' Collins' '.' ' He' ' is'
+ ' a' ' very' ' smart' ' man' '.' ' He' ' is' ' a']" ( NI H ) is Dr . Francis Collins . He is a very smart man . He is a False Marston joined the National Institutes of Health (NIH) as a 7 [' Mar', 'ston', ' joined', ' the', ' National', ' Institutes', ' of', ' Health']
+245 71 The name of the CEO of x -1 The name of the CEO of Shell Wael Sawan Shell "[' Oil' ',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' company'
+ ' that' ' owns' ' the' ' oil' ' company' ' that' ' owns' ' the' ' oil'
+ ' company' ' that']" Oil , the company that owns the oil company that owns the oil company that owns the oil company that False Ghost in the Shell is a cyberpunk-themed 3 [' Ghost', ' in', ' the', ' Shell']
+246 71 The name of the CEO of x -1 The name of the CEO of Shell Wael Sawan Shell "[' Oil' ',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' company'
+ ' that' ' owns' ' the' ' oil' ' company' ' that' ' owns' ' the' ' oil'
+ ' company' ' that']" Oil , the company that owns the oil company that owns the oil company that owns the oil company that False and Wikipedia. Tim Shell and Michael David 4 [' and', ' Wikipedia', '.', ' Tim', ' Shell']
+247 71 The name of the CEO of x -1 The name of the CEO of Shell Wael Sawan Shell "[' Oil' ',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' company'
+ ' that' ' owns' ' the' ' oil' ' company' ' that' ' owns' ' the' ' oil'
+ ' company' ' that']" Oil , the company that owns the oil company that owns the oil company that owns the oil company that False namespace 2 [' names', 'pac', 'Shell']
+248 71 The name of the CEO of x -1 The name of the CEO of Shell Wael Sawan Shell "[' Oil' ',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' company'
+ ' that' ' owns' ' the' ' oil' ' company' ' that' ' owns' ' the' ' oil'
+ ' company' ' that']" Oil , the company that owns the oil company that owns the oil company that owns the oil company that False who worked at a Shell station opposite 4 [' who', ' worked', ' at', ' a', ' Shell']
+249 71 The name of the CEO of x -1 The name of the CEO of Shell Wael Sawan Shell "[' Oil' ',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' company'
+ ' that' ' owns' ' the' ' oil' ' company' ' that' ' owns' ' the' ' oil'
+ ' company' ' that']" Oil , the company that owns the oil company that owns the oil company that owns the oil company that False performance on the Shell Chateau Hour. 3 [' performance', ' on', ' the', ' Shell']
+250 74 The name of the CEO of x -1 The name of the CEO of Starbucks Kevin Johnson Starbucks "[' is' ' Howard' ' Schultz' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan'
+ ' of' ' Starbucks' '.' ' I' ' love' ' their' ' coffee' ' and' ' their']" " is Howard Schultz .
+
+ I am a huge fan of Starbucks . I love their coffee and their" False June 1997. In 2002 a Starbucks coffee (discontinued 6 [' June', ' 1997', '.', ' In', ' 2002', ' a', ' Starbucks']
+251 74 The name of the CEO of x -1 The name of the CEO of Starbucks Kevin Johnson Starbucks "[' is' ' Howard' ' Schultz' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan'
+ ' of' ' Starbucks' '.' ' I' ' love' ' their' ' coffee' ' and' ' their']" " is Howard Schultz .
+
+ I am a huge fan of Starbucks . I love their coffee and their" False compete with Starbucks by serving cappuccino, 2 [' compete', ' with', ' Starbucks']
+252 74 The name of the CEO of x -1 The name of the CEO of Starbucks Kevin Johnson Starbucks "[' is' ' Howard' ' Schultz' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan'
+ ' of' ' Starbucks' '.' ' I' ' love' ' their' ' coffee' ' and' ' their']" " is Howard Schultz .
+
+ I am a huge fan of Starbucks . I love their coffee and their" False floor is occupied by a Starbucks and an Oceanaire 5 [' floor', ' is', ' occupied', ' by', ' a', ' Starbucks']
+253 74 The name of the CEO of x -1 The name of the CEO of Starbucks Kevin Johnson Starbucks "[' is' ' Howard' ' Schultz' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan'
+ ' of' ' Starbucks' '.' ' I' ' love' ' their' ' coffee' ' and' ' their']" " is Howard Schultz .
+
+ I am a huge fan of Starbucks . I love their coffee and their" False ground floor is a Starbucks coffee shop 4 [' ground', ' floor', ' is', ' a', ' Starbucks']
+254 74 The name of the CEO of x -1 The name of the CEO of Starbucks Kevin Johnson Starbucks "[' is' ' Howard' ' Schultz' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan'
+ ' of' ' Starbucks' '.' ' I' ' love' ' their' ' coffee' ' and' ' their']" " is Howard Schultz .
+
+ I am a huge fan of Starbucks . I love their coffee and their" False bomb outside a Starbucks Coffee shop in the 3 [' bomb', ' outside', ' a', ' Starbucks']
+255 75 The name of the CEO of x -1 The name of the CEO of ExxonMobil Darren Woods ExxonMobil "[' is' ' Rex' ' Tillerson' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' a' ' friend' ' of' ' the' ' Koch' ' brothers' ' for' ' years'
+ '.']" is Rex Tillerson . He is a man who has been a friend of the Koch brothers for years . False of the storm, ExxonMobil oil refineries in 5 [' of', ' the', ' storm', ',', ' Exxon', 'Mobil']
+256 75 The name of the CEO of x -1 The name of the CEO of ExxonMobil Darren Woods ExxonMobil "[' is' ' Rex' ' Tillerson' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' a' ' friend' ' of' ' the' ' Koch' ' brothers' ' for' ' years'
+ '.']" is Rex Tillerson . He is a man who has been a friend of the Koch brothers for years . False followed in order by ExxonMobil (20.34 percent), 5 [' followed', ' in', ' order', ' by', ' Exxon', 'Mobil']
+257 75 The name of the CEO of x -1 The name of the CEO of ExxonMobil Darren Woods ExxonMobil "[' is' ' Rex' ' Tillerson' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' a' ' friend' ' of' ' the' ' Koch' ' brothers' ' for' ' years'
+ '.']" is Rex Tillerson . He is a man who has been a friend of the Koch brothers for years . False ConocoPhillips and ExxonMobil remained on site. The 6 [' Con', 'oco', 'Phill', 'ips', ' and', ' Exxon', 'Mobil']
+258 75 The name of the CEO of x -1 The name of the CEO of ExxonMobil Darren Woods ExxonMobil "[' is' ' Rex' ' Tillerson' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' a' ' friend' ' of' ' the' ' Koch' ' brothers' ' for' ' years'
+ '.']" is Rex Tillerson . He is a man who has been a friend of the Koch brothers for years . False companies such as ExxonMobil have challenged 4 [' companies', ' such', ' as', ' Exxon', 'Mobil']
+259 75 The name of the CEO of x -1 The name of the CEO of ExxonMobil Darren Woods ExxonMobil "[' is' ' Rex' ' Tillerson' '.' ' He' ' is' ' a' ' man' ' who' ' has'
+ ' been' ' a' ' friend' ' of' ' the' ' Koch' ' brothers' ' for' ' years'
+ '.']" is Rex Tillerson . He is a man who has been a friend of the Koch brothers for years . False of the storm, ExxonMobil oil refineries in 5 [' of', ' the', ' storm', ',', ' Exxon', 'Mobil']
+260 76 The name of the CEO of x -1 The name of the CEO of Bank of America Brian Moynihan Bank of America "[' is' ' Brian' ' Moy' 'nih' 'an' '.' ' He' ' is' ' a' ' very' ' nice'
+ ' man' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Brian Moy nih an . He is a very nice man . He is a very nice man . True weekly event at Bank of America Stadium which 5 [' weekly', ' event', ' at', ' Bank', ' of', ' America']
+261 76 The name of the CEO of x -1 The name of the CEO of Bank of America Brian Moynihan Bank of America "[' is' ' Brian' ' Moy' 'nih' 'an' '.' ' He' ' is' ' a' ' very' ' nice'
+ ' man' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Brian Moy nih an . He is a very nice man . He is a very nice man . True sponsored by Bank of America following the 4 [' sponsored', ' by', ' Bank', ' of', ' America']
+262 76 The name of the CEO of x -1 The name of the CEO of Bank of America Brian Moynihan Bank of America "[' is' ' Brian' ' Moy' 'nih' 'an' '.' ' He' ' is' ' a' ' very' ' nice'
+ ' man' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Brian Moy nih an . He is a very nice man . He is a very nice man . True MBNA America and Bank of America since 1999; as a Bank 6 [' MB', 'NA', ' America', ' and', ' Bank', ' of', ' America']
+263 76 The name of the CEO of x -1 The name of the CEO of Bank of America Brian Moynihan Bank of America "[' is' ' Brian' ' Moy' 'nih' 'an' '.' ' He' ' is' ' a' ' very' ' nice'
+ ' man' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Brian Moy nih an . He is a very nice man . He is a very nice man . True Product Red and Bank of America to fight AIDS. 5 [' Product', ' Red', ' and', ' Bank', ' of', ' America']
+264 76 The name of the CEO of x -1 The name of the CEO of Bank of America Brian Moynihan Bank of America "[' is' ' Brian' ' Moy' 'nih' 'an' '.' ' He' ' is' ' a' ' very' ' nice'
+ ' man' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Brian Moy nih an . He is a very nice man . He is a very nice man . True open-air fields next to Bank of America Stadium where they 8 [' open', '-', 'air', ' fields', ' next', ' to', ' Bank', ' of', ' America']
+265 77 The name of the CEO of x -1 The name of the CEO of Bayer Werner Baumann Bayer "[',' ' the' ' German' ' pharmaceutical' ' giant' ',' ' is' ' Dr' '.'
+ ' Werner' ' Ba' 'um' 'gart' 'ner' '.' ' He' ' is' ' a' ' man' ' of']" , the German pharmaceutical giant , is Dr . Werner Ba um gart ner . He is a man of False statistical output in Bayern's next five league 3 [' statistical', ' output', ' in', ' Bayer']
+266 77 The name of the CEO of x -1 The name of the CEO of Bayer Werner Baumann Bayer "[',' ' the' ' German' ' pharmaceutical' ' giant' ',' ' is' ' Dr' '.'
+ ' Werner' ' Ba' 'um' 'gart' 'ner' '.' ' He' ' is' ' a' ' man' ' of']" , the German pharmaceutical giant , is Dr . Werner Ba um gart ner . He is a man of False Sondheim, while Jeff Bayer of The Scorecard 6 [' S', 'ond', 'heim', ',', ' while', ' Jeff', ' Bayer']
+267 77 The name of the CEO of x -1 The name of the CEO of Bayer Werner Baumann Bayer "[',' ' the' ' German' ' pharmaceutical' ' giant' ',' ' is' ' Dr' '.'
+ ' Werner' ' Ba' 'um' 'gart' 'ner' '.' ' He' ' is' ' a' ' man' ' of']" , the German pharmaceutical giant , is Dr . Werner Ba um gart ner . He is a man of False gave nine stars Bayer designations, 3 [' gave', ' nine', ' stars', ' Bayer']
+268 77 The name of the CEO of x -1 The name of the CEO of Bayer Werner Baumann Bayer "[',' ' the' ' German' ' pharmaceutical' ' giant' ',' ' is' ' Dr' '.'
+ ' Werner' ' Ba' 'um' 'gart' 'ner' '.' ' He' ' is' ' a' ' man' ' of']" , the German pharmaceutical giant , is Dr . Werner Ba um gart ner . He is a man of False also known by its Bayer designation Beta Persei, 4 [' also', ' known', ' by', ' its', ' Bayer']
+269 77 The name of the CEO of x -1 The name of the CEO of Bayer Werner Baumann Bayer "[',' ' the' ' German' ' pharmaceutical' ' giant' ',' ' is' ' Dr' '.'
+ ' Werner' ' Ba' 'um' 'gart' 'ner' '.' ' He' ' is' ' a' ' man' ' of']" , the German pharmaceutical giant , is Dr . Werner Ba um gart ner . He is a man of False current name by Johann Bayer in his 1603 4 [' current', ' name', ' by', ' Johann', ' Bayer']
+270 79 The name of the CEO of x -1 The name of the CEO of Gazprom Alexey Miller Gazprom "[',' ' the' ' Russian' ' gas' ' giant' ',' ' is' ' Sergey' ' Sh' 'akh'
+ 'ray' '.' ' He' ' is' ' a' ' former' ' KGB' ' agent' ' who' ' was']" , the Russian gas giant , is Sergey Sh akh ray . He is a former KGB agent who was False state gas company Gazprom announced a deal 4 [' state', ' gas', ' company', ' Gaz', 'prom']
+271 80 The name of the CEO of x -1 The name of the CEO of Pfizer Albert Bourla Pfizer "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' Vi' 'agra' ','
+ ' is' ' on' ' the' ' list' '.' '\n' '\n' 'The' ' list' ' of' ' the']" ", the pharmaceutical company that makes Vi agra , is on the list .
+
+ The list of the" False Conference in New York, Pfizer reported the development 6 [' Conference', ' in', ' New', ' York', ',', ' Pf', 'izer']
+272 80 The name of the CEO of x -1 The name of the CEO of Pfizer Albert Bourla Pfizer "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' Vi' 'agra' ','
+ ' is' ' on' ' the' ' list' '.' '\n' '\n' 'The' ' list' ' of' ' the']" ", the pharmaceutical company that makes Vi agra , is on the list .
+
+ The list of the" False " ==== Lawsuits with Pfizer ====
+" 6 [' =', '===', ' Law', 'suits', ' with', ' Pf', 'izer']
+273 80 The name of the CEO of x -1 The name of the CEO of Pfizer Albert Bourla Pfizer "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' Vi' 'agra' ','
+ ' is' ' on' ' the' ' list' '.' '\n' '\n' 'The' ' list' ' of' ' the']" ", the pharmaceutical company that makes Vi agra , is on the list .
+
+ The list of the" False as consultants for Pfizer, which at the 4 [' as', ' consultants', ' for', ' Pf', 'izer']
+274 80 The name of the CEO of x -1 The name of the CEO of Pfizer Albert Bourla Pfizer "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' Vi' 'agra' ','
+ ' is' ' on' ' the' ' list' '.' '\n' '\n' 'The' ' list' ' of' ' the']" ", the pharmaceutical company that makes Vi agra , is on the list .
+
+ The list of the" False " manufacturer, Pfizer in 2010.
+" 3 [' manufacturer', ',', ' Pf', 'izer']
+275 80 The name of the CEO of x -1 The name of the CEO of Pfizer Albert Bourla Pfizer "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' Vi' 'agra' ','
+ ' is' ' on' ' the' ' list' '.' '\n' '\n' 'The' ' list' ' of' ' the']" ", the pharmaceutical company that makes Vi agra , is on the list .
+
+ The list of the" False consultants for Pfizer, which at the 3 [' consultants', ' for', ' Pf', 'izer']
+276 81 The name of the CEO of x -1 The name of the CEO of American Airlines Doug Parker American Airlines "[' is' ' Tom' ' Horton' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.'
+ ' He' ' is' ' a' ' very' ' nice' ' man' '.' ' He' ' is']" is Tom Horton . He is a very nice man . He is a very nice man . He is False flights serving Kauai; American Airlines and numerous 6 [' flights', ' serving', ' Kau', 'ai', ';', ' American', ' Airlines']
+277 81 The name of the CEO of x -1 The name of the CEO of American Airlines Doug Parker American Airlines "[' is' ' Tom' ' Horton' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.'
+ ' He' ' is' ' a' ' very' ' nice' ' man' '.' ' He' ' is']" is Tom Horton . He is a very nice man . He is a very nice man . He is False " ""big three"" at the American Airlines Arena, an event" 7 "[' ""', 'big', ' three', '""', ' at', ' the', ' American', ' Airlines']"
+278 81 The name of the CEO of x -1 The name of the CEO of American Airlines Doug Parker American Airlines "[' is' ' Tom' ' Horton' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.'
+ ' He' ' is' ' a' ' very' ' nice' ' man' '.' ' He' ' is']" is Tom Horton . He is a very nice man . He is a very nice man . He is False main towers. American Airlines Flight 11, a 767-200ER, 4 [' main', ' towers', '.', ' American', ' Airlines']
+279 81 The name of the CEO of x -1 The name of the CEO of American Airlines Doug Parker American Airlines "[' is' ' Tom' ' Horton' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.'
+ ' He' ' is' ' a' ' very' ' nice' ' man' '.' ' He' ' is']" is Tom Horton . He is a very nice man . He is a very nice man . He is False and sixth worldwide. American Airlines Group's American 5 [' and', ' sixth', ' worldwide', '.', ' American', ' Airlines']
+280 81 The name of the CEO of x -1 The name of the CEO of American Airlines Doug Parker American Airlines "[' is' ' Tom' ' Horton' '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.'
+ ' He' ' is' ' a' ' very' ' nice' ' man' '.' ' He' ' is']" is Tom Horton . He is a very nice man . He is a very nice man . He is False Airlines (91), American Airlines (84), UPS Airlines 5 [' Airlines', ' (', '91', '),', ' American', ' Airlines']
+281 83 The name of the CEO of x -1 The name of the CEO of Nvidia Jen-Hsun Huang Nvidia "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' GPU'
+ ' architecture']" is not a secret . It is not a secret that the company is working on a new GPU architecture False and ran on the Nvidia GeForce 4 Ti GPU, 4 [' and', ' ran', ' on', ' the', ' Nvidia']
+282 83 The name of the CEO of x -1 The name of the CEO of Nvidia Jen-Hsun Huang Nvidia "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' GPU'
+ ' architecture']" is not a secret . It is not a secret that the company is working on a new GPU architecture False featuring a new Nvidia chipset at a Cupertino, 3 [' featuring', ' a', ' new', ' Nvidia']
+283 83 The name of the CEO of x -1 The name of the CEO of Nvidia Jen-Hsun Huang Nvidia "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' GPU'
+ ' architecture']" is not a secret . It is not a secret that the company is working on a new GPU architecture False support DirectX 12, and Nvidia revealed a new 5 [' support', ' DirectX', ' 12', ',', ' and', ' Nvidia']
+284 83 The name of the CEO of x -1 The name of the CEO of Nvidia Jen-Hsun Huang Nvidia "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' GPU'
+ ' architecture']" is not a secret . It is not a secret that the company is working on a new GPU architecture False 5, 2007 with new Nvidia Geforce 8600M 5 [' 5', ',', ' 2007', ' with', ' new', ' Nvidia']
+285 83 The name of the CEO of x -1 The name of the CEO of Nvidia Jen-Hsun Huang Nvidia "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' not' ' a' ' secret' ' that'
+ ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' GPU'
+ ' architecture']" is not a secret . It is not a secret that the company is working on a new GPU architecture False GPUs with both Nvidia and AMD releasing programming 3 [' GPUs', ' with', ' both', ' Nvidia']
+286 84 The name of the CEO of x -1 The name of the CEO of DuPont Charles O. Holliday DuPont "[',' ' the' ' company' ' that' ' makes' ' the' ' chemical' ' that' ' is'
+ ' used' ' to' ' make' ' the' ' plastic' ' in' ' the' ' bottle' ',' ' is'
+ ' a']" , the company that makes the chemical that is used to make the plastic in the bottle , is a False construction project. DuPont advertised for workers 4 [' construction', ' project', '.', ' Du', 'Pont']
+287 84 The name of the CEO of x -1 The name of the CEO of DuPont Charles O. Holliday DuPont "[',' ' the' ' company' ' that' ' makes' ' the' ' chemical' ' that' ' is'
+ ' used' ' to' ' make' ' the' ' plastic' ' in' ' the' ' bottle' ',' ' is'
+ ' a']" , the company that makes the chemical that is used to make the plastic in the bottle , is a False 1944, by which time DuPont began transferring 6 [' 1944', ',', ' by', ' which', ' time', ' Du', 'Pont']
+288 84 The name of the CEO of x -1 The name of the CEO of DuPont Charles O. Holliday DuPont "[',' ' the' ' company' ' that' ' makes' ' the' ' chemical' ' that' ' is'
+ ' used' ' to' ' make' ' the' ' plastic' ' in' ' the' ' bottle' ',' ' is'
+ ' a']" , the company that makes the chemical that is used to make the plastic in the bottle , is a False reactor to DuPont in January 1943. 3 [' reactor', ' to', ' Du', 'Pont']
+289 84 The name of the CEO of x -1 The name of the CEO of DuPont Charles O. Holliday DuPont "[',' ' the' ' company' ' that' ' makes' ' the' ' chemical' ' that' ' is'
+ ' used' ' to' ' make' ' the' ' plastic' ' in' ' the' ' bottle' ',' ' is'
+ ' a']" , the company that makes the chemical that is used to make the plastic in the bottle , is a False followed US 13 along the DuPont Highway be 6 [' followed', ' US', ' 13', ' along', ' the', ' Du', 'Pont']
+290 84 The name of the CEO of x -1 The name of the CEO of DuPont Charles O. Holliday DuPont "[',' ' the' ' company' ' that' ' makes' ' the' ' chemical' ' that' ' is'
+ ' used' ' to' ' make' ' the' ' plastic' ' in' ' the' ' bottle' ',' ' is'
+ ' a']" , the company that makes the chemical that is used to make the plastic in the bottle , is a False Groves recruited the DuPont Company to be the 5 [' Gro', 'ves', ' recruited', ' the', ' Du', 'Pont']
+291 85 The name of the CEO of x -1 The name of the CEO of Porsche Oliver Blume Porsche "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Porsche' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Porsche is" False Mouton drove a Porsche Carrera RS to victory 5 [' M', 'out', 'on', ' drove', ' a', ' Porsche']
+292 85 The name of the CEO of x -1 The name of the CEO of Porsche Oliver Blume Porsche "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Porsche' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Porsche is" False championship and shared a Porsche 962 at Le Mans 4 [' championship', ' and', ' shared', ' a', ' Porsche']
+293 85 The name of the CEO of x -1 The name of the CEO of Porsche Oliver Blume Porsche "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Porsche' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Porsche is" False replacing the entire Porsche front suspension 3 [' replacing', ' the', ' entire', ' Porsche']
+294 85 The name of the CEO of x -1 The name of the CEO of Porsche Oliver Blume Porsche "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Porsche' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Porsche is" False later, he drove a Porsche 907 with Gerhard 5 [' later', ',', ' he', ' drove', ' a', ' Porsche']
+295 85 The name of the CEO of x -1 The name of the CEO of Porsche Oliver Blume Porsche "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Porsche' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Porsche is" False extensively modified Porsche 956s and 962Cs, 2 [' extensively', ' modified', ' Porsche']
+296 86 The name of the CEO of x -1 The name of the CEO of BASF Martin Brudermüller BASF "[',' ' the' ' world' '�' '�' 's' ' largest' ' chemical' ' company' ','
+ ' is' ' Ludwig' ' Er' 'hard' '.' ' He' ' was' ' a' ' German' ' economist']" , the world � � s largest chemical company , is Ludwig Er hard . He was a German economist False received three BASF Awards, sponsored 3 [' received', ' three', ' BAS', 'F']
+297 86 The name of the CEO of x -1 The name of the CEO of BASF Martin Brudermüller BASF "[',' ' the' ' world' '�' '�' 's' ' largest' ' chemical' ' company' ','
+ ' is' ' Ludwig' ' Er' 'hard' '.' ' He' ' was' ' a' ' German' ' economist']" , the world � � s largest chemical company , is Ludwig Er hard . He was a German economist False time, a group at BASF led by Carl Bosch 6 [' time', ',', ' a', ' group', ' at', ' BAS', 'F']
+298 86 The name of the CEO of x -1 The name of the CEO of BASF Martin Brudermüller BASF "[',' ' the' ' world' '�' '�' 's' ' largest' ' chemical' ' company' ','
+ ' is' ' Ludwig' ' Er' 'hard' '.' ' He' ' was' ' a' ' German' ' economist']" , the world � � s largest chemical company , is Ludwig Er hard . He was a German economist False a group at BASF led by Carl Bosch 4 [' a', ' group', ' at', ' BAS', 'F']
+299 86 The name of the CEO of x -1 The name of the CEO of BASF Martin Brudermüller BASF "[',' ' the' ' world' '�' '�' 's' ' largest' ' chemical' ' company' ','
+ ' is' ' Ludwig' ' Er' 'hard' '.' ' He' ' was' ' a' ' German' ' economist']" , the world � � s largest chemical company , is Ludwig Er hard . He was a German economist False chemical company BASF in 1963. In 3 [' chemical', ' company', ' BAS', 'F']
+300 86 The name of the CEO of x -1 The name of the CEO of BASF Martin Brudermüller BASF "[',' ' the' ' world' '�' '�' 's' ' largest' ' chemical' ' company' ','
+ ' is' ' Ludwig' ' Er' 'hard' '.' ' He' ' was' ' a' ' German' ' economist']" , the world � � s largest chemical company , is Ludwig Er hard . He was a German economist False received three BASF Awards, sponsored 3 [' received', ' three', ' BAS', 'F']
+301 87 The name of the CEO of x -1 The name of the CEO of AMD Lisa Su AMD "[' is' ' not' ' AMD' ',' ' it' ""'s"" ' AMD' '.' '\n' '\n' 'I' ""'m"" ' not'
+ ' sure' ' if' ' you' ""'re"" ' aware' ' of' ' this']" " is not AMD , it 's AMD .
+
+ I 'm not sure if you 're aware of this" False stutters. Nvidia and AMD released new 5 [' st', 'utters', '.', ' Nvidia', ' and', ' AMD']
+302 87 The name of the CEO of x -1 The name of the CEO of AMD Lisa Su AMD "[' is' ' not' ' AMD' ',' ' it' ""'s"" ' AMD' '.' '\n' '\n' 'I' ""'m"" ' not'
+ ' sure' ' if' ' you' ""'re"" ' aware' ' of' ' this']" " is not AMD , it 's AMD .
+
+ I 'm not sure if you 're aware of this" False support for Nvidia and AMD graphics cards under 4 [' support', ' for', ' Nvidia', ' and', ' AMD']
+303 87 The name of the CEO of x -1 The name of the CEO of AMD Lisa Su AMD "[' is' ' not' ' AMD' ',' ' it' ""'s"" ' AMD' '.' '\n' '\n' 'I' ""'m"" ' not'
+ ' sure' ' if' ' you' ""'re"" ' aware' ' of' ' this']" " is not AMD , it 's AMD .
+
+ I 'm not sure if you 're aware of this" False " One is powered by an AMD ""Jaguar"" Accelerated" 5 [' One', ' is', ' powered', ' by', ' an', ' AMD']
+304 87 The name of the CEO of x -1 The name of the CEO of AMD Lisa Su AMD "[' is' ' not' ' AMD' ',' ' it' ""'s"" ' AMD' '.' '\n' '\n' 'I' ""'m"" ' not'
+ ' sure' ' if' ' you' ""'re"" ' aware' ' of' ' this']" " is not AMD , it 's AMD .
+
+ I 'm not sure if you 're aware of this" False DDR3 slots, dual AMD FirePro D series 5 [' DDR', '3', ' slots', ',', ' dual', ' AMD']
+305 87 The name of the CEO of x -1 The name of the CEO of AMD Lisa Su AMD "[' is' ' not' ' AMD' ',' ' it' ""'s"" ' AMD' '.' '\n' '\n' 'I' ""'m"" ' not'
+ ' sure' ' if' ' you' ""'re"" ' aware' ' of' ' this']" " is not AMD , it 's AMD .
+
+ I 'm not sure if you 're aware of this" False notebook with an AMD graphics card. The 3 [' notebook', ' with', ' an', ' AMD']
+306 88 The name of the CEO of x -1 The name of the CEO of Nissan Makoto Uchida Nissan "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Nissan' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Nissan is" False in March 2015. The Nissan Leaf has topped 5 [' in', ' March', ' 2015', '.', ' The', ' Nissan']
+307 88 The name of the CEO of x -1 The name of the CEO of Nissan Makoto Uchida Nissan "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Nissan' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Nissan is" False " also featured in a Nissan commercial.
+" 4 [' also', ' featured', ' in', ' a', ' Nissan']
+308 88 The name of the CEO of x -1 The name of the CEO of Nissan Makoto Uchida Nissan "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Nissan' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Nissan is" False US $ 10,000 and the Nissan Leaf pack costs 7 [' US', ' $', ' 10', ',', '000', ' and', ' the', ' Nissan']
+309 88 The name of the CEO of x -1 The name of the CEO of Nissan Makoto Uchida Nissan "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Nissan' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Nissan is" False and RadioShack – Nissan rider Jens Voigt, 5 [' and', ' Radio', 'Sh', 'ack', ' –', ' Nissan']
+310 88 The name of the CEO of x -1 The name of the CEO of Nissan Makoto Uchida Nissan "[' is' ' not' ' a' ' secret' '.' ' It' ' is' ' a' ' secret' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' Nissan' ' is']" " is not a secret . It is a secret .
+
+ The name of the CEO of Nissan is" False 1 ['N', 'issan']
+311 89 The name of the CEO of x -1 The name of the CEO of Kodak Jeff Clarke Kodak "[' is' ' not' ' Kod' 'ak' '.' ' It' ' is' ' Kod' 'ak' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' company' ' is' ' Kod' 'ak']" " is not Kod ak . It is Kod ak .
+
+ The name of the company is Kod ak" False representative for Kodak said Cosby had become 3 [' representative', ' for', ' Kod', 'ak']
+312 89 The name of the CEO of x -1 The name of the CEO of Kodak Jeff Clarke Kodak "[' is' ' not' ' Kod' 'ak' '.' ' It' ' is' ' Kod' 'ak' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' company' ' is' ' Kod' 'ak']" " is not Kod ak . It is Kod ak .
+
+ The name of the company is Kod ak" False " area known as Kodak Park.
+" 4 [' area', ' known', ' as', ' Kod', 'ak']
+313 89 The name of the CEO of x -1 The name of the CEO of Kodak Jeff Clarke Kodak "[' is' ' not' ' Kod' 'ak' '.' ' It' ' is' ' Kod' 'ak' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' company' ' is' ' Kod' 'ak']" " is not Kod ak . It is Kod ak .
+
+ The name of the company is Kod ak" False 3, 2002 at the Kodak Theatre in Los 6 [' 3', ',', ' 2002', ' at', ' the', ' Kod', 'ak']
+314 89 The name of the CEO of x -1 The name of the CEO of Kodak Jeff Clarke Kodak "[' is' ' not' ' Kod' 'ak' '.' ' It' ' is' ' Kod' 'ak' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' company' ' is' ' Kod' 'ak']" " is not Kod ak . It is Kod ak .
+
+ The name of the company is Kod ak" False from low-grade Kodak Velox onto 5 [' from', ' low', '-', 'grade', ' Kod', 'ak']
+315 89 The name of the CEO of x -1 The name of the CEO of Kodak Jeff Clarke Kodak "[' is' ' not' ' Kod' 'ak' '.' ' It' ' is' ' Kod' 'ak' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' company' ' is' ' Kod' 'ak']" " is not Kod ak . It is Kod ak .
+
+ The name of the company is Kod ak" False 35 mm Eastman Kodak film and anamorphic 5 [' 35', ' mm', ' East', 'man', ' Kod', 'ak']
+316 90 The name of the CEO of x -1 The name of the CEO of Nike John Donahoe Nike "[',' ' Inc' '.' ' is' ' Mark' ' Parker' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' shoe' ' business' ' for' ' over']" , Inc . is Mark Parker . He is a man who has been in the shoe business for over False casual clothes and Nike Dunks. Wearing 3 [' casual', ' clothes', ' and', ' Nike']
+317 90 The name of the CEO of x -1 The name of the CEO of Nike John Donahoe Nike "[',' ' Inc' '.' ' is' ' Mark' ' Parker' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' shoe' ' business' ' for' ' over']" , Inc . is Mark Parker . He is a man who has been in the shoe business for over False to Falk and Jordan, Nike had decided to target 5 [' to', ' Falk', ' and', ' Jordan', ',', ' Nike']
+318 90 The name of the CEO of x -1 The name of the CEO of Nike John Donahoe Nike "[',' ' Inc' '.' ' is' ' Mark' ' Parker' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' shoe' ' business' ' for' ' over']" , Inc . is Mark Parker . He is a man who has been in the shoe business for over False covered by Nike stickers. When 2 [' covered', ' by', ' Nike']
+319 90 The name of the CEO of x -1 The name of the CEO of Nike John Donahoe Nike "[',' ' Inc' '.' ' is' ' Mark' ' Parker' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' shoe' ' business' ' for' ' over']" , Inc . is Mark Parker . He is a man who has been in the shoe business for over False 40-yard dash at a Nike combine at Clemson 6 [' 40', '-', 'yard', ' dash', ' at', ' a', ' Nike']
+320 90 The name of the CEO of x -1 The name of the CEO of Nike John Donahoe Nike "[',' ' Inc' '.' ' is' ' Mark' ' Parker' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' shoe' ' business' ' for' ' over']" , Inc . is Mark Parker . He is a man who has been in the shoe business for over False 5th annual Nike Global Challenge 3 [' 5', 'th', ' annual', ' Nike']
+321 91 The name of the CEO of x -1 The name of the CEO of Vinci Xavier Huillard Vinci "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False of The Da Vinci Code. In March 4 [' of', ' The', ' Da', ' Vin', 'ci']
+322 91 The name of the CEO of x -1 The name of the CEO of Vinci Xavier Huillard Vinci "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False ser Piero da Vinci, more commonly 5 [' ser', ' Pier', 'o', ' da', ' Vin', 'ci']
+323 91 The name of the CEO of x -1 The name of the CEO of Vinci Xavier Huillard Vinci "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False boats, Leonardo da Vinci carried fourteen 50-caliber 5 [' boats', ',', ' Leonardo', ' da', ' Vin', 'ci']
+324 91 The name of the CEO of x -1 The name of the CEO of Vinci Xavier Huillard Vinci "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False Life, and The Da Vinci Code, but 6 [' Life', ',', ' and', ' The', ' Da', ' Vin', 'ci']
+325 91 The name of the CEO of x -1 The name of the CEO of Vinci Xavier Huillard Vinci "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False " Leonardo da Vinci =
+" 3 [' Leonardo', ' da', ' Vin', 'ci']
+326 92 The name of the CEO of x -1 The name of the CEO of RWE Rolf Martin Schmitz RWE "[',' ' the' ' German' ' energy' ' giant' ',' ' is' ' R' 'WE' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' R' 'WE']" ", the German energy giant , is R WE .
+
+ The name of the CEO of R WE" False by energy supplier RWE npower. If the 4 [' by', ' energy', ' supplier', ' R', 'WE']
+327 92 The name of the CEO of x -1 The name of the CEO of RWE Rolf Martin Schmitz RWE "[',' ' the' ' German' ' energy' ' giant' ',' ' is' ' R' 'WE' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' R' 'WE']" ", the German energy giant , is R WE .
+
+ The name of the CEO of R WE" False 2013, when RWE npower took over. 4 [' 2013', ',', ' when', ' R', 'WE']
+328 92 The name of the CEO of x -1 The name of the CEO of RWE Rolf Martin Schmitz RWE "[',' ' the' ' German' ' energy' ' giant' ',' ' is' ' R' 'WE' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' R' 'WE']" ", the German energy giant , is R WE .
+
+ The name of the CEO of R WE" False was taken over by RWE in 2002, the site 5 [' was', ' taken', ' over', ' by', ' R', 'WE']
+329 92 The name of the CEO of x -1 The name of the CEO of RWE Rolf Martin Schmitz RWE "[',' ' the' ' German' ' energy' ' giant' ',' ' is' ' R' 'WE' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' R' 'WE']" ", the German energy giant , is R WE .
+
+ The name of the CEO of R WE" False developed by RWE npower and Thales, 3 [' developed', ' by', ' R', 'WE']
+330 92 The name of the CEO of x -1 The name of the CEO of RWE Rolf Martin Schmitz RWE "[',' ' the' ' German' ' energy' ' giant' ',' ' is' ' R' 'WE' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' CEO' ' of' ' R' 'WE']" ", the German energy giant , is R WE .
+
+ The name of the CEO of R WE" False major utility RWE and electrical 3 [' major', ' utility', ' R', 'WE']
+331 94 The name of the CEO of x -1 The name of the CEO of West Japan Railway Company Kazuaki Hasegawa West Japan Railway Company "[',' ' the' ' operator' ' of' ' the' ' Sh' 'ink' 'ansen' ',' ' is' ' also'
+ ' a' ' hom' 'onym' ' of' ' the' ' Japanese' ' word' ' for' ' ""']" ", the operator of the Sh ink ansen , is also a hom onym of the Japanese word for """ False entirety of the West Japan Railway Company was shut down. 6 [' entirety', ' of', ' the', ' West', ' Japan', ' Railway', ' Company']
+332 94 The name of the CEO of x -1 The name of the CEO of West Japan Railway Company Kazuaki Hasegawa West Japan Railway Company "[',' ' the' ' operator' ' of' ' the' ' Sh' 'ink' 'ansen' ',' ' is' ' also'
+ ' a' ' hom' 'onym' ' of' ' the' ' Japanese' ' word' ' for' ' ""']" ", the operator of the Sh ink ansen , is also a hom onym of the Japanese word for """ False entirety of the West Japan Railway Company was shut down. 6 [' entirety', ' of', ' the', ' West', ' Japan', ' Railway', ' Company']
+333 94 The name of the CEO of x -1 The name of the CEO of West Japan Railway Company Kazuaki Hasegawa West Japan Railway Company "[',' ' the' ' operator' ' of' ' the' ' Sh' 'ink' 'ansen' ',' ' is' ' also'
+ ' a' ' hom' 'onym' ' of' ' the' ' Japanese' ' word' ' for' ' ""']" ", the operator of the Sh ink ansen , is also a hom onym of the Japanese word for """ False entirety of the West Japan Railway Company was shut down. 6 [' entirety', ' of', ' the', ' West', ' Japan', ' Railway', ' Company']
+334 94 The name of the CEO of x -1 The name of the CEO of West Japan Railway Company Kazuaki Hasegawa West Japan Railway Company "[',' ' the' ' operator' ' of' ' the' ' Sh' 'ink' 'ansen' ',' ' is' ' also'
+ ' a' ' hom' 'onym' ' of' ' the' ' Japanese' ' word' ' for' ' ""']" ", the operator of the Sh ink ansen , is also a hom onym of the Japanese word for """ False The entirety of the West Japan Railway Company was shut down. Canceled 7 [' The', ' entirety', ' of', ' the', ' West', ' Japan', ' Railway', ' Company']
+335 94 The name of the CEO of x -1 The name of the CEO of West Japan Railway Company Kazuaki Hasegawa West Japan Railway Company "[',' ' the' ' operator' ' of' ' the' ' Sh' 'ink' 'ansen' ',' ' is' ' also'
+ ' a' ' hom' 'onym' ' of' ' the' ' Japanese' ' word' ' for' ' ""']" ", the operator of the Sh ink ansen , is also a hom onym of the Japanese word for """ False entirety of the West Japan Railway Company was shut down. 6 [' entirety', ' of', ' the', ' West', ' Japan', ' Railway', ' Company']
+336 95 The name of the CEO of x -1 The name of the CEO of Lockheed Martin Marillyn Hewson Lockheed Martin "[',' ' the' ' world' '�' '�' 's' ' largest' ' defense' ' contractor' ','
+ ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 100' ' most' ' powerful']" , the world � � s largest defense contractor , is on the list of the top 100 most powerful False AgustaWestland and Lockheed Martin was selected as the 7 [' Ag', 'ust', 'a', 'West', 'land', ' and', ' Lockheed', ' Martin']
+337 95 The name of the CEO of x -1 The name of the CEO of Lockheed Martin Marillyn Hewson Lockheed Martin "[',' ' the' ' world' '�' '�' 's' ' largest' ' defense' ' contractor' ','
+ ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 100' ' most' ' powerful']" , the world � � s largest defense contractor , is on the list of the top 100 most powerful False compatibility with Lockheed Martin Low-Altitude 3 [' compatibility', ' with', ' Lockheed', ' Martin']
+338 95 The name of the CEO of x -1 The name of the CEO of Lockheed Martin Marillyn Hewson Lockheed Martin "[',' ' the' ' world' '�' '�' 's' ' largest' ' defense' ' contractor' ','
+ ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 100' ' most' ' powerful']" , the world � � s largest defense contractor , is on the list of the top 100 most powerful False " VH-71 Kestrel
+" 9 [' V', 'H', '-', '71', ' K', 'est', 'rel', 'Lock', 'heed', ' Martin']
+339 95 The name of the CEO of x -1 The name of the CEO of Lockheed Martin Marillyn Hewson Lockheed Martin "[',' ' the' ' world' '�' '�' 's' ' largest' ' defense' ' contractor' ','
+ ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 100' ' most' ' powerful']" , the world � � s largest defense contractor , is on the list of the top 100 most powerful False however, in 2009 Lockheed Martin released information 5 [' however', ',', ' in', ' 2009', ' Lockheed', ' Martin']
+340 95 The name of the CEO of x -1 The name of the CEO of Lockheed Martin Marillyn Hewson Lockheed Martin "[',' ' the' ' world' '�' '�' 's' ' largest' ' defense' ' contractor' ','
+ ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 100' ' most' ' powerful']" , the world � � s largest defense contractor , is on the list of the top 100 most powerful False to carry the Lockheed Martin F-35's F135 engine 4 [' to', ' carry', ' the', ' Lockheed', ' Martin']
+341 96 The name of the CEO of x -1 The name of the CEO of British Airways Álex Cruz British Airways "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False Boeing customer, British Airways ordered 59 Airbus A320 4 [' Boeing', ' customer', ',', ' British', ' Airways']
+342 96 The name of the CEO of x -1 The name of the CEO of British Airways Álex Cruz British Airways "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False prevented from boarding a British Airways flight to Bahrain 5 [' prevented', ' from', ' boarding', ' a', ' British', ' Airways']
+343 96 The name of the CEO of x -1 The name of the CEO of British Airways Álex Cruz British Airways "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False Minister flew in a British Airways Concorde (G-BOAC) 5 [' Minister', ' flew', ' in', ' a', ' British', ' Airways']
+344 96 The name of the CEO of x -1 The name of the CEO of British Airways Álex Cruz British Airways "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False " Airways =
+" 3 [' Airways', ' =', 'British', ' Airways']
+345 96 The name of the CEO of x -1 The name of the CEO of British Airways Álex Cruz British Airways "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' ""'s"" ' a'
+ ' mouth' 'ful' ' that' ""'s"" ' worth' ' it' '.' '\n']" " is a bit of a mouth ful , but it 's a mouth ful that 's worth it .
+" False beds. In 2006 British Airways launched Next Generation 5 [' beds', '.', ' In', ' 2006', ' British', ' Airways']
+346 97 The name of the CEO of x -1 The name of the CEO of American Library Association Tracie D. Hall American Library Association "[',' ' the' ' American' ' Library' ' Association' ',' ' is' ' Mary'
+ ' Frances' ' Berry' '.' ' She' ' is' ' a' ' l' 'ibrarian' ' and' ' a'
+ ' professor' ' of']" , the American Library Association , is Mary Frances Berry . She is a l ibrarian and a professor of False Motors (GM), the American Library Association (ALA), and RCA, as 7 [' Motors', ' (', 'GM', '),', ' the', ' American', ' Library', ' Association']
+347 97 The name of the CEO of x -1 The name of the CEO of American Library Association Tracie D. Hall American Library Association "[',' ' the' ' American' ' Library' ' Association' ',' ' is' ' Mary'
+ ' Frances' ' Berry' '.' ' She' ' is' ' a' ' l' 'ibrarian' ' and' ' a'
+ ' professor' ' of']" , the American Library Association , is Mary Frances Berry . She is a l ibrarian and a professor of False " division of the American Library Association, is given for ""the" 5 [' division', ' of', ' the', ' American', ' Library', ' Association']
+348 97 The name of the CEO of x -1 The name of the CEO of American Library Association Tracie D. Hall American Library Association "[',' ' the' ' American' ' Library' ' Association' ',' ' is' ' Mary'
+ ' Frances' ' Berry' '.' ' She' ' is' ' a' ' l' 'ibrarian' ' and' ' a'
+ ' professor' ' of']" , the American Library Association , is Mary Frances Berry . She is a l ibrarian and a professor of False being named an American Library Association Best Book for 5 [' being', ' named', ' an', ' American', ' Library', ' Association']
+349 97 The name of the CEO of x -1 The name of the CEO of American Library Association Tracie D. Hall American Library Association "[',' ' the' ' American' ' Library' ' Association' ',' ' is' ' Mary'
+ ' Frances' ' Berry' '.' ' She' ' is' ' a' ' l' 'ibrarian' ' and' ' a'
+ ' professor' ' of']" , the American Library Association , is Mary Frances Berry . She is a l ibrarian and a professor of False third on the American Library Association (ALA) list of 5 [' third', ' on', ' the', ' American', ' Library', ' Association']
+350 97 The name of the CEO of x -1 The name of the CEO of American Library Association Tracie D. Hall American Library Association "[',' ' the' ' American' ' Library' ' Association' ',' ' is' ' Mary'
+ ' Frances' ' Berry' '.' ' She' ' is' ' a' ' l' 'ibrarian' ' and' ' a'
+ ' professor' ' of']" , the American Library Association , is Mary Frances Berry . She is a l ibrarian and a professor of False United States, the American Library Association listed it among 6 [' United', ' States', ',', ' the', ' American', ' Library', ' Association']
+351 99 The name of the CEO of x -1 The name of the CEO of Southwest Airlines Bob Jordan Southwest Airlines "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' around' ' since' ' 1967' ',' ' but' ' the' ' name']" " is a bit of a mystery .
+
+ The company has been around since 1967 , but the name" False low-cost airline Southwest Airlines entered the Pacific 5 [' low', '-', 'cost', ' airline', ' Southwest', ' Airlines']
+352 99 The name of the CEO of x -1 The name of the CEO of Southwest Airlines Bob Jordan Southwest Airlines "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' around' ' since' ' 1967' ',' ' but' ' the' ' name']" " is a bit of a mystery .
+
+ The company has been around since 1967 , but the name" False guaranteed. Southwest Airlines quickly filled 3 [' guaranteed', '.', ' Southwest', ' Airlines']
+353 99 The name of the CEO of x -1 The name of the CEO of Southwest Airlines Bob Jordan Southwest Airlines "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' around' ' since' ' 1967' ',' ' but' ' the' ' name']" " is a bit of a mystery .
+
+ The company has been around since 1967 , but the name" False focus city for Southwest Airlines and was previously 4 [' focus', ' city', ' for', ' Southwest', ' Airlines']
+354 99 The name of the CEO of x -1 The name of the CEO of Southwest Airlines Bob Jordan Southwest Airlines "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' around' ' since' ' 1967' ',' ' but' ' the' ' name']" " is a bit of a mystery .
+
+ The company has been around since 1967 , but the name" False However, in 2015 Southwest Airlines launched service 5 [' However', ',', ' in', ' 2015', ' Southwest', ' Airlines']
+355 99 The name of the CEO of x -1 The name of the CEO of Southwest Airlines Bob Jordan Southwest Airlines "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' around' ' since' ' 1967' ',' ' but' ' the' ' name']" " is a bit of a mystery .
+
+ The company has been around since 1967 , but the name" False flights. However, in 2015 Southwest Airlines launched service 7 [' flights', '.', ' However', ',', ' in', ' 2015', ' Southwest', ' Airlines']
+356 100 The name of the CEO of x -1 The name of the CEO of HSBC Stuart Gulliver HSBC "[',' ' the' ' world' '�' '�' 's' ' largest' ' bank' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mouth' 'ful' '.' ' But' ' it' '�']" , the world � � s largest bank , is a bit of a mouth ful . But it � False According to a HSBC report, Malaysia 3 [' According', ' to', ' a', ' HSBC']
+357 100 The name of the CEO of x -1 The name of the CEO of HSBC Stuart Gulliver HSBC "[',' ' the' ' world' '�' '�' 's' ' largest' ' bank' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mouth' 'ful' '.' ' But' ' it' '�']" , the world � � s largest bank , is a bit of a mouth ful . But it � False (a subsidiary of HSBC El Salvador) and Bancomercio. 4 [' (', 'a', ' subsidiary', ' of', ' HSBC']
+358 100 The name of the CEO of x -1 The name of the CEO of HSBC Stuart Gulliver HSBC "[',' ' the' ' world' '�' '�' 's' ' largest' ' bank' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mouth' 'ful' '.' ' But' ' it' '�']" , the world � � s largest bank , is a bit of a mouth ful . But it � False 1 ['HS', 'BC']
+359 100 The name of the CEO of x -1 The name of the CEO of HSBC Stuart Gulliver HSBC "[',' ' the' ' world' '�' '�' 's' ' largest' ' bank' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mouth' 'ful' '.' ' But' ' it' '�']" , the world � � s largest bank , is a bit of a mouth ful . But it � False 1 ['HS', 'BC']
+360 100 The name of the CEO of x -1 The name of the CEO of HSBC Stuart Gulliver HSBC "[',' ' the' ' world' '�' '�' 's' ' largest' ' bank' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mouth' 'ful' '.' ' But' ' it' '�']" , the world � � s largest bank , is a bit of a mouth ful . But it � False 1 ['HS', 'BC']
+361 101 The name of the CEO of x -1 The name of the CEO of Warner Bros. Barry Meyer Warner Bros. "[' is' ' Kevin' ' Tsu' 'j' 'ihara' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' entertainment' ' industry' ' for' ' over'
+ ' 30']" is Kevin Tsu j ihara . He is a man who has been in the entertainment industry for over 30 False 31, 2009 by Warner Bros. Records. Madonna collaborated 6 [' 31', ',', ' 2009', ' by', ' Warner', ' Bros', '.']
+362 101 The name of the CEO of x -1 The name of the CEO of Warner Bros. Barry Meyer Warner Bros. "[' is' ' Kevin' ' Tsu' 'j' 'ihara' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' entertainment' ' industry' ' for' ' over'
+ ' 30']" is Kevin Tsu j ihara . He is a man who has been in the entertainment industry for over 30 False Casino Royale, Warner Bros. expressed an 5 [' Casino', ' Royale', ',', ' Warner', ' Bros', '.']
+363 101 The name of the CEO of x -1 The name of the CEO of Warner Bros. Barry Meyer Warner Bros. "[' is' ' Kevin' ' Tsu' 'j' 'ihara' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' entertainment' ' industry' ' for' ' over'
+ ' 30']" is Kevin Tsu j ihara . He is a man who has been in the entertainment industry for over 30 False August 21, 2015, Warner Bros. revealed the first 7 [' August', ' 21', ',', ' 2015', ',', ' Warner', ' Bros', '.']
+364 101 The name of the CEO of x -1 The name of the CEO of Warner Bros. Barry Meyer Warner Bros. "[' is' ' Kevin' ' Tsu' 'j' 'ihara' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' entertainment' ' industry' ' for' ' over'
+ ' 30']" is Kevin Tsu j ihara . He is a man who has been in the entertainment industry for over 30 False successes led to Warner Bros. Television and 5 [' successes', ' led', ' to', ' Warner', ' Bros', '.']
+365 101 The name of the CEO of x -1 The name of the CEO of Warner Bros. Barry Meyer Warner Bros. "[' is' ' Kevin' ' Tsu' 'j' 'ihara' '.' ' He' ' is' ' a' ' man' ' who'
+ ' has' ' been' ' in' ' the' ' entertainment' ' industry' ' for' ' over'
+ ' 30']" is Kevin Tsu j ihara . He is a man who has been in the entertainment industry for over 30 False recording with Warner Bros. Records. The 4 [' recording', ' with', ' Warner', ' Bros', '.']
+366 102 The name of the CEO of x -1 The name of the CEO of Verizon Hans Vestberg Verizon "[' Wireless' ',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to'
+ ' the' ' name' ',' ' is' ' Verizon' ' Wireless' '.' '\n' '\n' 'Ver'
+ 'izon']" " Wireless , the company that owns the rights to the name , is Verizon Wireless .
+
+ Ver izon" False D.C., at the Verizon Center. The event 6 [' D', '.', 'C', '.,', ' at', ' the', ' Verizon']
+367 102 The name of the CEO of x -1 The name of the CEO of Verizon Hans Vestberg Verizon "[' Wireless' ',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to'
+ ' the' ' name' ',' ' is' ' Verizon' ' Wireless' '.' '\n' '\n' 'Ver'
+ 'izon']" " Wireless , the company that owns the rights to the name , is Verizon Wireless .
+
+ Ver izon" False California at the Verizon Wireless Amphitheatre. 3 [' California', ' at', ' the', ' Verizon']
+368 102 The name of the CEO of x -1 The name of the CEO of Verizon Hans Vestberg Verizon "[' Wireless' ',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to'
+ ' the' ' name' ',' ' is' ' Verizon' ' Wireless' '.' '\n' '\n' 'Ver'
+ 'izon']" " Wireless , the company that owns the rights to the name , is Verizon Wireless .
+
+ Ver izon" False " Future"" event at the Verizon Center. Demi Lovato," 5 "[' Future', '""', ' event', ' at', ' the', ' Verizon']"
+369 102 The name of the CEO of x -1 The name of the CEO of Verizon Hans Vestberg Verizon "[' Wireless' ',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to'
+ ' the' ' name' ',' ' is' ' Verizon' ' Wireless' '.' '\n' '\n' 'Ver'
+ 'izon']" " Wireless , the company that owns the rights to the name , is Verizon Wireless .
+
+ Ver izon" False album since 2000. Verizon Wireless has gained 4 [' album', ' since', ' 2000', '.', ' Verizon']
+370 102 The name of the CEO of x -1 The name of the CEO of Verizon Hans Vestberg Verizon "[' Wireless' ',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to'
+ ' the' ' name' ',' ' is' ' Verizon' ' Wireless' '.' '\n' '\n' 'Ver'
+ 'izon']" " Wireless , the company that owns the rights to the name , is Verizon Wireless .
+
+ Ver izon" False Washington D.C. at the Verizon Center to celebrate 7 [' Washington', ' D', '.', 'C', '.', ' at', ' the', ' Verizon']
+371 104 The name of the CEO of x -1 The name of the CEO of Berkshire Hathaway Warren Buffett Berkshire Hathaway "[' is' ' Warren' ' Buffett' '.' ' He' ' is' ' the' ' world' ""'s"" ' most'
+ ' successful' ' investor' '.' ' He' ' is' ' also' ' the' ' world' ""'s""
+ ' most']" is Warren Buffett . He is the world 's most successful investor . He is also the world 's most True the creation on his Berkshire Hathaway conglomerate 6 [' the', ' creation', ' on', ' his', ' Berkshire', ' Hath', 'away']
+372 104 The name of the CEO of x -1 The name of the CEO of Berkshire Hathaway Warren Buffett Berkshire Hathaway "[' is' ' Warren' ' Buffett' '.' ' He' ' is' ' the' ' world' ""'s"" ' most'
+ ' successful' ' investor' '.' ' He' ' is' ' also' ' the' ' world' ""'s""
+ ' most']" is Warren Buffett . He is the world 's most successful investor . He is also the world 's most True shareholders. Berkshire Hathaway partially funded 4 [' shareholders', '.', ' Berkshire', ' Hath', 'away']
+373 104 The name of the CEO of x -1 The name of the CEO of Berkshire Hathaway Warren Buffett Berkshire Hathaway "[' is' ' Warren' ' Buffett' '.' ' He' ' is' ' the' ' world' ""'s"" ' most'
+ ' successful' ' investor' '.' ' He' ' is' ' also' ' the' ' world' ""'s""
+ ' most']" is Warren Buffett . He is the world 's most successful investor . He is also the world 's most True the creation on his Berkshire Hathaway conglomerate 6 [' the', ' creation', ' on', ' his', ' Berkshire', ' Hath', 'away']
+374 104 The name of the CEO of x -1 The name of the CEO of Berkshire Hathaway Warren Buffett Berkshire Hathaway "[' is' ' Warren' ' Buffett' '.' ' He' ' is' ' the' ' world' ""'s"" ' most'
+ ' successful' ' investor' '.' ' He' ' is' ' also' ' the' ' world' ""'s""
+ ' most']" is Warren Buffett . He is the world 's most successful investor . He is also the world 's most True shareholders. Berkshire Hathaway partially funded 4 [' shareholders', '.', ' Berkshire', ' Hath', 'away']
+375 104 The name of the CEO of x -1 The name of the CEO of Berkshire Hathaway Warren Buffett Berkshire Hathaway "[' is' ' Warren' ' Buffett' '.' ' He' ' is' ' the' ' world' ""'s"" ' most'
+ ' successful' ' investor' '.' ' He' ' is' ' also' ' the' ' world' ""'s""
+ ' most']" is Warren Buffett . He is the world 's most successful investor . He is also the world 's most True Hortons in 2014, Berkshire Hathaway bought into 7 [' Hort', 'ons', ' in', ' 2014', ',', ' Berkshire', ' Hath', 'away']
+376 105 The name of the CEO of x -1 The name of the CEO of JPMorgan Chase Jamie Dimon JPMorgan Chase "[',' ' Jamie' ' Dim' 'on' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.'
+ ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but']" , Jamie Dim on , is a bit of a mystery . He is a very smart guy , but True 1,002-foot (305 m) -tall JPMorgan Chase Tower (formerly 12 [' 1', ',', '002', '-', 'foot', ' (', '305', ' m', ')', ' -', 'tall', ' JPMorgan', ' Chase']
+377 105 The name of the CEO of x -1 The name of the CEO of JPMorgan Chase Jamie Dimon JPMorgan Chase "[',' ' Jamie' ' Dim' 'on' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.'
+ ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but']" , Jamie Dim on , is a bit of a mystery . He is a very smart guy , but True (Rick Wagoner), JPMorgan Chase (Steven Black), 7 [' (', 'Rick', ' W', 'agon', 'er', '),', ' JPMorgan', ' Chase']
+378 105 The name of the CEO of x -1 The name of the CEO of JPMorgan Chase Jamie Dimon JPMorgan Chase "[',' ' Jamie' ' Dim' 'on' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.'
+ ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but']" , Jamie Dim on , is a bit of a mystery . He is a very smart guy , but True 2000 to form JPMorgan Chase & Co. Throughout 4 [' 2000', ' to', ' form', ' JPMorgan', ' Chase']
+379 105 The name of the CEO of x -1 The name of the CEO of JPMorgan Chase Jamie Dimon JPMorgan Chase "[',' ' Jamie' ' Dim' 'on' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.'
+ ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but']" , Jamie Dim on , is a bit of a mystery . He is a very smart guy , but True Bank One merged with JPMorgan Chase in 2004, and the name 5 [' Bank', ' One', ' merged', ' with', ' JPMorgan', ' Chase']
+380 105 The name of the CEO of x -1 The name of the CEO of JPMorgan Chase Jamie Dimon JPMorgan Chase "[',' ' Jamie' ' Dim' 'on' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.'
+ ' He' ' is' ' a' ' very' ' smart' ' guy' ',' ' but']" , Jamie Dim on , is a bit of a mystery . He is a very smart guy , but True (Rick Wagoner), JPMorgan Chase (Steven Black), 7 [' (', 'Rick', ' W', 'agon', 'er', '),', ' JPMorgan', ' Chase']
+381 108 The name of the CEO of x -1 The name of the CEO of Citigroup Jane Fraser Citigroup "[',' ' the' ' bank' ' that' ' is' ' the' ' parent' ' company' ' of' ' Cit'
+ 'ib' 'ank' ',' ' is' ' Vik' 'ram' ' Pand' 'it' '.' '\n']" ", the bank that is the parent company of Cit ib ank , is Vik ram Pand it .
+" False 10 Hudson Yards, and Citigroup is considering moving 6 [' 10', ' Hudson', ' Yards', ',', ' and', ' Cit', 'igroup']
+382 108 The name of the CEO of x -1 The name of the CEO of Citigroup Jane Fraser Citigroup "[',' ' the' ' bank' ' that' ' is' ' the' ' parent' ' company' ' of' ' Cit'
+ 'ib' 'ank' ',' ' is' ' Vik' 'ram' ' Pand' 'it' '.' '\n']" ", the bank that is the parent company of Cit ib ank , is Vik ram Pand it .
+" False acquisition of the Citigroup Loan Portfolio (2008) 4 [' acquisition', ' of', ' the', ' Cit', 'igroup']
+383 108 The name of the CEO of x -1 The name of the CEO of Citigroup Jane Fraser Citigroup "[',' ' the' ' bank' ' that' ' is' ' the' ' parent' ' company' ' of' ' Cit'
+ 'ib' 'ank' ',' ' is' ' Vik' 'ram' ' Pand' 'it' '.' '\n']" ", the bank that is the parent company of Cit ib ank , is Vik ram Pand it .
+" False Brothers, a subsidiary of Citigroup housed in the building, 6 [' Brothers', ',', ' a', ' subsidiary', ' of', ' Cit', 'igroup']
+384 108 The name of the CEO of x -1 The name of the CEO of Citigroup Jane Fraser Citigroup "[',' ' the' ' bank' ' that' ' is' ' the' ' parent' ' company' ' of' ' Cit'
+ 'ib' 'ank' ',' ' is' ' Vik' 'ram' ' Pand' 'it' '.' '\n']" ", the bank that is the parent company of Cit ib ank , is Vik ram Pand it .
+" False statements issued by UBS, Citigroup and Goldman 7 [' statements', ' issued', ' by', ' U', 'BS', ',', ' Cit', 'igroup']
+385 108 The name of the CEO of x -1 The name of the CEO of Citigroup Jane Fraser Citigroup "[',' ' the' ' bank' ' that' ' is' ' the' ' parent' ' company' ' of' ' Cit'
+ 'ib' 'ank' ',' ' is' ' Vik' 'ram' ' Pand' 'it' '.' '\n']" ", the bank that is the parent company of Cit ib ank , is Vik ram Pand it .
+" False index announced by Citigroup in February 2011, 4 [' index', ' announced', ' by', ' Cit', 'igroup']
+386 109 The name of the CEO of x -1 The name of the CEO of Society of Jesus Arturo Sosa Abascal Society of Jesus "[',' ' Fr' '.' ' John' ' A' '.' ' Hard' 'on' ',' ' S' '.' 'J' '.' '\n'
+ '\n' 'The' ' Society' ' of' ' Jesus' ' is']" ", Fr . John A . Hard on , S . J .
+
+ The Society of Jesus is" False members of the Society of Jesus and the Récollets 5 [' members', ' of', ' the', ' Society', ' of', ' Jesus']
+387 109 The name of the CEO of x -1 The name of the CEO of Society of Jesus Arturo Sosa Abascal Society of Jesus "[',' ' Fr' '.' ' John' ' A' '.' ' Hard' 'on' ',' ' S' '.' 'J' '.' '\n'
+ '\n' 'The' ' Society' ' of' ' Jesus' ' is']" ", Fr . John A . Hard on , S . J .
+
+ The Society of Jesus is" False novitiate with the Society of Jesus in September 8 [' no', 'v', 'iti', 'ate', ' with', ' the', ' Society', ' of', ' Jesus']
+388 109 The name of the CEO of x -1 The name of the CEO of Society of Jesus Arturo Sosa Abascal Society of Jesus "[',' ' Fr' '.' ' John' ' A' '.' ' Hard' 'on' ',' ' S' '.' 'J' '.' '\n'
+ '\n' 'The' ' Society' ' of' ' Jesus' ' is']" ", Fr . John A . Hard on , S . J .
+
+ The Society of Jesus is" False novitiate with the Society of Jesus in September 1920 8 [' no', 'v', 'iti', 'ate', ' with', ' the', ' Society', ' of', ' Jesus']
+389 109 The name of the CEO of x -1 The name of the CEO of Society of Jesus Arturo Sosa Abascal Society of Jesus "[',' ' Fr' '.' ' John' ' A' '.' ' Hard' 'on' ',' ' S' '.' 'J' '.' '\n'
+ '\n' 'The' ' Society' ' of' ' Jesus' ' is']" ", Fr . John A . Hard on , S . J .
+
+ The Society of Jesus is" False artefacts from the Society of Jesus and English Catholicism. 6 [' arte', 'facts', ' from', ' the', ' Society', ' of', ' Jesus']
+390 109 The name of the CEO of x -1 The name of the CEO of Society of Jesus Arturo Sosa Abascal Society of Jesus "[',' ' Fr' '.' ' John' ' A' '.' ' Hard' 'on' ',' ' S' '.' 'J' '.' '\n'
+ '\n' 'The' ' Society' ' of' ' Jesus' ' is']" ", Fr . John A . Hard on , S . J .
+
+ The Society of Jesus is" False promise to destroy the Society of Jesus would be in violation 6 [' promise', ' to', ' destroy', ' the', ' Society', ' of', ' Jesus']
+391 112 The name of the CEO of x -1 The name of the CEO of Northrop Grumman Kathy J. Warden Northrop Grumman "[' is' ' Robert' ' J' '.' ' Stevens' '.' ' He' ' is' ' a' ' graduate'
+ ' of' ' the' ' University' ' of' ' Michigan' ' and' ' the' ' University'
+ ' of' ' Michigan']" is Robert J . Stevens . He is a graduate of the University of Michigan and the University of Michigan False Martin and Northrop Grumman were competing to connect 6 [' Martin', ' and', ' North', 'rop', ' Gr', 'um', 'man']
+392 112 The name of the CEO of x -1 The name of the CEO of Northrop Grumman Kathy J. Warden Northrop Grumman "[' is' ' Robert' ' J' '.' ' Stevens' '.' ' He' ' is' ' a' ' graduate'
+ ' of' ' the' ' University' ' of' ' Michigan' ' and' ' the' ' University'
+ ' of' ' Michigan']" is Robert J . Stevens . He is a graduate of the University of Michigan and the University of Michigan False was reclaimed by Northrop Grumman and used as 7 [' was', ' reclaimed', ' by', ' North', 'rop', ' Gr', 'um', 'man']
+393 112 The name of the CEO of x -1 The name of the CEO of Northrop Grumman Kathy J. Warden Northrop Grumman "[' is' ' Robert' ' J' '.' ' Stevens' '.' ' He' ' is' ' a' ' graduate'
+ ' of' ' the' ' University' ' of' ' Michigan' ' and' ' the' ' University'
+ ' of' ' Michigan']" is Robert J . Stevens . He is a graduate of the University of Michigan and the University of Michigan False / A-18 Hornets, Northrop Grumman EA-6 Prowlers 10 [' /', ' A', '-', '18', ' Hornets', ',', ' North', 'rop', ' Gr', 'um', 'man']
+394 112 The name of the CEO of x -1 The name of the CEO of Northrop Grumman Kathy J. Warden Northrop Grumman "[' is' ' Robert' ' J' '.' ' Stevens' '.' ' He' ' is' ' a' ' graduate'
+ ' of' ' the' ' University' ' of' ' Michigan' ' and' ' the' ' University'
+ ' of' ' Michigan']" is Robert J . Stevens . He is a graduate of the University of Michigan and the University of Michigan False 4 ['North', 'rop', ' Gr', 'um', 'man']
+395 112 The name of the CEO of x -1 The name of the CEO of Northrop Grumman Kathy J. Warden Northrop Grumman "[' is' ' Robert' ' J' '.' ' Stevens' '.' ' He' ' is' ' a' ' graduate'
+ ' of' ' the' ' University' ' of' ' Michigan' ' and' ' the' ' University'
+ ' of' ' Michigan']" is Robert J . Stevens . He is a graduate of the University of Michigan and the University of Michigan False the project for the Northrop Grumman B-2 Spirit stealth 8 [' the', ' project', ' for', ' the', ' North', 'rop', ' Gr', 'um', 'man']
+396 114 The name of the CEO of x -1 The name of the CEO of Toulouse-Blagnac Philippe Crébassa Toulouse-Blagnac "[' Airport' ' is' ' Mr' '.' ' Jean' '-' 'Pierre' ' Jou' 'yet' '.' ' He'
+ ' is' ' a' ' French' ' politician' ' and' ' a' ' member' ' of' ' the']" Airport is Mr . Jean - Pierre Jou yet . He is a French politician and a member of the False adjacent to Toulouse-Blagnac Airport in Colomiers; 8 [' adjacent', ' to', ' T', 'oul', 'ouse', '-', 'Bl', 'agn', 'ac']
+397 115 The name of the CEO of x -1 The name of the CEO of Universal Music Group Lucian Grainge Universal Music Group "[',' ' the' ' parent' ' company' ' of' ' Universal' ' Music' ' Group' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' owns'
+ ' the' ' rights']" , the parent company of Universal Music Group , is also the name of the company that owns the rights False " Music Group ===
+" 5 [' Music', ' Group', ' ===', 'Universal', ' Music', ' Group']
+398 115 The name of the CEO of x -1 The name of the CEO of Universal Music Group Lucian Grainge Universal Music Group "[',' ' the' ' parent' ' company' ' of' ' Universal' ' Music' ' Group' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' owns'
+ ' the' ' rights']" , the parent company of Universal Music Group , is also the name of the company that owns the rights False " litigation with Universal Music Group over their record contract.
+" 4 [' litigation', ' with', ' Universal', ' Music', ' Group']
+399 115 The name of the CEO of x -1 The name of the CEO of Universal Music Group Lucian Grainge Universal Music Group "[',' ' the' ' parent' ' company' ' of' ' Universal' ' Music' ' Group' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' owns'
+ ' the' ' rights']" , the parent company of Universal Music Group , is also the name of the company that owns the rights False Reznor condemned Universal Music Group — the parent company 5 [' Rez', 'nor', ' condemned', ' Universal', ' Music', ' Group']
+400 115 The name of the CEO of x -1 The name of the CEO of Universal Music Group Lucian Grainge Universal Music Group "[',' ' the' ' parent' ' company' ' of' ' Universal' ' Music' ' Group' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' owns'
+ ' the' ' rights']" , the parent company of Universal Music Group , is also the name of the company that owns the rights False reported that Universal Music Group and YouTube have 4 [' reported', ' that', ' Universal', ' Music', ' Group']
+401 115 The name of the CEO of x -1 The name of the CEO of Universal Music Group Lucian Grainge Universal Music Group "[',' ' the' ' parent' ' company' ' of' ' Universal' ' Music' ' Group' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' owns'
+ ' the' ' rights']" , the parent company of Universal Music Group , is also the name of the company that owns the rights False Interscope Records and Universal Music Group in order to expand 8 [' In', 'ters', 'c', 'ope', ' Records', ' and', ' Universal', ' Music', ' Group']
+402 116 The name of the CEO of x -1 The name of the CEO of Tencent Ma Huateng Tencent "[' is' ' also' ' a' ' Chinese' ' name' '.' '\n' '\n' '~~' '~' '\n' '\n'
+ '~~' '~' '\n' '\n' '~~' '~' '\n' '\n']" " is also a Chinese name .
+
+ ~~ ~
+
+ ~~ ~
+
+ ~~ ~
+
+" False partner with Qihoo and Tencent to help promote 7 [' partner', ' with', ' Q', 'ih', 'oo', ' and', ' Ten', 'cent']
+403 116 The name of the CEO of x -1 The name of the CEO of Tencent Ma Huateng Tencent "[' is' ' also' ' a' ' Chinese' ' name' '.' '\n' '\n' '~~' '~' '\n' '\n'
+ '~~' '~' '\n' '\n' '~~' '~' '\n' '\n']" " is also a Chinese name .
+
+ ~~ ~
+
+ ~~ ~
+
+ ~~ ~
+
+" False published by Tencent Games on November 3 [' published', ' by', ' Ten', 'cent']
+404 116 The name of the CEO of x -1 The name of the CEO of Tencent Ma Huateng Tencent "[' is' ' also' ' a' ' Chinese' ' name' '.' '\n' '\n' '~~' '~' '\n' '\n'
+ '~~' '~' '\n' '\n' '~~' '~' '\n' '\n']" " is also a Chinese name .
+
+ ~~ ~
+
+ ~~ ~
+
+ ~~ ~
+
+" False release supporting Tencent QQ. Communication with 3 [' release', ' supporting', ' Ten', 'cent']
+405 116 The name of the CEO of x -1 The name of the CEO of Tencent Ma Huateng Tencent "[' is' ' also' ' a' ' Chinese' ' name' '.' '\n' '\n' '~~' '~' '\n' '\n'
+ '~~' '~' '\n' '\n' '~~' '~' '\n' '\n']" " is also a Chinese name .
+
+ ~~ ~
+
+ ~~ ~
+
+ ~~ ~
+
+" False release supporting Tencent QQ. Communication 3 [' release', ' supporting', ' Ten', 'cent']
+406 116 The name of the CEO of x -1 The name of the CEO of Tencent Ma Huateng Tencent "[' is' ' also' ' a' ' Chinese' ' name' '.' '\n' '\n' '~~' '~' '\n' '\n'
+ '~~' '~' '\n' '\n' '~~' '~' '\n' '\n']" " is also a Chinese name .
+
+ ~~ ~
+
+ ~~ ~
+
+ ~~ ~
+
+" False published by Tencent Games on November 3 [' published', ' by', ' Ten', 'cent']
+407 117 The name of the CEO of x -1 The name of the CEO of Salesforce Marc Benioff Salesforce "['.' 'com' ' is' ' Marc' ' Ben' 'io' 'ff' '.' '\n' '\n' 'The' ' company'
+ ' is' ' a' ' cloud' ' computing' ' company' ' that' ' provides' ' a']" ". com is Marc Ben io ff .
+
+ The company is a cloud computing company that provides a" True Social Studio beyond Salesforce ’ s Marketing 4 [' Social', ' Studio', ' beyond', ' Sales', 'force']
+408 117 The name of the CEO of x -1 The name of the CEO of Salesforce Marc Benioff Salesforce "['.' 'com' ' is' ' Marc' ' Ben' 'io' 'ff' '.' '\n' '\n' 'The' ' company'
+ ' is' ' a' ' cloud' ' computing' ' company' ' that' ' provides' ' a']" ". com is Marc Ben io ff .
+
+ The company is a cloud computing company that provides a" True new division called Salesforce ExactTarget Marketing 4 [' new', ' division', ' called', ' Sales', 'force']
+409 117 The name of the CEO of x -1 The name of the CEO of Salesforce Marc Benioff Salesforce "['.' 'com' ' is' ' Marc' ' Ben' 'io' 'ff' '.' '\n' '\n' 'The' ' company'
+ ' is' ' a' ' cloud' ' computing' ' company' ' that' ' provides' ' a']" ". com is Marc Ben io ff .
+
+ The company is a cloud computing company that provides a" True 1 ['Sales', 'force']
+410 117 The name of the CEO of x -1 The name of the CEO of Salesforce Marc Benioff Salesforce "['.' 'com' ' is' ' Marc' ' Ben' 'io' 'ff' '.' '\n' '\n' 'The' ' company'
+ ' is' ' a' ' cloud' ' computing' ' company' ' that' ' provides' ' a']" ". com is Marc Ben io ff .
+
+ The company is a cloud computing company that provides a" True " Marketing Cloud =
+" 4 [' Marketing', ' Cloud', ' =', 'Sales', 'force']
+411 117 The name of the CEO of x -1 The name of the CEO of Salesforce Marc Benioff Salesforce "['.' 'com' ' is' ' Marc' ' Ben' 'io' 'ff' '.' '\n' '\n' 'The' ' company'
+ ' is' ' a' ' cloud' ' computing' ' company' ' that' ' provides' ' a']" ". com is Marc Ben io ff .
+
+ The company is a cloud computing company that provides a" True division called Salesforce ExactTarget Marketing 3 [' division', ' called', ' Sales', 'force']
+412 118 The name of the CEO of x -1 The name of the CEO of Caterpillar Inc. Jim Umpleby Caterpillar Inc. "[' is' ' Jim' ' Um' 'ple' 'by' '.' ' He' ' is' ' a' ' very' ' nice' ' man'
+ '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Jim Um ple by . He is a very nice man . He is a very nice man . True As of 2010, Caterpillar Inc. was the 229nd largest 7 [' As', ' of', ' 2010', ',', ' Cater', 'pillar', ' Inc', '.']
+413 118 The name of the CEO of x -1 The name of the CEO of Caterpillar Inc. Jim Umpleby Caterpillar Inc. "[' is' ' Jim' ' Um' 'ple' 'by' '.' ' He' ' is' ' a' ' very' ' nice' ' man'
+ '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Jim Um ple by . He is a very nice man . He is a very nice man . True As of 2010, Caterpillar Inc. was the 229th largest 7 [' As', ' of', ' 2010', ',', ' Cater', 'pillar', ' Inc', '.']
+414 118 The name of the CEO of x -1 The name of the CEO of Caterpillar Inc. Jim Umpleby Caterpillar Inc. "[' is' ' Jim' ' Um' 'ple' 'by' '.' ' He' ' is' ' a' ' very' ' nice' ' man'
+ '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Jim Um ple by . He is a very nice man . He is a very nice man . True distribution center by Caterpillar Inc. and a major Payless 6 [' distribution', ' center', ' by', ' Cater', 'pillar', ' Inc', '.']
+415 118 The name of the CEO of x -1 The name of the CEO of Caterpillar Inc. Jim Umpleby Caterpillar Inc. "[' is' ' Jim' ' Um' 'ple' 'by' '.' ' He' ' is' ' a' ' very' ' nice' ' man'
+ '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Jim Um ple by . He is a very nice man . He is a very nice man . True distribution center by Caterpillar Inc. and a major 6 [' distribution', ' center', ' by', ' Cater', 'pillar', ' Inc', '.']
+416 118 The name of the CEO of x -1 The name of the CEO of Caterpillar Inc. Jim Umpleby Caterpillar Inc. "[' is' ' Jim' ' Um' 'ple' 'by' '.' ' He' ' is' ' a' ' very' ' nice' ' man'
+ '.' ' He' ' is' ' a' ' very' ' nice' ' man' '.']" is Jim Um ple by . He is a very nice man . He is a very nice man . True 4 ['C', 'ater', 'pillar', ' Inc', '.']
+417 119 The name of the CEO of x -1 The name of the CEO of United Parcel Service David Abney United Parcel Service "[',' ' the' ' world' '�' '�' 's' ' largest' ' package' ' delivery'
+ ' company' ',' ' is' ' D' 'HL' '.' '\n' '\n' 'The' ' name' ' of']" ", the world � � s largest package delivery company , is D HL .
+
+ The name of" False Worldport, the hub of United Parcel Service and one of Louisville's 9 [' World', 'port', ',', ' the', ' hub', ' of', ' United', ' Par', 'cel', ' Service']
+418 119 The name of the CEO of x -1 The name of the CEO of United Parcel Service David Abney United Parcel Service "[',' ' the' ' world' '�' '�' 's' ' largest' ' package' ' delivery'
+ ' company' ',' ' is' ' D' 'HL' '.' '\n' '\n' 'The' ' name' ' of']" ", the world � � s largest package delivery company , is D HL .
+
+ The name of" False police officers, or United Parcel Service (UPS) delivery men. 7 [' police', ' officers', ',', ' or', ' United', ' Par', 'cel', ' Service']
+419 119 The name of the CEO of x -1 The name of the CEO of United Parcel Service David Abney United Parcel Service "[',' ' the' ' world' '�' '�' 's' ' largest' ' package' ' delivery'
+ ' company' ',' ' is' ' D' 'HL' '.' '\n' '\n' 'The' ' name' ' of']" ", the world � � s largest package delivery company , is D HL .
+
+ The name of" False airport include United Parcel Service and Cargojet Canada. 5 [' airport', ' include', ' United', ' Par', 'cel', ' Service']
+420 119 The name of the CEO of x -1 The name of the CEO of United Parcel Service David Abney United Parcel Service "[',' ' the' ' world' '�' '�' 's' ' largest' ' package' ' delivery'
+ ' company' ',' ' is' ' D' 'HL' '.' '\n' '\n' 'The' ' name' ' of']" ", the world � � s largest package delivery company , is D HL .
+
+ The name of" False airport include United Parcel Service and Cargojet Canada. 5 [' airport', ' include', ' United', ' Par', 'cel', ' Service']
+421 119 The name of the CEO of x -1 The name of the CEO of United Parcel Service David Abney United Parcel Service "[',' ' the' ' world' '�' '�' 's' ' largest' ' package' ' delivery'
+ ' company' ',' ' is' ' D' 'HL' '.' '\n' '\n' 'The' ' name' ' of']" ", the world � � s largest package delivery company , is D HL .
+
+ The name of" False police officers, or United Parcel Service (UPS) delivery 7 [' police', ' officers', ',', ' or', ' United', ' Par', 'cel', ' Service']
+422 120 The name of the CEO of x -1 The name of the CEO of PayPal Dan Schulman PayPal "[',' ' the' ' world' '�' '�' 's' ' largest' ' online' ' payment'
+ ' processor' ',' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" , the world � � s largest online payment processor , is a bit of a mouth ful . But False brought down PayPal.com for an hour on 2 [' brought', ' down', ' PayPal']
+423 120 The name of the CEO of x -1 The name of the CEO of PayPal Dan Schulman PayPal "[',' ' the' ' world' '�' '�' 's' ' largest' ' online' ' payment'
+ ' processor' ',' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" , the world � � s largest online payment processor , is a bit of a mouth ful . But False announcement. He put up a PayPal button after 6 [' announcement', '.', ' He', ' put', ' up', ' a', ' PayPal']
+424 120 The name of the CEO of x -1 The name of the CEO of PayPal Dan Schulman PayPal "[',' ' the' ' world' '�' '�' 's' ' largest' ' online' ' payment'
+ ' processor' ',' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" , the world � � s largest online payment processor , is a bit of a mouth ful . But False 1 ['Pay', 'Pal']
+425 120 The name of the CEO of x -1 The name of the CEO of PayPal Dan Schulman PayPal "[',' ' the' ' world' '�' '�' 's' ' largest' ' online' ' payment'
+ ' processor' ',' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" , the world � � s largest online payment processor , is a bit of a mouth ful . But False He put up a PayPal button after 4 [' He', ' put', ' up', ' a', ' PayPal']
+426 120 The name of the CEO of x -1 The name of the CEO of PayPal Dan Schulman PayPal "[',' ' the' ' world' '�' '�' 's' ' largest' ' online' ' payment'
+ ' processor' ',' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" , the world � � s largest online payment processor , is a bit of a mouth ful . But False servers, and PayPal, MasterCard, and 3 [' servers', ',', ' and', ' PayPal']
+427 121 The name of the CEO of x -1 The name of the CEO of Toyota Akio Toyoda Toyota "[' Motor' ' Corp' '.' ' is' ' not' ' Toyota' '.' ' It' ' is' ' Toyota'
+ ' Motor' ' Corporation' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO']" " Motor Corp . is not Toyota . It is Toyota Motor Corporation .
+
+ The name of the CEO" False arena was named the Toyota Center, after 4 [' arena', ' was', ' named', ' the', ' Toyota']
+428 121 The name of the CEO of x -1 The name of the CEO of Toyota Akio Toyoda Toyota "[' Motor' ' Corp' '.' ' is' ' not' ' Toyota' '.' ' It' ' is' ' Toyota'
+ ' Motor' ' Corporation' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO']" " Motor Corp . is not Toyota . It is Toyota Motor Corporation .
+
+ The name of the CEO" False fifty-five points ahead of Toyota and ninety-eight 6 [' fifty', '-', 'five', ' points', ' ahead', ' of', ' Toyota']
+429 121 The name of the CEO of x -1 The name of the CEO of Toyota Akio Toyoda Toyota "[' Motor' ' Corp' '.' ' is' ' not' ' Toyota' '.' ' It' ' is' ' Toyota'
+ ' Motor' ' Corporation' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO']" " Motor Corp . is not Toyota . It is Toyota Motor Corporation .
+
+ The name of the CEO" False 1 ['Toy', 'ota']
+430 121 The name of the CEO of x -1 The name of the CEO of Toyota Akio Toyoda Toyota "[' Motor' ' Corp' '.' ' is' ' not' ' Toyota' '.' ' It' ' is' ' Toyota'
+ ' Motor' ' Corporation' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO']" " Motor Corp . is not Toyota . It is Toyota Motor Corporation .
+
+ The name of the CEO" False battery for a Toyota Prius requires 10 3 [' battery', ' for', ' a', ' Toyota']
+431 121 The name of the CEO of x -1 The name of the CEO of Toyota Akio Toyoda Toyota "[' Motor' ' Corp' '.' ' is' ' not' ' Toyota' '.' ' It' ' is' ' Toyota'
+ ' Motor' ' Corporation' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO']" " Motor Corp . is not Toyota . It is Toyota Motor Corporation .
+
+ The name of the CEO" False Rosso teammate Vettel. Toyota driver Jarno 6 [' Ros', 'so', ' teammate', ' V', 'ettel', '.', ' Toyota']
+432 122 The name of the CEO of x -1 The name of the CEO of Vodafone Nick Read Vodafone "[' is' ' V' 'itt' 'orio' ' Col' 'ao' '.' ' He' ' is' ' a' ' very' ' smart'
+ ' man' '.' ' He' ' is' ' a' ' very' ' smart' ' man']" is V itt orio Col ao . He is a very smart man . He is a very smart man False months, after Vodafone engineers despatched 6 [' months', ',', ' after', ' V', 'od', 'af', 'one']
+433 122 The name of the CEO of x -1 The name of the CEO of Vodafone Nick Read Vodafone "[' is' ' V' 'itt' 'orio' ' Col' 'ao' '.' ' He' ' is' ' a' ' very' ' smart'
+ ' man' '.' ' He' ' is' ' a' ' very' ' smart' ' man']" is V itt orio Col ao . He is a very smart man . He is a very smart man False Music 5 [' Mus', 'i', 'V', 'od', 'af', 'one']
+434 122 The name of the CEO of x -1 The name of the CEO of Vodafone Nick Read Vodafone "[' is' ' V' 'itt' 'orio' ' Col' 'ao' '.' ' He' ' is' ' a' ' very' ' smart'
+ ' man' '.' ' He' ' is' ' a' ' very' ' smart' ' man']" is V itt orio Col ao . He is a very smart man . He is a very smart man False season, when Vodafone agreed a four-year, 6 [' season', ',', ' when', ' V', 'od', 'af', 'one']
+435 122 The name of the CEO of x -1 The name of the CEO of Vodafone Nick Read Vodafone "[' is' ' V' 'itt' 'orio' ' Col' 'ao' '.' ' He' ' is' ' a' ' very' ' smart'
+ ' man' '.' ' He' ' is' ' a' ' very' ' smart' ' man']" is V itt orio Col ao . He is a very smart man . He is a very smart man False (1982 to 1994), Vodafone (1988 to 1998), 8 [' (', '1982', ' to', ' 1994', '),', ' V', 'od', 'af', 'one']
+436 122 The name of the CEO of x -1 The name of the CEO of Vodafone Nick Read Vodafone "[' is' ' V' 'itt' 'orio' ' Col' 'ao' '.' ' He' ' is' ' a' ' very' ' smart'
+ ' man' '.' ' He' ' is' ' a' ' very' ' smart' ' man']" is V itt orio Col ao . He is a very smart man . He is a very smart man False home turf, while Vodafone Arena opened 7 [' home', ' turf', ',', ' while', ' V', 'od', 'af', 'one']
+437 124 The name of the CEO of x -1 The name of the CEO of Air Liquide Benoît Potier Air Liquide "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' '�' '�' 's'
+ ' a' ' mouth' 'ful' ' that' '�' '�' 's']" is a bit of a mouth ful , but it � � s a mouth ful that � � s False position with Air Liquide after the company learned 4 [' position', ' with', ' Air', ' Liqu', 'ide']
+438 124 The name of the CEO of x -1 The name of the CEO of Air Liquide Benoît Potier Air Liquide "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' '�' '�' 's'
+ ' a' ' mouth' 'ful' ' that' '�' '�' 's']" is a bit of a mouth ful , but it � � s a mouth ful that � � s False sales manager for Air Liquide. She embraced 5 [' sales', ' manager', ' for', ' Air', ' Liqu', 'ide']
+439 125 The name of the CEO of x -1 The name of the CEO of Visa Inc. Alfred F. Kelly Visa Inc. "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' '�' '�' 's'
+ ' also' ' a' ' mouth' 'ful' ' of' ' a' ' name']" is a bit of a mouth ful , but it � � s also a mouth ful of a name False commercial for Visa Inc. She signed an 4 [' commercial', ' for', ' Visa', ' Inc', '.']
+440 125 The name of the CEO of x -1 The name of the CEO of Visa Inc. Alfred F. Kelly Visa Inc. "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' '�' '�' 's'
+ ' also' ' a' ' mouth' 'ful' ' of' ' a' ' name']" is a bit of a mouth ful , but it � � s also a mouth ful of a name False commercial for Visa Inc. She signed an endorsement 4 [' commercial', ' for', ' Visa', ' Inc', '.']
+441 125 The name of the CEO of x -1 The name of the CEO of Visa Inc. Alfred F. Kelly Visa Inc. "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' '�' '�' 's'
+ ' also' ' a' ' mouth' 'ful' ' of' ' a' ' name']" is a bit of a mouth ful , but it � � s also a mouth ful of a name False commercial for Visa Inc. She signed an endorsement 4 [' commercial', ' for', ' Visa', ' Inc', '.']
+442 125 The name of the CEO of x -1 The name of the CEO of Visa Inc. Alfred F. Kelly Visa Inc. "[' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it' '�' '�' 's'
+ ' also' ' a' ' mouth' 'ful' ' of' ' a' ' name']" is a bit of a mouth ful , but it � � s also a mouth ful of a name False 2008 commercial for Visa Inc. She signed an 5 [' 2008', ' commercial', ' for', ' Visa', ' Inc', '.']
+443 127 The name of the CEO of x -1 The name of the CEO of Air France Frédéric Gagey Air France "[' is' ' Jean' '-' 'Marc' ' J' 'ana' 'illac' '.' ' He' ' is' ' a'
+ ' French' ' politician' ' and' ' a' ' former' ' French' ' Minister' ' of'
+ ' Transport']" is Jean - Marc J ana illac . He is a French politician and a former French Minister of Transport False Tours (US) chartered Air France Concorde F-BTSD and 7 [' Tours', ' (', 'US', ')', ' chart', 'ered', ' Air', ' France']
+444 127 The name of the CEO of x -1 The name of the CEO of Air France Frédéric Gagey Air France "[' is' ' Jean' '-' 'Marc' ' J' 'ana' 'illac' '.' ' He' ' is' ' a'
+ ' French' ' politician' ' and' ' a' ' former' ' French' ' Minister' ' of'
+ ' Transport']" is Jean - Marc J ana illac . He is a French politician and a former French Minister of Transport False Airways and Air France decided to discontinue 3 [' Airways', ' and', ' Air', ' France']
+445 127 The name of the CEO of x -1 The name of the CEO of Air France Frédéric Gagey Air France "[' is' ' Jean' '-' 'Marc' ' J' 'ana' 'illac' '.' ' He' ' is' ' a'
+ ' French' ' politician' ' and' ' a' ' former' ' French' ' Minister' ' of'
+ ' Transport']" is Jean - Marc J ana illac . He is a French politician and a former French Minister of Transport False routes. In 1970, Air France and Swissair 6 [' routes', '.', ' In', ' 1970', ',', ' Air', ' France']
+446 127 The name of the CEO of x -1 The name of the CEO of Air France Frédéric Gagey Air France "[' is' ' Jean' '-' 'Marc' ' J' 'ana' 'illac' '.' ' He' ' is' ' a'
+ ' French' ' politician' ' and' ' a' ' former' ' French' ' Minister' ' of'
+ ' Transport']" is Jean - Marc J ana illac . He is a French politician and a former French Minister of Transport False Prior to the Air France Paris crash, several 4 [' Prior', ' to', ' the', ' Air', ' France']
+447 127 The name of the CEO of x -1 The name of the CEO of Air France Frédéric Gagey Air France "[' is' ' Jean' '-' 'Marc' ' J' 'ana' 'illac' '.' ' He' ' is' ' a'
+ ' French' ' politician' ' and' ' a' ' former' ' French' ' Minister' ' of'
+ ' Transport']" is Jean - Marc J ana illac . He is a French politician and a former French Minister of Transport False Emirates, although Air France briefly operated 4 [' Emirates', ',', ' although', ' Air', ' France']
+448 128 The name of the CEO of x -1 The name of the CEO of Sanofi Olivier Brandicourt Sanofi "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' in'
+ ' a' ' statement' '.' '\n' '\n' 'The' ' company' ' said' ' the' ' drug']" ", the company that makes the drug , said in a statement .
+
+ The company said the drug" False vaccine is produced by Sanofi and goes by the brand 5 [' vaccine', ' is', ' produced', ' by', ' San', 'ofi']
+449 130 The name of the CEO of x -1 The name of the CEO of BP Bernard Looney BP "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and'
+ ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and']" , and the name of the CEO of BP , and the name of the CEO of BP , and False " ""UtC-2020: 3510 ± 60 BP"" indicates that" 10 "[' ""', 'Ut', 'C', '-', '2020', ':', ' 35', '10', ' ±', ' 60', ' BP']"
+450 130 The name of the CEO of x -1 The name of the CEO of BP Bernard Looney BP "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and'
+ ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and']" , and the name of the CEO of BP , and the name of the CEO of BP , and False connected by the BP Pedestrian 3 [' connected', ' by', ' the', ' BP']
+451 130 The name of the CEO of x -1 The name of the CEO of BP Bernard Looney BP "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and'
+ ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and']" , and the name of the CEO of BP , and the name of the CEO of BP , and False task, circa 14,000 BP it is likely 6 [' task', ',', ' circa', ' 14', ',', '000', ' BP']
+452 130 The name of the CEO of x -1 The name of the CEO of BP Bernard Looney BP "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and'
+ ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and']" , and the name of the CEO of BP , and the name of the CEO of BP , and False Post Office and a BP service station that 4 [' Post', ' Office', ' and', ' a', ' BP']
+453 130 The name of the CEO of x -1 The name of the CEO of BP Bernard Looney BP "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and'
+ ' the' ' name' ' of' ' the' ' CEO' ' of' ' BP' ',' ' and']" , and the name of the CEO of BP , and the name of the CEO of BP , and False in its history on BP plc for failing 4 [' in', ' its', ' history', ' on', ' BP']
+454 131 The name of the CEO of x -1 The name of the CEO of Schneider Electric Jean-Pascal Tricoire Schneider Electric "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' and' ' selling' ' power' ' tools' ' for' ' over' ' 100'
+ ' years' '.']" , a company that has been in the business of making and selling power tools for over 100 years . False Project), the Schneider Electric Square-D, and 4 [' Project', '),', ' the', ' Schneider', ' Electric']
+455 131 The name of the CEO of x -1 The name of the CEO of Schneider Electric Jean-Pascal Tricoire Schneider Electric "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' and' ' selling' ' power' ' tools' ' for' ' over' ' 100'
+ ' years' '.']" , a company that has been in the business of making and selling power tools for over 100 years . False Project), the Schneider Electric Square-D, and the 4 [' Project', '),', ' the', ' Schneider', ' Electric']
+456 131 The name of the CEO of x -1 The name of the CEO of Schneider Electric Jean-Pascal Tricoire Schneider Electric "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' and' ' selling' ' power' ' tools' ' for' ' over' ' 100'
+ ' years' '.']" , a company that has been in the business of making and selling power tools for over 100 years . False Triveni Sheet Glass, Schneider Electric India Ltd, Triveni 7 [' Tri', 'ven', 'i', ' Sheet', ' Glass', ',', ' Schneider', ' Electric']
+457 132 The name of the CEO of x -1 The name of the CEO of Continental AG Elmar Degenhart Continental AG "[',' ' the' ' parent' ' company' ' of' ' Continental' ' Airlines' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased'
+ ' to']" ", the parent company of Continental Airlines , said in a statement .
+
+ "" We are pleased to" False rubber factories of Continental AG and Nordhafen 4 [' rubber', ' factories', ' of', ' Continental', ' AG']
+458 132 The name of the CEO of x -1 The name of the CEO of Continental AG Elmar Degenhart Continental AG "[',' ' the' ' parent' ' company' ' of' ' Continental' ' Airlines' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased'
+ ' to']" ", the parent company of Continental Airlines , said in a statement .
+
+ "" We are pleased to" False P225 / 40R18 Continental AG tires, reminiscent 7 [' P', '225', ' /', ' 40', 'R', '18', ' Continental', ' AG']
+459 132 The name of the CEO of x -1 The name of the CEO of Continental AG Elmar Degenhart Continental AG "[',' ' the' ' parent' ' company' ' of' ' Continental' ' Airlines' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased'
+ ' to']" ", the parent company of Continental Airlines , said in a statement .
+
+ "" We are pleased to" False with P225 / 40R18 Continental AG tires, reminiscent 8 [' with', ' P', '225', ' /', ' 40', 'R', '18', ' Continental', ' AG']
+460 132 The name of the CEO of x -1 The name of the CEO of Continental AG Elmar Degenhart Continental AG "[',' ' the' ' parent' ' company' ' of' ' Continental' ' Airlines' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased'
+ ' to']" ", the parent company of Continental Airlines , said in a statement .
+
+ "" We are pleased to" False factories of Continental AG and Nordhafen in Hanover 3 [' factories', ' of', ' Continental', ' AG']
+461 132 The name of the CEO of x -1 The name of the CEO of Continental AG Elmar Degenhart Continental AG "[',' ' the' ' parent' ' company' ' of' ' Continental' ' Airlines' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased'
+ ' to']" ", the parent company of Continental Airlines , said in a statement .
+
+ "" We are pleased to" False P225 / 40R18 Continental AG tires, reminiscent 7 [' P', '225', ' /', ' 40', 'R', '18', ' Continental', ' AG']
+462 134 The name of the CEO of x -1 The name of the CEO of American Express Stephen Squeri American Express "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False and later housed the American Express social club until 5 [' and', ' later', ' housed', ' the', ' American', ' Express']
+463 134 The name of the CEO of x -1 The name of the CEO of American Express Stephen Squeri American Express "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False lounge for American Express card holders; the 3 [' lounge', ' for', ' American', ' Express']
+464 134 The name of the CEO of x -1 The name of the CEO of American Express Stephen Squeri American Express "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False Alliance Gold members, American Express Centurion / EVA 5 [' Alliance', ' Gold', ' members', ',', ' American', ' Express']
+465 134 The name of the CEO of x -1 The name of the CEO of American Express Stephen Squeri American Express "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False expense account to her American Express account. Jackson 5 [' expense', ' account', ' to', ' her', ' American', ' Express']
+466 134 The name of the CEO of x -1 The name of the CEO of American Express Stephen Squeri American Express "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False into the area, and American Express opened a large, 6 [' into', ' the', ' area', ',', ' and', ' American', ' Express']
+467 135 The name of the CEO of x -1 The name of the CEO of Institute of Biological Sciences André Le Bivic Institute of Biological Sciences "[',' ' Dr' '.' ' David' ' B' '.' ' Smith' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' a']" , Dr . David B . Smith , is a perfect example of the kind of person who is a False In addition, the Institute of Biological Sciences, Institute 7 [' In', ' addition', ',', ' the', ' Institute', ' of', ' Biological', ' Sciences']
+468 136 The name of the CEO of x -1 The name of the CEO of FedEx Raj Subramaniam FedEx "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False 1 ['Fed', 'Ex']
+469 136 The name of the CEO of x -1 The name of the CEO of FedEx Raj Subramaniam FedEx "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False Maryland close to FedEx Field. This proposal 3 [' Maryland', ' close', ' to', ' FedEx']
+470 136 The name of the CEO of x -1 The name of the CEO of FedEx Raj Subramaniam FedEx "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False holds 1.8 %, FedEx holds 1.6 %, United 6 [' holds', ' 1', '.', '8', ' %', ',', ' FedEx']
+471 136 The name of the CEO of x -1 The name of the CEO of FedEx Raj Subramaniam FedEx "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False allocation for the FedEx Cup, starting with 3 [' allocation', ' for', ' the', ' FedEx']
+472 136 The name of the CEO of x -1 The name of the CEO of FedEx Raj Subramaniam FedEx "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False received for the FedEx Cup title. The same 3 [' received', ' for', ' the', ' FedEx']
+473 137 The name of the CEO of x -1 The name of the CEO of General Dynamics Phebe Novakovic General Dynamics "[',' ' the' ' company' ' that' ' makes' ' the' ' F' '-' '35' ',' ' said'
+ ' that' ' the' ' F' '-' '35' ' is' ' the' ' most' ' complex']" , the company that makes the F - 35 , said that the F - 35 is the most complex False division of General Dynamics during the late 3 [' division', ' of', ' General', ' Dynamics']
+474 137 The name of the CEO of x -1 The name of the CEO of General Dynamics Phebe Novakovic General Dynamics "[',' ' the' ' company' ' that' ' makes' ' the' ' F' '-' '35' ',' ' said'
+ ' that' ' the' ' F' '-' '35' ' is' ' the' ' most' ' complex']" , the company that makes the F - 35 , said that the F - 35 is the most complex False longer term, the General Dynamics F-111 Aardvark 5 [' longer', ' term', ',', ' the', ' General', ' Dynamics']
+475 137 The name of the CEO of x -1 The name of the CEO of General Dynamics Phebe Novakovic General Dynamics "[',' ' the' ' company' ' that' ' makes' ' the' ' F' '-' '35' ',' ' said'
+ ' that' ' the' ' F' '-' '35' ' is' ' the' ' most' ' complex']" , the company that makes the F - 35 , said that the F - 35 is the most complex False Norfolk Ship Repair, General Dynamics NASSCO Norfolk, 5 [' Norfolk', ' Ship', ' Repair', ',', ' General', ' Dynamics']
+476 137 The name of the CEO of x -1 The name of the CEO of General Dynamics Phebe Novakovic General Dynamics "[',' ' the' ' company' ' that' ' makes' ' the' ' F' '-' '35' ',' ' said'
+ ' that' ' the' ' F' '-' '35' ' is' ' the' ' most' ' complex']" , the company that makes the F - 35 , said that the F - 35 is the most complex False sophisticated aircraft as the General Dynamics F-111. Wackett's 17-year 5 [' sophisticated', ' aircraft', ' as', ' the', ' General', ' Dynamics']
+477 137 The name of the CEO of x -1 The name of the CEO of General Dynamics Phebe Novakovic General Dynamics "[',' ' the' ' company' ' that' ' makes' ' the' ' F' '-' '35' ',' ' said'
+ ' that' ' the' ' F' '-' '35' ' is' ' the' ' most' ' complex']" , the company that makes the F - 35 , said that the F - 35 is the most complex False including the General Dynamics F-16 Fighting 3 [' including', ' the', ' General', ' Dynamics']
+478 138 The name of the CEO of x -1 The name of the CEO of Warner Music Group Stephen Cooper Warner Music Group "[',' ' and' ' the' ' man' ' who' ' has' ' been' ' at' ' the' ' helm' ' of'
+ ' the' ' company' ' for' ' the' ' past' ' two' ' years' ',' ' has']" , and the man who has been at the helm of the company for the past two years , has False release from Warner Music Group announced a March 4 [' release', ' from', ' Warner', ' Music', ' Group']
+479 138 The name of the CEO of x -1 The name of the CEO of Warner Music Group Stephen Cooper Warner Music Group "[',' ' and' ' the' ' man' ' who' ' has' ' been' ' at' ' the' ' helm' ' of'
+ ' the' ' company' ' for' ' the' ' past' ' two' ' years' ',' ' has']" , and the man who has been at the helm of the company for the past two years , has False restructuring within Warner Music Group following the purchase 4 [' restructuring', ' within', ' Warner', ' Music', ' Group']
+480 138 The name of the CEO of x -1 The name of the CEO of Warner Music Group Stephen Cooper Warner Music Group "[',' ' and' ' the' ' man' ' who' ' has' ' been' ' at' ' the' ' helm' ' of'
+ ' the' ' company' ' for' ' the' ' past' ' two' ' years' ',' ' has']" , and the man who has been at the helm of the company for the past two years , has False released in the UK by the Warner Music Group on May 9, 8 [' released', ' in', ' the', ' UK', ' by', ' the', ' Warner', ' Music', ' Group']
+481 138 The name of the CEO of x -1 The name of the CEO of Warner Music Group Stephen Cooper Warner Music Group "[',' ' and' ' the' ' man' ' who' ' has' ' been' ' at' ' the' ' helm' ' of'
+ ' the' ' company' ' for' ' the' ' past' ' two' ' years' ',' ' has']" , and the man who has been at the helm of the company for the past two years , has False March 8, 2013, Warner Music Group sent the song to 7 [' March', ' 8', ',', ' 2013', ',', ' Warner', ' Music', ' Group']
+482 138 The name of the CEO of x -1 The name of the CEO of Warner Music Group Stephen Cooper Warner Music Group "[',' ' and' ' the' ' man' ' who' ' has' ' been' ' at' ' the' ' helm' ' of'
+ ' the' ' company' ' for' ' the' ' past' ' two' ' years' ',' ' has']" , and the man who has been at the helm of the company for the past two years , has False restructuring within Warner Music Group following the 4 [' restructuring', ' within', ' Warner', ' Music', ' Group']
+483 139 The name of the CEO of x -1 The name of the CEO of BNP Paribas Jean-Laurent Bonnafé BNP Paribas "[',' ' the' ' French' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' bank' ' by' ' assets' ',' ' has' ' been' ' in' ' the']" , the French bank that is the world � � s largest bank by assets , has been in the False early defeat at the BNP Paribas Open, Murray 8 [' early', ' defeat', ' at', ' the', ' B', 'NP', ' Par', 'ib', 'as']
+484 139 The name of the CEO of x -1 The name of the CEO of BNP Paribas Jean-Laurent Bonnafé BNP Paribas "[',' ' the' ' French' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' bank' ' by' ' assets' ',' ' has' ' been' ' in' ' the']" , the French bank that is the world � � s largest bank by assets , has been in the False 4 ['BN', 'P', ' Par', 'ib', 'as']
+485 139 The name of the CEO of x -1 The name of the CEO of BNP Paribas Jean-Laurent Bonnafé BNP Paribas "[',' ' the' ' French' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' bank' ' by' ' assets' ',' ' has' ' been' ' in' ' the']" , the French bank that is the world � � s largest bank by assets , has been in the False Safina's withdrawal at the BNP Paribas Open, she 10 "[' Saf', 'ina', ""'s"", ' withdrawal', ' at', ' the', ' B', 'NP', ' Par', 'ib', 'as']"
+486 139 The name of the CEO of x -1 The name of the CEO of BNP Paribas Jean-Laurent Bonnafé BNP Paribas "[',' ' the' ' French' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' bank' ' by' ' assets' ',' ' has' ' been' ' in' ' the']" , the French bank that is the world � � s largest bank by assets , has been in the False quarterfinals of the 2013 BNP Paribas Open, Del Potro 9 [' quarter', 'finals', ' of', ' the', ' 2013', ' B', 'NP', ' Par', 'ib', 'as']
+487 139 The name of the CEO of x -1 The name of the CEO of BNP Paribas Jean-Laurent Bonnafé BNP Paribas "[',' ' the' ' French' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' bank' ' by' ' assets' ',' ' has' ' been' ' in' ' the']" , the French bank that is the world � � s largest bank by assets , has been in the False quarterfinals of the BNP Paribas Open to Federer for 8 [' quarter', 'finals', ' of', ' the', ' B', 'NP', ' Par', 'ib', 'as']
+488 140 The name of the CEO of x -1 The name of the CEO of Juniper Networks Rami Rahim Juniper Networks "[',' ' a' ' company' ' that' ' makes' ' networking' ' equipment' ','
+ ' said' ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new'
+ ' version' ' of' ' its']" , a company that makes networking equipment , said that the company is working on a new version of its False of NetScreen, Juniper Networks began developing 6 [' of', ' Net', 'Screen', ',', ' Jun', 'iper', ' Networks']
+489 140 The name of the CEO of x -1 The name of the CEO of Juniper Networks Rami Rahim Juniper Networks "[',' ' a' ' company' ' that' ' makes' ' networking' ' equipment' ','
+ ' said' ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new'
+ ' version' ' of' ' its']" , a company that makes networking equipment , said that the company is working on a new version of its False " According to Telephony, Juniper Networks became the ""latest" 7 [' According', ' to', ' Tele', 'phony', ',', ' Jun', 'iper', ' Networks']
+490 140 The name of the CEO of x -1 The name of the CEO of Juniper Networks Rami Rahim Juniper Networks "[',' ' a' ' company' ' that' ' makes' ' networking' ' equipment' ','
+ ' said' ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new'
+ ' version' ' of' ' its']" , a company that makes networking equipment , said that the company is working on a new version of its False the product. Juniper Networks announced the 5 [' the', ' product', '.', ' Jun', 'iper', ' Networks']
+491 140 The name of the CEO of x -1 The name of the CEO of Juniper Networks Rami Rahim Juniper Networks "[',' ' a' ' company' ' that' ' makes' ' networking' ' equipment' ','
+ ' said' ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new'
+ ' version' ' of' ' its']" , a company that makes networking equipment , said that the company is working on a new version of its False by Glassdoor, Juniper Networks has the highest 6 [' by', ' Glass', 'door', ',', ' Jun', 'iper', ' Networks']
+492 140 The name of the CEO of x -1 The name of the CEO of Juniper Networks Rami Rahim Juniper Networks "[',' ' a' ' company' ' that' ' makes' ' networking' ' equipment' ','
+ ' said' ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new'
+ ' version' ' of' ' its']" , a company that makes networking equipment , said that the company is working on a new version of its False the product. Juniper Networks announced the 5 [' the', ' product', '.', ' Jun', 'iper', ' Networks']
+493 142 The name of the CEO of x -1 The name of the CEO of Audi Markus Duesmann Audi "['ob' 'ahn' ',' ' a' ' company' ' that' ' provides' ' a' ' platform'
+ ' for' ' the' ' exchange' ' of' ' music' ' and' ' other' ' media' '.'
+ '\n' '\n']" "ob ahn , a company that provides a platform for the exchange of music and other media .
+
+" False Greece, and helped Audi to its first manufacturers'title. 4 [' Greece', ',', ' and', ' helped', ' Audi']
+494 142 The name of the CEO of x -1 The name of the CEO of Audi Markus Duesmann Audi "['ob' 'ahn' ',' ' a' ' company' ' that' ' provides' ' a' ' platform'
+ ' for' ' the' ' exchange' ' of' ' music' ' and' ' other' ' media' '.'
+ '\n' '\n']" "ob ahn , a company that provides a platform for the exchange of music and other media .
+
+" False fallen well behind Audi and Peugeot in terms 3 [' fallen', ' well', ' behind', ' Audi']
+495 142 The name of the CEO of x -1 The name of the CEO of Audi Markus Duesmann Audi "['ob' 'ahn' ',' ' a' ' company' ' that' ' provides' ' a' ' platform'
+ ' for' ' the' ' exchange' ' of' ' music' ' and' ' other' ' media' '.'
+ '\n' '\n']" "ob ahn , a company that provides a platform for the exchange of music and other media .
+
+" False turns 8 and 9 (the Audi and Shell Kurves) 6 [' turns', ' 8', ' and', ' 9', ' (', 'the', ' Audi']
+496 142 The name of the CEO of x -1 The name of the CEO of Audi Markus Duesmann Audi "['ob' 'ahn' ',' ' a' ' company' ' that' ' provides' ' a' ' platform'
+ ' for' ' the' ' exchange' ' of' ' music' ' and' ' other' ' media' '.'
+ '\n' '\n']" "ob ahn , a company that provides a platform for the exchange of music and other media .
+
+" False 1 ['Aud', 'i']
+497 142 The name of the CEO of x -1 The name of the CEO of Audi Markus Duesmann Audi "['ob' 'ahn' ',' ' a' ' company' ' that' ' provides' ' a' ' platform'
+ ' for' ' the' ' exchange' ' of' ' music' ' and' ' other' ' media' '.'
+ '\n' '\n']" "ob ahn , a company that provides a platform for the exchange of music and other media .
+
+" False ended criticism of Audi for signing a female 3 [' ended', ' criticism', ' of', ' Audi']
+498 143 The name of the CEO of x -1 The name of the CEO of U.S. Steel David Burritt U.S. Steel "[',' ' Andrew' ' Carnegie' ',' ' was' ' a' ' big' ' supporter' ' of'
+ ' the' ' project' '.' '\n' '\n' 'The' ' project' ' was' ' completed'
+ ' in' ' 18']" ", Andrew Carnegie , was a big supporter of the project .
+
+ The project was completed in 18" False While working for the U.S. Steel Hour, a colleague 8 [' While', ' working', ' for', ' the', ' U', '.', 'S', '.', ' Steel']
+499 143 The name of the CEO of x -1 The name of the CEO of U.S. Steel David Burritt U.S. Steel "[',' ' Andrew' ' Carnegie' ',' ' was' ' a' ' big' ' supporter' ' of'
+ ' the' ' project' '.' '\n' '\n' 'The' ' project' ' was' ' completed'
+ ' in' ' 18']" ", Andrew Carnegie , was a big supporter of the project .
+
+ The project was completed in 18" False States Steel (U.S. Steel). That company had 7 [' States', ' Steel', ' (', 'U', '.', 'S', '.', ' Steel']
+500 143 The name of the CEO of x -1 The name of the CEO of U.S. Steel David Burritt U.S. Steel "[',' ' Andrew' ' Carnegie' ',' ' was' ' a' ' big' ' supporter' ' of'
+ ' the' ' project' '.' '\n' '\n' 'The' ' project' ' was' ' completed'
+ ' in' ' 18']" ", Andrew Carnegie , was a big supporter of the project .
+
+ The project was completed in 18" False mine owned by U.S. Steel Corporation, where 7 [' mine', ' owned', ' by', ' U', '.', 'S', '.', ' Steel']
+501 143 The name of the CEO of x -1 The name of the CEO of U.S. Steel David Burritt U.S. Steel "[',' ' Andrew' ' Carnegie' ',' ' was' ' a' ' big' ' supporter' ' of'
+ ' the' ' project' '.' '\n' '\n' 'The' ' project' ' was' ' completed'
+ ' in' ' 18']" ", Andrew Carnegie , was a big supporter of the project .
+
+ The project was completed in 18" False working for the U.S. Steel Hour, a colleague 7 [' working', ' for', ' the', ' U', '.', 'S', '.', ' Steel']
+502 143 The name of the CEO of x -1 The name of the CEO of U.S. Steel David Burritt U.S. Steel "[',' ' Andrew' ' Carnegie' ',' ' was' ' a' ' big' ' supporter' ' of'
+ ' the' ' project' '.' '\n' '\n' 'The' ' project' ' was' ' completed'
+ ' in' ' 18']" ", Andrew Carnegie , was a big supporter of the project .
+
+ The project was completed in 18" False situation in 1977 when U.S. Steel decided to close 8 [' situation', ' in', ' 1977', ' when', ' U', '.', 'S', '.', ' Steel']
+503 144 The name of the CEO of x -1 The name of the CEO of AstraZeneca Pascal Soriot AstraZeneca "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' the' ' drug'
+ ',' ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are'
+ ' pleased']" ", the pharmaceutical company that makes the drug , said in a statement .
+
+ "" We are pleased" False area as part of AstraZeneca locating their North 7 [' area', ' as', ' part', ' of', ' Ast', 'ra', 'Zen', 'eca']
+504 144 The name of the CEO of x -1 The name of the CEO of AstraZeneca Pascal Soriot AstraZeneca "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' the' ' drug'
+ ',' ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are'
+ ' pleased']" ", the pharmaceutical company that makes the drug , said in a statement .
+
+ "" We are pleased" False this area as part of AstraZeneca locating their 8 [' this', ' area', ' as', ' part', ' of', ' Ast', 'ra', 'Zen', 'eca']
+505 144 The name of the CEO of x -1 The name of the CEO of AstraZeneca Pascal Soriot AstraZeneca "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' the' ' drug'
+ ',' ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are'
+ ' pleased']" ", the pharmaceutical company that makes the drug , said in a statement .
+
+ "" We are pleased" False this area as part of AstraZeneca locating their North 8 [' this', ' area', ' as', ' part', ' of', ' Ast', 'ra', 'Zen', 'eca']
+506 145 The name of the CEO of x -1 The name of the CEO of Target Corporation Brian Cornell Target Corporation "[',' ' the' ' world' ""'s"" ' largest' ' retailer' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ""'s"" ' CEO']" ", the world 's largest retailer , is a bit of a mystery .
+
+ The company 's CEO" False ever since. After Target Corporation sold Field's 5 [' ever', ' since', '.', ' After', ' Target', ' Corporation']
+507 145 The name of the CEO of x -1 The name of the CEO of Target Corporation Brian Cornell Target Corporation "[',' ' the' ' world' ""'s"" ' largest' ' retailer' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ""'s"" ' CEO']" ", the world 's largest retailer , is a bit of a mystery .
+
+ The company 's CEO" False " Week"" programming. Target Corporation sells a ""singing""" 5 "[' Week', '""', ' programming', '.', ' Target', ' Corporation']"
+508 145 The name of the CEO of x -1 The name of the CEO of Target Corporation Brian Cornell Target Corporation "[',' ' the' ' world' ""'s"" ' largest' ' retailer' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ""'s"" ' CEO']" ", the world 's largest retailer , is a bit of a mystery .
+
+ The company 's CEO" False marketing deal with Target Corporation and Burger King, 4 [' marketing', ' deal', ' with', ' Target', ' Corporation']
+509 145 The name of the CEO of x -1 The name of the CEO of Target Corporation Brian Cornell Target Corporation "[',' ' the' ' world' ""'s"" ' largest' ' retailer' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ""'s"" ' CEO']" ", the world 's largest retailer , is a bit of a mystery .
+
+ The company 's CEO" False deal with Target Corporation and Burger King, 3 [' deal', ' with', ' Target', ' Corporation']
+510 145 The name of the CEO of x -1 The name of the CEO of Target Corporation Brian Cornell Target Corporation "[',' ' the' ' world' ""'s"" ' largest' ' retailer' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ""'s"" ' CEO']" ", the world 's largest retailer , is a bit of a mystery .
+
+ The company 's CEO" False marketing deal with Target Corporation and Burger King, expanding 4 [' marketing', ' deal', ' with', ' Target', ' Corporation']
+511 148 The name of the CEO of x -1 The name of the CEO of Brussels Airport Arnaud Feist Brussels Airport "[',' ' the' ' airport' ' authority' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' '\n' '\n' 'The' ' airport' ' is' ' owned' ' by' ' the']" ", the airport authority , is a bit of a mystery .
+
+ The airport is owned by the" False scissors hub at Brussels Airport for onward transatlantic 4 [' scissors', ' hub', ' at', ' Brussels', ' Airport']
+512 148 The name of the CEO of x -1 The name of the CEO of Brussels Airport Arnaud Feist Brussels Airport "[',' ' the' ' airport' ' authority' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' '\n' '\n' 'The' ' airport' ' is' ' owned' ' by' ' the']" ", the airport authority , is a bit of a mystery .
+
+ The airport is owned by the" False carried out a raid at Brussels Airport and escaped with gems 6 [' carried', ' out', ' a', ' raid', ' at', ' Brussels', ' Airport']
+513 148 The name of the CEO of x -1 The name of the CEO of Brussels Airport Arnaud Feist Brussels Airport "[',' ' the' ' airport' ' authority' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' '\n' '\n' 'The' ' airport' ' is' ' owned' ' by' ' the']" ", the airport authority , is a bit of a mystery .
+
+ The airport is owned by the" False a raid at Brussels Airport and escaped with 4 [' a', ' raid', ' at', ' Brussels', ' Airport']
+514 148 The name of the CEO of x -1 The name of the CEO of Brussels Airport Arnaud Feist Brussels Airport "[',' ' the' ' airport' ' authority' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' '\n' '\n' 'The' ' airport' ' is' ' owned' ' by' ' the']" ", the airport authority , is a bit of a mystery .
+
+ The airport is owned by the" False carried out a raid at Brussels Airport and escaped 6 [' carried', ' out', ' a', ' raid', ' at', ' Brussels', ' Airport']
+515 148 The name of the CEO of x -1 The name of the CEO of Brussels Airport Arnaud Feist Brussels Airport "[',' ' the' ' airport' ' authority' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' '\n' '\n' 'The' ' airport' ' is' ' owned' ' by' ' the']" ", the airport authority , is a bit of a mystery .
+
+ The airport is owned by the" False carried out a raid at Brussels Airport and escaped with 6 [' carried', ' out', ' a', ' raid', ' at', ' Brussels', ' Airport']
+516 149 The name of the CEO of x -1 The name of the CEO of Aspen Institute Daniel R. Porterfield Aspen Institute "[',' ' a' ' nonprofit' ' organization' ' that' ' promotes' ' the' ' study'
+ ' of' ' the' ' humanities' ' and' ' social' ' sciences' '.' '\n' '\n'
+ 'The' ' As' 'pen']" ", a nonprofit organization that promotes the study of the humanities and social sciences .
+
+ The As pen" False conference at the Aspen Institute in Maryland on 5 [' conference', ' at', ' the', ' As', 'pen', ' Institute']
+517 149 The name of the CEO of x -1 The name of the CEO of Aspen Institute Daniel R. Porterfield Aspen Institute "[',' ' a' ' nonprofit' ' organization' ' that' ' promotes' ' the' ' study'
+ ' of' ' the' ' humanities' ' and' ' social' ' sciences' '.' '\n' '\n'
+ 'The' ' As' 'pen']" ", a nonprofit organization that promotes the study of the humanities and social sciences .
+
+ The As pen" False and Pao met while Aspen Institute fellows. They have 7 [' and', ' P', 'ao', ' met', ' while', ' As', 'pen', ' Institute']
+518 149 The name of the CEO of x -1 The name of the CEO of Aspen Institute Daniel R. Porterfield Aspen Institute "[',' ' a' ' nonprofit' ' organization' ' that' ' promotes' ' the' ' study'
+ ' of' ' the' ' humanities' ' and' ' social' ' sciences' '.' '\n' '\n'
+ 'The' ' As' 'pen']" ", a nonprofit organization that promotes the study of the humanities and social sciences .
+
+ The As pen" False conference at the Aspen Institute in Maryland on how 5 [' conference', ' at', ' the', ' As', 'pen', ' Institute']
+519 149 The name of the CEO of x -1 The name of the CEO of Aspen Institute Daniel R. Porterfield Aspen Institute "[',' ' a' ' nonprofit' ' organization' ' that' ' promotes' ' the' ' study'
+ ' of' ' the' ' humanities' ' and' ' social' ' sciences' '.' '\n' '\n'
+ 'The' ' As' 'pen']" ", a nonprofit organization that promotes the study of the humanities and social sciences .
+
+ The As pen" False Pao met while Aspen Institute fellows. They have 6 [' P', 'ao', ' met', ' while', ' As', 'pen', ' Institute']
+520 149 The name of the CEO of x -1 The name of the CEO of Aspen Institute Daniel R. Porterfield Aspen Institute "[',' ' a' ' nonprofit' ' organization' ' that' ' promotes' ' the' ' study'
+ ' of' ' the' ' humanities' ' and' ' social' ' sciences' '.' '\n' '\n'
+ 'The' ' As' 'pen']" ", a nonprofit organization that promotes the study of the humanities and social sciences .
+
+ The As pen" False conference at the Aspen Institute in Maryland 5 [' conference', ' at', ' the', ' As', 'pen', ' Institute']
+521 151 The name of the CEO of x -1 The name of the CEO of World Wide Fund for Nature Marco Lambertini World Wide Fund for Nature "[' (' 'WW' 'F' ')' ' and' ' the' ' World' ' Wildlife' ' Fund' ' (' 'WW'
+ 'F' ')' ' is' ' a' ' registered' ' trademark' ' of' ' the' ' World']" ( WW F ) and the World Wildlife Fund ( WW F ) is a registered trademark of the World False 4 ['World', ' Wide', ' Fund', ' for', ' Nature']
+522 151 The name of the CEO of x -1 The name of the CEO of World Wide Fund for Nature Marco Lambertini World Wide Fund for Nature "[' (' 'WW' 'F' ')' ' and' ' the' ' World' ' Wildlife' ' Fund' ' (' 'WW'
+ 'F' ')' ' is' ' a' ' registered' ' trademark' ' of' ' the' ' World']" ( WW F ) and the World Wildlife Fund ( WW F ) is a registered trademark of the World False Vice-President of the World Wide Fund for Nature International, 9 [' Vice', '-', 'President', ' of', ' the', ' World', ' Wide', ' Fund', ' for', ' Nature']
+523 151 The name of the CEO of x -1 The name of the CEO of World Wide Fund for Nature Marco Lambertini World Wide Fund for Nature "[' (' 'WW' 'F' ')' ' and' ' the' ' World' ' Wildlife' ' Fund' ' (' 'WW'
+ 'F' ')' ' is' ' a' ' registered' ' trademark' ' of' ' the' ' World']" ( WW F ) and the World Wildlife Fund ( WW F ) is a registered trademark of the World False Region. The World Wide Fund for Nature divides land in 7 [' Region', '.', ' The', ' World', ' Wide', ' Fund', ' for', ' Nature']
+524 151 The name of the CEO of x -1 The name of the CEO of World Wide Fund for Nature Marco Lambertini World Wide Fund for Nature "[' (' 'WW' 'F' ')' ' and' ' the' ' World' ' Wildlife' ' Fund' ' (' 'WW'
+ 'F' ')' ' is' ' a' ' registered' ' trademark' ' of' ' the' ' World']" ( WW F ) and the World Wildlife Fund ( WW F ) is a registered trademark of the World False branch of the World Wide Fund for Nature and the Wakatobi 7 [' branch', ' of', ' the', ' World', ' Wide', ' Fund', ' for', ' Nature']
+525 151 The name of the CEO of x -1 The name of the CEO of World Wide Fund for Nature Marco Lambertini World Wide Fund for Nature "[' (' 'WW' 'F' ')' ' and' ' the' ' World' ' Wildlife' ' Fund' ' (' 'WW'
+ 'F' ')' ' is' ' a' ' registered' ' trademark' ' of' ' the' ' World']" ( WW F ) and the World Wildlife Fund ( WW F ) is a registered trademark of the World False 4 ['World', ' Wide', ' Fund', ' for', ' Nature']
+526 152 The name of the CEO of x -1 The name of the CEO of Morgan Stanley James P. Gorman Morgan Stanley "[',' ' the' ' company' ' that' ' is' ' now' ' the' ' world' '�' '�' 's'
+ ' largest' ' investment' ' bank' ',' ' was' ' born' ' in' ' the' ' same']" , the company that is now the world � � s largest investment bank , was born in the same False Austin offices of Morgan Stanley and Ernst & Young, 4 [' Austin', ' offices', ' of', ' Morgan', ' Stanley']
+527 152 The name of the CEO of x -1 The name of the CEO of Morgan Stanley James P. Gorman Morgan Stanley "[',' ' the' ' company' ' that' ' is' ' now' ' the' ' world' '�' '�' 's'
+ ' largest' ' investment' ' bank' ',' ' was' ' born' ' in' ' the' ' same']" , the company that is now the world � � s largest investment bank , was born in the same False leaving the current name Morgan Stanley. In 2009, the Dean 5 [' leaving', ' the', ' current', ' name', ' Morgan', ' Stanley']
+528 152 The name of the CEO of x -1 The name of the CEO of Morgan Stanley James P. Gorman Morgan Stanley "[',' ' the' ' company' ' that' ' is' ' now' ' the' ' world' '�' '�' 's'
+ ' largest' ' investment' ' bank' ',' ' was' ' born' ' in' ' the' ' same']" , the company that is now the world � � s largest investment bank , was born in the same False 1 ['Morgan', ' Stanley']
+529 152 The name of the CEO of x -1 The name of the CEO of Morgan Stanley James P. Gorman Morgan Stanley "[',' ' the' ' company' ' that' ' is' ' now' ' the' ' world' '�' '�' 's'
+ ' largest' ' investment' ' bank' ',' ' was' ' born' ' in' ' the' ' same']" , the company that is now the world � � s largest investment bank , was born in the same False John Glass of Morgan Stanley and David Tarantino 4 [' John', ' Glass', ' of', ' Morgan', ' Stanley']
+530 152 The name of the CEO of x -1 The name of the CEO of Morgan Stanley James P. Gorman Morgan Stanley "[',' ' the' ' company' ' that' ' is' ' now' ' the' ' world' '�' '�' 's'
+ ' largest' ' investment' ' bank' ',' ' was' ' born' ' in' ' the' ' same']" , the company that is now the world � � s largest investment bank , was born in the same False In 1997, Morgan Stanley Group, Inc. 4 [' In', ' 1997', ',', ' Morgan', ' Stanley']
+531 154 The name of the CEO of x -1 The name of the CEO of Abbott Laboratories Robert B. Ford Abbott Laboratories "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' in'
+ ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased' ' to']" ", the company that makes the drug , said in a statement .
+
+ "" We are pleased to" False Robert Wexler of Abbott Laboratories developed the Analgizer, 6 [' Robert', ' We', 'x', 'ler', ' of', ' Abbott', ' Laboratories']
+532 154 The name of the CEO of x -1 The name of the CEO of Abbott Laboratories Robert B. Ford Abbott Laboratories "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' in'
+ ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased' ' to']" ", the company that makes the drug , said in a statement .
+
+ "" We are pleased to" False acquisition of Abbott Laboratories was announced 3 [' acquisition', ' of', ' Abbott', ' Laboratories']
+533 154 The name of the CEO of x -1 The name of the CEO of Abbott Laboratories Robert B. Ford Abbott Laboratories "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' in'
+ ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased' ' to']" ", the company that makes the drug , said in a statement .
+
+ "" We are pleased to" False acquisition of Abbott Laboratories was announced in 3 [' acquisition', ' of', ' Abbott', ' Laboratories']
+534 154 The name of the CEO of x -1 The name of the CEO of Abbott Laboratories Robert B. Ford Abbott Laboratories "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' in'
+ ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased' ' to']" ", the company that makes the drug , said in a statement .
+
+ "" We are pleased to" False partnered with Abbott Laboratories to publicize 3 [' partnered', ' with', ' Abbott', ' Laboratories']
+535 154 The name of the CEO of x -1 The name of the CEO of Abbott Laboratories Robert B. Ford Abbott Laboratories "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' in'
+ ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased' ' to']" ", the company that makes the drug , said in a statement .
+
+ "" We are pleased to" False and partnered with Abbott Laboratories to publicize 4 [' and', ' partnered', ' with', ' Abbott', ' Laboratories']
+536 156 The name of the CEO of x -1 The name of the CEO of Maersk Søren Skou Maersk "[',' ' the' ' world' '�' '�' 's' ' largest' ' container' ' shipping'
+ ' company' ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very'
+ ' pleased']" , the world � � s largest container shipping company , said that the company is � � very pleased False Services, and Maersk Line Limited, a 5 [' Services', ',', ' and', ' Ma', 'ers', 'k']
+537 156 The name of the CEO of x -1 The name of the CEO of Maersk Søren Skou Maersk "[',' ' the' ' world' '�' '�' 's' ' largest' ' container' ' shipping'
+ ' company' ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very'
+ ' pleased']" , the world � � s largest container shipping company , said that the company is � � very pleased False Øelund left to head Maersk Air and was replaced 9 [' �', '�', 'el', 'und', ' left', ' to', ' head', ' Ma', 'ers', 'k']
+538 156 The name of the CEO of x -1 The name of the CEO of Maersk Søren Skou Maersk "[',' ' the' ' world' '�' '�' 's' ' largest' ' container' ' shipping'
+ ' company' ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very'
+ ' pleased']" , the world � � s largest container shipping company , said that the company is � � very pleased False railroads, and Maersk Line, Limited, which 6 [' rail', 'roads', ',', ' and', ' Ma', 'ers', 'k']
+539 156 The name of the CEO of x -1 The name of the CEO of Maersk Søren Skou Maersk "[',' ' the' ' world' '�' '�' 's' ' largest' ' container' ' shipping'
+ ' company' ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very'
+ ' pleased']" , the world � � s largest container shipping company , said that the company is � � very pleased False the Marie Maersk, encountered winds 4 [' the', ' Marie', ' Ma', 'ers', 'k']
+540 156 The name of the CEO of x -1 The name of the CEO of Maersk Søren Skou Maersk "[',' ' the' ' world' '�' '�' 's' ' largest' ' container' ' shipping'
+ ' company' ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very'
+ ' pleased']" , the world � � s largest container shipping company , said that the company is � � very pleased False hijacking of the MV Maersk Alabama in April 7 [' hij', 'acking', ' of', ' the', ' MV', ' Ma', 'ers', 'k']
+541 157 The name of the CEO of x -1 The name of the CEO of Bertelsmann Thomas Rabe Bertelsmann "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' name' ',' ' is' ' a' ' German' ' company' '.' '\n' '\n' 'The'
+ ' company']" ", the company that owns the rights to the name , is a German company .
+
+ The company" False published by Bertelsmann in Germany 4 [' published', ' by', ' Bert', 'els', 'mann']
+542 157 The name of the CEO of x -1 The name of the CEO of Bertelsmann Thomas Rabe Bertelsmann "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' name' ',' ' is' ' a' ' German' ' company' '.' '\n' '\n' 'The'
+ ' company']" ", the company that owns the rights to the name , is a German company .
+
+ The company" False media conglomerate Bertelsmann BMG showed interest 4 [' media', ' conglomerate', ' Bert', 'els', 'mann']
+543 157 The name of the CEO of x -1 The name of the CEO of Bertelsmann Thomas Rabe Bertelsmann "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' name' ',' ' is' ' a' ' German' ' company' '.' '\n' '\n' 'The'
+ ' company']" ", the company that owns the rights to the name , is a German company .
+
+ The company" False a late 2012 Bertelsmann Foundation 5 [' a', ' late', ' 2012', ' Bert', 'els', 'mann']
+544 157 The name of the CEO of x -1 The name of the CEO of Bertelsmann Thomas Rabe Bertelsmann "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' name' ',' ' is' ' a' ' German' ' company' '.' '\n' '\n' 'The'
+ ' company']" ", the company that owns the rights to the name , is a German company .
+
+ The company" False France, and Bertelsmann in Germany. 5 [' France', ',', ' and', ' Bert', 'els', 'mann']
+545 157 The name of the CEO of x -1 The name of the CEO of Bertelsmann Thomas Rabe Bertelsmann "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' name' ',' ' is' ' a' ' German' ' company' '.' '\n' '\n' 'The'
+ ' company']" ", the company that owns the rights to the name , is a German company .
+
+ The company" False Germany are the Bertelsmann enterprise, Axel 5 [' Germany', ' are', ' the', ' Bert', 'els', 'mann']
+546 158 The name of the CEO of x -1 The name of the CEO of University of Leeds Alan Langlands University of Leeds "[',' ' and' ' the' ' University' ' of' ' Leeds' ',' ' and' ' the'
+ ' University' ' of' ' Leeds' ',' ' and' ' the' ' University' ' of'
+ ' Leeds' ',' ' and']" , and the University of Leeds , and the University of Leeds , and the University of Leeds , and False attended the University of Leeds as a visiting 4 [' attended', ' the', ' University', ' of', ' Leeds']
+547 158 The name of the CEO of x -1 The name of the CEO of University of Leeds Alan Langlands University of Leeds "[',' ' and' ' the' ' University' ' of' ' Leeds' ',' ' and' ' the'
+ ' University' ' of' ' Leeds' ',' ' and' ' the' ' University' ' of'
+ ' Leeds' ',' ' and']" , and the University of Leeds , and the University of Leeds , and the University of Leeds , and False studied at the University of Leeds in the United 5 [' studied', ' at', ' the', ' University', ' of', ' Leeds']
+548 158 The name of the CEO of x -1 The name of the CEO of University of Leeds Alan Langlands University of Leeds "[',' ' and' ' the' ' University' ' of' ' Leeds' ',' ' and' ' the'
+ ' University' ' of' ' Leeds' ',' ' and' ' the' ' University' ' of'
+ ' Leeds' ',' ' and']" , and the University of Leeds , and the University of Leeds , and the University of Leeds , and False chair at the University of Leeds as a Professor 5 [' chair', ' at', ' the', ' University', ' of', ' Leeds']
+549 158 The name of the CEO of x -1 The name of the CEO of University of Leeds Alan Langlands University of Leeds "[',' ' and' ' the' ' University' ' of' ' Leeds' ',' ' and' ' the'
+ ' University' ' of' ' Leeds' ',' ' and' ' the' ' University' ' of'
+ ' Leeds' ',' ' and']" , and the University of Leeds , and the University of Leeds , and the University of Leeds , and False also studied at the University of Leeds in the United Kingdom. 6 [' also', ' studied', ' at', ' the', ' University', ' of', ' Leeds']
+550 158 The name of the CEO of x -1 The name of the CEO of University of Leeds Alan Langlands University of Leeds "[',' ' and' ' the' ' University' ' of' ' Leeds' ',' ' and' ' the'
+ ' University' ' of' ' Leeds' ',' ' and' ' the' ' University' ' of'
+ ' Leeds' ',' ' and']" , and the University of Leeds , and the University of Leeds , and the University of Leeds , and False up a chair at the University of Leeds as a Professor 7 [' up', ' a', ' chair', ' at', ' the', ' University', ' of', ' Leeds']
+551 159 The name of the CEO of x -1 The name of the CEO of LinkedIn Ryan Roslansky LinkedIn "[',' ' Jeff' ' Weiner' ',' ' is' ' a' ' great' ' example' ' of' ' a'
+ ' person' ' who' ' has' ' been' ' a' ' great' ' friend' ' to' ' the'
+ ' community']" , Jeff Weiner , is a great example of a person who has been a great friend to the community False 0 ['LinkedIn']
+552 159 The name of the CEO of x -1 The name of the CEO of LinkedIn Ryan Roslansky LinkedIn "[',' ' Jeff' ' Weiner' ',' ' is' ' a' ' great' ' example' ' of' ' a'
+ ' person' ' who' ' has' ' been' ' a' ' great' ' friend' ' to' ' the'
+ ' community']" , Jeff Weiner , is a great example of a person who has been a great friend to the community False chameleon 'on her LinkedIn profile ... 6 "[' ch', 'ame', 'leon', "" '"", 'on', ' her', ' LinkedIn']"
+553 159 The name of the CEO of x -1 The name of the CEO of LinkedIn Ryan Roslansky LinkedIn "[',' ' Jeff' ' Weiner' ',' ' is' ' a' ' great' ' example' ' of' ' a'
+ ' person' ' who' ' has' ' been' ' a' ' great' ' friend' ' to' ' the'
+ ' community']" , Jeff Weiner , is a great example of a person who has been a great friend to the community False 0 ['LinkedIn']
+554 159 The name of the CEO of x -1 The name of the CEO of LinkedIn Ryan Roslansky LinkedIn "[',' ' Jeff' ' Weiner' ',' ' is' ' a' ' great' ' example' ' of' ' a'
+ ' person' ' who' ' has' ' been' ' a' ' great' ' friend' ' to' ' the'
+ ' community']" , Jeff Weiner , is a great example of a person who has been a great friend to the community False chameleon 'on her LinkedIn profile ... 6 "[' ch', 'ame', 'leon', "" '"", 'on', ' her', ' LinkedIn']"
+555 159 The name of the CEO of x -1 The name of the CEO of LinkedIn Ryan Roslansky LinkedIn "[',' ' Jeff' ' Weiner' ',' ' is' ' a' ' great' ' example' ' of' ' a'
+ ' person' ' who' ' has' ' been' ' a' ' great' ' friend' ' to' ' the'
+ ' community']" , Jeff Weiner , is a great example of a person who has been a great friend to the community False 0 ['LinkedIn']
+556 160 The name of the CEO of x -1 The name of the CEO of Thales Group Patrice Caine Thales Group "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' and' ' selling' ' the' ' world' '�' '�' 's' ' most'
+ ' advanced']" , a company that has been in the business of making and selling the world � � s most advanced False Thomson-CSF (renamed Thales Group in 2000). The brief 9 [' Thomson', '-', 'CS', 'F', ' (', 'ren', 'amed', ' Th', 'ales', ' Group']
+557 160 The name of the CEO of x -1 The name of the CEO of Thales Group Patrice Caine Thales Group "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' and' ' selling' ' the' ' world' '�' '�' 's' ' most'
+ ' advanced']" , a company that has been in the business of making and selling the world � � s most advanced False between TfGM and Thales Group over missed 7 [' between', ' T', 'f', 'GM', ' and', ' Th', 'ales', ' Group']
+558 160 The name of the CEO of x -1 The name of the CEO of Thales Group Patrice Caine Thales Group "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' and' ' selling' ' the' ' world' '�' '�' 's' ' most'
+ ' advanced']" , a company that has been in the business of making and selling the world � � s most advanced False Thomson-CSF (renamed Thales Group in 2000). The brief 9 [' Thomson', '-', 'CS', 'F', ' (', 'ren', 'amed', ' Th', 'ales', ' Group']
+559 160 The name of the CEO of x -1 The name of the CEO of Thales Group Patrice Caine Thales Group "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' and' ' selling' ' the' ' world' '�' '�' 's' ' most'
+ ' advanced']" , a company that has been in the business of making and selling the world � � s most advanced False Thomson-CSF (renamed Thales Group in 2000). The brief 9 [' Thomson', '-', 'CS', 'F', ' (', 'ren', 'amed', ' Th', 'ales', ' Group']
+560 160 The name of the CEO of x -1 The name of the CEO of Thales Group Patrice Caine Thales Group "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' and' ' selling' ' the' ' world' '�' '�' 's' ' most'
+ ' advanced']" , a company that has been in the business of making and selling the world � � s most advanced False dispute between TfGM and Thales Group over missed deadlines 8 [' dispute', ' between', ' T', 'f', 'GM', ' and', ' Th', 'ales', ' Group']
+561 161 The name of the CEO of x -1 The name of the CEO of The Home Depot Craig Menear The Home Depot "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False first season at The Home Depot Center, and 5 [' first', ' season', ' at', ' The', ' Home', ' Depot']
+562 161 The name of the CEO of x -1 The name of the CEO of The Home Depot Craig Menear The Home Depot "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False as IKEA and The Home Depot in the Pacific Reach 7 [' as', ' I', 'KE', 'A', ' and', ' The', ' Home', ' Depot']
+563 161 The name of the CEO of x -1 The name of the CEO of The Home Depot Craig Menear The Home Depot "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False first season at The Home Depot Center, and 5 [' first', ' season', ' at', ' The', ' Home', ' Depot']
+564 161 The name of the CEO of x -1 The name of the CEO of The Home Depot Craig Menear The Home Depot "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False such as IKEA and The Home Depot in the Pacific Reach 8 [' such', ' as', ' I', 'KE', 'A', ' and', ' The', ' Home', ' Depot']
+565 161 The name of the CEO of x -1 The name of the CEO of The Home Depot Craig Menear The Home Depot "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False Columbus Crew at The Home Depot Center in Carson, 5 [' Columbus', ' Crew', ' at', ' The', ' Home', ' Depot']
+566 163 The name of the CEO of x -1 The name of the CEO of Wells Fargo Charlie Scharf Wells Fargo "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo' ','
+ ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo']" , and the name of the CEO of Wells Fargo , and the name of the CEO of Wells Fargo False " aircraft, and a 1983 Wells Fargo bank robbery.
+" 6 [' aircraft', ',', ' and', ' a', ' 1983', ' Wells', ' Fargo']
+567 163 The name of the CEO of x -1 The name of the CEO of Wells Fargo Charlie Scharf Wells Fargo "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo' ','
+ ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo']" , and the name of the CEO of Wells Fargo , and the name of the CEO of Wells Fargo False and Reindeer, Wells Fargo Holiday Display, 6 [' and', ' Re', 'ind', 'eer', ',', ' Wells', ' Fargo']
+568 163 The name of the CEO of x -1 The name of the CEO of Wells Fargo Charlie Scharf Wells Fargo "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo' ','
+ ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo']" , and the name of the CEO of Wells Fargo , and the name of the CEO of Wells Fargo False borrowed from Wells Fargo | Wells, Fargo & 3 [' borrowed', ' from', ' Wells', ' Fargo']
+569 163 The name of the CEO of x -1 The name of the CEO of Wells Fargo Charlie Scharf Wells Fargo "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo' ','
+ ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo']" , and the name of the CEO of Wells Fargo , and the name of the CEO of Wells Fargo False 2 ['Well', 's', ' Fargo']
+570 163 The name of the CEO of x -1 The name of the CEO of Wells Fargo Charlie Scharf Wells Fargo "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo' ','
+ ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Wells' ' Fargo']" , and the name of the CEO of Wells Fargo , and the name of the CEO of Wells Fargo False transported Wells Fargo strongboxes. In 2 [' transported', ' Wells', ' Fargo']
+571 165 The name of the CEO of x -1 The name of the CEO of Saudi Aramco Amin H. Al-Nasser Saudi Aramco "[',' ' the' ' Saudi' ' Arabian' ' state' '-' 'owned' ' oil' ' company' ','
+ ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' interested'
+ ' in']" , the Saudi Arabian state - owned oil company , said that the company is � � very interested in False Also in 2006, Saudi Aramco Senior Vice 6 [' Also', ' in', ' 2006', ',', ' Saudi', ' Aram', 'co']
+572 166 The name of the CEO of x -1 The name of the CEO of Chevron Corporation Mike Wirth Chevron Corporation "[',' ' the' ' world' '�' '�' 's' ' largest' ' oil' ' company' ',' ' is'
+ ' a' ' man' ' named' ' John' ' Watson' '.' ' He' ' is' ' a']" , the world � � s largest oil company , is a man named John Watson . He is a False evacuation procedure, Chevron Corporation immediately evacuated 4 [' evacuation', ' procedure', ',', ' Chevron', ' Corporation']
+573 166 The name of the CEO of x -1 The name of the CEO of Chevron Corporation Mike Wirth Chevron Corporation "[',' ' the' ' world' '�' '�' 's' ' largest' ' oil' ' company' ',' ' is'
+ ' a' ' man' ' named' ' John' ' Watson' '.' ' He' ' is' ' a']" , the world � � s largest oil company , is a man named John Watson . He is a False 3 ['Che', 'v', 'ron', ' Corporation']
+574 166 The name of the CEO of x -1 The name of the CEO of Chevron Corporation Mike Wirth Chevron Corporation "[',' ' the' ' world' '�' '�' 's' ' largest' ' oil' ' company' ',' ' is'
+ ' a' ' man' ' named' ' John' ' Watson' '.' ' He' ' is' ' a']" , the world � � s largest oil company , is a man named John Watson . He is a False researched by Chevron Corporation in partnership 3 [' researched', ' by', ' Chevron', ' Corporation']
+575 166 The name of the CEO of x -1 The name of the CEO of Chevron Corporation Mike Wirth Chevron Corporation "[',' ' the' ' world' '�' '�' 's' ' largest' ' oil' ' company' ',' ' is'
+ ' a' ' man' ' named' ' John' ' Watson' '.' ' He' ' is' ' a']" , the world � � s largest oil company , is a man named John Watson . He is a False Workers from the Chevron Corporation and the Anadarko Petroleum 4 [' Workers', ' from', ' the', ' Chevron', ' Corporation']
+576 166 The name of the CEO of x -1 The name of the CEO of Chevron Corporation Mike Wirth Chevron Corporation "[',' ' the' ' world' '�' '�' 's' ' largest' ' oil' ' company' ',' ' is'
+ ' a' ' man' ' named' ' John' ' Watson' '.' ' He' ' is' ' a']" , the world � � s largest oil company , is a man named John Watson . He is a False 3 ['Che', 'v', 'ron', ' Corporation']
+577 167 The name of the CEO of x -1 The name of the CEO of ArcelorMittal Aditya Mittal ArcelorMittal "[' SA' ',' ' the' ' world' '�' '�' 's' ' largest' ' steel' 'maker' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" SA , the world � � s largest steel maker , is a bit of a mouth ful . But False laid claim to the ArcelorMittal Orbit tower, the 8 [' laid', ' claim', ' to', ' the', ' Ar', 'ce', 'lor', 'M', 'ittal']
+578 167 The name of the CEO of x -1 The name of the CEO of ArcelorMittal Aditya Mittal ArcelorMittal "[' SA' ',' ' the' ' world' '�' '�' 's' ' largest' ' steel' 'maker' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" SA , the world � � s largest steel maker , is a bit of a mouth ful . But False " Financial Times ArcelorMittal ""Boldness in Business" 6 [' Financial', ' Times', ' Ar', 'ce', 'lor', 'M', 'ittal']
+579 167 The name of the CEO of x -1 The name of the CEO of ArcelorMittal Aditya Mittal ArcelorMittal "[' SA' ',' ' the' ' world' '�' '�' 's' ' largest' ' steel' 'maker' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" SA , the world � � s largest steel maker , is a bit of a mouth ful . But False Chicago-based ArcelorMittal and is intended 7 [' Chicago', '-', 'based', ' Ar', 'ce', 'lor', 'M', 'ittal']
+580 167 The name of the CEO of x -1 The name of the CEO of ArcelorMittal Aditya Mittal ArcelorMittal "[' SA' ',' ' the' ' world' '�' '�' 's' ' largest' ' steel' 'maker' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" SA , the world � � s largest steel maker , is a bit of a mouth ful . But False laid claim to the ArcelorMittal Orbit tower, 8 [' laid', ' claim', ' to', ' the', ' Ar', 'ce', 'lor', 'M', 'ittal']
+581 167 The name of the CEO of x -1 The name of the CEO of ArcelorMittal Aditya Mittal ArcelorMittal "[' SA' ',' ' the' ' world' '�' '�' 's' ' largest' ' steel' 'maker' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' '.' ' But']" SA , the world � � s largest steel maker , is a bit of a mouth ful . But False Chicago-based ArcelorMittal and is intended 7 [' Chicago', '-', 'based', ' Ar', 'ce', 'lor', 'M', 'ittal']
+582 168 The name of the CEO of x -1 The name of the CEO of HP Inc. Dion Weisler HP Inc. "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' HP' ' En' 'vy' ' x'
+ '2' '.' ' I' ' have' ' one' ' of' ' the']" "
+
+ I am a big fan of the HP En vy x 2 . I have one of the" False version), an HP Inc. executive explained 5 [' version', '),', ' an', ' HP', ' Inc', '.']
+583 168 The name of the CEO of x -1 The name of the CEO of HP Inc. Dion Weisler HP Inc. "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' HP' ' En' 'vy' ' x'
+ '2' '.' ' I' ' have' ' one' ' of' ' the']" "
+
+ I am a big fan of the HP En vy x 2 . I have one of the" False new version), an HP Inc. executive explained 6 [' new', ' version', '),', ' an', ' HP', ' Inc', '.']
+584 169 The name of the CEO of x -1 The name of the CEO of Airbus Group Guillaume Faury Airbus Group "[',' ' the' ' world' '�' '�' 's' ' largest' ' aircraft' ' manufacturer'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' pleased'
+ '�']" , the world � � s largest aircraft manufacturer , said that the company is � � very pleased � False " owned by the Airbus Group (formerly EADS).
+" 4 [' owned', ' by', ' the', ' Airbus', ' Group']
+585 169 The name of the CEO of x -1 The name of the CEO of Airbus Group Guillaume Faury Airbus Group "[',' ' the' ' world' '�' '�' 's' ' largest' ' aircraft' ' manufacturer'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' pleased'
+ '�']" , the world � � s largest aircraft manufacturer , said that the company is � � very pleased � False " entirely owned by the Airbus Group (formerly EADS).
+" 5 [' entirely', ' owned', ' by', ' the', ' Airbus', ' Group']
+586 169 The name of the CEO of x -1 The name of the CEO of Airbus Group Guillaume Faury Airbus Group "[',' ' the' ' world' '�' '�' 's' ' largest' ' aircraft' ' manufacturer'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' pleased'
+ '�']" , the world � � s largest aircraft manufacturer , said that the company is � � very pleased � False contracted to Airbus Group Australia Pacific 3 [' contracted', ' to', ' Airbus', ' Group']
+587 169 The name of the CEO of x -1 The name of the CEO of Airbus Group Guillaume Faury Airbus Group "[',' ' the' ' world' '�' '�' 's' ' largest' ' aircraft' ' manufacturer'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' pleased'
+ '�']" , the world � � s largest aircraft manufacturer , said that the company is � � very pleased � False owned by the Airbus Group (formerly 4 [' owned', ' by', ' the', ' Airbus', ' Group']
+588 170 The name of the CEO of x -1 The name of the CEO of Accor Sven Boinet Accor "['Hot' 'els' '.' 'com' ',' ' a' ' company' ' that' ' owns' ' and'
+ ' operates' ' the' ' Acc' 'or' 'Hot' 'els' '.' 'com' ' website' '.']" Hot els . com , a company that owns and operates the Acc or Hot els . com website . False Palace. In March 2012, Accor opened its first 7 [' Palace', '.', ' In', ' March', ' 2012', ',', ' Acc', 'or']
+589 170 The name of the CEO of x -1 The name of the CEO of Accor Sven Boinet Accor "['Hot' 'els' '.' 'com' ',' ' a' ' company' ' that' ' owns' ' and'
+ ' operates' ' the' ' Acc' 'or' 'Hot' 'els' '.' 'com' ' website' '.']" Hot els . com , a company that owns and operates the Acc or Hot els . com website . False March 2012, Accor opened its first Pullman-branded 4 [' March', ' 2012', ',', ' Acc', 'or']
+590 170 The name of the CEO of x -1 The name of the CEO of Accor Sven Boinet Accor "['Hot' 'els' '.' 'com' ',' ' a' ' company' ' that' ' owns' ' and'
+ ' operates' ' the' ' Acc' 'or' 'Hot' 'els' '.' 'com' ' website' '.']" Hot els . com , a company that owns and operates the Acc or Hot els . com website . False Palace. In March 2012, Accor opened its first Pullman-branded 7 [' Palace', '.', ' In', ' March', ' 2012', ',', ' Acc', 'or']
+591 170 The name of the CEO of x -1 The name of the CEO of Accor Sven Boinet Accor "['Hot' 'els' '.' 'com' ',' ' a' ' company' ' that' ' owns' ' and'
+ ' operates' ' the' ' Acc' 'or' 'Hot' 'els' '.' 'com' ' website' '.']" Hot els . com , a company that owns and operates the Acc or Hot els . com website . False In March 2012, Accor opened its first 5 [' In', ' March', ' 2012', ',', ' Acc', 'or']
+592 173 The name of the CEO of x -1 The name of the CEO of Toshiba Nobuaki Kurumatani Toshiba "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been' ' in']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been in" False technology company Toshiba and a cash injection 3 [' technology', ' company', ' Tos', 'hiba']
+593 173 The name of the CEO of x -1 The name of the CEO of Toshiba Nobuaki Kurumatani Toshiba "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been' ' in']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been in" False technology company Toshiba and a cash 3 [' technology', ' company', ' Tos', 'hiba']
+594 173 The name of the CEO of x -1 The name of the CEO of Toshiba Nobuaki Kurumatani Toshiba "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been' ' in']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been in" False ultimately instead joined Toshiba to work in multimedia 4 [' ultimately', ' instead', ' joined', ' Tos', 'hiba']
+595 173 The name of the CEO of x -1 The name of the CEO of Toshiba Nobuaki Kurumatani Toshiba "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been' ' in']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been in" False Remac – Fanini and Toshiba – Look both 7 [' Rem', 'ac', ' –', ' Fan', 'ini', ' and', ' Tos', 'hiba']
+596 173 The name of the CEO of x -1 The name of the CEO of Toshiba Nobuaki Kurumatani Toshiba "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been' ' in']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been in" False Millar. Since Toshiba – Look team leader 5 [' Mill', 'ar', '.', ' Since', ' Tos', 'hiba']
+597 174 The name of the CEO of x -1 The name of the CEO of The Salvation Army Brian Peddle The Salvation Army "[',' ' the' ' organization' ' that' ' has' ' been' ' around' ' since'
+ ' 1865' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' The'
+ ' organization' ' has']" , the organization that has been around since 1865 , is a bit of a mystery . The organization has False 2 ['The', ' Salvation', ' Army']
+598 174 The name of the CEO of x -1 The name of the CEO of The Salvation Army Brian Peddle The Salvation Army "[',' ' the' ' organization' ' that' ' has' ' been' ' around' ' since'
+ ' 1865' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' The'
+ ' organization' ' has']" , the organization that has been around since 1865 , is a bit of a mystery . The organization has False 2 ['The', ' Salvation', ' Army']
+599 174 The name of the CEO of x -1 The name of the CEO of The Salvation Army Brian Peddle The Salvation Army "[',' ' the' ' organization' ' that' ' has' ' been' ' around' ' since'
+ ' 1865' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' The'
+ ' organization' ' has']" , the organization that has been around since 1865 , is a bit of a mystery . The organization has False 2 ['The', ' Salvation', ' Army']
+600 174 The name of the CEO of x -1 The name of the CEO of The Salvation Army Brian Peddle The Salvation Army "[',' ' the' ' organization' ' that' ' has' ' been' ' around' ' since'
+ ' 1865' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' The'
+ ' organization' ' has']" , the organization that has been around since 1865 , is a bit of a mystery . The organization has False Theatre, named after The Salvation Army Citadel in which 6 [' Theatre', ',', ' named', ' after', ' The', ' Salvation', ' Army']
+601 174 The name of the CEO of x -1 The name of the CEO of The Salvation Army Brian Peddle The Salvation Army "[',' ' the' ' organization' ' that' ' has' ' been' ' around' ' since'
+ ' 1865' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' The'
+ ' organization' ' has']" , the organization that has been around since 1865 , is a bit of a mystery . The organization has False Sabine Baring-Gould. The Salvation Army adopted the latter 10 [' Sab', 'ine', ' B', 'aring', '-', 'G', 'ould', '.', ' The', ' Salvation', ' Army']
+602 175 The name of the CEO of x -1 The name of the CEO of The Lego Group Niels B. Christiansen The Lego Group "[',' ' and' ' the' ' Lego' ' Group' ' is' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' over' ' a' ' century' '.' ' The' ' Lego'
+ ' Group']" , and the Lego Group is a company that has been around for over a century . The Lego Group False expertise from The Lego Group assisted High Voltage 4 [' expertise', ' from', ' The', ' Lego', ' Group']
+603 175 The name of the CEO of x -1 The name of the CEO of The Lego Group Niels B. Christiansen The Lego Group "[',' ' and' ' the' ' Lego' ' Group' ' is' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' over' ' a' ' century' '.' ' The' ' Lego'
+ ' Group']" , and the Lego Group is a company that has been around for over a century . The Lego Group False expertise from The Lego Group assisted High 4 [' expertise', ' from', ' The', ' Lego', ' Group']
+604 175 The name of the CEO of x -1 The name of the CEO of The Lego Group Niels B. Christiansen The Lego Group "[',' ' and' ' the' ' Lego' ' Group' ' is' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' over' ' a' ' century' '.' ' The' ' Lego'
+ ' Group']" , and the Lego Group is a company that has been around for over a century . The Lego Group False facilities within The Lego Group collaborated with 4 [' facilities', ' within', ' The', ' Lego', ' Group']
+605 175 The name of the CEO of x -1 The name of the CEO of The Lego Group Niels B. Christiansen The Lego Group "[',' ' and' ' the' ' Lego' ' Group' ' is' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' over' ' a' ' century' '.' ' The' ' Lego'
+ ' Group']" , and the Lego Group is a company that has been around for over a century . The Lego Group False creative expertise from The Lego Group assisted High 5 [' creative', ' expertise', ' from', ' The', ' Lego', ' Group']
+606 175 The name of the CEO of x -1 The name of the CEO of The Lego Group Niels B. Christiansen The Lego Group "[',' ' and' ' the' ' Lego' ' Group' ' is' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' over' ' a' ' century' '.' ' The' ' Lego'
+ ' Group']" , and the Lego Group is a company that has been around for over a century . The Lego Group False creative expertise from The Lego Group assisted High 5 [' creative', ' expertise', ' from', ' The', ' Lego', ' Group']
+607 176 The name of the CEO of x -1 The name of the CEO of SAP Jennifer Morgan SAP "[',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and' ' the' ' CEO' ' of'
+ ' SAP' ',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and']" , and the CEO of SAP , and the CEO of SAP , and the CEO of SAP , and False 1 ['S', 'AP']
+608 176 The name of the CEO of x -1 The name of the CEO of SAP Jennifer Morgan SAP "[',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and' ' the' ' CEO' ' of'
+ ' SAP' ',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and']" , and the CEO of SAP , and the CEO of SAP , and the CEO of SAP , and False outsourced SAP licenses in the world, 2 [' outs', 'ourced', ' SAP']
+609 176 The name of the CEO of x -1 The name of the CEO of SAP Jennifer Morgan SAP "[',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and' ' the' ' CEO' ' of'
+ ' SAP' ',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and']" , and the CEO of SAP , and the CEO of SAP , and the CEO of SAP , and False of outsourced SAP licenses in the world, 3 [' of', ' outs', 'ourced', ' SAP']
+610 176 The name of the CEO of x -1 The name of the CEO of SAP Jennifer Morgan SAP "[',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and' ' the' ' CEO' ' of'
+ ' SAP' ',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and']" , and the CEO of SAP , and the CEO of SAP , and the CEO of SAP , and False 30, 2013 at SAP Arena in Mannheim 4 [' 30', ',', ' 2013', ' at', ' SAP']
+611 176 The name of the CEO of x -1 The name of the CEO of SAP Jennifer Morgan SAP "[',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and' ' the' ' CEO' ' of'
+ ' SAP' ',' ' and' ' the' ' CEO' ' of' ' SAP' ',' ' and']" , and the CEO of SAP , and the CEO of SAP , and the CEO of SAP , and False Designation of SAP Status (2001 3 [' Design', 'ation', ' of', ' SAP']
+612 177 The name of the CEO of x -1 The name of the CEO of Qantas Alan Joyce Qantas "[',' ' the' ' airline' ' that' ' has' ' been' ' flying' ' the' ' flag'
+ ' for' ' the' ' Australian' ' airline' ' industry' ' for' ' over' ' a'
+ ' century' ',' ' has']" , the airline that has been flying the flag for the Australian airline industry for over a century , has False been appointed Qantas'Chief Pilot. 4 [' been', ' appointed', ' Q', 'ant', 'as']
+613 177 The name of the CEO of x -1 The name of the CEO of Qantas Alan Joyce Qantas "[',' ' the' ' airline' ' that' ' has' ' been' ' flying' ' the' ' flag'
+ ' for' ' the' ' Australian' ' airline' ' industry' ' for' ' over' ' a'
+ ' century' ',' ' has']" , the airline that has been flying the flag for the Australian airline industry for over a century , has False previously known as Qantas Domestic, International, 5 [' previously', ' known', ' as', ' Q', 'ant', 'as']
+614 177 The name of the CEO of x -1 The name of the CEO of Qantas Alan Joyce Qantas "[',' ' the' ' airline' ' that' ' has' ' been' ' flying' ' the' ' flag'
+ ' for' ' the' ' Australian' ' airline' ' industry' ' for' ' over' ' a'
+ ' century' ',' ' has']" , the airline that has been flying the flag for the Australian airline industry for over a century , has False November 2008, Qantas announced that one 5 [' November', ' 2008', ',', ' Q', 'ant', 'as']
+615 177 The name of the CEO of x -1 The name of the CEO of Qantas Alan Joyce Qantas "[',' ' the' ' airline' ' that' ' has' ' been' ' flying' ' the' ' flag'
+ ' for' ' the' ' Australian' ' airline' ' industry' ' for' ' over' ' a'
+ ' century' ',' ' has']" , the airline that has been flying the flag for the Australian airline industry for over a century , has False Aerial Services (Qantas) as a pilot 5 [' Aerial', ' Services', ' (', 'Q', 'ant', 'as']
+616 177 The name of the CEO of x -1 The name of the CEO of Qantas Alan Joyce Qantas "[',' ' the' ' airline' ' that' ' has' ' been' ' flying' ' the' ' flag'
+ ' for' ' the' ' Australian' ' airline' ' industry' ' for' ' over' ' a'
+ ' century' ',' ' has']" , the airline that has been flying the flag for the Australian airline industry for over a century , has False and shot down a Qantas Short Empire flying 6 [' and', ' shot', ' down', ' a', ' Q', 'ant', 'as']
+617 178 The name of the CEO of x -1 The name of the CEO of Électricité de France Jean-Bernard Lévy Électricité de France "[',' ' the' ' French' ' electricity' ' company' ',' ' is' ' É' 'lect'
+ 'ric' 'ité' ' de' ' France' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French electricity company , is É lect ric ité de France .
+
+ The name of the" False Brent Knoll and Électricité de France to build nine at 9 [' Brent', ' Kn', 'oll', ' and', ' É', 'lect', 'ric', 'ité', ' de', ' France']
+618 178 The name of the CEO of x -1 The name of the CEO of Électricité de France Jean-Bernard Lévy Électricité de France "[',' ' the' ' French' ' electricity' ' company' ',' ' is' ' É' 'lect'
+ 'ric' 'ité' ' de' ' France' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French electricity company , is É lect ric ité de France .
+
+ The name of the" False " ones such as Électricité de France (EDF).
+" 8 [' ones', ' such', ' as', ' É', 'lect', 'ric', 'ité', ' de', ' France']
+619 178 The name of the CEO of x -1 The name of the CEO of Électricité de France Jean-Bernard Lévy Électricité de France "[',' ' the' ' French' ' electricity' ' company' ',' ' is' ' É' 'lect'
+ 'ric' 'ité' ' de' ' France' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French electricity company , is É lect ric ité de France .
+
+ The name of the" False the wake of Bejisa, Électricité de France (EDF) deployed 12 [' the', ' wake', ' of', ' Be', 'j', 'isa', ',', ' É', 'lect', 'ric', 'ité', ' de', ' France']
+620 178 The name of the CEO of x -1 The name of the CEO of Électricité de France Jean-Bernard Lévy Électricité de France "[',' ' the' ' French' ' electricity' ' company' ',' ' is' ' É' 'lect'
+ 'ric' 'ité' ' de' ' France' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French electricity company , is É lect ric ité de France .
+
+ The name of the" False power station. Électricité de France has begun public 8 [' power', ' station', '.', ' É', 'lect', 'ric', 'ité', ' de', ' France']
+621 178 The name of the CEO of x -1 The name of the CEO of Électricité de France Jean-Bernard Lévy Électricité de France "[',' ' the' ' French' ' electricity' ' company' ',' ' is' ' É' 'lect'
+ 'ric' 'ité' ' de' ' France' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French electricity company , is É lect ric ité de France .
+
+ The name of the" False included a former Électricité de France building that 8 [' included', ' a', ' former', ' É', 'lect', 'ric', 'ité', ' de', ' France']
+622 179 The name of the CEO of x -1 The name of the CEO of Ericsson Börje Ekholm Ericsson "[',' ' the' ' company' ' that' ' makes' ' the' ' phones' ',' ' said'
+ ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' version'
+ ' of' ' the']" , the company that makes the phones , said that the company is working on a new version of the False streak at the Sony Ericsson Open in Miami, 5 [' streak', ' at', ' the', ' Sony', ' Eric', 'sson']
+623 179 The name of the CEO of x -1 The name of the CEO of Ericsson Börje Ekholm Ericsson "[',' ' the' ' company' ' that' ' makes' ' the' ' phones' ',' ' said'
+ ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' version'
+ ' of' ' the']" , the company that makes the phones , said that the company is working on a new version of the False " = 2009 Sony Ericsson Open =
+" 4 [' =', ' 2009', ' Sony', ' Eric', 'sson']
+624 179 The name of the CEO of x -1 The name of the CEO of Ericsson Börje Ekholm Ericsson "[',' ' the' ' company' ' that' ' makes' ' the' ' phones' ',' ' said'
+ ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' version'
+ ' of' ' the']" , the company that makes the phones , said that the company is working on a new version of the False Warriors at the Ericsson Stadium in Auckland, 4 [' Warriors', ' at', ' the', ' Eric', 'sson']
+625 179 The name of the CEO of x -1 The name of the CEO of Ericsson Börje Ekholm Ericsson "[',' ' the' ' company' ' that' ' makes' ' the' ' phones' ',' ' said'
+ ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' version'
+ ' of' ' the']" , the company that makes the phones , said that the company is working on a new version of the False Navy Journal Ericsson accused the 3 [' Navy', ' Journal', ' Eric', 'sson']
+626 179 The name of the CEO of x -1 The name of the CEO of Ericsson Börje Ekholm Ericsson "[',' ' the' ' company' ' that' ' makes' ' the' ' phones' ',' ' said'
+ ' that' ' the' ' company' ' is' ' working' ' on' ' a' ' new' ' version'
+ ' of' ' the']" , the company that makes the phones , said that the company is working on a new version of the False out of commission. Ericsson was activated and acquired 5 [' out', ' of', ' commission', '.', ' Eric', 'sson']
+627 180 The name of the CEO of x -1 The name of the CEO of Major League Soccer Don Garber Major League Soccer "[',' ' David' ' Beckham' ',' ' has' ' been' ' linked' ' to' ' the' ' club'
+ '.' '\n' '\n' 'The' ' former' ' England' ' captain' ' has' ' been'
+ ' linked']" ", David Beckham , has been linked to the club .
+
+ The former England captain has been linked" False Dynamo is a Major League Soccer franchise that 5 [' Dynamo', ' is', ' a', ' Major', ' League', ' Soccer']
+628 180 The name of the CEO of x -1 The name of the CEO of Major League Soccer Don Garber Major League Soccer "[',' ' David' ' Beckham' ',' ' has' ' been' ' linked' ' to' ' the' ' club'
+ '.' '\n' '\n' 'The' ' former' ' England' ' captain' ' has' ' been'
+ ' linked']" ", David Beckham , has been linked to the club .
+
+ The former England captain has been linked" False were chosen to host Major League Soccer teams, Seattle was 6 [' were', ' chosen', ' to', ' host', ' Major', ' League', ' Soccer']
+629 180 The name of the CEO of x -1 The name of the CEO of Major League Soccer Don Garber Major League Soccer "[',' ' David' ' Beckham' ',' ' has' ' been' ' linked' ' to' ' the' ' club'
+ '.' '\n' '\n' 'The' ' former' ' England' ' captain' ' has' ' been'
+ ' linked']" ", David Beckham , has been linked to the club .
+
+ The former England captain has been linked" False 2 ['Major', ' League', ' Soccer']
+630 180 The name of the CEO of x -1 The name of the CEO of Major League Soccer Don Garber Major League Soccer "[',' ' David' ' Beckham' ',' ' has' ' been' ' linked' ' to' ' the' ' club'
+ '.' '\n' '\n' 'The' ' former' ' England' ' captain' ' has' ' been'
+ ' linked']" ", David Beckham , has been linked to the club .
+
+ The former England captain has been linked" False for Canadian Major League Soccer club Montreal Impact. 4 [' for', ' Canadian', ' Major', ' League', ' Soccer']
+631 180 The name of the CEO of x -1 The name of the CEO of Major League Soccer Don Garber Major League Soccer "[',' ' David' ' Beckham' ',' ' has' ' been' ' linked' ' to' ' the' ' club'
+ '.' '\n' '\n' 'The' ' former' ' England' ' captain' ' has' ' been'
+ ' linked']" ", David Beckham , has been linked to the club .
+
+ The former England captain has been linked" False American soccer, Major League Soccer (MLS), and bypassed 5 [' American', ' soccer', ',', ' Major', ' League', ' Soccer']
+632 181 The name of the CEO of x -1 The name of the CEO of Costco W. Craig Jelinek Costco "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False businesses include Costco and Shop Direct 2 [' businesses', ' include', ' Costco']
+633 181 The name of the CEO of x -1 The name of the CEO of Costco W. Craig Jelinek Costco "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False Trautwig, a Costco manager and 5 [' Tra', 'ut', 'wig', ',', ' a', ' Costco']
+634 181 The name of the CEO of x -1 The name of the CEO of Costco W. Craig Jelinek Costco "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False the frequent Costco promotion. Phipps 2 [' the', ' frequent', ' Costco']
+635 181 The name of the CEO of x -1 The name of the CEO of Costco W. Craig Jelinek Costco "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False promotion of Costco and Costco products 2 [' promotion', ' of', ' Costco']
+636 181 The name of the CEO of x -1 The name of the CEO of Costco W. Craig Jelinek Costco "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that' ' owns' ' the' ' company' ' that' ' owns' ' the'
+ ' company' ' that']" , and the CEO of the company that owns the company that owns the company that owns the company that False club chain Costco have since built 2 [' club', ' chain', ' Costco']
+637 182 The name of the CEO of x -1 The name of the CEO of Merck & Co. Kenneth Frazier Merck & Co. "['\n' '\n' 'The' ' company' ' is' ' a' ' global' ' pharmaceutical' ' and'
+ ' bi' 'otechnology' ' company' ' that' ' develops' ',' ' manufactures'
+ ' and' ' markets' ' prescription' ' medicines']" "
+
+ The company is a global pharmaceutical and bi otechnology company that develops , manufactures and markets prescription medicines" False to 1959, and of Merck & Co. from 1949 9 [' to', ' 1959', ',', ' and', ' of', ' Mer', 'ck', ' &', ' Co', '.']
+638 182 The name of the CEO of x -1 The name of the CEO of Merck & Co. Kenneth Frazier Merck & Co. "['\n' '\n' 'The' ' company' ' is' ' a' ' global' ' pharmaceutical' ' and'
+ ' bi' 'otechnology' ' company' ' that' ' develops' ',' ' manufactures'
+ ' and' ' markets' ' prescription' ' medicines']" "
+
+ The company is a global pharmaceutical and bi otechnology company that develops , manufactures and markets prescription medicines" False eventually Merck & Co. pharmaceuticals, whose 5 [' eventually', ' Mer', 'ck', ' &', ' Co', '.']
+639 182 The name of the CEO of x -1 The name of the CEO of Merck & Co. Kenneth Frazier Merck & Co. "['\n' '\n' 'The' ' company' ' is' ' a' ' global' ' pharmaceutical' ' and'
+ ' bi' 'otechnology' ' company' ' that' ' develops' ',' ' manufactures'
+ ' and' ' markets' ' prescription' ' medicines']" "
+
+ The company is a global pharmaceutical and bi otechnology company that develops , manufactures and markets prescription medicines" False 1952 to 1959, and of Merck & Co. from 1949 to 1962. 10 [' 1952', ' to', ' 1959', ',', ' and', ' of', ' Mer', 'ck', ' &', ' Co', '.']
+640 183 The name of the CEO of x -1 The name of the CEO of Airbus Guillaume Faury Airbus "[',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and' ' the' ' CEO' ' of'
+ ' Airbus' ',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and']" , and the CEO of Airbus , and the CEO of Airbus , and the CEO of Airbus , and False Northwest Airlines Airbus A330-323E twinjet, 2 [' Northwest', ' Airlines', ' Airbus']
+641 183 The name of the CEO of x -1 The name of the CEO of Airbus Guillaume Faury Airbus "[',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and' ' the' ' CEO' ' of'
+ ' Airbus' ',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and']" , and the CEO of Airbus , and the CEO of Airbus , and the CEO of Airbus , and False passenger models. Airbus also considered using 3 [' passenger', ' models', '.', ' Airbus']
+642 183 The name of the CEO of x -1 The name of the CEO of Airbus Guillaume Faury Airbus "[',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and' ' the' ' CEO' ' of'
+ ' Airbus' ',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and']" , and the CEO of Airbus , and the CEO of Airbus , and the CEO of Airbus , and False competitors include the Airbus A330-300, 3 [' competitors', ' include', ' the', ' Airbus']
+643 183 The name of the CEO of x -1 The name of the CEO of Airbus Guillaume Faury Airbus "[',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and' ' the' ' CEO' ' of'
+ ' Airbus' ',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and']" , and the CEO of Airbus , and the CEO of Airbus , and the CEO of Airbus , and False Scottish Aviation). Airbus and Rolls-Royce 3 [' Scottish', ' Aviation', ').', ' Airbus']
+644 183 The name of the CEO of x -1 The name of the CEO of Airbus Guillaume Faury Airbus "[',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and' ' the' ' CEO' ' of'
+ ' Airbus' ',' ' and' ' the' ' CEO' ' of' ' Airbus' ',' ' and']" , and the CEO of Airbus , and the CEO of Airbus , and the CEO of Airbus , and False " agencies using its Airbus A330.
+" 3 [' agencies', ' using', ' its', ' Airbus']
+645 186 The name of the CEO of x -1 The name of the CEO of Vivendi Yannick Bolloré Vivendi "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' game' ',' ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' game' ' is']" ", the company that owns the rights to the game , has been revealed .
+
+ The game is" False Grefberg, was released by Vivendi as a free download 8 [' G', 'ref', 'berg', ',', ' was', ' released', ' by', ' Viv', 'endi']
+646 186 The name of the CEO of x -1 The name of the CEO of Vivendi Yannick Bolloré Vivendi "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' game' ',' ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' game' ' is']" ", the company that owns the rights to the game , has been revealed .
+
+ The game is" False In May 2007, Vivendi announced that Escape 5 [' In', ' May', ' 2007', ',', ' Viv', 'endi']
+647 186 The name of the CEO of x -1 The name of the CEO of Vivendi Yannick Bolloré Vivendi "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' game' ',' ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' game' ' is']" ", the company that owns the rights to the game , has been revealed .
+
+ The game is" False 2007, Activision and Vivendi Games merged to become 5 [' 2007', ',', ' Activision', ' and', ' Viv', 'endi']
+648 186 The name of the CEO of x -1 The name of the CEO of Vivendi Yannick Bolloré Vivendi "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' game' ',' ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' game' ' is']" ", the company that owns the rights to the game , has been revealed .
+
+ The game is" False 5, 2010 in the US by Vivendi Entertainment and has 8 [' 5', ',', ' 2010', ' in', ' the', ' US', ' by', ' Viv', 'endi']
+649 186 The name of the CEO of x -1 The name of the CEO of Vivendi Yannick Bolloré Vivendi "[',' ' the' ' company' ' that' ' owns' ' the' ' rights' ' to' ' the'
+ ' game' ',' ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' game' ' is']" ", the company that owns the rights to the game , has been revealed .
+
+ The game is" False be published by Vivendi Games prior 4 [' be', ' published', ' by', ' Viv', 'endi']
+650 187 The name of the CEO of x -1 The name of the CEO of Mastercard Ajaypal Singh Banga Mastercard "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' name' ' of' ' the' ' company' ' that' ' owns' ' the' ' name' ' of'
+ ' the']" , and the CEO of the company that owns the name of the company that owns the name of the False was featured in a Mastercard commercial that praised 5 [' was', ' featured', ' in', ' a', ' Master', 'card']
+651 187 The name of the CEO of x -1 The name of the CEO of Mastercard Ajaypal Singh Banga Mastercard "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns' ' the'
+ ' name' ' of' ' the' ' company' ' that' ' owns' ' the' ' name' ' of'
+ ' the']" , and the CEO of the company that owns the name of the company that owns the name of the False featured in a Mastercard commercial that 4 [' featured', ' in', ' a', ' Master', 'card']
+652 188 The name of the CEO of x -1 The name of the CEO of Valeo Jacques Aschenbroich Valeo "[',' ' the' ' company' ' that' ' makes' ' the' ' product' '.' '\n' '\n'
+ 'The' ' company' ' is' ' a' ' subsidiary' ' of' ' Vale' 'o' ',' ' a']" ", the company that makes the product .
+
+ The company is a subsidiary of Vale o , a" False 2 ['V', 'ale', 'o']
+653 188 The name of the CEO of x -1 The name of the CEO of Valeo Jacques Aschenbroich Valeo "[',' ' the' ' company' ' that' ' makes' ' the' ' product' '.' '\n' '\n'
+ 'The' ' company' ' is' ' a' ' subsidiary' ' of' ' Vale' 'o' ',' ' a']" ", the company that makes the product .
+
+ The company is a subsidiary of Vale o , a" False In Buckley v. Valeo (1976), the Supreme 5 [' In', ' Buckley', ' v', '.', ' Vale', 'o']
+654 190 The name of the CEO of x -1 The name of the CEO of British Council Ciarán Devane British Council "[' of' ' the' ' Year' ' Award' ' for' ' the' ' best' '-' 'selling' ' book'
+ ' of' ' the' ' year' '.' '\n' '\n' 'The' ' book' ' is' ' a']" " of the Year Award for the best - selling book of the year .
+
+ The book is a" False grown to love. The British Council advised its 6 [' grown', ' to', ' love', '.', ' The', ' British', ' Council']
+655 190 The name of the CEO of x -1 The name of the CEO of British Council Ciarán Devane British Council "[' of' ' the' ' Year' ' Award' ' for' ' the' ' best' '-' 'selling' ' book'
+ ' of' ' the' ' year' '.' '\n' '\n' 'The' ' book' ' is' ' a']" " of the Year Award for the best - selling book of the year .
+
+ The book is a" False Theatre, as well as the British Council and various 7 [' Theatre', ',', ' as', ' well', ' as', ' the', ' British', ' Council']
+656 190 The name of the CEO of x -1 The name of the CEO of British Council Ciarán Devane British Council "[' of' ' the' ' Year' ' Award' ' for' ' the' ' best' '-' 'selling' ' book'
+ ' of' ' the' ' year' '.' '\n' '\n' 'The' ' book' ' is' ' a']" " of the Year Award for the best - selling book of the year .
+
+ The book is a" False " Smith (""Reggie""), a British Council lecturer posted" 7 "[' Smith', ' (""', 'Re', 'ggie', '""),', ' a', ' British', ' Council']"
+657 190 The name of the CEO of x -1 The name of the CEO of British Council Ciarán Devane British Council "[' of' ' the' ' Year' ' Award' ' for' ' the' ' best' '-' 'selling' ' book'
+ ' of' ' the' ' year' '.' '\n' '\n' 'The' ' book' ' is' ' a']" " of the Year Award for the best - selling book of the year .
+
+ The book is a" False 1 ['British', ' Council']
+658 190 The name of the CEO of x -1 The name of the CEO of British Council Ciarán Devane British Council "[' of' ' the' ' Year' ' Award' ' for' ' the' ' best' '-' 'selling' ' book'
+ ' of' ' the' ' year' '.' '\n' '\n' 'The' ' book' ' is' ' a']" " of the Year Award for the best - selling book of the year .
+
+ The book is a" False 1 ['British', ' Council']
+659 191 The name of the CEO of x -1 The name of the CEO of Philip Morris International Jacek Olczak Philip Morris International "[',' ' the' ' world' '�' '�' 's' ' largest' ' tobacco' ' company' ','
+ ' is' ' Philip' ' Morris' ' International' '.' '\n' '\n' 'The' ' company'
+ ' is']" ", the world � � s largest tobacco company , is Philip Morris International .
+
+ The company is" False " ===
+" 4 [' ===', 'Phil', 'ip', ' Morris', ' International']
+660 191 The name of the CEO of x -1 The name of the CEO of Philip Morris International Jacek Olczak Philip Morris International "[',' ' the' ' world' '�' '�' 's' ' largest' ' tobacco' ' company' ','
+ ' is' ' Philip' ' Morris' ' International' '.' '\n' '\n' 'The' ' company'
+ ' is']" ", the world � � s largest tobacco company , is Philip Morris International .
+
+ The company is" False Cowboy Hat) to Philip Morris International to use free of charge. 6 [' Cowboy', ' Hat', ')', ' to', ' Philip', ' Morris', ' International']
+661 193 The name of the CEO of x -1 The name of the CEO of FC Bayern Munich Herbert Hainer FC Bayern Munich "[',' ' Karl' '-' 'He' 'in' 'z' ' Rum' 'men' 'ig' 'ge' ',' ' has' ' been'
+ ' linked' ' with' ' a' ' move' ' to' ' the' ' Premier']" , Karl - He in z Rum men ig ge , has been linked with a move to the Premier False Biggest defeat: FC Bayern Munich 5 – 1 Arsenal, 6 [' Big', 'gest', ' defeat', ':', ' FC', ' Bayern', ' Munich']
+662 193 The name of the CEO of x -1 The name of the CEO of FC Bayern Munich Herbert Hainer FC Bayern Munich "[',' ' Karl' '-' 'He' 'in' 'z' ' Rum' 'men' 'ig' 'ge' ',' ' has' ' been'
+ ' linked' ' with' ' a' ' move' ' to' ' the' ' Premier']" , Karl - He in z Rum men ig ge , has been linked with a move to the Premier False Saarbrücken asked FC Bayern Munich II to lose every 9 [' Sa', 'ar', 'br', 'ü', 'ck', 'en', ' asked', ' FC', ' Bayern', ' Munich']
+663 193 The name of the CEO of x -1 The name of the CEO of FC Bayern Munich Herbert Hainer FC Bayern Munich "[',' ' Karl' '-' 'He' 'in' 'z' ' Rum' 'men' 'ig' 'ge' ',' ' has' ' been'
+ ' linked' ' with' ' a' ' move' ' to' ' the' ' Premier']" , Karl - He in z Rum men ig ge , has been linked with a move to the Premier False 2 ['FC', ' Bayern', ' Munich']
+664 193 The name of the CEO of x -1 The name of the CEO of FC Bayern Munich Herbert Hainer FC Bayern Munich "[',' ' Karl' '-' 'He' 'in' 'z' ' Rum' 'men' 'ig' 'ge' ',' ' has' ' been'
+ ' linked' ' with' ' a' ' move' ' to' ' the' ' Premier']" , Karl - He in z Rum men ig ge , has been linked with a move to the Premier False president of the FC Bayern Munich football club between 5 [' president', ' of', ' the', ' FC', ' Bayern', ' Munich']
+665 193 The name of the CEO of x -1 The name of the CEO of FC Bayern Munich Herbert Hainer FC Bayern Munich "[',' ' Karl' '-' 'He' 'in' 'z' ' Rum' 'men' 'ig' 'ge' ',' ' has' ' been'
+ ' linked' ' with' ' a' ' move' ' to' ' the' ' Premier']" , Karl - He in z Rum men ig ge , has been linked with a move to the Premier False president of the FC Bayern Munich football club between 5 [' president', ' of', ' the', ' FC', ' Bayern', ' Munich']
+666 197 The name of the CEO of x -1 The name of the CEO of Activision Blizzard Bobby Kotick Activision Blizzard "[',' ' the' ' company' ' that' ' owns' ' Blizzard' ' Entertainment' ','
+ ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' company' ' is'
+ ' currently' ' in' ' the']" ", the company that owns Blizzard Entertainment , has been revealed .
+
+ The company is currently in the" False February 2009, Activision Blizzard had asserted 4 [' February', ' 2009', ',', ' Activision', ' Blizzard']
+667 197 The name of the CEO of x -1 The name of the CEO of Activision Blizzard Bobby Kotick Activision Blizzard "[',' ' the' ' company' ' that' ' owns' ' Blizzard' ' Entertainment' ','
+ ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' company' ' is'
+ ' currently' ' in' ' the']" ", the company that owns Blizzard Entertainment , has been revealed .
+
+ The company is currently in the" False was canceled by Activision Blizzard in 2011, as a result 4 [' was', ' canceled', ' by', ' Activision', ' Blizzard']
+668 197 The name of the CEO of x -1 The name of the CEO of Activision Blizzard Bobby Kotick Activision Blizzard "[',' ' the' ' company' ' that' ' owns' ' Blizzard' ' Entertainment' ','
+ ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' company' ' is'
+ ' currently' ' in' ' the']" ", the company that owns Blizzard Entertainment , has been revealed .
+
+ The company is currently in the" False canceled by Activision Blizzard in 2011, as 3 [' canceled', ' by', ' Activision', ' Blizzard']
+669 197 The name of the CEO of x -1 The name of the CEO of Activision Blizzard Bobby Kotick Activision Blizzard "[',' ' the' ' company' ' that' ' owns' ' Blizzard' ' Entertainment' ','
+ ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' company' ' is'
+ ' currently' ' in' ' the']" ", the company that owns Blizzard Entertainment , has been revealed .
+
+ The company is currently in the" False were then developed. Activision Blizzard is now the 5 [' were', ' then', ' developed', '.', ' Activision', ' Blizzard']
+670 197 The name of the CEO of x -1 The name of the CEO of Activision Blizzard Bobby Kotick Activision Blizzard "[',' ' the' ' company' ' that' ' owns' ' Blizzard' ' Entertainment' ','
+ ' has' ' been' ' revealed' '.' '\n' '\n' 'The' ' company' ' is'
+ ' currently' ' in' ' the']" ", the company that owns Blizzard Entertainment , has been revealed .
+
+ The company is currently in the" False Video game publisher Activision Blizzard posted a tribute 4 [' Video', ' game', ' publisher', ' Activision', ' Blizzard']
+671 198 The name of the CEO of x -1 The name of the CEO of KLM Pieter Elbers KLM "[',' ' the' ' Dutch' ' airline' ',' ' is' ' a' ' Dutch' ' word' ' for'
+ ' ""' 'I' ' love' ' you' '""' ' in' ' Dutch' '.' '\n' '\n']" ", the Dutch airline , is a Dutch word for "" I love you "" in Dutch .
+
+" False 1998 and early 1999, KLM and Braathens 6 [' 1998', ' and', ' early', ' 1999', ',', ' K', 'LM']
+672 198 The name of the CEO of x -1 The name of the CEO of KLM Pieter Elbers KLM "[',' ' the' ' Dutch' ' airline' ',' ' is' ' a' ' Dutch' ' word' ' for'
+ ' ""' 'I' ' love' ' you' '""' ' in' ' Dutch' '.' '\n' '\n']" ", the Dutch airline , is a Dutch word for "" I love you "" in Dutch .
+
+" False Icelandair (Reykjavík), KLM (Amsterdam), 11 [' Iceland', 'air', ' (', 'R', 'ey', 'kj', 'av', 'í', 'k', '),', ' K', 'LM']
+673 198 The name of the CEO of x -1 The name of the CEO of KLM Pieter Elbers KLM "[',' ' the' ' Dutch' ' airline' ',' ' is' ' a' ' Dutch' ' word' ' for'
+ ' ""' 'I' ' love' ' you' '""' ' in' ' Dutch' '.' '\n' '\n']" ", the Dutch airline , is a Dutch word for "" I love you "" in Dutch .
+
+" False brands similar to KLM and Air France in 4 [' brands', ' similar', ' to', ' K', 'LM']
+674 198 The name of the CEO of x -1 The name of the CEO of KLM Pieter Elbers KLM "[',' ' the' ' Dutch' ' airline' ',' ' is' ' a' ' Dutch' ' word' ' for'
+ ' ""' 'I' ' love' ' you' '""' ' in' ' Dutch' '.' '\n' '\n']" ", the Dutch airline , is a Dutch word for "" I love you "" in Dutch .
+
+" False the landing. A KLM aircraft had 5 [' the', ' landing', '.', ' A', ' K', 'LM']
+675 198 The name of the CEO of x -1 The name of the CEO of KLM Pieter Elbers KLM "[',' ' the' ' Dutch' ' airline' ',' ' is' ' a' ' Dutch' ' word' ' for'
+ ' ""' 'I' ' love' ' you' '""' ' in' ' Dutch' '.' '\n' '\n']" ", the Dutch airline , is a Dutch word for "" I love you "" in Dutch .
+
+" False named after her. A KLM McDonnell Douglas 6 [' named', ' after', ' her', '.', ' A', ' K', 'LM']
+676 199 The name of the CEO of x -1 The name of the CEO of Comcast Brian L. Roberts Comcast "[',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and' ' the' ' CEO' ' of'
+ ' Comcast' ',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and']" , and the CEO of Comcast , and the CEO of Comcast , and the CEO of Comcast , and False 1 ['Com', 'cast']
+677 199 The name of the CEO of x -1 The name of the CEO of Comcast Brian L. Roberts Comcast "[',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and' ' the' ' CEO' ' of'
+ ' Comcast' ',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and']" , and the CEO of Comcast , and the CEO of Comcast , and the CEO of Comcast , and False 1 ['Com', 'cast']
+678 199 The name of the CEO of x -1 The name of the CEO of Comcast Brian L. Roberts Comcast "[',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and' ' the' ' CEO' ' of'
+ ' Comcast' ',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and']" , and the CEO of Comcast , and the CEO of Comcast , and the CEO of Comcast , and False 1 ['Com', 'cast']
+679 199 The name of the CEO of x -1 The name of the CEO of Comcast Brian L. Roberts Comcast "[',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and' ' the' ' CEO' ' of'
+ ' Comcast' ',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and']" , and the CEO of Comcast , and the CEO of Comcast , and the CEO of Comcast , and False demand to 500,000 Comcast customers, 5 [' demand', ' to', ' 500', ',', '000', ' Comcast']
+680 199 The name of the CEO of x -1 The name of the CEO of Comcast Brian L. Roberts Comcast "[',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and' ' the' ' CEO' ' of'
+ ' Comcast' ',' ' and' ' the' ' CEO' ' of' ' Comcast' ',' ' and']" , and the CEO of Comcast , and the CEO of Comcast , and the CEO of Comcast , and False McDonald's, Comcast and radio station 3 "[' McDonald', ""'s"", ',', ' Comcast']"
+681 200 The name of the CEO of x -1 The name of the CEO of DR Maria Rørbye Rønn DR "['.' ' J' '.' 'J' '.' ' Abrams' ',' ' the' ' director' ' of' ' the' ' new'
+ ' Star' ' Wars' ' movie' ',' ' Star' ' Wars' ':' ' The']" . J . J . Abrams , the director of the new Star Wars movie , Star Wars : The False competition from the DR and the MR, 3 [' competition', ' from', ' the', ' DR']
+682 200 The name of the CEO of x -1 The name of the CEO of DR Maria Rørbye Rønn DR "['.' ' J' '.' 'J' '.' ' Abrams' ',' ' the' ' director' ' of' ' the' ' new'
+ ' Star' ' Wars' ' movie' ',' ' Star' ' Wars' ':' ' The']" . J . J . Abrams , the director of the new Star Wars movie , Star Wars : The False lines ran. The DR was extended 4 [' lines', ' ran', '.', ' The', ' DR']
+683 200 The name of the CEO of x -1 The name of the CEO of DR Maria Rørbye Rønn DR "['.' ' J' '.' 'J' '.' ' Abrams' ',' ' the' ' director' ' of' ' the' ' new'
+ ' Star' ' Wars' ' movie' ',' ' Star' ' Wars' ':' ' The']" . J . J . Abrams , the director of the new Star Wars movie , Star Wars : The False Act, 1910, The DR was not merged with 5 [' Act', ',', ' 1910', ',', ' The', ' DR']
+684 200 The name of the CEO of x -1 The name of the CEO of DR Maria Rørbye Rønn DR "['.' ' J' '.' 'J' '.' ' Abrams' ',' ' the' ' director' ' of' ' the' ' new'
+ ' Star' ' Wars' ' movie' ',' ' Star' ' Wars' ':' ' The']" . J . J . Abrams , the director of the new Star Wars movie , Star Wars : The False use the S & DR station. The Sunniside 4 [' use', ' the', ' S', ' &', ' DR']
+685 200 The name of the CEO of x -1 The name of the CEO of DR Maria Rørbye Rønn DR "['.' ' J' '.' 'J' '.' ' Abrams' ',' ' the' ' director' ' of' ' the' ' new'
+ ' Star' ' Wars' ' movie' ',' ' Star' ' Wars' ':' ' The']" . J . J . Abrams , the director of the new Star Wars movie , Star Wars : The False and Co, and the S & DR worked traffic 7 [' and', ' Co', ',', ' and', ' the', ' S', ' &', ' DR']
+686 202 The name of the CEO of x -1 The name of the CEO of Accenture Julie Sweet Accenture "[',' ' a' ' global' ' management' ' consulting' ' firm' ',' ' and' ' a'
+ ' former' ' U' '.' 'S' '.' ' Army' ' officer' ',' ' is' ' a' ' member']" , a global management consulting firm , and a former U . S . Army officer , is a member False consulting companies like Accenture and also including 4 [' consulting', ' companies', ' like', ' Acc', 'enture']
+687 202 The name of the CEO of x -1 The name of the CEO of Accenture Julie Sweet Accenture "[',' ' a' ' global' ' management' ' consulting' ' firm' ',' ' and' ' a'
+ ' former' ' U' '.' 'S' '.' ' Army' ' officer' ',' ' is' ' a' ' member']" , a global management consulting firm , and a former U . S . Army officer , is a member False 1 ['Acc', 'enture']
+688 202 The name of the CEO of x -1 The name of the CEO of Accenture Julie Sweet Accenture "[',' ' a' ' global' ' management' ' consulting' ' firm' ',' ' and' ' a'
+ ' former' ' U' '.' 'S' '.' ' Army' ' officer' ',' ' is' ' a' ' member']" , a global management consulting firm , and a former U . S . Army officer , is a member False companies like Accenture and also including 3 [' companies', ' like', ' Acc', 'enture']
+689 202 The name of the CEO of x -1 The name of the CEO of Accenture Julie Sweet Accenture "[',' ' a' ' global' ' management' ' consulting' ' firm' ',' ' and' ' a'
+ ' former' ' U' '.' 'S' '.' ' Army' ' officer' ',' ' is' ' a' ' member']" , a global management consulting firm , and a former U . S . Army officer , is a member False companies like Accenture and also including 3 [' companies', ' like', ' Acc', 'enture']
+690 204 The name of the CEO of x -1 The name of the CEO of AOL Tim Armstrong AOL "[',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and' ' the' ' CEO' ' of'
+ ' AOL' ',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and']" , and the CEO of AOL , and the CEO of AOL , and the CEO of AOL , and False " ""love-song."" Furthermore, AOL Music's Charley Rogulewski" 7 "[' ""', 'love', '-', 'song', '.""', ' Furthermore', ',', ' AOL']"
+691 204 The name of the CEO of x -1 The name of the CEO of AOL Tim Armstrong AOL "[',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and' ' the' ' CEO' ' of'
+ ' AOL' ',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and']" , and the CEO of AOL , and the CEO of AOL , and the CEO of AOL , and False launched an AOL Radio station dedicated 2 [' launched', ' an', ' AOL']
+692 204 The name of the CEO of x -1 The name of the CEO of AOL Tim Armstrong AOL "[',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and' ' the' ' CEO' ' of'
+ ' AOL' ',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and']" , and the CEO of AOL , and the CEO of AOL , and the CEO of AOL , and False Sara Anderson from AOL Radio noted that 3 [' Sara', ' Anderson', ' from', ' AOL']
+693 204 The name of the CEO of x -1 The name of the CEO of AOL Tim Armstrong AOL "[',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and' ' the' ' CEO' ' of'
+ ' AOL' ',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and']" , and the CEO of AOL , and the CEO of AOL , and the CEO of AOL , and False Shima Maya of AOL Radio Blog gave the 4 [' Sh', 'ima', ' Maya', ' of', ' AOL']
+694 204 The name of the CEO of x -1 The name of the CEO of AOL Tim Armstrong AOL "[',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and' ' the' ' CEO' ' of'
+ ' AOL' ',' ' and' ' the' ' CEO' ' of' ' AOL' ',' ' and']" , and the CEO of AOL , and the CEO of AOL , and the CEO of AOL , and False Maureen Ryan from AOL TV rated it with 4 [' Ma', 'ureen', ' Ryan', ' from', ' AOL']
+695 206 The name of the CEO of x -1 The name of the CEO of SpaceX Elon Musk SpaceX "[',' ' Elon' ' Musk' ',' ' has' ' been' ' a' ' bit' ' of' ' a'
+ ' controversial' ' figure' ' in' ' the' ' space' ' industry' '.' ' He'
+ ' has' ' been']" , Elon Musk , has been a bit of a controversial figure in the space industry . He has been True brought about by SpaceX lower launch 3 [' brought', ' about', ' by', ' SpaceX']
+696 206 The name of the CEO of x -1 The name of the CEO of SpaceX Elon Musk SpaceX "[',' ' Elon' ' Musk' ',' ' has' ' been' ' a' ' bit' ' of' ' a'
+ ' controversial' ' figure' ' in' ' the' ' space' ' industry' '.' ' He'
+ ' has' ' been']" , Elon Musk , has been a bit of a controversial figure in the space industry . He has been True 1 ['Space', 'X']
+697 206 The name of the CEO of x -1 The name of the CEO of SpaceX Elon Musk SpaceX "[',' ' Elon' ' Musk' ',' ' has' ' been' ' a' ' bit' ' of' ' a'
+ ' controversial' ' figure' ' in' ' the' ' space' ' industry' '.' ' He'
+ ' has' ' been']" , Elon Musk , has been a bit of a controversial figure in the space industry . He has been True Elon Musk announced SpaceX's plans to build a 3 [' Elon', ' Musk', ' announced', ' SpaceX']
+698 206 The name of the CEO of x -1 The name of the CEO of SpaceX Elon Musk SpaceX "[',' ' Elon' ' Musk' ',' ' has' ' been' ' a' ' bit' ' of' ' a'
+ ' controversial' ' figure' ' in' ' the' ' space' ' industry' '.' ' He'
+ ' has' ' been']" , Elon Musk , has been a bit of a controversial figure in the space industry . He has been True 1 ['Space', 'X']
+699 206 The name of the CEO of x -1 The name of the CEO of SpaceX Elon Musk SpaceX "[',' ' Elon' ' Musk' ',' ' has' ' been' ' a' ' bit' ' of' ' a'
+ ' controversial' ' figure' ' in' ' the' ' space' ' industry' '.' ' He'
+ ' has' ' been']" , Elon Musk , has been a bit of a controversial figure in the space industry . He has been True hover testing at the SpaceX Rocket Test 4 [' hover', ' testing', ' at', ' the', ' SpaceX']
+700 207 The name of the CEO of x -1 The name of the CEO of Finnair Pekka Vauramo Finnair "[',' ' the' ' airline' ""'s"" ' parent' ' company' ',' ' said' ' the'
+ ' airline' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' deal' '.'
+ '\n']" ", the airline 's parent company , said the airline is "" very pleased "" with the deal .
+" False domestic services, while Finnair served the airport 5 [' domestic', ' services', ',', ' while', ' Finn', 'air']
+701 207 The name of the CEO of x -1 The name of the CEO of Finnair Pekka Vauramo Finnair "[',' ' the' ' airline' ""'s"" ' parent' ' company' ',' ' said' ' the'
+ ' airline' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' deal' '.'
+ '\n']" ", the airline 's parent company , said the airline is "" very pleased "" with the deal .
+" False Japan and Europe. Finnair applied to join 5 [' Japan', ' and', ' Europe', '.', ' Finn', 'air']
+702 207 The name of the CEO of x -1 The name of the CEO of Finnair Pekka Vauramo Finnair "[',' ' the' ' airline' ""'s"" ' parent' ' company' ',' ' said' ' the'
+ ' airline' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' deal' '.'
+ '\n']" ", the airline 's parent company , said the airline is "" very pleased "" with the deal .
+" False Svalbard from Oslo. Finnair announced to begin 7 [' S', 'val', 'bard', ' from', ' Oslo', '.', ' Finn', 'air']
+703 207 The name of the CEO of x -1 The name of the CEO of Finnair Pekka Vauramo Finnair "[',' ' the' ' airline' ""'s"" ' parent' ' company' ',' ' said' ' the'
+ ' airline' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' deal' '.'
+ '\n']" ", the airline 's parent company , said the airline is "" very pleased "" with the deal .
+" False services, while Finnair served the airport 4 [' services', ',', ' while', ' Finn', 'air']
+704 207 The name of the CEO of x -1 The name of the CEO of Finnair Pekka Vauramo Finnair "[',' ' the' ' airline' ""'s"" ' parent' ' company' ',' ' said' ' the'
+ ' airline' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' deal' '.'
+ '\n']" ", the airline 's parent company , said the airline is "" very pleased "" with the deal .
+" False and Europe. Finnair applied to join 4 [' and', ' Europe', '.', ' Finn', 'air']
+705 208 The name of the CEO of x -1 The name of the CEO of Kroger Rodney McMullen Kroger "[',' ' the' ' world' '�' '�' 's' ' largest' ' grocery' ' store' ' chain'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' interested']" , the world � � s largest grocery store chain , said that the company is � � very interested False Corp., United Airlines, Kroger Co. and Xcel 6 [' Corp', '.,', ' United', ' Airlines', ',', ' Kro', 'ger']
+706 208 The name of the CEO of x -1 The name of the CEO of Kroger Rodney McMullen Kroger "[',' ' the' ' world' '�' '�' 's' ' largest' ' grocery' ' store' ' chain'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' interested']" , the world � � s largest grocery store chain , said that the company is � � very interested False companies in Ohio include Kroger (Cincinnati), 5 [' companies', ' in', ' Ohio', ' include', ' Kro', 'ger']
+707 208 The name of the CEO of x -1 The name of the CEO of Kroger Rodney McMullen Kroger "[',' ' the' ' world' '�' '�' 's' ' largest' ' grocery' ' store' ' chain'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' interested']" , the world � � s largest grocery store chain , said that the company is � � very interested False Peter and Helen Kroger were found 4 [' Peter', ' and', ' Helen', ' Kro', 'ger']
+708 208 The name of the CEO of x -1 The name of the CEO of Kroger Rodney McMullen Kroger "[',' ' the' ' world' '�' '�' 's' ' largest' ' grocery' ' store' ' chain'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' interested']" , the world � � s largest grocery store chain , said that the company is � � very interested False presenter Kroger Babb. Mom and Dad 2 [' presenter', ' Kro', 'ger']
+709 208 The name of the CEO of x -1 The name of the CEO of Kroger Rodney McMullen Kroger "[',' ' the' ' world' '�' '�' 's' ' largest' ' grocery' ' store' ' chain'
+ ',' ' said' ' that' ' the' ' company' ' is' ' �' '�' 'very' ' interested']" , the world � � s largest grocery store chain , said that the company is � � very interested False Ohio include Kroger (Cincinnati), 3 [' Ohio', ' include', ' Kro', 'ger']
+710 210 The name of the CEO of x -1 The name of the CEO of ConocoPhillips Ryan Lance ConocoPhillips "[',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' and' ' gas'
+ ' company' ',' ' is' ' a' ' member' ' of' ' the' ' board' ' of'
+ ' directors' ' of']" , the company that owns the oil and gas company , is a member of the board of directors of False Freeport-McMoRan to the ConocoPhillips refinery in Santa 12 [' Fre', 'eport', '-', 'Mc', 'Mo', 'R', 'an', ' to', ' the', ' Con', 'oco', 'Phill', 'ips']
+711 210 The name of the CEO of x -1 The name of the CEO of ConocoPhillips Ryan Lance ConocoPhillips "[',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' and' ' gas'
+ ' company' ',' ' is' ' a' ' member' ' of' ' the' ' board' ' of'
+ ' directors' ' of']" , the company that owns the oil and gas company , is a member of the board of directors of False those working for ConocoPhillips and ExxonMobil remained 6 [' those', ' working', ' for', ' Con', 'oco', 'Phill', 'ips']
+712 210 The name of the CEO of x -1 The name of the CEO of ConocoPhillips Ryan Lance ConocoPhillips "[',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' and' ' gas'
+ ' company' ',' ' is' ' a' ' member' ' of' ' the' ' board' ' of'
+ ' directors' ' of']" , the company that owns the oil and gas company , is a member of the board of directors of False funding provided by ConocoPhillips in 2003. NSTA 6 [' funding', ' provided', ' by', ' Con', 'oco', 'Phill', 'ips']
+713 210 The name of the CEO of x -1 The name of the CEO of ConocoPhillips Ryan Lance ConocoPhillips "[',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' and' ' gas'
+ ' company' ',' ' is' ' a' ' member' ' of' ' the' ' board' ' of'
+ ' directors' ' of']" , the company that owns the oil and gas company , is a member of the board of directors of False provided by ConocoPhillips in 2003. NSTA indicated 5 [' provided', ' by', ' Con', 'oco', 'Phill', 'ips']
+714 210 The name of the CEO of x -1 The name of the CEO of ConocoPhillips Ryan Lance ConocoPhillips "[',' ' the' ' company' ' that' ' owns' ' the' ' oil' ' and' ' gas'
+ ' company' ',' ' is' ' a' ' member' ' of' ' the' ' board' ' of'
+ ' directors' ' of']" , the company that owns the oil and gas company , is a member of the board of directors of False Freeport-McMoRan to the ConocoPhillips refinery in 12 [' Fre', 'eport', '-', 'Mc', 'Mo', 'R', 'an', ' to', ' the', ' Con', 'oco', 'Phill', 'ips']
+715 211 The name of the CEO of x -1 The name of the CEO of Lenovo Yang Yuanqing Lenovo "[',' ' Lenovo' ',' ' is' ' a' ' Chinese' ' company' ' that' ' makes'
+ ' computers' ',' ' phones' ',' ' tablets' ',' ' and' ' other'
+ ' electronics' '.' '\n']" ", Lenovo , is a Chinese company that makes computers , phones , tablets , and other electronics .
+" False such as Huawei and Lenovo have become 4 [' such', ' as', ' Huawei', ' and', ' Lenovo']
+716 211 The name of the CEO of x -1 The name of the CEO of Lenovo Yang Yuanqing Lenovo "[',' ' Lenovo' ',' ' is' ' a' ' Chinese' ' company' ' that' ' makes'
+ ' computers' ',' ' phones' ',' ' tablets' ',' ' and' ' other'
+ ' electronics' '.' '\n']" ", Lenovo , is a Chinese company that makes computers , phones , tablets , and other electronics .
+" False Ericsson Xperia X10. Lenovo announced the 6 [' Eric', 'sson', ' Xperia', ' X', '10', '.', ' Lenovo']
+717 211 The name of the CEO of x -1 The name of the CEO of Lenovo Yang Yuanqing Lenovo "[',' ' Lenovo' ',' ' is' ' a' ' Chinese' ' company' ' that' ' makes'
+ ' computers' ',' ' phones' ',' ' tablets' ',' ' and' ' other'
+ ' electronics' '.' '\n']" ", Lenovo , is a Chinese company that makes computers , phones , tablets , and other electronics .
+" False Panasonic, Renault India, Lenovo and the Spanish football 5 [' Panasonic', ',', ' Renault', ' India', ',', ' Lenovo']
+718 211 The name of the CEO of x -1 The name of the CEO of Lenovo Yang Yuanqing Lenovo "[',' ' Lenovo' ',' ' is' ' a' ' Chinese' ' company' ' that' ' makes'
+ ' computers' ',' ' phones' ',' ' tablets' ',' ' and' ' other'
+ ' electronics' '.' '\n']" ", Lenovo , is a Chinese company that makes computers , phones , tablets , and other electronics .
+" False Panasonic, Renault India, Lenovo and the Spanish football 5 [' Panasonic', ',', ' Renault', ' India', ',', ' Lenovo']
+719 211 The name of the CEO of x -1 The name of the CEO of Lenovo Yang Yuanqing Lenovo "[',' ' Lenovo' ',' ' is' ' a' ' Chinese' ' company' ' that' ' makes'
+ ' computers' ',' ' phones' ',' ' tablets' ',' ' and' ' other'
+ ' electronics' '.' '\n']" ", Lenovo , is a Chinese company that makes computers , phones , tablets , and other electronics .
+" False transforms into a Lenovo ThinkPad Edge 3 [' transforms', ' into', ' a', ' Lenovo']
+720 214 The name of the CEO of x -1 The name of the CEO of Viacom Robert Bakish Viacom "[',' ' the' ' parent' ' company' ' of' ' MTV' ',' ' Nickel' 'ode' 'on' ','
+ ' Comedy' ' Central' ',' ' and' ' V' 'H' '1' ',' ' is']" , the parent company of MTV , Nickel ode on , Comedy Central , and V H 1 , is False Rock Band helped Viacom to become the fifth-largest 5 [' Rock', ' Band', ' helped', ' V', 'iac', 'om']
+721 214 The name of the CEO of x -1 The name of the CEO of Viacom Robert Bakish Viacom "[',' ' the' ' parent' ' company' ' of' ' MTV' ',' ' Nickel' 'ode' 'on' ','
+ ' Comedy' ' Central' ',' ' and' ' V' 'H' '1' ',' ' is']" , the parent company of MTV , Nickel ode on , Comedy Central , and V H 1 , is False Rock Band helped Viacom to become the 5 [' Rock', ' Band', ' helped', ' V', 'iac', 'om']
+722 214 The name of the CEO of x -1 The name of the CEO of Viacom Robert Bakish Viacom "[',' ' the' ' parent' ' company' ' of' ' MTV' ',' ' Nickel' 'ode' 'on' ','
+ ' Comedy' ' Central' ',' ' and' ' V' 'H' '1' ',' ' is']" , the parent company of MTV , Nickel ode on , Comedy Central , and V H 1 , is False music channels owned by Viacom and CBS due 6 [' music', ' channels', ' owned', ' by', ' V', 'iac', 'om']
+723 214 The name of the CEO of x -1 The name of the CEO of Viacom Robert Bakish Viacom "[',' ' the' ' parent' ' company' ' of' ' MTV' ',' ' Nickel' 'ode' 'on' ','
+ ' Comedy' ' Central' ',' ' and' ' V' 'H' '1' ',' ' is']" , the parent company of MTV , Nickel ode on , Comedy Central , and V H 1 , is False systems. In late 2010, Viacom sold Harmonix 8 [' systems', '.', ' In', ' late', ' 2010', ',', ' V', 'iac', 'om']
+724 214 The name of the CEO of x -1 The name of the CEO of Viacom Robert Bakish Viacom "[',' ' the' ' parent' ' company' ' of' ' MTV' ',' ' Nickel' 'ode' 'on' ','
+ ' Comedy' ' Central' ',' ' and' ' V' 'H' '1' ',' ' is']" , the parent company of MTV , Nickel ode on , Comedy Central , and V H 1 , is False originally provided by Viacom, the parent 5 [' originally', ' provided', ' by', ' V', 'iac', 'om']
+725 215 The name of the CEO of x -1 The name of the CEO of Honda Takahiro Hachigō Honda "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Honda' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Honda' ' is']" ", and the name of the CEO of Honda .
+
+ The name of the CEO of Honda is" False Chevrolet Volt and Honda Insight. For its 3 [' Chevrolet', ' Volt', ' and', ' Honda']
+726 215 The name of the CEO of x -1 The name of the CEO of Honda Takahiro Hachigō Honda "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Honda' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Honda' ' is']" ", and the name of the CEO of Honda .
+
+ The name of the CEO of Honda is" False previous form, the Hondas of Barrichello and 4 [' previous', ' form', ',', ' the', ' Honda']
+727 215 The name of the CEO of x -1 The name of the CEO of Honda Takahiro Hachigō Honda "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Honda' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Honda' ' is']" ", and the name of the CEO of Honda .
+
+ The name of the CEO of Honda is" False second at the Honda Classic, with the 3 [' second', ' at', ' the', ' Honda']
+728 215 The name of the CEO of x -1 The name of the CEO of Honda Takahiro Hachigō Honda "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Honda' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Honda' ' is']" ", and the name of the CEO of Honda .
+
+ The name of the CEO of Honda is" False to a bed at Laguna Honda Hospital, a 6 [' to', ' a', ' bed', ' at', ' Lag', 'una', ' Honda']
+729 215 The name of the CEO of x -1 The name of the CEO of Honda Takahiro Hachigō Honda "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Honda' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' CEO' ' of' ' Honda' ' is']" ", and the name of the CEO of Honda .
+
+ The name of the CEO of Honda is" False Enemies Tour, the Honda Civic Tour, 4 [' Enemies', ' Tour', ',', ' the', ' Honda']
+730 218 The name of the CEO of x -1 The name of the CEO of Kyushu Railway Company Yōji Furumiya Kyushu Railway Company "[',' ' the' ' operator' ' of' ' the' ' Ky' 'ush' 'u' ' Sh' 'ink' 'ansen'
+ ',' ' said' ' that' ' the' ' company' ' is' ' considering' ' the'
+ ' possibility']" , the operator of the Ky ush u Sh ink ansen , said that the company is considering the possibility False December 23, 2002, the Kyushu Railway Company has displayed the 10 [' December', ' 23', ',', ' 2002', ',', ' the', ' Ky', 'ush', 'u', ' Railway', ' Company']
+731 219 The name of the CEO of x -1 The name of the CEO of Heineken Freddy Heineken Heineken "[',' ' the' ' beer' ' company' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' '\n' '\n' 'The' ' beer' ' is' ' a' ' dark' ',']" ", the beer company , is a bit of a mystery .
+
+ The beer is a dark ," False Target, IKEA, and Heineken operate distribution 9 [' Target', ',', ' I', 'KE', 'A', ',', ' and', ' He', 'ine', 'ken']
+732 219 The name of the CEO of x -1 The name of the CEO of Heineken Freddy Heineken Heineken "[',' ' the' ' beer' ' company' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' '\n' '\n' 'The' ' beer' ' is' ' a' ' dark' ',']" ", the beer company , is a bit of a mystery .
+
+ The beer is a dark ," False stage of the Heineken Cup unbeaten during 5 [' stage', ' of', ' the', ' He', 'ine', 'ken']
+733 219 The name of the CEO of x -1 The name of the CEO of Heineken Freddy Heineken Heineken "[',' ' the' ' beer' ' company' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' '\n' '\n' 'The' ' beer' ' is' ' a' ' dark' ',']" ", the beer company , is a bit of a mystery .
+
+ The beer is a dark ," False an interview with Heineken Music, on 27 January 5 [' an', ' interview', ' with', ' He', 'ine', 'ken']
+734 219 The name of the CEO of x -1 The name of the CEO of Heineken Freddy Heineken Heineken "[',' ' the' ' beer' ' company' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' '\n' '\n' 'The' ' beer' ' is' ' a' ' dark' ',']" ", the beer company , is a bit of a mystery .
+
+ The beer is a dark ," False included Ford Motors, Heineken Pilsener (which 6 [' included', ' Ford', ' Motors', ',', ' He', 'ine', 'ken']
+735 219 The name of the CEO of x -1 The name of the CEO of Heineken Freddy Heineken Heineken "[',' ' the' ' beer' ' company' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' '\n' '\n' 'The' ' beer' ' is' ' a' ' dark' ',']" ", the beer company , is a bit of a mystery .
+
+ The beer is a dark ," False large quantities of Heineken beer. On several 5 [' large', ' quantities', ' of', ' He', 'ine', 'ken']
+736 220 The name of the CEO of x -1 The name of the CEO of Amazon Web Services Andy Jassy Amazon Web Services "[',' ' Jeff' ' Bezos' ',' ' is' ' a' ' great' ' example' ' of' ' this' '.'
+ ' He' ' is' ' a' ' great' ' example' ' of' ' a' ' person' ' who']" , Jeff Bezos , is a great example of this . He is a great example of a person who False AWS SDK 6 [' AWS', ' S', 'DA', 'm', 'azon', ' Web', ' Services']
+737 220 The name of the CEO of x -1 The name of the CEO of Amazon Web Services Andy Jassy Amazon Web Services "[',' ' Jeff' ' Bezos' ',' ' is' ' a' ' great' ' example' ' of' ' this' '.'
+ ' He' ' is' ' a' ' great' ' example' ' of' ' a' ' person' ' who']" , Jeff Bezos , is a great example of this . He is a great example of a person who False the 3 [' th', 'Amazon', ' Web', ' Services']
+738 220 The name of the CEO of x -1 The name of the CEO of Amazon Web Services Andy Jassy Amazon Web Services "[',' ' Jeff' ' Bezos' ',' ' is' ' a' ' great' ' example' ' of' ' this' '.'
+ ' He' ' is' ' a' ' great' ' example' ' of' ' a' ' person' ' who']" , Jeff Bezos , is a great example of this . He is a great example of a person who False AWS SDK 6 [' AWS', ' S', 'DA', 'm', 'azon', ' Web', ' Services']
+739 221 The name of the CEO of x -1 The name of the CEO of Paramount Pictures Jim Gianopulos Paramount Pictures "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Paramount'
+ ' Pictures' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Paramount']" ", and the name of the CEO of Paramount Pictures .
+
+ The name of the CEO of Paramount" False 2 ['Par', 'amount', ' Pictures']
+740 221 The name of the CEO of x -1 The name of the CEO of Paramount Pictures Jim Gianopulos Paramount Pictures "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Paramount'
+ ' Pictures' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Paramount']" ", and the name of the CEO of Paramount Pictures .
+
+ The name of the CEO of Paramount" False 2 ['Par', 'amount', ' Pictures']
+741 221 The name of the CEO of x -1 The name of the CEO of Paramount Pictures Jim Gianopulos Paramount Pictures "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Paramount'
+ ' Pictures' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Paramount']" ", and the name of the CEO of Paramount Pictures .
+
+ The name of the CEO of Paramount" False Since its debut, Paramount Pictures has released 5 [' Since', ' its', ' debut', ',', ' Paramount', ' Pictures']
+742 221 The name of the CEO of x -1 The name of the CEO of Paramount Pictures Jim Gianopulos Paramount Pictures "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Paramount'
+ ' Pictures' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Paramount']" ", and the name of the CEO of Paramount Pictures .
+
+ The name of the CEO of Paramount" False own success led Paramount Pictures to purchase 4 [' own', ' success', ' led', ' Paramount', ' Pictures']
+743 221 The name of the CEO of x -1 The name of the CEO of Paramount Pictures Jim Gianopulos Paramount Pictures "[',' ' and' ' the' ' name' ' of' ' the' ' CEO' ' of' ' Paramount'
+ ' Pictures' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' CEO' ' of'
+ ' Paramount']" ", and the name of the CEO of Paramount Pictures .
+
+ The name of the CEO of Paramount" False Beauty (1999), Paramount Pictures became interested 5 [' Beauty', ' (', '1999', '),', ' Paramount', ' Pictures']
+744 222 The name of the CEO of x -1 The name of the CEO of Nortel Mike S. Zafirovski Nortel "[' Networks' ',' ' a' ' company' ' that' ' was' ' once' ' the' ' world'
+ ""'s"" ' largest' ' maker' ' of' ' telephone' ' switching' ' equipment' ','
+ ' was' ' acquired' ' by']" Networks , a company that was once the world 's largest maker of telephone switching equipment , was acquired by False devices, and with Nortel for fiber optic technology. 5 [' devices', ',', ' and', ' with', ' Nort', 'el']
+745 222 The name of the CEO of x -1 The name of the CEO of Nortel Mike S. Zafirovski Nortel "[' Networks' ',' ' a' ' company' ' that' ' was' ' once' ' the' ' world'
+ ""'s"" ' largest' ' maker' ' of' ' telephone' ' switching' ' equipment' ','
+ ' was' ' acquired' ' by']" Networks , a company that was once the world 's largest maker of telephone switching equipment , was acquired by False 2 ['N', 'ort', 'el']
+746 222 The name of the CEO of x -1 The name of the CEO of Nortel Mike S. Zafirovski Nortel "[' Networks' ',' ' a' ' company' ' that' ' was' ' once' ' the' ' world'
+ ""'s"" ' largest' ' maker' ' of' ' telephone' ' switching' ' equipment' ','
+ ' was' ' acquired' ' by']" Networks , a company that was once the world 's largest maker of telephone switching equipment , was acquired by False Siemens, Ericsson, Nortel and 3Com. Juniper 7 [' Siem', 'ens', ',', ' Eric', 'sson', ',', ' Nort', 'el']
+747 222 The name of the CEO of x -1 The name of the CEO of Nortel Mike S. Zafirovski Nortel "[' Networks' ',' ' a' ' company' ' that' ' was' ' once' ' the' ' world'
+ ""'s"" ' largest' ' maker' ' of' ' telephone' ' switching' ' equipment' ','
+ ' was' ' acquired' ' by']" Networks , a company that was once the world 's largest maker of telephone switching equipment , was acquired by False Siemens, Ericsson, Nortel and 3Com. Juniper 7 [' Siem', 'ens', ',', ' Eric', 'sson', ',', ' Nort', 'el']
+748 222 The name of the CEO of x -1 The name of the CEO of Nortel Mike S. Zafirovski Nortel "[' Networks' ',' ' a' ' company' ' that' ' was' ' once' ' the' ' world'
+ ""'s"" ' largest' ' maker' ' of' ' telephone' ' switching' ' equipment' ','
+ ' was' ' acquired' ' by']" Networks , a company that was once the world 's largest maker of telephone switching equipment , was acquired by False Siemens, Ericsson, Nortel and 3Com. Juniper 7 [' Siem', 'ens', ',', ' Eric', 'sson', ',', ' Nort', 'el']
+749 224 The name of the CEO of x -1 The name of the CEO of United Airlines Scott Kirby United Airlines "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' United' ' Airlines' ',' ' and' ' the' ' CEO' ' of' ' the' ' company'
+ ' that' ' owns']" , and the CEO of the company that owns United Airlines , and the CEO of the company that owns False 4, 1955, a United Airlines Douglas DC-6 named 6 [' 4', ',', ' 1955', ',', ' a', ' United', ' Airlines']
+750 224 The name of the CEO of x -1 The name of the CEO of United Airlines Scott Kirby United Airlines "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' United' ' Airlines' ',' ' and' ' the' ' CEO' ' of' ' the' ' company'
+ ' that' ' owns']" , and the CEO of the company that owns United Airlines , and the CEO of the company that owns False 08, the pilots of United Airlines Flight 175 verified 6 [' 08', ',', ' the', ' pilots', ' of', ' United', ' Airlines']
+751 224 The name of the CEO of x -1 The name of the CEO of United Airlines Scott Kirby United Airlines "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' United' ' Airlines' ',' ' and' ' the' ' CEO' ' of' ' the' ' company'
+ ' that' ' owns']" , and the CEO of the company that owns United Airlines , and the CEO of the company that owns False another DC-10, United Airlines Flight 232, ten 6 [' another', ' DC', '-', '10', ',', ' United', ' Airlines']
+752 224 The name of the CEO of x -1 The name of the CEO of United Airlines Scott Kirby United Airlines "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' United' ' Airlines' ',' ' and' ' the' ' CEO' ' of' ' the' ' company'
+ ' that' ' owns']" , and the CEO of the company that owns United Airlines , and the CEO of the company that owns False service with United Airlines on September 3 [' service', ' with', ' United', ' Airlines']
+753 224 The name of the CEO of x -1 The name of the CEO of United Airlines Scott Kirby United Airlines "[',' ' and' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' United' ' Airlines' ',' ' and' ' the' ' CEO' ' of' ' the' ' company'
+ ' that' ' owns']" , and the CEO of the company that owns United Airlines , and the CEO of the company that owns False propeller blades. United Airlines also stated they 5 [' prope', 'ller', ' blades', '.', ' United', ' Airlines']
+754 231 The name of the CEO of x -1 The name of the CEO of Eli Lilly and Company David A. Ricks Eli Lilly and Company "[',' ' the' ' world' ""'s"" ' largest' ' pharmaceutical' ' company' ','
+ ' is' ' Eli' ' Lilly' '.' '\n' '\n' 'The' ' company' ' is' ' based' ' in'
+ ' Indianapolis']" ", the world 's largest pharmaceutical company , is Eli Lilly .
+
+ The company is based in Indianapolis" False medicines. Eli Lilly and Company was one of the first 5 [' medicines', '.', ' Eli', ' Lilly', ' and', ' Company']
+755 231 The name of the CEO of x -1 The name of the CEO of Eli Lilly and Company David A. Ricks Eli Lilly and Company "[',' ' the' ' world' ""'s"" ' largest' ' pharmaceutical' ' company' ','
+ ' is' ' Eli' ' Lilly' '.' '\n' '\n' 'The' ' company' ' is' ' based' ' in'
+ ' Indianapolis']" ", the world 's largest pharmaceutical company , is Eli Lilly .
+
+ The company is based in Indianapolis" False founder of the Eli Lilly and Company pharmaceutical 6 [' founder', ' of', ' the', ' Eli', ' Lilly', ' and', ' Company']
+756 231 The name of the CEO of x -1 The name of the CEO of Eli Lilly and Company David A. Ricks Eli Lilly and Company "[',' ' the' ' world' ""'s"" ' largest' ' pharmaceutical' ' company' ','
+ ' is' ' Eli' ' Lilly' '.' '\n' '\n' 'The' ' company' ' is' ' based' ' in'
+ ' Indianapolis']" ", the world 's largest pharmaceutical company , is Eli Lilly .
+
+ The company is based in Indianapolis" False liquid medicines. Eli Lilly and Company was one of the first 6 [' liquid', ' medicines', '.', ' Eli', ' Lilly', ' and', ' Company']
+757 231 The name of the CEO of x -1 The name of the CEO of Eli Lilly and Company David A. Ricks Eli Lilly and Company "[',' ' the' ' world' ""'s"" ' largest' ' pharmaceutical' ' company' ','
+ ' is' ' Eli' ' Lilly' '.' '\n' '\n' 'The' ' company' ' is' ' based' ' in'
+ ' Indianapolis']" ", the world 's largest pharmaceutical company , is Eli Lilly .
+
+ The company is based in Indianapolis" False The manufacturer Eli Lilly and Company ran an aggressive 5 [' The', ' manufacturer', ' Eli', ' Lilly', ' and', ' Company']
+758 231 The name of the CEO of x -1 The name of the CEO of Eli Lilly and Company David A. Ricks Eli Lilly and Company "[',' ' the' ' world' ""'s"" ' largest' ' pharmaceutical' ' company' ','
+ ' is' ' Eli' ' Lilly' '.' '\n' '\n' 'The' ' company' ' is' ' based' ' in'
+ ' Indianapolis']" ", the world 's largest pharmaceutical company , is Eli Lilly .
+
+ The company is based in Indianapolis" False medicines. Eli Lilly and Company was one of the 5 [' medicines', '.', ' Eli', ' Lilly', ' and', ' Company']
+759 232 The name of the CEO of x -1 The name of the CEO of Hitachi Toshiaki Higashihara Hitachi "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been'
+ ' around']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been around" False territory in Hitachi Province. Having 3 [' territory', ' in', ' Hit', 'achi']
+760 232 The name of the CEO of x -1 The name of the CEO of Hitachi Toshiaki Higashihara Hitachi "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been'
+ ' around']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been around" False bombardment of the Hitachi area took 4 [' bombardment', ' of', ' the', ' Hit', 'achi']
+761 232 The name of the CEO of x -1 The name of the CEO of Hitachi Toshiaki Higashihara Hitachi "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been'
+ ' around']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been around" False 1 ['Hit', 'achi']
+762 232 The name of the CEO of x -1 The name of the CEO of Hitachi Toshiaki Higashihara Hitachi "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been'
+ ' around']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been around" False his group chose the Hitachi SH-4 processor architecture 5 [' his', ' group', ' chose', ' the', ' Hit', 'achi']
+763 232 The name of the CEO of x -1 The name of the CEO of Hitachi Toshiaki Higashihara Hitachi "[',' ' the' ' Japanese' ' electronics' ' giant' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company' ' has' ' been'
+ ' around']" ", the Japanese electronics giant , is a bit of a mystery .
+
+ The company has been around" False against the city of Hitachi during the 5 [' against', ' the', ' city', ' of', ' Hit', 'achi']
+764 234 The name of the CEO of x -1 The name of the CEO of Bristol-Myers Squibb Giovanni Caforio Bristol-Myers Squibb "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' the'
+ ' company' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' results'
+ ' of']" ", the company that makes the drug , said the company is "" very pleased "" with the results of" False a deal with Bristol-Myers Squibb to license this 9 [' a', ' deal', ' with', ' Bristol', '-', 'My', 'ers', ' Squ', 'ib', 'b']
+765 234 The name of the CEO of x -1 The name of the CEO of Bristol-Myers Squibb Giovanni Caforio Bristol-Myers Squibb "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' the'
+ ' company' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' results'
+ ' of']" ", the company that makes the drug , said the company is "" very pleased "" with the results of" False negotiated with the Bristol-Myers Squibb company and 9 [' negotiated', ' with', ' the', ' Bristol', '-', 'My', 'ers', ' Squ', 'ib', 'b']
+766 234 The name of the CEO of x -1 The name of the CEO of Bristol-Myers Squibb Giovanni Caforio Bristol-Myers Squibb "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' the'
+ ' company' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' results'
+ ' of']" ", the company that makes the drug , said the company is "" very pleased "" with the results of" False whereupon it was leased to Bristol-Myers Squibb for 25 years. Since 12 [' where', 'upon', ' it', ' was', ' leased', ' to', ' Bristol', '-', 'My', 'ers', ' Squ', 'ib', 'b']
+767 234 The name of the CEO of x -1 The name of the CEO of Bristol-Myers Squibb Giovanni Caforio Bristol-Myers Squibb "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' the'
+ ' company' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' results'
+ ' of']" ", the company that makes the drug , said the company is "" very pleased "" with the results of" False Butterworth, former Bristol-Myers Squibb CEO Peter R. 10 [' Butter', 'worth', ',', ' former', ' Bristol', '-', 'My', 'ers', ' Squ', 'ib', 'b']
+768 234 The name of the CEO of x -1 The name of the CEO of Bristol-Myers Squibb Giovanni Caforio Bristol-Myers Squibb "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' the'
+ ' company' ' is' ' ""' 'very' ' pleased' '""' ' with' ' the' ' results'
+ ' of']" ", the company that makes the drug , said the company is "" very pleased "" with the results of" False Butterworth, former Bristol-Myers Squibb CEO Peter R. 10 [' Butter', 'worth', ',', ' former', ' Bristol', '-', 'My', 'ers', ' Squ', 'ib', 'b']
+769 237 The name of the CEO of x -1 The name of the CEO of Lukoil Vagit Alekperov Lukoil "[',' ' the' ' Russian' ' oil' ' giant' ',' ' is' ' Vladimir' ' Putin' '.'
+ '\n' '\n' 'The' ' Russian' ' president' ' has' ' been' ' accused' ' of'
+ ' ordering']" ", the Russian oil giant , is Vladimir Putin .
+
+ The Russian president has been accused of ordering" False BP, ExxonMobil, Lukoil and Statoil. 6 [' BP', ',', ' Exxon', 'Mobil', ',', ' Luk', 'oil']
+770 237 The name of the CEO of x -1 The name of the CEO of Lukoil Vagit Alekperov Lukoil "[',' ' the' ' Russian' ' oil' ' giant' ',' ' is' ' Vladimir' ' Putin' '.'
+ '\n' '\n' 'The' ' Russian' ' president' ' has' ' been' ' accused' ' of'
+ ' ordering']" ", the Russian oil giant , is Vladimir Putin .
+
+ The Russian president has been accused of ordering" False ExxonMobil, Lukoil and Statoil. As Western 4 [' Exxon', 'Mobil', ',', ' Luk', 'oil']
+771 239 The name of the CEO of x -1 The name of the CEO of International Telecommunication Union Houlin Zhao International Telecommunication Union "[' (' 'IT' 'U' ')' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but'
+ ' it' '�' '�' 's' ' a' ' mouth' 'ful']" ( IT U ) is a bit of a mouth ful , but it � � s a mouth ful False Organization (ILO), the International Telecommunication Union (ITU), the United 9 [' Organization', ' (', 'IL', 'O', '),', ' the', ' International', ' Tele', 'communication', ' Union']
+772 239 The name of the CEO of x -1 The name of the CEO of International Telecommunication Union Houlin Zhao International Telecommunication Union "[' (' 'IT' 'U' ')' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but'
+ ' it' '�' '�' 's' ' a' ' mouth' 'ful']" ( IT U ) is a bit of a mouth ful , but it � � s a mouth ful False but later the International Telecommunication Union set a standard 6 [' but', ' later', ' the', ' International', ' Tele', 'communication', ' Union']
+773 239 The name of the CEO of x -1 The name of the CEO of International Telecommunication Union Houlin Zhao International Telecommunication Union "[' (' 'IT' 'U' ')' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but'
+ ' it' '�' '�' 's' ' a' ' mouth' 'ful']" ( IT U ) is a bit of a mouth ful , but it � � s a mouth ful False " defined by the International Telecommunication Union (ITU).
+" 6 [' defined', ' by', ' the', ' International', ' Tele', 'communication', ' Union']
+774 239 The name of the CEO of x -1 The name of the CEO of International Telecommunication Union Houlin Zhao International Telecommunication Union "[' (' 'IT' 'U' ')' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but'
+ ' it' '�' '�' 's' ' a' ' mouth' 'ful']" ( IT U ) is a bit of a mouth ful , but it � � s a mouth ful False established by the International Telecommunication Union (ITU), which 6 [' established', ' by', ' the', ' International', ' Tele', 'communication', ' Union']
+775 239 The name of the CEO of x -1 The name of the CEO of International Telecommunication Union Houlin Zhao International Telecommunication Union "[' (' 'IT' 'U' ')' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but'
+ ' it' '�' '�' 's' ' a' ' mouth' 'ful']" ( IT U ) is a bit of a mouth ful , but it � � s a mouth ful False " defined by the International Telecommunication Union (ITU).
+" 6 [' defined', ' by', ' the', ' International', ' Tele', 'communication', ' Union']
+776 240 The name of the CEO of x -1 The name of the CEO of Médecins Sans Frontières Christos Christou Médecins Sans Frontières "[' (' 'MS' 'F' ')' ' in' ' the' ' Congo' ',' ' Dr' '.' ' Jo' 'anne' ' Liu'
+ ',' ' said' ' that' ' the' ' organization' ' is' ' �']" ( MS F ) in the Congo , Dr . Jo anne Liu , said that the organization is � False affected the three Médecins Sans Frontières (Doctors Without 10 [' affected', ' the', ' three', ' Mé', 'dec', 'ins', ' Sans', ' Front', 'i', 'è', 'res']
+777 240 The name of the CEO of x -1 The name of the CEO of Médecins Sans Frontières Christos Christou Médecins Sans Frontières "[' (' 'MS' 'F' ')' ' in' ' the' ' Congo' ',' ' Dr' '.' ' Jo' 'anne' ' Liu'
+ ',' ' said' ' that' ' the' ' organization' ' is' ' �']" ( MS F ) in the Congo , Dr . Jo anne Liu , said that the organization is � False available from Médecins Sans Frontières in the next-door 9 [' available', ' from', ' Mé', 'dec', 'ins', ' Sans', ' Front', 'i', 'è', 'res']
+778 240 The name of the CEO of x -1 The name of the CEO of Médecins Sans Frontières Christos Christou Médecins Sans Frontières "[' (' 'MS' 'F' ')' ' in' ' the' ' Congo' ',' ' Dr' '.' ' Jo' 'anne' ' Liu'
+ ',' ' said' ' that' ' the' ' organization' ' is' ' �']" ( MS F ) in the Congo , Dr . Jo anne Liu , said that the organization is � False was available from Médecins Sans Frontières in the next-door 10 [' was', ' available', ' from', ' Mé', 'dec', 'ins', ' Sans', ' Front', 'i', 'è', 'res']
+779 240 The name of the CEO of x -1 The name of the CEO of Médecins Sans Frontières Christos Christou Médecins Sans Frontières "[' (' 'MS' 'F' ')' ' in' ' the' ' Congo' ',' ' Dr' '.' ' Jo' 'anne' ' Liu'
+ ',' ' said' ' that' ' the' ' organization' ' is' ' �']" ( MS F ) in the Congo , Dr . Jo anne Liu , said that the organization is � False department, and Médecins Sans Frontières distributes antiretroviral 10 [' department', ',', ' and', ' Mé', 'dec', 'ins', ' Sans', ' Front', 'i', 'è', 'res']
+780 240 The name of the CEO of x -1 The name of the CEO of Médecins Sans Frontières Christos Christou Médecins Sans Frontières "[' (' 'MS' 'F' ')' ' in' ' the' ' Congo' ',' ' Dr' '.' ' Jo' 'anne' ' Liu'
+ ',' ' said' ' that' ' the' ' organization' ' is' ' �']" ( MS F ) in the Congo , Dr . Jo anne Liu , said that the organization is � False " the medical charity Médecins Sans Frontières (MSF, or ""Doctors" 10 [' the', ' medical', ' charity', ' Mé', 'dec', 'ins', ' Sans', ' Front', 'i', 'è', 'res']
+781 241 The name of the CEO of x -1 The name of the CEO of Capgemini Paul Hermelin Capgemini "[',' ' a' ' French' ' multinational' ' company' ',' ' is' ' a' ' good'
+ ' example' ' of' ' a' ' company' ' that' ' has' ' been' ' able' ' to'
+ ' successfully' ' implement']" , a French multinational company , is a good example of a company that has been able to successfully implement False was designed by Capgemini and previously had 5 [' was', ' designed', ' by', ' Cap', 'gem', 'ini']
+782 241 The name of the CEO of x -1 The name of the CEO of Capgemini Paul Hermelin Capgemini "[',' ' a' ' French' ' multinational' ' company' ',' ' is' ' a' ' good'
+ ' example' ' of' ' a' ' company' ' that' ' has' ' been' ' able' ' to'
+ ' successfully' ' implement']" , a French multinational company , is a good example of a company that has been able to successfully implement False designed by Capgemini and previously 4 [' designed', ' by', ' Cap', 'gem', 'ini']
+783 241 The name of the CEO of x -1 The name of the CEO of Capgemini Paul Hermelin Capgemini "[',' ' a' ' French' ' multinational' ' company' ',' ' is' ' a' ' good'
+ ' example' ' of' ' a' ' company' ' that' ' has' ' been' ' able' ' to'
+ ' successfully' ' implement']" , a French multinational company , is a good example of a company that has been able to successfully implement False was designed by Capgemini and previously 5 [' was', ' designed', ' by', ' Cap', 'gem', 'ini']
+784 241 The name of the CEO of x -1 The name of the CEO of Capgemini Paul Hermelin Capgemini "[',' ' a' ' French' ' multinational' ' company' ',' ' is' ' a' ' good'
+ ' example' ' of' ' a' ' company' ' that' ' has' ' been' ' able' ' to'
+ ' successfully' ' implement']" , a French multinational company , is a good example of a company that has been able to successfully implement False was designed by Capgemini and previously had 5 [' was', ' designed', ' by', ' Cap', 'gem', 'ini']
+785 241 The name of the CEO of x -1 The name of the CEO of Capgemini Paul Hermelin Capgemini "[',' ' a' ' French' ' multinational' ' company' ',' ' is' ' a' ' good'
+ ' example' ' of' ' a' ' company' ' that' ' has' ' been' ' able' ' to'
+ ' successfully' ' implement']" , a French multinational company , is a good example of a company that has been able to successfully implement False government selected Capgemini to design the 4 [' government', ' selected', ' Cap', 'gem', 'ini']
+786 242 The name of the CEO of x -1 The name of the CEO of GitHub Nat Friedman GitHub "[',' ' and' ' the' ' name' ' of' ' the' ' company' '.' '\n' '\n' '------'
+ '\n' '\n' 'The' ' company' ' is' ' also' ' known' ' as' ' GitHub']" ", and the name of the company .
+
+ ------
+
+ The company is also known as GitHub" False tools such as GitHub and Coverity and speed 3 [' tools', ' such', ' as', ' GitHub']
+787 242 The name of the CEO of x -1 The name of the CEO of GitHub Nat Friedman GitHub "[',' ' and' ' the' ' name' ' of' ' the' ' company' '.' '\n' '\n' '------'
+ '\n' '\n' 'The' ' company' ' is' ' also' ' known' ' as' ' GitHub']" ", and the name of the company .
+
+ ------
+
+ The company is also known as GitHub" False development tools such as GitHub and Coverity 4 [' development', ' tools', ' such', ' as', ' GitHub']
+788 242 The name of the CEO of x -1 The name of the CEO of GitHub Nat Friedman GitHub "[',' ' and' ' the' ' name' ' of' ' the' ' company' '.' '\n' '\n' '------'
+ '\n' '\n' 'The' ' company' ' is' ' also' ' known' ' as' ' GitHub']" ", and the name of the company .
+
+ ------
+
+ The company is also known as GitHub" False tools such as GitHub and Coverity 3 [' tools', ' such', ' as', ' GitHub']
+789 246 The name of the CEO of x -1 The name of the CEO of Electrolux Jonas Samuelson Electrolux "[',' ' the' ' company' ' that' ' makes' ' the' ' vacuum' ' cleaner' ','
+ ' is' ' a' ' Swedish' ' company' '.' '\n' '\n' 'The' ' company' ' is'
+ ' based']" ", the company that makes the vacuum cleaner , is a Swedish company .
+
+ The company is based" False President and CEO of Electrolux and Alden Partridge, 5 [' President', ' and', ' CEO', ' of', ' Electro', 'lux']
+790 246 The name of the CEO of x -1 The name of the CEO of Electrolux Jonas Samuelson Electrolux "[',' ' the' ' company' ' that' ' makes' ' the' ' vacuum' ' cleaner' ','
+ ' is' ' a' ' Swedish' ' company' '.' '\n' '\n' 'The' ' company' ' is'
+ ' based']" ", the company that makes the vacuum cleaner , is a Swedish company .
+
+ The company is based" False Extras (2003 to 2005), Electrolux (2005 to 2008), 7 [' Extras', ' (', '2003', ' to', ' 2005', '),', ' Electro', 'lux']
+791 246 The name of the CEO of x -1 The name of the CEO of Electrolux Jonas Samuelson Electrolux "[',' ' the' ' company' ' that' ' makes' ' the' ' vacuum' ' cleaner' ','
+ ' is' ' a' ' Swedish' ' company' '.' '\n' '\n' 'The' ' company' ' is'
+ ' based']" ", the company that makes the vacuum cleaner , is a Swedish company .
+
+ The company is based" False President and CEO of Electrolux and Alden Partridge, 5 [' President', ' and', ' CEO', ' of', ' Electro', 'lux']
+792 246 The name of the CEO of x -1 The name of the CEO of Electrolux Jonas Samuelson Electrolux "[',' ' the' ' company' ' that' ' makes' ' the' ' vacuum' ' cleaner' ','
+ ' is' ' a' ' Swedish' ' company' '.' '\n' '\n' 'The' ' company' ' is'
+ ' based']" ", the company that makes the vacuum cleaner , is a Swedish company .
+
+ The company is based" False President and CEO of Electrolux and Alden Partridge, 5 [' President', ' and', ' CEO', ' of', ' Electro', 'lux']
+793 246 The name of the CEO of x -1 The name of the CEO of Electrolux Jonas Samuelson Electrolux "[',' ' the' ' company' ' that' ' makes' ' the' ' vacuum' ' cleaner' ','
+ ' is' ' a' ' Swedish' ' company' '.' '\n' '\n' 'The' ' company' ' is'
+ ' based']" ", the company that makes the vacuum cleaner , is a Swedish company .
+
+ The company is based" False Extras (2003 to 2005), Electrolux (2005 to 2008), Carbrini 7 [' Extras', ' (', '2003', ' to', ' 2005', '),', ' Electro', 'lux']
+794 249 The name of the CEO of x -1 The name of the CEO of Novartis Joseph Jimenez Novartis "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' in'
+ ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased' ' to']" ", the company that makes the drug , said in a statement .
+
+ "" We are pleased to" False subsidiary of Novartis Pharma AG, in conjunction 4 [' subsidiary', ' of', ' Nov', 'art', 'is']
+795 249 The name of the CEO of x -1 The name of the CEO of Novartis Joseph Jimenez Novartis "[',' ' the' ' company' ' that' ' makes' ' the' ' drug' ',' ' said' ' in'
+ ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are' ' pleased' ' to']" ", the company that makes the drug , said in a statement .
+
+ "" We are pleased to" False subsidiary of Novartis Pharma AG, 4 [' subsidiary', ' of', ' Nov', 'art', 'is']
+796 250 The name of the CEO of x -1 The name of the CEO of Yum! Brands Greg Creed Yum! Brands "[',' ' the' ' parent' ' company' ' of' ' K' 'FC' ',' ' Pizza' ' Hut' ','
+ ' and' ' Taco' ' Bell' ',' ' said' ' in' ' a' ' statement' '.']" , the parent company of K FC , Pizza Hut , and Taco Bell , said in a statement . False was renamed Yum! Brands in May 2002. In 5 [' was', ' renamed', ' Y', 'um', '!', ' Brands']
+797 250 The name of the CEO of x -1 The name of the CEO of Yum! Brands Greg Creed Yum! Brands "[',' ' the' ' parent' ' company' ' of' ' K' 'FC' ',' ' Pizza' ' Hut' ','
+ ' and' ' Taco' ' Bell' ',' ' said' ' in' ' a' ' statement' '.']" , the parent company of K FC , Pizza Hut , and Taco Bell , said in a statement . False Tricon was renamed Yum! Brands in May 2002. 7 [' Tr', 'icon', ' was', ' renamed', ' Y', 'um', '!', ' Brands']
+798 250 The name of the CEO of x -1 The name of the CEO of Yum! Brands Greg Creed Yum! Brands "[',' ' the' ' parent' ' company' ' of' ' K' 'FC' ',' ' Pizza' ' Hut' ','
+ ' and' ' Taco' ' Bell' ',' ' said' ' in' ' a' ' statement' '.']" , the parent company of K FC , Pizza Hut , and Taco Bell , said in a statement . False changed its name to Yum! Brands in 2002. Yum has proved 7 [' changed', ' its', ' name', ' to', ' Y', 'um', '!', ' Brands']
+799 250 The name of the CEO of x -1 The name of the CEO of Yum! Brands Greg Creed Yum! Brands "[',' ' the' ' parent' ' company' ' of' ' K' 'FC' ',' ' Pizza' ' Hut' ','
+ ' and' ' Taco' ' Bell' ',' ' said' ' in' ' a' ' statement' '.']" , the parent company of K FC , Pizza Hut , and Taco Bell , said in a statement . False Tricon was renamed Yum! Brands in May 2002. In 7 [' Tr', 'icon', ' was', ' renamed', ' Y', 'um', '!', ' Brands']
+800 250 The name of the CEO of x -1 The name of the CEO of Yum! Brands Greg Creed Yum! Brands "[',' ' the' ' parent' ' company' ' of' ' K' 'FC' ',' ' Pizza' ' Hut' ','
+ ' and' ' Taco' ' Bell' ',' ' said' ' in' ' a' ' statement' '.']" , the parent company of K FC , Pizza Hut , and Taco Bell , said in a statement . False overall after Yum! Brands (parent of KFC, Taco 5 [' overall', ' after', ' Y', 'um', '!', ' Brands']
+801 251 The name of the CEO of x -1 The name of the CEO of GlaxoSmithKline Emma Walmsley GlaxoSmithKline "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' the' ' drug'
+ ',' ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are'
+ ' pleased']" ", the pharmaceutical company that makes the drug , said in a statement .
+
+ "" We are pleased" False Icos, Pfizer, and GlaxoSmithKline spent a combined 12 [' I', 'cos', ',', ' Pf', 'izer', ',', ' and', ' Gl', 'ax', 'o', 'Smith', 'K', 'line']
+802 251 The name of the CEO of x -1 The name of the CEO of GlaxoSmithKline Emma Walmsley GlaxoSmithKline "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' the' ' drug'
+ ',' ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are'
+ ' pleased']" ", the pharmaceutical company that makes the drug , said in a statement .
+
+ "" We are pleased" False pharmaceutical company GlaxoSmithKline employed the injunction 7 [' pharmaceutical', ' company', ' Gl', 'ax', 'o', 'Smith', 'K', 'line']
+803 251 The name of the CEO of x -1 The name of the CEO of GlaxoSmithKline Emma Walmsley GlaxoSmithKline "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' the' ' drug'
+ ',' ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are'
+ ' pleased']" ", the pharmaceutical company that makes the drug , said in a statement .
+
+ "" We are pleased" False pharmaceutical company GlaxoSmithKline (GSK) later identified 7 [' pharmaceutical', ' company', ' Gl', 'ax', 'o', 'Smith', 'K', 'line']
+804 251 The name of the CEO of x -1 The name of the CEO of GlaxoSmithKline Emma Walmsley GlaxoSmithKline "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' the' ' drug'
+ ',' ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are'
+ ' pleased']" ", the pharmaceutical company that makes the drug , said in a statement .
+
+ "" We are pleased" False manufactured by GlaxoSmithKline in modern times), paid 7 [' manufactured', ' by', ' Gl', 'ax', 'o', 'Smith', 'K', 'line']
+805 251 The name of the CEO of x -1 The name of the CEO of GlaxoSmithKline Emma Walmsley GlaxoSmithKline "[',' ' the' ' pharmaceutical' ' company' ' that' ' makes' ' the' ' drug'
+ ',' ' said' ' in' ' a' ' statement' '.' '\n' '\n' '""' 'We' ' are'
+ ' pleased']" ", the pharmaceutical company that makes the drug , said in a statement .
+
+ "" We are pleased" False children: Rotarix by GlaxoSmithKline and RotaTeq 11 [' children', ':', ' Rot', 'ar', 'ix', ' by', ' Gl', 'ax', 'o', 'Smith', 'K', 'line']
+806 252 The name of the CEO of x -1 The name of the CEO of Solvay Ilham Kadri Solvay "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' chemicals' ' for' ' over' ' a' ' century' '.' '\n' '\n' 'The']" ", a company that has been in the business of making chemicals for over a century .
+
+ The" False conferences, the Solvay congress of 5 [' conferences', ',', ' the', ' Sol', 'v', 'ay']
+807 252 The name of the CEO of x -1 The name of the CEO of Solvay Ilham Kadri Solvay "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' chemicals' ' for' ' over' ' a' ' century' '.' '\n' '\n' 'The']" ", a company that has been in the business of making chemicals for over a century .
+
+ The" False meet Ernest Solvay to negotiate 4 [' meet', ' Ernest', ' Sol', 'v', 'ay']
+808 252 The name of the CEO of x -1 The name of the CEO of Solvay Ilham Kadri Solvay "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' chemicals' ' for' ' over' ' a' ' century' '.' '\n' '\n' 'The']" ", a company that has been in the business of making chemicals for over a century .
+
+ The" False making alkali by the Solvay process. As 7 [' making', ' alk', 'ali', ' by', ' the', ' Sol', 'v', 'ay']
+809 252 The name of the CEO of x -1 The name of the CEO of Solvay Ilham Kadri Solvay "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' chemicals' ' for' ' over' ' a' ' century' '.' '\n' '\n' 'The']" ", a company that has been in the business of making chemicals for over a century .
+
+ The" False coke ovens, 120 Solvay coke ovens, 15 red 8 [' co', 'ke', ' oven', 's', ',', ' 120', ' Sol', 'v', 'ay']
+810 252 The name of the CEO of x -1 The name of the CEO of Solvay Ilham Kadri Solvay "[',' ' a' ' company' ' that' ' has' ' been' ' in' ' the' ' business' ' of'
+ ' making' ' chemicals' ' for' ' over' ' a' ' century' '.' '\n' '\n' 'The']" ", a company that has been in the business of making chemicals for over a century .
+
+ The" False of Camillus and Solvay. Construction on 6 [' of', ' Cam', 'illus', ' and', ' Sol', 'v', 'ay']
+811 255 The name of the CEO of x -1 The name of the CEO of BlackRock Larry Fink BlackRock "[',' ' the' ' world' '�' '�' 's' ' largest' ' asset' ' manager' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'We']" ", the world � � s largest asset manager , said in a statement .
+
+ � � We" False Private Limited, BlackRock Inc and Norges 4 [' Private', ' Limited', ',', ' Black', 'Rock']
+812 255 The name of the CEO of x -1 The name of the CEO of BlackRock Larry Fink BlackRock "[',' ' the' ' world' '�' '�' 's' ' largest' ' asset' ' manager' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'We']" ", the world � � s largest asset manager , said in a statement .
+
+ � � We" False Industries, Inc., 1987), BlackRock (1988), Federated Department 7 [' Industries', ',', ' Inc', '.,', ' 1987', '),', ' Black', 'Rock']
+813 255 The name of the CEO of x -1 The name of the CEO of BlackRock Larry Fink BlackRock "[',' ' the' ' world' '�' '�' 's' ' largest' ' asset' ' manager' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'We']" ", the world � � s largest asset manager , said in a statement .
+
+ � � We" False Maheshwari, EVP at DSP BlackRock Investment 11 [' Ma', 'hes', 'hw', 'ari', ',', ' EV', 'P', ' at', ' D', 'SP', ' Black', 'Rock']
+814 255 The name of the CEO of x -1 The name of the CEO of BlackRock Larry Fink BlackRock "[',' ' the' ' world' '�' '�' 's' ' largest' ' asset' ' manager' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'We']" ", the world � � s largest asset manager , said in a statement .
+
+ � � We" False Private Limited, BlackRock Inc and Norges Bank 4 [' Private', ' Limited', ',', ' Black', 'Rock']
+815 255 The name of the CEO of x -1 The name of the CEO of BlackRock Larry Fink BlackRock "[',' ' the' ' world' '�' '�' 's' ' largest' ' asset' ' manager' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'We']" ", the world � � s largest asset manager , said in a statement .
+
+ � � We" False Maheshwari, EVP at DSP BlackRock Investment 11 [' Ma', 'hes', 'hw', 'ari', ',', ' EV', 'P', ' at', ' D', 'SP', ' Black', 'Rock']
+816 256 The name of the CEO of x -1 The name of the CEO of Food and Agriculture Organization Mohammed Shia' Al Sudani Food and Agriculture Organization "[' of' ' the' ' United' ' Nations' ' (' 'FA' 'O' ')' ' is' ' José' ' G'
+ 'raz' 'iano' ' da' ' Silva' '.' ' He' ' is' ' a' ' Brazilian']" of the United Nations ( FA O ) is José G raz iano da Silva . He is a Brazilian False food market. The Food and Agriculture Organization of the United 7 [' food', ' market', '.', ' The', ' Food', ' and', ' Agriculture', ' Organization']
+817 256 The name of the CEO of x -1 The name of the CEO of Food and Agriculture Organization Mohammed Shia' Al Sudani Food and Agriculture Organization "[' of' ' the' ' United' ' Nations' ' (' 'FA' 'O' ')' ' is' ' José' ' G'
+ 'raz' 'iano' ' da' ' Silva' '.' ' He' ' is' ' a' ' Brazilian']" of the United Nations ( FA O ) is José G raz iano da Silva . He is a Brazilian False breeding mares. The Food and Agriculture Organization considers it to be 8 [' breeding', ' m', 'ares', '.', ' The', ' Food', ' and', ' Agriculture', ' Organization']
+818 256 The name of the CEO of x -1 The name of the CEO of Food and Agriculture Organization Mohammed Shia' Al Sudani Food and Agriculture Organization "[' of' ' the' ' United' ' Nations' ' (' 'FA' 'O' ')' ' is' ' José' ' G'
+ 'raz' 'iano' ' da' ' Silva' '.' ' He' ' is' ' a' ' Brazilian']" of the United Nations ( FA O ) is José G raz iano da Silva . He is a Brazilian False " Chairman of the Food and Agriculture Organization Council.
+" 6 [' Chairman', ' of', ' the', ' Food', ' and', ' Agriculture', ' Organization']
+819 256 The name of the CEO of x -1 The name of the CEO of Food and Agriculture Organization Mohammed Shia' Al Sudani Food and Agriculture Organization "[' of' ' the' ' United' ' Nations' ' (' 'FA' 'O' ')' ' is' ' José' ' G'
+ 'raz' 'iano' ' da' ' Silva' '.' ' He' ' is' ' a' ' Brazilian']" of the United Nations ( FA O ) is José G raz iano da Silva . He is a Brazilian False Ambassador for the Food and Agriculture Organization (FAO) of the United 6 [' Ambassador', ' for', ' the', ' Food', ' and', ' Agriculture', ' Organization']
+820 256 The name of the CEO of x -1 The name of the CEO of Food and Agriculture Organization Mohammed Shia' Al Sudani Food and Agriculture Organization "[' of' ' the' ' United' ' Nations' ' (' 'FA' 'O' ')' ' is' ' José' ' G'
+ 'raz' 'iano' ' da' ' Silva' '.' ' He' ' is' ' a' ' Brazilian']" of the United Nations ( FA O ) is José G raz iano da Silva . He is a Brazilian False long-lines. The Food and Agriculture Organization of the United Nations 8 [' long', '-', 'lines', '.', ' The', ' Food', ' and', ' Agriculture', ' Organization']
+821 257 The name of the CEO of x -1 The name of the CEO of Simon & Schuster Richard E. Snyder Simon & Schuster "[',' ' the' ' publisher' ' of' ' the' ' book' ',' ' is' ' also' ' the'
+ ' name' ' of' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' the']" , the publisher of the book , is also the name of the CEO of the company that owns the False writers, in 1980 Simon & Schuster created a new 7 [' writers', ',', ' in', ' 1980', ' Simon', ' &', ' Sch', 'uster']
+822 257 The name of the CEO of x -1 The name of the CEO of Simon & Schuster Richard E. Snyder Simon & Schuster "[',' ' the' ' publisher' ' of' ' the' ' book' ',' ' is' ' also' ' the'
+ ' name' ' of' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' the']" , the publisher of the book , is also the name of the CEO of the company that owns the False published by Simon & Schuster and sold 30,000 5 [' published', ' by', ' Simon', ' &', ' Sch', 'uster']
+823 257 The name of the CEO of x -1 The name of the CEO of Simon & Schuster Richard E. Snyder Simon & Schuster "[',' ' the' ' publisher' ' of' ' the' ' book' ',' ' is' ' also' ' the'
+ ' name' ' of' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' the']" , the publisher of the book , is also the name of the CEO of the company that owns the False full-time for Simon & Schuster in Cambridge. 7 [' full', '-', 'time', ' for', ' Simon', ' &', ' Sch', 'uster']
+824 257 The name of the CEO of x -1 The name of the CEO of Simon & Schuster Richard E. Snyder Simon & Schuster "[',' ' the' ' publisher' ' of' ' the' ' book' ',' ' is' ' also' ' the'
+ ' name' ' of' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' the']" , the publisher of the book , is also the name of the CEO of the company that owns the False 3 ['Simon', ' &', ' Sch', 'uster']
+825 257 The name of the CEO of x -1 The name of the CEO of Simon & Schuster Richard E. Snyder Simon & Schuster "[',' ' the' ' publisher' ' of' ' the' ' book' ',' ' is' ' also' ' the'
+ ' name' ' of' ' the' ' CEO' ' of' ' the' ' company' ' that' ' owns'
+ ' the']" , the publisher of the book , is also the name of the CEO of the company that owns the False an audiobook by Simon & Schuster Audio, read by Megan 7 [' an', ' audi', 'obook', ' by', ' Simon', ' &', ' Sch', 'uster']
+826 258 The name of the CEO of x -1 The name of the CEO of Best Buy Hubert Joly Best Buy "[' is' ' Brian' ' Dunn' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' very' ' helpful' '.' ' He' ' is' ' very' ' knowledgeable']" is Brian Dunn . He is a very nice guy and he is very helpful . He is very knowledgeable False available from Best Buy and Napster. The album's 3 [' available', ' from', ' Best', ' Buy']
+827 258 The name of the CEO of x -1 The name of the CEO of Best Buy Hubert Joly Best Buy "[' is' ' Brian' ' Dunn' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' very' ' helpful' '.' ' He' ' is' ' very' ' knowledgeable']" is Brian Dunn . He is a very nice guy and he is very helpful . He is very knowledgeable False partnerships with Eaton and Best Buy for sales and 5 [' partnerships', ' with', ' Eaton', ' and', ' Best', ' Buy']
+828 258 The name of the CEO of x -1 The name of the CEO of Best Buy Hubert Joly Best Buy "[' is' ' Brian' ' Dunn' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' very' ' helpful' '.' ' He' ' is' ' very' ' knowledgeable']" is Brian Dunn . He is a very nice guy and he is very helpful . He is very knowledgeable False Season through Best Buy in the United States. 3 [' Season', ' through', ' Best', ' Buy']
+829 258 The name of the CEO of x -1 The name of the CEO of Best Buy Hubert Joly Best Buy "[' is' ' Brian' ' Dunn' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' very' ' helpful' '.' ' He' ' is' ' very' ' knowledgeable']" is Brian Dunn . He is a very nice guy and he is very helpful . He is very knowledgeable False drumsticks with the game. Best Buy provided a discount 7 [' drum', 'sticks', ' with', ' the', ' game', '.', ' Best', ' Buy']
+830 258 The name of the CEO of x -1 The name of the CEO of Best Buy Hubert Joly Best Buy "[' is' ' Brian' ' Dunn' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' very' ' helpful' '.' ' He' ' is' ' very' ' knowledgeable']" is Brian Dunn . He is a very nice guy and he is very helpful . He is very knowledgeable False ' Lakes, SuperValu, Best Buy and Valspar. 8 "["" '"", ' Lakes', ',', ' Super', 'Val', 'u', ',', ' Best', ' Buy']"
+831 259 The name of the CEO of x -1 The name of the CEO of Humana Bruce Broussard Humana "[',' ' the' ' largest' ' health' ' insurer' ' in' ' the' ' United'
+ ' States' ',' ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 50'
+ ' most' ' powerful']" , the largest health insurer in the United States , is on the list of the top 50 most powerful False transported to the Humana Northside Hospital. 4 [' transported', ' to', ' the', ' Hum', 'ana']
+832 259 The name of the CEO of x -1 The name of the CEO of Humana Bruce Broussard Humana "[',' ' the' ' largest' ' health' ' insurer' ' in' ' the' ' United'
+ ' States' ',' ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 50'
+ ' most' ' powerful']" , the largest health insurer in the United States , is on the list of the top 50 most powerful False transported to the Humana Northside Hospital. 4 [' transported', ' to', ' the', ' Hum', 'ana']
+833 259 The name of the CEO of x -1 The name of the CEO of Humana Bruce Broussard Humana "[',' ' the' ' largest' ' health' ' insurer' ' in' ' the' ' United'
+ ' States' ',' ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 50'
+ ' most' ' powerful']" , the largest health insurer in the United States , is on the list of the top 50 most powerful False victory in the Humana Distaff Handicap, 4 [' victory', ' in', ' the', ' Hum', 'ana']
+834 259 The name of the CEO of x -1 The name of the CEO of Humana Bruce Broussard Humana "[',' ' the' ' largest' ' health' ' insurer' ' in' ' the' ' United'
+ ' States' ',' ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 50'
+ ' most' ' powerful']" , the largest health insurer in the United States , is on the list of the top 50 most powerful False victory in the Humana Distaff Handicap, 4 [' victory', ' in', ' the', ' Hum', 'ana']
+835 259 The name of the CEO of x -1 The name of the CEO of Humana Bruce Broussard Humana "[',' ' the' ' largest' ' health' ' insurer' ' in' ' the' ' United'
+ ' States' ',' ' is' ' on' ' the' ' list' ' of' ' the' ' top' ' 50'
+ ' most' ' powerful']" , the largest health insurer in the United States , is on the list of the top 50 most powerful False victory in the Humana Distaff Handicap, 4 [' victory', ' in', ' the', ' Hum', 'ana']
+836 261 The name of the CEO of x -1 The name of the CEO of T-Mobile US Mike Sievert T-Mobile US "[',' ' John' ' Le' 'ge' 're' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' '�' '�' 's' ' a' ' bit' ' of']" , John Le ge re , is a bit of a mystery . He � � s a bit of False Wireless and T-Mobile US were among the 5 [' Wireless', ' and', ' T', '-', 'Mobile', ' US']
+837 261 The name of the CEO of x -1 The name of the CEO of T-Mobile US Mike Sievert T-Mobile US "[',' ' John' ' Le' 'ge' 're' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' '�' '�' 's' ' a' ' bit' ' of']" , John Le ge re , is a bit of a mystery . He � � s a bit of False Verizon Wireless and T-Mobile US were among the first 6 [' Verizon', ' Wireless', ' and', ' T', '-', 'Mobile', ' US']
+838 261 The name of the CEO of x -1 The name of the CEO of T-Mobile US Mike Sievert T-Mobile US "[',' ' John' ' Le' 'ge' 're' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' '�' '�' 's' ' a' ' bit' ' of']" , John Le ge re , is a bit of a mystery . He � � s a bit of False America in Redmond, T-Mobile US in Bellevue, 7 [' America', ' in', ' Redmond', ',', ' T', '-', 'Mobile', ' US']
+839 261 The name of the CEO of x -1 The name of the CEO of T-Mobile US Mike Sievert T-Mobile US "[',' ' John' ' Le' 'ge' 're' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' '�' '�' 's' ' a' ' bit' ' of']" , John Le ge re , is a bit of a mystery . He � � s a bit of False Verizon Wireless and T-Mobile US were among the 6 [' Verizon', ' Wireless', ' and', ' T', '-', 'Mobile', ' US']
+840 261 The name of the CEO of x -1 The name of the CEO of T-Mobile US Mike Sievert T-Mobile US "[',' ' John' ' Le' 'ge' 're' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' '�' '�' 's' ' a' ' bit' ' of']" , John Le ge re , is a bit of a mystery . He � � s a bit of False in Redmond, T-Mobile US in Bellevue, Expedia 6 [' in', ' Redmond', ',', ' T', '-', 'Mobile', ' US']
+841 263 The name of the CEO of x -1 The name of the CEO of Broderbund Doug Carlston Broderbund "[',' ' the' ' company' ' that' ' created' ' the' ' original' ' game' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' created'
+ ' the' ' game']" , the company that created the original game , is also the name of the company that created the game False were sold to Broderbund, which was owned 6 [' were', ' sold', ' to', ' Bro', 'der', 'b', 'und']
+842 263 The name of the CEO of x -1 The name of the CEO of Broderbund Doug Carlston Broderbund "[',' ' the' ' company' ' that' ' created' ' the' ' original' ' game' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' created'
+ ' the' ' game']" , the company that created the original game , is also the name of the company that created the game False Mack, the Broderbund executive who 6 [' Mack', ',', ' the', ' Bro', 'der', 'b', 'und']
+843 263 The name of the CEO of x -1 The name of the CEO of Broderbund Doug Carlston Broderbund "[',' ' the' ' company' ' that' ' created' ' the' ' original' ' game' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' created'
+ ' the' ' game']" , the company that created the original game , is also the name of the company that created the game False by Robert Cook and Broderbund in which the player 7 [' by', ' Robert', ' Cook', ' and', ' Bro', 'der', 'b', 'und']
+844 263 The name of the CEO of x -1 The name of the CEO of Broderbund Doug Carlston Broderbund "[',' ' the' ' company' ' that' ' created' ' the' ' original' ' game' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' created'
+ ' the' ' game']" , the company that created the original game , is also the name of the company that created the game False were sold to Broderbund, which was owned 6 [' were', ' sold', ' to', ' Bro', 'der', 'b', 'und']
+845 263 The name of the CEO of x -1 The name of the CEO of Broderbund Doug Carlston Broderbund "[',' ' the' ' company' ' that' ' created' ' the' ' original' ' game' ','
+ ' is' ' also' ' the' ' name' ' of' ' the' ' company' ' that' ' created'
+ ' the' ' game']" , the company that created the original game , is also the name of the company that created the game False publication by Broderbund in 1983. Broderbund 5 [' publication', ' by', ' Bro', 'der', 'b', 'und']
+846 264 The name of the CEO of x -1 The name of the CEO of Hasbro Brian Goldner Hasbro "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' is' ' based' ' in' ' Paw' 't' 'ucket' ',' ' Rhode' ' Island']" " is a bit of a mystery .
+
+ The company is based in Paw t ucket , Rhode Island" False In January 2014, Hasbro announced a 5 [' In', ' January', ' 2014', ',', ' Has', 'bro']
+847 264 The name of the CEO of x -1 The name of the CEO of Hasbro Brian Goldner Hasbro "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' is' ' based' ' in' ' Paw' 't' 'ucket' ',' ' Rhode' ' Island']" " is a bit of a mystery .
+
+ The company is based in Paw t ucket , Rhode Island" False Imaginations after Hasbro ceased production 4 [' Imag', 'inations', ' after', ' Has', 'bro']
+848 264 The name of the CEO of x -1 The name of the CEO of Hasbro Brian Goldner Hasbro "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' is' ' based' ' in' ' Paw' 't' 'ucket' ',' ' Rhode' ' Island']" " is a bit of a mystery .
+
+ The company is based in Paw t ucket , Rhode Island" False signed a deal with Hasbro to make Friendship 5 [' signed', ' a', ' deal', ' with', ' Has', 'bro']
+849 264 The name of the CEO of x -1 The name of the CEO of Hasbro Brian Goldner Hasbro "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' is' ' based' ' in' ' Paw' 't' 'ucket' ',' ' Rhode' ' Island']" " is a bit of a mystery .
+
+ The company is based in Paw t ucket , Rhode Island" False " survey conducted by Hasbro found that ""two" 4 [' survey', ' conducted', ' by', ' Has', 'bro']
+850 264 The name of the CEO of x -1 The name of the CEO of Hasbro Brian Goldner Hasbro "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' is' ' based' ' in' ' Paw' 't' 'ucket' ',' ' Rhode' ' Island']" " is a bit of a mystery .
+
+ The company is based in Paw t ucket , Rhode Island" False two. In 1991, Hasbro became Play-Doh's 6 [' two', '.', ' In', ' 1991', ',', ' Has', 'bro']
+851 266 The name of the CEO of x -1 The name of the CEO of Nederlandse Spoorwegen Marjan Rintel Nederlandse Spoorwegen "[',' ' the' ' Dutch' ' national' ' railway' ' company' ',' ' is' ' Jan'
+ ' H' 'enn' 'eman' '.' ' He' ' is' ' a' ' man' ' of' ' few' ' words']" , the Dutch national railway company , is Jan H enn eman . He is a man of few words False (1868 – 1937), Nederlandse Spoorwegen (1938 – 2000), 13 [' (', '18', '68', ' –', ' 1937', '),', ' N', 'eder', 'land', 'se', ' Sp', 'oor', 'we', 'gen']
+852 267 The name of the CEO of x -1 The name of the CEO of NBCUniversal Steve Burke NBCUniversal "[',' ' Jeff' ' Zucker' ',' ' is' ' a' ' perfect' ' example' ' of' ' this'
+ '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Jeff Zucker , is a perfect example of this . He is a man who has been in the False 1 ['NBC', 'Universal']
+853 267 The name of the CEO of x -1 The name of the CEO of NBCUniversal Steve Burke NBCUniversal "[',' ' Jeff' ' Zucker' ',' ' is' ' a' ' perfect' ' example' ' of' ' this'
+ '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Jeff Zucker , is a perfect example of this . He is a man who has been in the False 1 ['NBC', 'Universal']
+854 267 The name of the CEO of x -1 The name of the CEO of NBCUniversal Steve Burke NBCUniversal "[',' ' Jeff' ' Zucker' ',' ' is' ' a' ' perfect' ' example' ' of' ' this'
+ '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Jeff Zucker , is a perfect example of this . He is a man who has been in the False 1 ['NBC', 'Universal']
+855 267 The name of the CEO of x -1 The name of the CEO of NBCUniversal Steve Burke NBCUniversal "[',' ' Jeff' ' Zucker' ',' ' is' ' a' ' perfect' ' example' ' of' ' this'
+ '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Jeff Zucker , is a perfect example of this . He is a man who has been in the False 1 ['NBC', 'Universal']
+856 267 The name of the CEO of x -1 The name of the CEO of NBCUniversal Steve Burke NBCUniversal "[',' ' Jeff' ' Zucker' ',' ' is' ' a' ' perfect' ' example' ' of' ' this'
+ '.' ' He' ' is' ' a' ' man' ' who' ' has' ' been' ' in' ' the']" , Jeff Zucker , is a perfect example of this . He is a man who has been in the False the merger of NBCUniversal with Kabletown, 4 [' the', ' merger', ' of', ' NBC', 'Universal']
+857 270 The name of the CEO of x -1 The name of the CEO of Valve Corporation Gabe Newell Valve Corporation "[' is' ' Gabe' ' New' 'ell' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Valve' ' and' ' the' ' creator' ' of' ' the' ' Half' '-' 'Life'
+ ' series']" is Gabe New ell . He is the founder of Valve and the creator of the Half - Life series True " ""powerfully atmospheric"" by Valve Corporation writer Marc Laidlaw.
+" 7 "[' ""', 'power', 'fully', ' atmospheric', '""', ' by', ' Valve', ' Corporation']"
+858 270 The name of the CEO of x -1 The name of the CEO of Valve Corporation Gabe Newell Valve Corporation "[' is' ' Gabe' ' New' 'ell' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Valve' ' and' ' the' ' creator' ' of' ' the' ' Half' '-' 'Life'
+ ' series']" is Gabe New ell . He is the founder of Valve and the creator of the Half - Life series True RapidFire engine. Valve Corporation approached Schatz 5 [' Rapid', 'Fire', ' engine', '.', ' Valve', ' Corporation']
+859 270 The name of the CEO of x -1 The name of the CEO of Valve Corporation Gabe Newell Valve Corporation "[' is' ' Gabe' ' New' 'ell' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Valve' ' and' ' the' ' creator' ' of' ' the' ' Half' '-' 'Life'
+ ' series']" is Gabe New ell . He is the founder of Valve and the creator of the Half - Life series True secured permission from Valve Corporation to cast McLain in 4 [' secured', ' permission', ' from', ' Valve', ' Corporation']
+860 270 The name of the CEO of x -1 The name of the CEO of Valve Corporation Gabe Newell Valve Corporation "[' is' ' Gabe' ' New' 'ell' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Valve' ' and' ' the' ' creator' ' of' ' the' ' Half' '-' 'Life'
+ ' series']" is Gabe New ell . He is the founder of Valve and the creator of the Half - Life series True PC Magazine, Valve Corporation CEO and co-founder 4 [' PC', ' Magazine', ',', ' Valve', ' Corporation']
+861 270 The name of the CEO of x -1 The name of the CEO of Valve Corporation Gabe Newell Valve Corporation "[' is' ' Gabe' ' New' 'ell' '.' ' He' ' is' ' the' ' founder' ' of'
+ ' Valve' ' and' ' the' ' creator' ' of' ' the' ' Half' '-' 'Life'
+ ' series']" is Gabe New ell . He is the founder of Valve and the creator of the Half - Life series True " ""powerfully atmospheric"" by Valve Corporation writer Marc" 7 "[' ""', 'power', 'fully', ' atmospheric', '""', ' by', ' Valve', ' Corporation']"
+862 271 The name of the CEO of x -1 The name of the CEO of Tesco Ken Murph Tesco "[' is' ' not' ' a' ' name' ' that' ' is' ' well' ' known' ' in' ' the'
+ ' world' ' of' ' business' '.' ' But' ',' ' if' ' you' ' are' ' a']" is not a name that is well known in the world of business . But , if you are a False site in 2014, but Tesco announced early 6 [' site', ' in', ' 2014', ',', ' but', ' Tes', 'co']
+863 271 The name of the CEO of x -1 The name of the CEO of Tesco Ken Murph Tesco "[' is' ' not' ' a' ' name' ' that' ' is' ' well' ' known' ' in' ' the'
+ ' world' ' of' ' business' '.' ' But' ',' ' if' ' you' ' are' ' a']" is not a name that is well known in the world of business . But , if you are a False proposed opening of a Tesco Express store on 5 [' proposed', ' opening', ' of', ' a', ' Tes', 'co']
+864 271 The name of the CEO of x -1 The name of the CEO of Tesco Ken Murph Tesco "[' is' ' not' ' a' ' name' ' that' ' is' ' well' ' known' ' in' ' the'
+ ' world' ' of' ' business' '.' ' But' ',' ' if' ' you' ' are' ' a']" is not a name that is well known in the world of business . But , if you are a False is next door to the Tesco supermarket. Also 6 [' is', ' next', ' door', ' to', ' the', ' Tes', 'co']
+865 271 The name of the CEO of x -1 The name of the CEO of Tesco Ken Murph Tesco "[' is' ' not' ' a' ' name' ' that' ' is' ' well' ' known' ' in' ' the'
+ ' world' ' of' ' business' '.' ' But' ',' ' if' ' you' ' are' ' a']" is not a name that is well known in the world of business . But , if you are a False by a large 24-hour Tesco Extra superstore 7 [' by', ' a', ' large', ' 24', '-', 'hour', ' Tes', 'co']
+866 271 The name of the CEO of x -1 The name of the CEO of Tesco Ken Murph Tesco "[' is' ' not' ' a' ' name' ' that' ' is' ' well' ' known' ' in' ' the'
+ ' world' ' of' ' business' '.' ' But' ',' ' if' ' you' ' are' ' a']" is not a name that is well known in the world of business . But , if you are a False with one provider, Tesco Personal Finance, 5 [' with', ' one', ' provider', ',', ' Tes', 'co']
+867 273 The name of the CEO of x -1 The name of the CEO of PricewaterhouseCoopers Samuel DiPiazza PricewaterhouseCoopers "[',' ' the' ' world' '�' '�' 's' ' largest' ' professional' ' services'
+ ' firm' ',' ' is' ' Bob' ' Mor' 'itz' '.' ' He' ' is' ' a' ' man']" , the world � � s largest professional services firm , is Bob Mor itz . He is a man False According to a 2011 PricewaterhouseCoopers report, India's 8 [' According', ' to', ' a', ' 2011', ' Pric', 'ewater', 'house', 'Co', 'opers']
+868 273 The name of the CEO of x -1 The name of the CEO of PricewaterhouseCoopers Samuel DiPiazza PricewaterhouseCoopers "[',' ' the' ' world' '�' '�' 's' ' largest' ' professional' ' services'
+ ' firm' ',' ' is' ' Bob' ' Mor' 'itz' '.' ' He' ' is' ' a' ' man']" , the world � � s largest professional services firm , is Bob Mor itz . He is a man False 2000 report from PricewaterhouseCoopers concluded 7 [' 2000', ' report', ' from', ' Pric', 'ewater', 'house', 'Co', 'opers']
+869 273 The name of the CEO of x -1 The name of the CEO of PricewaterhouseCoopers Samuel DiPiazza PricewaterhouseCoopers "[',' ' the' ' world' '�' '�' 's' ' largest' ' professional' ' services'
+ ' firm' ',' ' is' ' Bob' ' Mor' 'itz' '.' ' He' ' is' ' a' ' man']" , the world � � s largest professional services firm , is Bob Mor itz . He is a man False engaged accounting firm PricewaterhouseCoopers (PwC) to give substance 7 [' engaged', ' accounting', ' firm', ' Pric', 'ewater', 'house', 'Co', 'opers']
+870 273 The name of the CEO of x -1 The name of the CEO of PricewaterhouseCoopers Samuel DiPiazza PricewaterhouseCoopers "[',' ' the' ' world' '�' '�' 's' ' largest' ' professional' ' services'
+ ' firm' ',' ' is' ' Bob' ' Mor' 'itz' '.' ' He' ' is' ' a' ' man']" , the world � � s largest professional services firm , is Bob Mor itz . He is a man False & Young or PricewaterhouseCoopers to complete 7 [' &', ' Young', ' or', ' Pric', 'ewater', 'house', 'Co', 'opers']
+871 273 The name of the CEO of x -1 The name of the CEO of PricewaterhouseCoopers Samuel DiPiazza PricewaterhouseCoopers "[',' ' the' ' world' '�' '�' 's' ' largest' ' professional' ' services'
+ ' firm' ',' ' is' ' Bob' ' Mor' 'itz' '.' ' He' ' is' ' a' ' man']" , the world � � s largest professional services firm , is Bob Mor itz . He is a man False According to a 2011 PricewaterhouseCoopers report, India's 8 [' According', ' to', ' a', ' 2011', ' Pric', 'ewater', 'house', 'Co', 'opers']
+872 274 The name of the CEO of x -1 The name of the CEO of UBS Sergio Ermotti UBS "[',' ' the' ' Swiss' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' wealth' ' manager' ',' ' is' ' U' 'BS' ' AG' '.']" , the Swiss bank that is the world � � s largest wealth manager , is U BS AG . False mid-1990s, UBS came under fire from 6 [' mid', '-', '1990', 's', ',', ' U', 'BS']
+873 274 The name of the CEO of x -1 The name of the CEO of UBS Sergio Ermotti UBS "[',' ' the' ' Swiss' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' wealth' ' manager' ',' ' is' ' U' 'BS' ' AG' '.']" , the Swiss bank that is the world � � s largest wealth manager , is U BS AG . False 1 ['U', 'BS']
+874 274 The name of the CEO of x -1 The name of the CEO of UBS Sergio Ermotti UBS "[',' ' the' ' Swiss' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' wealth' ' manager' ',' ' is' ' U' 'BS' ' AG' '.']" , the Swiss bank that is the world � � s largest wealth manager , is U BS AG . False 26 October 2010, UBS announced that its 5 [' 26', ' October', ' 2010', ',', ' U', 'BS']
+875 274 The name of the CEO of x -1 The name of the CEO of UBS Sergio Ermotti UBS "[',' ' the' ' Swiss' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' wealth' ' manager' ',' ' is' ' U' 'BS' ' AG' '.']" , the Swiss bank that is the world � � s largest wealth manager , is U BS AG . False demonstrated that UBS likely took active 3 [' demonstrated', ' that', ' U', 'BS']
+876 274 The name of the CEO of x -1 The name of the CEO of UBS Sergio Ermotti UBS "[',' ' the' ' Swiss' ' bank' ' that' ' is' ' the' ' world' '�' '�' 's'
+ ' largest' ' wealth' ' manager' ',' ' is' ' U' 'BS' ' AG' '.']" , the Swiss bank that is the world � � s largest wealth manager , is U BS AG . False the spring of 2009, UBS announced another management 6 [' the', ' spring', ' of', ' 2009', ',', ' U', 'BS']
+877 275 The name of the CEO of x -1 The name of the CEO of Opera Software Jon Stephenson von Tetzchner Opera Software "[',' ' the' ' company' ' that' ' makes' ' the' ' Opera' ' browser' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company']" ", the company that makes the Opera browser , is a bit of a mystery .
+
+ The company" False technical officer of Opera Software and creator 4 [' technical', ' officer', ' of', ' Opera', ' Software']
+878 275 The name of the CEO of x -1 The name of the CEO of Opera Software Jon Stephenson von Tetzchner Opera Software "[',' ' the' ' company' ' that' ' makes' ' the' ' Opera' ' browser' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company']" ", the company that makes the Opera browser , is a bit of a mystery .
+
+ The company" False browser developed by Opera Software and Nintendo. 4 [' browser', ' developed', ' by', ' Opera', ' Software']
+879 275 The name of the CEO of x -1 The name of the CEO of Opera Software Jon Stephenson von Tetzchner Opera Software "[',' ' the' ' company' ' that' ' makes' ' the' ' Opera' ' browser' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company']" ", the company that makes the Opera browser , is a bit of a mystery .
+
+ The company" False browser developed by Opera Software and Nintendo. 4 [' browser', ' developed', ' by', ' Opera', ' Software']
+880 275 The name of the CEO of x -1 The name of the CEO of Opera Software Jon Stephenson von Tetzchner Opera Software "[',' ' the' ' company' ' that' ' makes' ' the' ' Opera' ' browser' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company']" ", the company that makes the Opera browser , is a bit of a mystery .
+
+ The company" False technical officer of Opera Software and creator of 4 [' technical', ' officer', ' of', ' Opera', ' Software']
+881 275 The name of the CEO of x -1 The name of the CEO of Opera Software Jon Stephenson von Tetzchner Opera Software "[',' ' the' ' company' ' that' ' makes' ' the' ' Opera' ' browser' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company']" ", the company that makes the Opera browser , is a bit of a mystery .
+
+ The company" False developed by Opera Software and Nintendo. A DSiWare 3 [' developed', ' by', ' Opera', ' Software']
+882 277 The name of the CEO of x -1 The name of the CEO of BT Group Philip Jansen BT Group "[',' ' the' ' parent' ' company' ' of' ' BT' ',' ' is' ' Robert' ' J' '.'
+ ' Brown' '.' '\n' '\n' 'BT' ' Group' ' is' ' a' ' British']" ", the parent company of BT , is Robert J . Brown .
+
+ BT Group is a British" False collaborate, supported by BT Group (BT), which upgraded 5 [' collaborate', ',', ' supported', ' by', ' BT', ' Group']
+883 277 The name of the CEO of x -1 The name of the CEO of BT Group Philip Jansen BT Group "[',' ' the' ' parent' ' company' ' of' ' BT' ',' ' is' ' Robert' ' J' '.'
+ ' Brown' '.' '\n' '\n' 'BT' ' Group' ' is' ' a' ' British']" ", the parent company of BT , is Robert J . Brown .
+
+ BT Group is a British" False 2001, for example, BT Group agreed to 6 [' 2001', ',', ' for', ' example', ',', ' BT', ' Group']
+884 277 The name of the CEO of x -1 The name of the CEO of BT Group Philip Jansen BT Group "[',' ' the' ' parent' ' company' ' of' ' BT' ',' ' is' ' Robert' ' J' '.'
+ ' Brown' '.' '\n' '\n' 'BT' ' Group' ' is' ' a' ' British']" ", the parent company of BT , is Robert J . Brown .
+
+ BT Group is a British" False began when BT Group introduced Cleanfeed, 3 [' began', ' when', ' BT', ' Group']
+885 277 The name of the CEO of x -1 The name of the CEO of BT Group Philip Jansen BT Group "[',' ' the' ' parent' ' company' ' of' ' BT' ',' ' is' ' Robert' ' J' '.'
+ ' Brown' '.' '\n' '\n' 'BT' ' Group' ' is' ' a' ' British']" ", the parent company of BT , is Robert J . Brown .
+
+ BT Group is a British" False collaborate, supported by BT Group (BT), which upgraded 5 [' collaborate', ',', ' supported', ' by', ' BT', ' Group']
+886 277 The name of the CEO of x -1 The name of the CEO of BT Group Philip Jansen BT Group "[',' ' the' ' parent' ' company' ' of' ' BT' ',' ' is' ' Robert' ' J' '.'
+ ' Brown' '.' '\n' '\n' 'BT' ' Group' ' is' ' a' ' British']" ", the parent company of BT , is Robert J . Brown .
+
+ BT Group is a British" False for example, BT Group agreed to sell 4 [' for', ' example', ',', ' BT', ' Group']
+887 279 The name of the CEO of x -1 The name of the CEO of Cognizant Francisco D'Souza Cognizant "[' is' ' Mr' '.' ' Ram' 'aling' 'a' ' Raj' 'u' '.' ' He' ' is' ' the'
+ ' founder' ' of' ' the' ' company' '.' ' He' ' is' ' a']" is Mr . Ram aling a Raj u . He is the founder of the company . He is a False 3 ['C', 'ogn', 'iz', 'ant']
+888 279 The name of the CEO of x -1 The name of the CEO of Cognizant Francisco D'Souza Cognizant "[' is' ' Mr' '.' ' Ram' 'aling' 'a' ' Raj' 'u' '.' ' He' ' is' ' the'
+ ' founder' ' of' ' the' ' company' '.' ' He' ' is' ' a']" is Mr . Ram aling a Raj u . He is the founder of the company . He is a False minister of state. Cognizant of the president's 6 [' minister', ' of', ' state', '.', ' Cogn', 'iz', 'ant']
+889 279 The name of the CEO of x -1 The name of the CEO of Cognizant Francisco D'Souza Cognizant "[' is' ' Mr' '.' ' Ram' 'aling' 'a' ' Raj' 'u' '.' ' He' ' is' ' the'
+ ' founder' ' of' ' the' ' company' '.' ' He' ' is' ' a']" is Mr . Ram aling a Raj u . He is the founder of the company . He is a False (MoU) with IT company Cognizant Technology 9 [' (', 'Mo', 'U', ')', ' with', ' IT', ' company', ' Cogn', 'iz', 'ant']
+890 279 The name of the CEO of x -1 The name of the CEO of Cognizant Francisco D'Souza Cognizant "[' is' ' Mr' '.' ' Ram' 'aling' 'a' ' Raj' 'u' '.' ' He' ' is' ' the'
+ ' founder' ' of' ' the' ' company' '.' ' He' ' is' ' a']" is Mr . Ram aling a Raj u . He is the founder of the company . He is a False 3 ['C', 'ogn', 'iz', 'ant']
+891 279 The name of the CEO of x -1 The name of the CEO of Cognizant Francisco D'Souza Cognizant "[' is' ' Mr' '.' ' Ram' 'aling' 'a' ' Raj' 'u' '.' ' He' ' is' ' the'
+ ' founder' ' of' ' the' ' company' '.' ' He' ' is' ' a']" is Mr . Ram aling a Raj u . He is the founder of the company . He is a False Territory from Spain. Cognizant of Jefferson's 6 [' Territory', ' from', ' Spain', '.', ' Cogn', 'iz', 'ant']
+892 280 The name of the CEO of x -1 The name of the CEO of BlackBerry John S. Chen BlackBerry "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' in' ' the' ' news' ' lately' ' for' ' all' ' the']" " is a bit of a mystery .
+
+ The company has been in the news lately for all the" False " Research In Motion's BlackBerry devices.
+" 4 "[' Research', ' In', ' Motion', ""'s"", ' BlackBerry']"
+893 280 The name of the CEO of x -1 The name of the CEO of BlackBerry John S. Chen BlackBerry "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' in' ' the' ' news' ' lately' ' for' ' all' ' the']" " is a bit of a mystery .
+
+ The company has been in the news lately for all the" False hosted by BlackBerry (then known 2 [' hosted', ' by', ' BlackBerry']
+894 280 The name of the CEO of x -1 The name of the CEO of BlackBerry John S. Chen BlackBerry "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' in' ' the' ' news' ' lately' ' for' ' all' ' the']" " is a bit of a mystery .
+
+ The company has been in the news lately for all the" False " apps for iOS, BlackBerry and Android devices.
+" 4 [' apps', ' for', ' iOS', ',', ' BlackBerry']
+895 280 The name of the CEO of x -1 The name of the CEO of BlackBerry John S. Chen BlackBerry "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' in' ' the' ' news' ' lately' ' for' ' all' ' the']" " is a bit of a mystery .
+
+ The company has been in the news lately for all the" False (after Symbian, BlackBerry OS, Android and 5 [' (', 'after', ' Symb', 'ian', ',', ' BlackBerry']
+896 280 The name of the CEO of x -1 The name of the CEO of BlackBerry John S. Chen BlackBerry "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' in' ' the' ' news' ' lately' ' for' ' all' ' the']" " is a bit of a mystery .
+
+ The company has been in the news lately for all the" False models for the BlackBerry line-up. The first 3 [' models', ' for', ' the', ' BlackBerry']
+897 281 The name of the CEO of x -1 The name of the CEO of Wolters Kluwer Nancy McKinstry Wolters Kluwer "[' Health' ' is' ' a' ' bit' ' of' ' a' ' mouth' 'ful' ',' ' but' ' it'
+ '�' '�' 's' ' a' ' mouth' 'ful' ' that' '�' '�']" Health is a bit of a mouth ful , but it � � s a mouth ful that � � False multimedia company Wolters Kluwer moved into 75,000 5 [' multimedia', ' company', ' Wol', 'ters', ' Klu', 'wer']
+898 282 The name of the CEO of x -1 The name of the CEO of Yle Merja Ylä-Anttila Yle "['em' ' is' ' a' ' combination' ' of' ' the' ' words' ' ""' 'Y' 'ale' '""'
+ ' and' ' ""' 'Em' 'ma' '"".' '\n' '\n' 'Y' 'le']" "em is a combination of the words "" Y ale "" and "" Em ma "".
+
+ Y le" False concert aired on Yle TV1 at the Christmas 4 [' concert', ' aired', ' on', ' Y', 'le']
+899 282 The name of the CEO of x -1 The name of the CEO of Yle Merja Ylä-Anttila Yle "['em' ' is' ' a' ' combination' ' of' ' the' ' words' ' ""' 'Y' 'ale' '""'
+ ' and' ' ""' 'Em' 'ma' '"".' '\n' '\n' 'Y' 'le']" "em is a combination of the words "" Y ale "" and "" Em ma "".
+
+ Y le" False public broadcaster, Yle, provides two Swedish-language 4 [' public', ' broadcaster', ',', ' Y', 'le']
+900 282 The name of the CEO of x -1 The name of the CEO of Yle Merja Ylä-Anttila Yle "['em' ' is' ' a' ' combination' ' of' ' the' ' words' ' ""' 'Y' 'ale' '""'
+ ' and' ' ""' 'Em' 'ma' '"".' '\n' '\n' 'Y' 'le']" "em is a combination of the words "" Y ale "" and "" Em ma "".
+
+ Y le" False concert aired on Yle TV1 at the Christmas 4 [' concert', ' aired', ' on', ' Y', 'le']
+901 282 The name of the CEO of x -1 The name of the CEO of Yle Merja Ylä-Anttila Yle "['em' ' is' ' a' ' combination' ' of' ' the' ' words' ' ""' 'Y' 'ale' '""'
+ ' and' ' ""' 'Em' 'ma' '"".' '\n' '\n' 'Y' 'le']" "em is a combination of the words "" Y ale "" and "" Em ma "".
+
+ Y le" False public broadcaster, Yle, provides two 4 [' public', ' broadcaster', ',', ' Y', 'le']
+902 282 The name of the CEO of x -1 The name of the CEO of Yle Merja Ylä-Anttila Yle "['em' ' is' ' a' ' combination' ' of' ' the' ' words' ' ""' 'Y' 'ale' '""'
+ ' and' ' ""' 'Em' 'ma' '"".' '\n' '\n' 'Y' 'le']" "em is a combination of the words "" Y ale "" and "" Em ma "".
+
+ Y le" False concert aired on Yle TV1 at the Christmas 4 [' concert', ' aired', ' on', ' Y', 'le']
+903 284 The name of the CEO of x -1 The name of the CEO of Bethesda Softworks Todd Howard Bethesda Softworks "[',' ' Pete' ' H' 'ines' ',' ' has' ' been' ' making' ' the' ' rounds'
+ ' on' ' the' ' internet' ' today' ',' ' and' ' he' ' has' ' been'
+ ' talking']" , Pete H ines , has been making the rounds on the internet today , and he has been talking False published by Bethesda Softworks and Ubisoft. It is 4 [' published', ' by', ' Bethesda', ' Soft', 'works']
+904 284 The name of the CEO of x -1 The name of the CEO of Bethesda Softworks Todd Howard Bethesda Softworks "[',' ' Pete' ' H' 'ines' ',' ' has' ' been' ' making' ' the' ' rounds'
+ ' on' ' the' ' internet' ' today' ',' ' and' ' he' ' has' ' been'
+ ' talking']" , Pete H ines , has been making the rounds on the internet today , and he has been talking False ESRB reported that Bethesda Softworks would promptly notify 7 [' E', 'SR', 'B', ' reported', ' that', ' Bethesda', ' Soft', 'works']
+905 284 The name of the CEO of x -1 The name of the CEO of Bethesda Softworks Todd Howard Bethesda Softworks "[',' ' Pete' ' H' 'ines' ',' ' has' ' been' ' making' ' the' ' rounds'
+ ' on' ' the' ' internet' ' today' ',' ' and' ' he' ' has' ' been'
+ ' talking']" , Pete H ines , has been making the rounds on the internet today , and he has been talking False property to Bethesda Softworks and the sale of 4 [' property', ' to', ' Bethesda', ' Soft', 'works']
+906 284 The name of the CEO of x -1 The name of the CEO of Bethesda Softworks Todd Howard Bethesda Softworks "[',' ' Pete' ' H' 'ines' ',' ' has' ' been' ' making' ' the' ' rounds'
+ ' on' ' the' ' internet' ' today' ',' ' and' ' he' ' has' ' been'
+ ' talking']" , Pete H ines , has been making the rounds on the internet today , and he has been talking False 3 ['Bet', 'hesda', ' Soft', 'works']
+907 284 The name of the CEO of x -1 The name of the CEO of Bethesda Softworks Todd Howard Bethesda Softworks "[',' ' Pete' ' H' 'ines' ',' ' has' ' been' ' making' ' the' ' rounds'
+ ' on' ' the' ' internet' ' today' ',' ' and' ' he' ' has' ' been'
+ ' talking']" , Pete H ines , has been making the rounds on the internet today , and he has been talking False 3 ['Bet', 'hesda', ' Soft', 'works']
+908 286 The name of the CEO of x -1 The name of the CEO of Aeroflot Sergei Alexandrovsky Aeroflot "[' is' ' Vladimir' ' Yak' 'unin' '.' ' He' ' is' ' a' ' Russian'
+ ' politician' ' and' ' former' ' Soviet' ' politician' '.' ' He' ' is'
+ ' the' ' former' ' head']" is Vladimir Yak unin . He is a Russian politician and former Soviet politician . He is the former head False the early 1980s, Aeroflot had experienced 7 [' the', ' early', ' 1980', 's', ',', ' Aer', 'of', 'lot']
+909 286 The name of the CEO of x -1 The name of the CEO of Aeroflot Sergei Alexandrovsky Aeroflot "[' is' ' Vladimir' ' Yak' 'unin' '.' ' He' ' is' ' a' ' Russian'
+ ' politician' ' and' ' former' ' Soviet' ' politician' '.' ' He' ' is'
+ ' the' ' former' ' head']" is Vladimir Yak unin . He is a Russian politician and former Soviet politician . He is the former head False Volga car and Aeroflot airlines. 6 [' Vol', 'ga', ' car', ' and', ' Aer', 'of', 'lot']
+910 286 The name of the CEO of x -1 The name of the CEO of Aeroflot Sergei Alexandrovsky Aeroflot "[' is' ' Vladimir' ' Yak' 'unin' '.' ' He' ' is' ' a' ' Russian'
+ ' politician' ' and' ' former' ' Soviet' ' politician' '.' ' He' ' is'
+ ' the' ' former' ' head']" is Vladimir Yak unin . He is a Russian politician and former Soviet politician . He is the former head False higher priority and Aeroflot had no requirement 5 [' higher', ' priority', ' and', ' Aer', 'of', 'lot']
+911 286 The name of the CEO of x -1 The name of the CEO of Aeroflot Sergei Alexandrovsky Aeroflot "[' is' ' Vladimir' ' Yak' 'unin' '.' ' He' ' is' ' a' ' Russian'
+ ' politician' ' and' ' former' ' Soviet' ' politician' '.' ' He' ' is'
+ ' the' ' former' ' head']" is Vladimir Yak unin . He is a Russian politician and former Soviet politician . He is the former head False revoke the license of Aeroflot Soviet Airlines to 6 [' revoke', ' the', ' license', ' of', ' Aer', 'of', 'lot']
+912 286 The name of the CEO of x -1 The name of the CEO of Aeroflot Sergei Alexandrovsky Aeroflot "[' is' ' Vladimir' ' Yak' 'unin' '.' ' He' ' is' ' a' ' Russian'
+ ' politician' ' and' ' former' ' Soviet' ' politician' '.' ' He' ' is'
+ ' the' ' former' ' head']" is Vladimir Yak unin . He is a Russian politician and former Soviet politician . He is the former head False service by Aeroflot from Murmansk 4 [' service', ' by', ' Aer', 'of', 'lot']
+913 288 The name of the CEO of x -1 The name of the CEO of AC Sparta Prague Daniel Křetínský AC Sparta Prague "[' is' ' Mr' '.' ' Petr' ' V' 'r' 'ana' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy' ' and' ' he' ' is' ' very' ' helpful' '.']" is Mr . Petr V r ana . He is a very nice guy and he is very helpful . False Germany), AC Sparta Prague (Czechoslovakia), Vitória 5 [' Germany', '),', ' AC', ' Spart', 'a', ' Prague']
+914 288 The name of the CEO of x -1 The name of the CEO of AC Sparta Prague Daniel Křetínský AC Sparta Prague "[' is' ' Mr' '.' ' Petr' ' V' 'r' 'ana' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy' ' and' ' he' ' is' ' very' ' helpful' '.']" is Mr . Petr V r ana . He is a very nice guy and he is very helpful . False Republic football team AC Sparta Prague before retiring and 6 [' Republic', ' football', ' team', ' AC', ' Spart', 'a', ' Prague']
+915 288 The name of the CEO of x -1 The name of the CEO of AC Sparta Prague Daniel Křetínský AC Sparta Prague "[' is' ' Mr' '.' ' Petr' ' V' 'r' 'ana' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy' ' and' ' he' ' is' ' very' ' helpful' '.']" is Mr . Petr V r ana . He is a very nice guy and he is very helpful . False Germany), AC Sparta Prague (Czechoslovakia), 5 [' Germany', '),', ' AC', ' Spart', 'a', ' Prague']
+916 288 The name of the CEO of x -1 The name of the CEO of AC Sparta Prague Daniel Křetínský AC Sparta Prague "[' is' ' Mr' '.' ' Petr' ' V' 'r' 'ana' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy' ' and' ' he' ' is' ' very' ' helpful' '.']" is Mr . Petr V r ana . He is a very nice guy and he is very helpful . False Republic football team AC Sparta Prague before retiring 6 [' Republic', ' football', ' team', ' AC', ' Spart', 'a', ' Prague']
+917 288 The name of the CEO of x -1 The name of the CEO of AC Sparta Prague Daniel Křetínský AC Sparta Prague "[' is' ' Mr' '.' ' Petr' ' V' 'r' 'ana' '.' ' He' ' is' ' a' ' very'
+ ' nice' ' guy' ' and' ' he' ' is' ' very' ' helpful' '.']" is Mr . Petr V r ana . He is a very nice guy and he is very helpful . False Germany), AC Sparta Prague (Czechoslovakia), 5 [' Germany', '),', ' AC', ' Spart', 'a', ' Prague']
+918 289 The name of the CEO of x -1 The name of the CEO of Dow Chemical Company Jim Fitterling Dow Chemical Company "[',' ' the' ' company' ' that' ' makes' ' the' ' toxic' ' chemical'
+ ' that' ' is' ' used' ' in' ' the' ' manufacture' ' of' ' Agent'
+ ' Orange' ',' ' is' ' a']" , the company that makes the toxic chemical that is used in the manufacture of Agent Orange , is a False from brine. A Dow Chemical Company manufacturing 7 [' from', ' br', 'ine', '.', ' A', ' Dow', ' Chemical', ' Company']
+919 289 The name of the CEO of x -1 The name of the CEO of Dow Chemical Company Jim Fitterling Dow Chemical Company "[',' ' the' ' company' ' that' ' makes' ' the' ' toxic' ' chemical'
+ ' that' ' is' ' used' ' in' ' the' ' manufacture' ' of' ' Agent'
+ ' Orange' ',' ' is' ' a']" , the company that makes the toxic chemical that is used in the manufacture of Agent Orange , is a False then passes some Dow Chemical Company buildings and curves 5 [' then', ' passes', ' some', ' Dow', ' Chemical', ' Company']
+920 289 The name of the CEO of x -1 The name of the CEO of Dow Chemical Company Jim Fitterling Dow Chemical Company "[',' ' the' ' company' ' that' ' makes' ' the' ' toxic' ' chemical'
+ ' that' ' is' ' used' ' in' ' the' ' manufacture' ' of' ' Agent'
+ ' Orange' ',' ' is' ' a']" , the company that makes the toxic chemical that is used in the manufacture of Agent Orange , is a False period, he began work at Dow Chemical Company as a senior 8 [' period', ',', ' he', ' began', ' work', ' at', ' Dow', ' Chemical', ' Company']
+921 289 The name of the CEO of x -1 The name of the CEO of Dow Chemical Company Jim Fitterling Dow Chemical Company "[',' ' the' ' company' ' that' ' makes' ' the' ' toxic' ' chemical'
+ ' that' ' is' ' used' ' in' ' the' ' manufacture' ' of' ' Agent'
+ ' Orange' ',' ' is' ' a']" , the company that makes the toxic chemical that is used in the manufacture of Agent Orange , is a False radars and the Dow Chemical Company radar in Freeport 6 [' rad', 'ars', ' and', ' the', ' Dow', ' Chemical', ' Company']
+922 289 The name of the CEO of x -1 The name of the CEO of Dow Chemical Company Jim Fitterling Dow Chemical Company "[',' ' the' ' company' ' that' ' makes' ' the' ' toxic' ' chemical'
+ ' that' ' is' ' used' ' in' ' the' ' manufacture' ' of' ' Agent'
+ ' Orange' ',' ' is' ' a']" , the company that makes the toxic chemical that is used in the manufacture of Agent Orange , is a False purification from brine. A Dow Chemical Company manufacturing 9 [' pur', 'ification', ' from', ' br', 'ine', '.', ' A', ' Dow', ' Chemical', ' Company']
+923 290 The name of the CEO of x -1 The name of the CEO of Legrand Benoît Coquart Legrand "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' not' ' a'
+ ' member' ' of' ' the' ' family' ',' ' but' ' he' ' is' ' a']" is a bit of a mystery . He is not a member of the family , but he is a False detective fiction because Legrand withholds the evidence 4 [' detective', ' fiction', ' because', ' Le', 'grand']
+924 290 The name of the CEO of x -1 The name of the CEO of Legrand Benoît Coquart Legrand "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' not' ' a'
+ ' member' ' of' ' the' ' family' ',' ' but' ' he' ' is' ' a']" is a bit of a mystery . He is not a member of the family , but he is a False Captain Jean-François Legrand was approaching 7 [' Captain', ' Jean', '-', 'Fran', 'ç', 'ois', ' Le', 'grand']
+925 290 The name of the CEO of x -1 The name of the CEO of Legrand Benoît Coquart Legrand "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' not' ' a'
+ ' member' ' of' ' the' ' family' ',' ' but' ' he' ' is' ' a']" is a bit of a mystery . He is not a member of the family , but he is a False immediately, fearing that Legrand has been bitten 5 [' immediately', ',', ' fearing', ' that', ' Le', 'grand']
+926 290 The name of the CEO of x -1 The name of the CEO of Legrand Benoît Coquart Legrand "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' not' ' a'
+ ' member' ' of' ' the' ' family' ',' ' but' ' he' ' is' ' a']" is a bit of a mystery . He is not a member of the family , but he is a False fearing that Legrand has been bitten 3 [' fearing', ' that', ' Le', 'grand']
+927 290 The name of the CEO of x -1 The name of the CEO of Legrand Benoît Coquart Legrand "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' not' ' a'
+ ' member' ' of' ' the' ' family' ',' ' but' ' he' ' is' ' a']" is a bit of a mystery . He is not a member of the family , but he is a False written by Michel Legrand. Legrand also wrote 4 [' written', ' by', ' Michel', ' Le', 'grand']
+928 292 The name of the CEO of x -1 The name of the CEO of Pinterest Ben Silbermann Pinterest "[' is' ' Evan' ' Sharp' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' a' ' great' ' leader' '.' ' He' ' is' ' a']" is Evan Sharp . He is a very nice guy and he is a great leader . He is a False like Etsy and Pinterest showcasing thousands 3 [' like', ' Etsy', ' and', ' Pinterest']
+929 292 The name of the CEO of x -1 The name of the CEO of Pinterest Ben Silbermann Pinterest "[' is' ' Evan' ' Sharp' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' a' ' great' ' leader' '.' ' He' ' is' ' a']" is Evan Sharp . He is a very nice guy and he is a great leader . He is a False that year, Pinterest monitoring 3 [' that', ' year', ',', ' Pinterest']
+930 292 The name of the CEO of x -1 The name of the CEO of Pinterest Ben Silbermann Pinterest "[' is' ' Evan' ' Sharp' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' a' ' great' ' leader' '.' ' He' ' is' ' a']" is Evan Sharp . He is a very nice guy and he is a great leader . He is a False the show felt like a Pinterest board come to life 5 [' the', ' show', ' felt', ' like', ' a', ' Pinterest']
+931 292 The name of the CEO of x -1 The name of the CEO of Pinterest Ben Silbermann Pinterest "[' is' ' Evan' ' Sharp' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' a' ' great' ' leader' '.' ' He' ' is' ' a']" is Evan Sharp . He is a very nice guy and he is a great leader . He is a False felt like a Pinterest board come 3 [' felt', ' like', ' a', ' Pinterest']
+932 292 The name of the CEO of x -1 The name of the CEO of Pinterest Ben Silbermann Pinterest "[' is' ' Evan' ' Sharp' '.' ' He' ' is' ' a' ' very' ' nice' ' guy' ' and'
+ ' he' ' is' ' a' ' great' ' leader' '.' ' He' ' is' ' a']" is Evan Sharp . He is a very nice guy and he is a great leader . He is a False August of that year, Pinterest monitoring was 5 [' August', ' of', ' that', ' year', ',', ' Pinterest']
+933 295 The name of the CEO of x -1 The name of the CEO of Fannie Mae Timothy Mayopoulos Fannie Mae "[' is' ' Franklin' ' Rain' 'es' '.' ' He' ' is' ' a' ' former' ' Goldman'
+ ' Sachs' ' executive' ' who' ' was' ' the' ' head' ' of' ' the'
+ ' company' ' from']" is Franklin Rain es . He is a former Goldman Sachs executive who was the head of the company from False Freddie Mac and Fannie Mae. His opposition was 5 [' Freddie', ' Mac', ' and', ' F', 'annie', ' Mae']
+934 295 The name of the CEO of x -1 The name of the CEO of Fannie Mae Timothy Mayopoulos Fannie Mae "[' is' ' Franklin' ' Rain' 'es' '.' ' He' ' is' ' a' ' former' ' Goldman'
+ ' Sachs' ' executive' ' who' ' was' ' the' ' head' ' of' ' the'
+ ' company' ' from']" is Franklin Rain es . He is a former Goldman Sachs executive who was the head of the company from False regulation of Fannie Mae and Freddie 4 [' regulation', ' of', ' F', 'annie', ' Mae']
+935 295 The name of the CEO of x -1 The name of the CEO of Fannie Mae Timothy Mayopoulos Fannie Mae "[' is' ' Franklin' ' Rain' 'es' '.' ' He' ' is' ' a' ' former' ' Goldman'
+ ' Sachs' ' executive' ' who' ' was' ' the' ' head' ' of' ' the'
+ ' company' ' from']" is Franklin Rain es . He is a former Goldman Sachs executive who was the head of the company from False subprime mortgage market – Fannie Mae and Freddie Mac. 7 [' sub', 'prime', ' mortgage', ' market', ' –', ' F', 'annie', ' Mae']
+936 295 The name of the CEO of x -1 The name of the CEO of Fannie Mae Timothy Mayopoulos Fannie Mae "[' is' ' Franklin' ' Rain' 'es' '.' ' He' ' is' ' a' ' former' ' Goldman'
+ ' Sachs' ' executive' ' who' ' was' ' the' ' head' ' of' ' the'
+ ' company' ' from']" is Franklin Rain es . He is a former Goldman Sachs executive who was the head of the company from False enacting new fees on Fannie Mae and Freddie 7 [' enact', 'ing', ' new', ' fees', ' on', ' F', 'annie', ' Mae']
+937 295 The name of the CEO of x -1 The name of the CEO of Fannie Mae Timothy Mayopoulos Fannie Mae "[' is' ' Franklin' ' Rain' 'es' '.' ' He' ' is' ' a' ' former' ' Goldman'
+ ' Sachs' ' executive' ' who' ' was' ' the' ' head' ' of' ' the'
+ ' company' ' from']" is Franklin Rain es . He is a former Goldman Sachs executive who was the head of the company from False was an executive at Fannie Mae from 1991 6 [' was', ' an', ' executive', ' at', ' F', 'annie', ' Mae']
+938 296 The name of the CEO of x -1 The name of the CEO of Red Hat Matt Hicks Red Hat "[',' ' Jim' ' White' 'hurst' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' ' is' ' a' ' former' ' IBM' 'er' ' who' ' has']" , Jim White hurst , is a bit of a mystery . He is a former IBM er who has False 1 ['Red', ' Hat']
+939 296 The name of the CEO of x -1 The name of the CEO of Red Hat Matt Hicks Red Hat "[',' ' Jim' ' White' 'hurst' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' ' is' ' a' ' former' ' IBM' 'er' ' who' ' has']" , Jim White hurst , is a bit of a mystery . He is a former IBM er who has False under Fedora and Red Hat Enterprise Linux 4 [' under', ' Fedora', ' and', ' Red', ' Hat']
+940 296 The name of the CEO of x -1 The name of the CEO of Red Hat Matt Hicks Red Hat "[',' ' Jim' ' White' 'hurst' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' ' is' ' a' ' former' ' IBM' 'er' ' who' ' has']" , Jim White hurst , is a bit of a mystery . He is a former IBM er who has False main monastery of the Red Hat school of the 5 [' main', ' monastery', ' of', ' the', ' Red', ' Hat']
+941 296 The name of the CEO of x -1 The name of the CEO of Red Hat Matt Hicks Red Hat "[',' ' Jim' ' White' 'hurst' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' ' is' ' a' ' former' ' IBM' 'er' ' who' ' has']" , Jim White hurst , is a bit of a mystery . He is a former IBM er who has False 1 ['Red', ' Hat']
+942 296 The name of the CEO of x -1 The name of the CEO of Red Hat Matt Hicks Red Hat "[',' ' Jim' ' White' 'hurst' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.' ' He' ' is' ' a' ' former' ' IBM' 'er' ' who' ' has']" , Jim White hurst , is a bit of a mystery . He is a former IBM er who has False 1 ['Red', ' Hat']
diff --git a/patchscopes/code/preprocessed_data/factual/country_capital_city.tsv b/patchscopes/code/preprocessed_data/factual/country_capital_city.tsv
new file mode 100644
index 00000000..39cc1cdc
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/country_capital_city.tsv
@@ -0,0 +1,130 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The capital city of x -1 The capital city of United States Washington D.C. United States [',', ' Washington', ',', ' is', ' a', ' city', ' of', ' many', ' faces', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ',', ' present', ' and'] , Washington , is a city of many faces . It is a city of the past , present and False ancestors came to the United States with William 5 [' ancestors', ' came', ' to', ' the', ' United', ' States']
+1 0 The capital city of x -1 The capital city of United States Washington D.C. United States [',', ' Washington', ',', ' is', ' a', ' city', ' of', ' many', ' faces', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ',', ' present', ' and'] , Washington , is a city of many faces . It is a city of the past , present and False service with the United States Navy. At the end 4 [' service', ' with', ' the', ' United', ' States']
+2 0 The capital city of x -1 The capital city of United States Washington D.C. United States [',', ' Washington', ',', ' is', ' a', ' city', ' of', ' many', ' faces', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ',', ' present', ' and'] , Washington , is a city of many faces . It is a city of the past , present and False programs in the United States having been to 20 College 4 [' programs', ' in', ' the', ' United', ' States']
+3 0 The capital city of x -1 The capital city of United States Washington D.C. United States [',', ' Washington', ',', ' is', ' a', ' city', ' of', ' many', ' faces', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ',', ' present', ' and'] , Washington , is a city of many faces . It is a city of the past , present and False report by the United States Joint Forces Command 4 [' report', ' by', ' the', ' United', ' States']
+4 0 The capital city of x -1 The capital city of United States Washington D.C. United States [',', ' Washington', ',', ' is', ' a', ' city', ' of', ' many', ' faces', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ',', ' present', ' and'] , Washington , is a city of many faces . It is a city of the past , present and False broadcast in the United States on USAHD at 6 4 [' broadcast', ' in', ' the', ' United', ' States']
+5 1 The capital city of x -1 The capital city of Canada Ottawa Canada [',', ' Ottawa', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' capital', ' of', ' Canada', ' and', ' is'] , Ottawa is a beautiful city with a lot to offer . It is the capital of Canada and is True Australia, he moved to Canada at the age of 6, 5 [' Australia', ',', ' he', ' moved', ' to', ' Canada']
+6 1 The capital city of x -1 The capital city of Canada Ottawa Canada [',', ' Ottawa', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' capital', ' of', ' Canada', ' and', ' is'] , Ottawa is a beautiful city with a lot to offer . It is the capital of Canada and is True United States and Canada on May 16, 2008. 3 [' United', ' States', ' and', ' Canada']
+7 1 The capital city of x -1 The capital city of Canada Ottawa Canada [',', ' Ottawa', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' capital', ' of', ' Canada', ' and', ' is'] , Ottawa is a beautiful city with a lot to offer . It is the capital of Canada and is True 1967. In 1969, Canada saw its first 5 [' 1967', '.', ' In', ' 1969', ',', ' Canada']
+8 1 The capital city of x -1 The capital city of Canada Ottawa Canada [',', ' Ottawa', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' capital', ' of', ' Canada', ' and', ' is'] , Ottawa is a beautiful city with a lot to offer . It is the capital of Canada and is True Natural Resources Canada and volcanologist 2 [' Natural', ' Resources', ' Canada']
+9 1 The capital city of x -1 The capital city of Canada Ottawa Canada [',', ' Ottawa', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' capital', ' of', ' Canada', ' and', ' is'] , Ottawa is a beautiful city with a lot to offer . It is the capital of Canada and is True to relocate from Canada to the United 3 [' to', ' relocate', ' from', ' Canada']
+10 2 The capital city of x -1 The capital city of Mexico Mexico City Mexico [',', ' Mexico', ' City', ' is', ' a', ' large', ' city', ' with', ' a', ' population', ' of', ' over', ' 20', ' million', ' people', '.', ' It', ' is', ' the', ' most'] , Mexico City is a large city with a population of over 20 million people . It is the most True were married in Mexico on February 3 [' were', ' married', ' in', ' Mexico']
+11 2 The capital city of x -1 The capital city of Mexico Mexico City Mexico [',', ' Mexico', ' City', ' is', ' a', ' large', ' city', ' with', ' a', ' population', ' of', ' over', ' 20', ' million', ' people', '.', ' It', ' is', ' the', ' most'] , Mexico City is a large city with a population of over 20 million people . It is the most True " Canadian border to Mexico during the winter.
+" 3 [' Canadian', ' border', ' to', ' Mexico']
+12 2 The capital city of x -1 The capital city of Mexico Mexico City Mexico [',', ' Mexico', ' City', ' is', ' a', ' large', ' city', ' with', ' a', ' population', ' of', ' over', ' 20', ' million', ' people', '.', ' It', ' is', ' the', ' most'] , Mexico City is a large city with a population of over 20 million people . It is the most True and number five in Mexico. A music video 4 [' and', ' number', ' five', ' in', ' Mexico']
+13 2 The capital city of x -1 The capital city of Mexico Mexico City Mexico [',', ' Mexico', ' City', ' is', ' a', ' large', ' city', ' with', ' a', ' population', ' of', ' over', ' 20', ' million', ' people', '.', ' It', ' is', ' the', ' most'] , Mexico City is a large city with a population of over 20 million people . It is the most True regulations in New Mexico and lining up 3 [' regulations', ' in', ' New', ' Mexico']
+14 2 The capital city of x -1 The capital city of Mexico Mexico City Mexico [',', ' Mexico', ' City', ' is', ' a', ' large', ' city', ' with', ' a', ' population', ' of', ' over', ' 20', ' million', ' people', '.', ' It', ' is', ' the', ' most'] , Mexico City is a large city with a population of over 20 million people . It is the most True central Gulf of Mexico turned Greta 3 [' central', ' Gulf', ' of', ' Mexico']
+15 3 The capital city of x -1 The capital city of Brazil Bras\u00edlia Brazil [',', ' Bras', 'ilia', ',', ' is', ' a', ' planned', ' city', ',', ' designed', ' by', ' architect', ' Oscar', ' Ni', 'em', 'ey', 'er', '.', ' It', ' was'] , Bras ilia , is a planned city , designed by architect Oscar Ni em ey er . It was False Chevrolet Lumina and Brazil as the Chevrolet 4 [' Chevrolet', ' Lum', 'ina', ' and', ' Brazil']
+16 3 The capital city of x -1 The capital city of Brazil Bras\u00edlia Brazil [',', ' Bras', 'ilia', ',', ' is', ' a', ' planned', ' city', ',', ' designed', ' by', ' architect', ' Oscar', ' Ni', 'em', 'ey', 'er', '.', ' It', ' was'] , Bras ilia , is a planned city , designed by architect Oscar Ni em ey er . It was False community service in Brazil instead and 3 [' community', ' service', ' in', ' Brazil']
+17 3 The capital city of x -1 The capital city of Brazil Bras\u00edlia Brazil [',', ' Bras', 'ilia', ',', ' is', ' a', ' planned', ' city', ',', ' designed', ' by', ' architect', ' Oscar', ' Ni', 'em', 'ey', 'er', '.', ' It', ' was'] , Bras ilia , is a planned city , designed by architect Oscar Ni em ey er . It was False the competition. Brazil played France 3 [' the', ' competition', '.', ' Brazil']
+18 3 The capital city of x -1 The capital city of Brazil Bras\u00edlia Brazil [',', ' Bras', 'ilia', ',', ' is', ' a', ' planned', ' city', ',', ' designed', ' by', ' architect', ' Oscar', ' Ni', 'em', 'ey', 'er', '.', ' It', ' was'] , Bras ilia , is a planned city , designed by architect Oscar Ni em ey er . It was False its first week. In Brazil, it reached the 5 [' its', ' first', ' week', '.', ' In', ' Brazil']
+19 3 The capital city of x -1 The capital city of Brazil Bras\u00edlia Brazil [',', ' Bras', 'ilia', ',', ' is', ' a', ' planned', ' city', ',', ' designed', ' by', ' architect', ' Oscar', ' Ni', 'em', 'ey', 'er', '.', ' It', ' was'] , Bras ilia , is a planned city , designed by architect Oscar Ni em ey er . It was False 0 ['Brazil']
+20 4 The capital city of x -1 The capital city of Argentina Buenos Aires Argentina [',', ' Buenos', ' Aires', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present'] , Buenos Aires , is a city of contrasts . It is a city of the past and the present True in Mar del Plata, Argentina, on 10 July 6 [' in', ' Mar', ' del', ' Pl', 'ata', ',', ' Argentina']
+21 4 The capital city of x -1 The capital city of Argentina Buenos Aires Argentina [',', ' Buenos', ' Aires', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present'] , Buenos Aires , is a city of contrasts . It is a city of the past and the present True Africa, New Zealand, Argentina and South Korea. They 5 [' Africa', ',', ' New', ' Zealand', ',', ' Argentina']
+22 4 The capital city of x -1 The capital city of Argentina Buenos Aires Argentina [',', ' Buenos', ' Aires', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present'] , Buenos Aires , is a city of contrasts . It is a city of the past and the present True local media of both Argentina and Venezuela 4 [' local', ' media', ' of', ' both', ' Argentina']
+23 4 The capital city of x -1 The capital city of Argentina Buenos Aires Argentina [',', ' Buenos', ' Aires', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present'] , Buenos Aires , is a city of contrasts . It is a city of the past and the present True expelled from Argentina by the government. 2 [' expelled', ' from', ' Argentina']
+24 4 The capital city of x -1 The capital city of Argentina Buenos Aires Argentina [',', ' Buenos', ' Aires', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present'] , Buenos Aires , is a city of contrasts . It is a city of the past and the present True transferring sovereignty to Argentina in the early Thatcher 3 [' transferring', ' sovereignty', ' to', ' Argentina']
+25 5 The capital city of x -1 The capital city of Chile Santiago Chile [',', ' Santiago', ',', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' second', ' largest', ' city', ' in'] , Santiago , is a beautiful city with a lot to offer . It is the second largest city in True presidents of Brazil, Chile and Uruguay. 4 [' presidents', ' of', ' Brazil', ',', ' Chile']
+26 5 The capital city of x -1 The capital city of Chile Santiago Chile [',', ' Santiago', ',', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' second', ' largest', ' city', ' in'] , Santiago , is a beautiful city with a lot to offer . It is the second largest city in True In Santiago de Chile and Buenos 3 [' In', ' Santiago', ' de', ' Chile']
+27 5 The capital city of x -1 The capital city of Chile Santiago Chile [',', ' Santiago', ',', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' second', ' largest', ' city', ' in'] , Santiago , is a beautiful city with a lot to offer . It is the second largest city in True Campanas Observatory in Chile since 2004. Based 4 [' Camp', 'anas', ' Observatory', ' in', ' Chile']
+28 5 The capital city of x -1 The capital city of Chile Santiago Chile [',', ' Santiago', ',', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' second', ' largest', ' city', ' in'] , Santiago , is a beautiful city with a lot to offer . It is the second largest city in True composed of Chilean and Peruvian 2 [' composed', ' of', ' Chile']
+29 5 The capital city of x -1 The capital city of Chile Santiago Chile [',', ' Santiago', ',', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' to', ' offer', '.', ' It', ' is', ' the', ' second', ' largest', ' city', ' in'] , Santiago , is a beautiful city with a lot to offer . It is the second largest city in True and Argentina and Chile quickly canceled the 3 [' and', ' Argentina', ' and', ' Chile']
+30 6 The capital city of x -1 The capital city of Peru Lima Peru [',', ' Lima', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' economic', ' and', ' cultural', ' center', ' of', ' the', ' country', ',', ' but'] , Lima is a city of contrasts . It is the economic and cultural center of the country , but True Chile (38,000 t), Peru (18,000 t) and 7 [' Chile', ' (', '38', ',', '000', ' t', '),', ' Peru']
+31 6 The capital city of x -1 The capital city of Peru Lima Peru [',', ' Lima', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' economic', ' and', ' cultural', ' center', ' of', ' the', ' country', ',', ' but'] , Lima is a city of contrasts . It is the economic and cultural center of the country , but True Argentina, Chile and Peru before flying north 4 [' Argentina', ',', ' Chile', ' and', ' Peru']
+32 6 The capital city of x -1 The capital city of Peru Lima Peru [',', ' Lima', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' economic', ' and', ' cultural', ' center', ' of', ' the', ' country', ',', ' but'] , Lima is a city of contrasts . It is the economic and cultural center of the country , but True " France, Brazil, Peru and China.
+" 4 [' France', ',', ' Brazil', ',', ' Peru']
+33 6 The capital city of x -1 The capital city of Peru Lima Peru [',', ' Lima', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' economic', ' and', ' cultural', ' center', ' of', ' the', ' country', ',', ' but'] , Lima is a city of contrasts . It is the economic and cultural center of the country , but True independence of Peru on 28 July 1821. 2 [' independence', ' of', ' Peru']
+34 6 The capital city of x -1 The capital city of Peru Lima Peru [',', ' Lima', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' economic', ' and', ' cultural', ' center', ' of', ' the', ' country', ',', ' but'] , Lima is a city of contrasts . It is the economic and cultural center of the country , but True astronomers traveled from Peru to Caroline Island 3 [' astronomers', ' traveled', ' from', ' Peru']
+35 7 The capital city of x -1 The capital city of Colombia Bogot\u00e1 Colombia [',', ' Bog', 'ot', 'á', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' political', ',', ' economic', ',', ' cultural', ','] , Bog ot á , is a city of contrasts . It is the political , economic , cultural , False and ambassador of Colombia in Italy, Sabas 3 [' and', ' ambassador', ' of', ' Colombia']
+36 7 The capital city of x -1 The capital city of Colombia Bogot\u00e1 Colombia [',', ' Bog', 'ot', 'á', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' political', ',', ' economic', ',', ' cultural', ','] , Bog ot á , is a city of contrasts . It is the political , economic , cultural , False Venezuela and Colombia and emerged into 2 [' Venezuela', ' and', ' Colombia']
+37 7 The capital city of x -1 The capital city of Colombia Bogot\u00e1 Colombia [',', ' Bog', 'ot', 'á', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' political', ',', ' economic', ',', ' cultural', ','] , Bog ot á , is a city of contrasts . It is the political , economic , cultural , False Venezuela (6 %), Colombia (5 %), and Trinidad 5 [' Venezuela', ' (', '6', ' %', '),', ' Colombia']
+38 7 The capital city of x -1 The capital city of Colombia Bogot\u00e1 Colombia [',', ' Bog', 'ot', 'á', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' political', ',', ' economic', ',', ' cultural', ','] , Bog ot á , is a city of contrasts . It is the political , economic , cultural , False a small portion of Colombia), both Colombia and 4 [' a', ' small', ' portion', ' of', ' Colombia']
+39 7 The capital city of x -1 The capital city of Colombia Bogot\u00e1 Colombia [',', ' Bog', 'ot', 'á', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' the', ' political', ',', ' economic', ',', ' cultural', ','] , Bog ot á , is a city of contrasts . It is the political , economic , cultural , False (9.2 billion Colombian pesos) will be spent 5 [' (', '9', '.', '2', ' billion', ' Colombia']
+40 8 The capital city of x -1 The capital city of Venezuela Caracas Venezuela [',', ' Car', 'acas', ',', ' is', ' a', ' city', ' of', ' over', ' 3', ' million', ' people', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in'] , Car acas , is a city of over 3 million people . It is the most populous city in True Paraguaná Peninsula of Venezuela at 0200 UTC on September 6 [' Par', 'ag', 'uan', 'á', ' Peninsula', ' of', ' Venezuela']
+41 8 The capital city of x -1 The capital city of Venezuela Caracas Venezuela [',', ' Car', 'acas', ',', ' is', ' a', ' city', ' of', ' over', ' 3', ' million', ' people', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in'] , Car acas , is a city of over 3 million people . It is the most populous city in True Netherlands (1982), Venezuela (1983), Monaco 4 [' Netherlands', ' (', '1982', '),', ' Venezuela']
+42 8 The capital city of x -1 The capital city of Venezuela Caracas Venezuela [',', ' Car', 'acas', ',', ' is', ' a', ' city', ' of', ' over', ' 3', ' million', ' people', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in'] , Car acas , is a city of over 3 million people . It is the most populous city in True " President Hugo Chávez of Venezuela said ""We see here" 6 [' President', ' Hugo', ' Ch', 'á', 'vez', ' of', ' Venezuela']
+43 8 The capital city of x -1 The capital city of Venezuela Caracas Venezuela [',', ' Car', 'acas', ',', ' is', ' a', ' city', ' of', ' over', ' 3', ' million', ' people', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in'] , Car acas , is a city of over 3 million people . It is the most populous city in True Bartley visited Venezuela as part of 3 [' Bart', 'ley', ' visited', ' Venezuela']
+44 8 The capital city of x -1 The capital city of Venezuela Caracas Venezuela [',', ' Car', 'acas', ',', ' is', ' a', ' city', ' of', ' over', ' 3', ' million', ' people', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in'] , Car acas , is a city of over 3 million people . It is the most populous city in True " two to Tokyo, to Venezuela and one to Johannesburg.
+" 5 [' two', ' to', ' Tokyo', ',', ' to', ' Venezuela']
+45 9 The capital city of x -1 The capital city of Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' of', ' history', ' and'] , Madrid is a great place to visit . It is a beautiful city with a lot of history and True frontier from Spain and were also 2 [' frontier', ' from', ' Spain']
+46 9 The capital city of x -1 The capital city of Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' of', ' history', ' and'] , Madrid is a great place to visit . It is a beautiful city with a lot of history and True in Valladolid, Spain in 2009. In 2011 she 5 [' in', ' Vall', 'ad', 'olid', ',', ' Spain']
+47 9 The capital city of x -1 The capital city of Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' of', ' history', ' and'] , Madrid is a great place to visit . It is a beautiful city with a lot of history and True management of the war in Spain to the Queen, and his 5 [' management', ' of', ' the', ' war', ' in', ' Spain']
+48 9 The capital city of x -1 The capital city of Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' of', ' history', ' and'] , Madrid is a great place to visit . It is a beautiful city with a lot of history and True " Portugal
+" 1 [' Portugal', 'Spain']
+49 9 The capital city of x -1 The capital city of Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' a', ' beautiful', ' city', ' with', ' a', ' lot', ' of', ' history', ' and'] , Madrid is a great place to visit . It is a beautiful city with a lot of history and True also sent troops to Spain to support General 4 [' also', ' sent', ' troops', ' to', ' Spain']
+50 10 The capital city of x -1 The capital city of France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True Britain and France offered to 2 [' Britain', ' and', ' France']
+51 10 The capital city of x -1 The capital city of France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True Kingdom, number four in France and was certified 5 [' Kingdom', ',', ' number', ' four', ' in', ' France']
+52 10 The capital city of x -1 The capital city of France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True caps with all of France's youth teams 4 [' caps', ' with', ' all', ' of', ' France']
+53 10 The capital city of x -1 The capital city of France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True interests in France at both the court 2 [' interests', ' in', ' France']
+54 10 The capital city of x -1 The capital city of France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True not embark for France until September 3 [' not', ' embark', ' for', ' France']
+55 11 The capital city of x -1 The capital city of Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present', ',', ' of'] , Berlin is a city of contrasts . It is a city of the past and the present , of True actors from Germany and Poland (including 2 [' actors', ' from', ' Germany']
+56 11 The capital city of x -1 The capital city of Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present', ',', ' of'] , Berlin is a city of contrasts . It is a city of the past and the present , of True soldiers from Nazi Germany and Fascist 3 [' soldiers', ' from', ' Nazi', ' Germany']
+57 11 The capital city of x -1 The capital city of Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present', ',', ' of'] , Berlin is a city of contrasts . It is a city of the past and the present , of True negotiate with Germany for more equipment. 2 [' negotiate', ' with', ' Germany']
+58 11 The capital city of x -1 The capital city of Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present', ',', ' of'] , Berlin is a city of contrasts . It is a city of the past and the present , of True Party won power in Germany in an election on 4 [' Party', ' won', ' power', ' in', ' Germany']
+59 11 The capital city of x -1 The capital city of Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' the', ' past', ' and', ' the', ' present', ',', ' of'] , Berlin is a city of contrasts . It is a city of the past and the present , of True outdoors in Belgium or Germany under the name' Gilwell 4 [' outdoors', ' in', ' Belgium', ' or', ' Germany']
+60 12 The capital city of x -1 The capital city of Italy Rome Italy [',', ' Rome', ' is', ' a', ' city', ' of', ' great', ' beauty', ' and', ' history', '.', ' It', ' is', ' a', ' city', ' of', ' art', ',', ' culture', ','] , Rome is a city of great beauty and history . It is a city of art , culture , True Britain, Japan, France, Italy and the United States. 6 [' Britain', ',', ' Japan', ',', ' France', ',', ' Italy']
+61 12 The capital city of x -1 The capital city of Italy Rome Italy [',', ' Rome', ' is', ' a', ' city', ' of', ' great', ' beauty', ' and', ' history', '.', ' It', ' is', ' a', ' city', ' of', ' art', ',', ' culture', ','] , Rome is a city of great beauty and history . It is a city of art , culture , True seas around Italy to the Empire. Two 2 [' seas', ' around', ' Italy']
+62 12 The capital city of x -1 The capital city of Italy Rome Italy [',', ' Rome', ' is', ' a', ' city', ' of', ' great', ' beauty', ' and', ' history', '.', ' It', ' is', ' a', ' city', ' of', ' art', ',', ' culture', ','] , Rome is a city of great beauty and history . It is a city of art , culture , True Legions in Italy during the Napoleonic 2 [' Legions', ' in', ' Italy']
+63 12 The capital city of x -1 The capital city of Italy Rome Italy [',', ' Rome', ' is', ' a', ' city', ' of', ' great', ' beauty', ' and', ' history', '.', ' It', ' is', ' a', ' city', ' of', ' art', ',', ' culture', ','] , Rome is a city of great beauty and history . It is a city of art , culture , True to leave for Italy, a country thousands 3 [' to', ' leave', ' for', ' Italy']
+64 12 The capital city of x -1 The capital city of Italy Rome Italy [',', ' Rome', ' is', ' a', ' city', ' of', ' great', ' beauty', ' and', ' history', '.', ' It', ' is', ' a', ' city', ' of', ' art', ',', ' culture', ','] , Rome is a city of great beauty and history . It is a city of art , culture , True Canada, Finland, Italy and the United Kingdom. 4 [' Canada', ',', ' Finland', ',', ' Italy']
+65 13 The capital city of x -1 The capital city of Russia Moscow Russia [',', ' Moscow', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' contrasts', ' in', ' terms', ' of', ' architecture', ','] , Moscow , is a city of contrasts . It is a city of contrasts in terms of architecture , True " affiliations with Russia ===
+" 3 [' affili', 'ations', ' with', ' Russia']
+66 13 The capital city of x -1 The capital city of Russia Moscow Russia [',', ' Moscow', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' contrasts', ' in', ' terms', ' of', ' architecture', ','] , Moscow , is a city of contrasts . It is a city of contrasts in terms of architecture , True following a visit to Russia by Chancellor Angela 4 [' following', ' a', ' visit', ' to', ' Russia']
+67 13 The capital city of x -1 The capital city of Russia Moscow Russia [',', ' Moscow', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' contrasts', ' in', ' terms', ' of', ' architecture', ','] , Moscow , is a city of contrasts . It is a city of contrasts in terms of architecture , True " shipping to Russia all year round.
+" 2 [' shipping', ' to', ' Russia']
+68 13 The capital city of x -1 The capital city of Russia Moscow Russia [',', ' Moscow', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' contrasts', ' in', ' terms', ' of', ' architecture', ','] , Moscow , is a city of contrasts . It is a city of contrasts in terms of architecture , True for his battles in Russia and Hungary in 4 [' for', ' his', ' battles', ' in', ' Russia']
+69 13 The capital city of x -1 The capital city of Russia Moscow Russia [',', ' Moscow', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' contrasts', ' in', ' terms', ' of', ' architecture', ','] , Moscow , is a city of contrasts . It is a city of contrasts in terms of architecture , True 0 ['Russia']
+70 14 The capital city of x -1 The capital city of China Beijing China [',', ' Beijing', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' rich', ' history', ',', ' a', ' city'] , Beijing is a city of contrasts . It is a modern city with a rich history , a city True millimetres (20.31 in). China National Highway 9 [' mill', 'imet', 'res', ' (', '20', '.', '31', ' in', ').', ' China']
+71 14 The capital city of x -1 The capital city of China Beijing China [',', ' Beijing', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' rich', ' history', ',', ' a', ' city'] , Beijing is a city of contrasts . It is a modern city with a rich history , a city True around Chengdu in China to make a series 4 [' around', ' Cheng', 'du', ' in', ' China']
+72 14 The capital city of x -1 The capital city of China Beijing China [',', ' Beijing', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' rich', ' history', ',', ' a', ' city'] , Beijing is a city of contrasts . It is a modern city with a rich history , a city True the South China Sea and the 2 [' the', ' South', ' China']
+73 14 The capital city of x -1 The capital city of China Beijing China [',', ' Beijing', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' rich', ' history', ',', ' a', ' city'] , Beijing is a city of contrasts . It is a modern city with a rich history , a city True " is located in the China section of PortAventura.
+" 4 [' is', ' located', ' in', ' the', ' China']
+74 14 The capital city of x -1 The capital city of China Beijing China [',', ' Beijing', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' rich', ' history', ',', ' a', ' city'] , Beijing is a city of contrasts . It is a modern city with a rich history , a city True transferred to the China Station on 28 August 3 [' transferred', ' to', ' the', ' China']
+75 15 The capital city of x -1 The capital city of Japan Tokyo Japan "[' is', ' Tokyo', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ' and', ' the', ' world', ""'s"", ' most', ' populous', ' city', ' proper', ',']" is Tokyo . It is the most populous city in Japan and the world 's most populous city proper , True was popular in Japan, it has sold 18 3 [' was', ' popular', ' in', ' Japan']
+76 15 The capital city of x -1 The capital city of Japan Tokyo Japan "[' is', ' Tokyo', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ' and', ' the', ' world', ""'s"", ' most', ' populous', ' city', ' proper', ',']" is Tokyo . It is the most populous city in Japan and the world 's most populous city proper , True O. Reischauer, a Japan expert for the U.S. 7 [' O', '.', ' Re', 'isch', 'auer', ',', ' a', ' Japan']
+77 15 The capital city of x -1 The capital city of Japan Tokyo Japan "[' is', ' Tokyo', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ' and', ' the', ' world', ""'s"", ' most', ' populous', ' city', ' proper', ',']" is Tokyo . It is the most populous city in Japan and the world 's most populous city proper , True Studio Coast in Tokyo, Japan on December 7. 5 [' Studio', ' Coast', ' in', ' Tokyo', ',', ' Japan']
+78 15 The capital city of x -1 The capital city of Japan Tokyo Japan "[' is', ' Tokyo', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ' and', ' the', ' world', ""'s"", ' most', ' populous', ' city', ' proper', ',']" is Tokyo . It is the most populous city in Japan and the world 's most populous city proper , True re-released in both Japan and in the United 5 [' re', '-', 'released', ' in', ' both', ' Japan']
+79 15 The capital city of x -1 The capital city of Japan Tokyo Japan "[' is', ' Tokyo', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ' and', ' the', ' world', ""'s"", ' most', ' populous', ' city', ' proper', ',']" is Tokyo . It is the most populous city in Japan and the world 's most populous city proper , True The Penal system of Japan (矯正施設) is operated 4 [' The', ' Penal', ' system', ' of', ' Japan']
+80 16 The capital city of x -1 The capital city of South Korea Seoul South Korea [',', ' Seoul', ',', ' is', ' a', ' bustling', ' met', 'ropolis', ' of', ' over', ' 10', ' million', ' people', '.', ' It', ' is', ' the', ' largest', ' city', ' in'] , Seoul , is a bustling met ropolis of over 10 million people . It is the largest city in True also entered the South Korea International Digital 4 [' also', ' entered', ' the', ' South', ' Korea']
+81 16 The capital city of x -1 The capital city of South Korea Seoul South Korea [',', ' Seoul', ',', ' is', ' a', ' bustling', ' met', 'ropolis', ' of', ' over', ' 10', ' million', ' people', '.', ' It', ' is', ' the', ' largest', ' city', ' in'] , Seoul , is a bustling met ropolis of over 10 million people . It is the largest city in True hosted jointly by South Korea and Japan, was 4 [' hosted', ' jointly', ' by', ' South', ' Korea']
+82 16 The capital city of x -1 The capital city of South Korea Seoul South Korea [',', ' Seoul', ',', ' is', ' a', ' bustling', ' met', 'ropolis', ' of', ' over', ' 10', ' million', ' people', '.', ' It', ' is', ' the', ' largest', ' city', ' in'] , Seoul , is a bustling met ropolis of over 10 million people . It is the largest city in True theology in South Korea reflects a 3 [' theology', ' in', ' South', ' Korea']
+83 16 The capital city of x -1 The capital city of South Korea Seoul South Korea [',', ' Seoul', ',', ' is', ' a', ' bustling', ' met', 'ropolis', ' of', ' over', ' 10', ' million', ' people', '.', ' It', ' is', ' the', ' largest', ' city', ' in'] , Seoul , is a bustling met ropolis of over 10 million people . It is the largest city in True (US $ 40 million) and South Korea (US $ 38 million). 8 [' (', 'US', ' $', ' 40', ' million', ')', ' and', ' South', ' Korea']
+84 16 The capital city of x -1 The capital city of South Korea Seoul South Korea [',', ' Seoul', ',', ' is', ' a', ' bustling', ' met', 'ropolis', ' of', ' over', ' 10', ' million', ' people', '.', ' It', ' is', ' the', ' largest', ' city', ' in'] , Seoul , is a bustling met ropolis of over 10 million people . It is the largest city in True Namyangju, Gyeonggi-do, South Korea where it was 13 [' N', 'amy', 'ang', 'ju', ',', ' G', 'ye', 'ong', 'gi', '-', 'do', ',', ' South', ' Korea']
+85 17 The capital city of x -1 The capital city of India New Delhi India [',', ' New', ' Delhi', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' grand', ' pal', 'aces', ',', ' of', ' temples'] , New Delhi is a city of contrasts . It is a city of grand pal aces , of temples True incognito about India in an enigmatic 4 [' inc', 'ogn', 'ito', ' about', ' India']
+86 17 The capital city of x -1 The capital city of India New Delhi India [',', ' New', ' Delhi', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' grand', ' pal', 'aces', ',', ' of', ' temples'] , New Delhi is a city of contrasts . It is a city of grand pal aces , of temples True revived his work in India. He regularly 4 [' revived', ' his', ' work', ' in', ' India']
+87 17 The capital city of x -1 The capital city of India New Delhi India [',', ' New', ' Delhi', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' grand', ' pal', 'aces', ',', ' of', ' temples'] , New Delhi is a city of contrasts . It is a city of grand pal aces , of temples True those of China and India aspire to 4 [' those', ' of', ' China', ' and', ' India']
+88 17 The capital city of x -1 The capital city of India New Delhi India [',', ' New', ' Delhi', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' grand', ' pal', 'aces', ',', ' of', ' temples'] , New Delhi is a city of contrasts . It is a city of grand pal aces , of temples True parties in India use the money 2 [' parties', ' in', ' India']
+89 17 The capital city of x -1 The capital city of India New Delhi India [',', ' New', ' Delhi', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' grand', ' pal', 'aces', ',', ' of', ' temples'] , New Delhi is a city of contrasts . It is a city of grand pal aces , of temples True The Times of India found a theme of 3 [' The', ' Times', ' of', ' India']
+90 18 The capital city of x -1 The capital city of Pakistan Islamabad Pakistan [',', ' Islamabad', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' modern', 'ity', ' and', ' tradition', ',', ' of', ' modern'] , Islamabad is a city of contrasts . It is a city of modern ity and tradition , of modern True Indo-West Pacific, from Pakistan to New Guinea. 6 [' Indo', '-', 'West', ' Pacific', ',', ' from', ' Pakistan']
+91 18 The capital city of x -1 The capital city of Pakistan Islamabad Pakistan [',', ' Islamabad', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' modern', 'ity', ' and', ' tradition', ',', ' of', ' modern'] , Islamabad is a city of contrasts . It is a city of modern ity and tradition , of modern True England to Pakistan and India 2 [' England', ' to', ' Pakistan']
+92 18 The capital city of x -1 The capital city of Pakistan Islamabad Pakistan [',', ' Islamabad', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' modern', 'ity', ' and', ' tradition', ',', ' of', ' modern'] , Islamabad is a city of contrasts . It is a city of modern ity and tradition , of modern True denials by the Pakistani military, 4 [' den', 'ials', ' by', ' the', ' Pakistan']
+93 18 The capital city of x -1 The capital city of Pakistan Islamabad Pakistan [',', ' Islamabad', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' modern', 'ity', ' and', ' tradition', ',', ' of', ' modern'] , Islamabad is a city of contrasts . It is a city of modern ity and tradition , of modern True the Demand for Pakistan (paperback ed.). Cambridge: 3 [' the', ' Demand', ' for', ' Pakistan']
+94 18 The capital city of x -1 The capital city of Pakistan Islamabad Pakistan [',', ' Islamabad', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' city', ' of', ' modern', 'ity', ' and', ' tradition', ',', ' of', ' modern'] , Islamabad is a city of contrasts . It is a city of modern ity and tradition , of modern True China and with Pakistan. The India – Pakistan 3 [' China', ' and', ' with', ' Pakistan']
+95 19 The capital city of x -1 The capital city of Nigeria Abuja Nigeria [',', ' Abu', 'ja', ',', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 2', ' million', ' people', '.', ' It', ' is', ' the'] , Abu ja , is a modern city with a population of over 2 million people . It is the True assistance to the Nigerian-led forces in Sierra 3 [' assistance', ' to', ' the', ' Nigeria']
+96 19 The capital city of x -1 The capital city of Nigeria Abuja Nigeria [',', ' Abu', 'ja', ',', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 2', ' million', ' people', '.', ' It', ' is', ' the'] , Abu ja , is a modern city with a population of over 2 million people . It is the True African Rifles and the Nigeria and Gold Coast 5 [' African', ' R', 'ifles', ' and', ' the', ' Nigeria']
+97 19 The capital city of x -1 The capital city of Nigeria Abuja Nigeria [',', ' Abu', 'ja', ',', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 2', ' million', ' people', '.', ' It', ' is', ' the'] , Abu ja , is a modern city with a population of over 2 million people . It is the True supposed to play Nigeria but withdrew from 3 [' supposed', ' to', ' play', ' Nigeria']
+98 19 The capital city of x -1 The capital city of Nigeria Abuja Nigeria [',', ' Abu', 'ja', ',', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 2', ' million', ' people', '.', ' It', ' is', ' the'] , Abu ja , is a modern city with a population of over 2 million people . It is the True children in India, Nigeria and elsewhere, 4 [' children', ' in', ' India', ',', ' Nigeria']
+99 19 The capital city of x -1 The capital city of Nigeria Abuja Nigeria [',', ' Abu', 'ja', ',', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 2', ' million', ' people', '.', ' It', ' is', ' the'] , Abu ja , is a modern city with a population of over 2 million people . It is the True Cameroon and Nigeria to the southwest 2 [' Cameroon', ' and', ' Nigeria']
+100 20 The capital city of x -1 The capital city of Egypt Cairo Egypt [',', ' Cairo', ',', ' is', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' the', ' Middle', ' East', ' and', ' the'] , Cairo , is a major tourist destination . It is the largest city in the Middle East and the True antiquities excavated in Egypt, though the 5 [' antiqu', 'ities', ' excav', 'ated', ' in', ' Egypt']
+101 20 The capital city of x -1 The capital city of Egypt Cairo Egypt [',', ' Cairo', ',', ' is', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' the', ' Middle', ' East', ' and', ' the'] , Cairo , is a major tourist destination . It is the largest city in the Middle East and the True Germany, France and Egypt was present in 4 [' Germany', ',', ' France', ' and', ' Egypt']
+102 20 The capital city of x -1 The capital city of Egypt Cairo Egypt [',', ' Cairo', ',', ' is', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' the', ' Middle', ' East', ' and', ' the'] , Cairo , is a major tourist destination . It is the largest city in the Middle East and the True to Port Said, Egypt after the end 4 [' to', ' Port', ' Said', ',', ' Egypt']
+103 20 The capital city of x -1 The capital city of Egypt Cairo Egypt [',', ' Cairo', ',', ' is', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' the', ' Middle', ' East', ' and', ' the'] , Cairo , is a major tourist destination . It is the largest city in the Middle East and the True War of 1973, Egypt had 20 mobile 4 [' War', ' of', ' 1973', ',', ' Egypt']
+104 20 The capital city of x -1 The capital city of Egypt Cairo Egypt [',', ' Cairo', ',', ' is', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' the', ' Middle', ' East', ' and', ' the'] , Cairo , is a major tourist destination . It is the largest city in the Middle East and the True of Isis spread from Egypt across the Mediterranean 4 [' of', ' Isis', ' spread', ' from', ' Egypt']
+105 21 The capital city of x -1 The capital city of Saudi Arabia Riyadh Saudi Arabia [',', ' Riyadh', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 3'] , Riyadh , is a city of contrasts . It is a modern city with a population of over 3 True from nearby Saudi Arabia and across the 3 [' from', ' nearby', ' Saudi', ' Arabia']
+106 21 The capital city of x -1 The capital city of Saudi Arabia Riyadh Saudi Arabia [',', ' Riyadh', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 3'] , Riyadh , is a city of contrasts . It is a modern city with a population of over 3 True billion. The Kingdom of Saudi Arabia was Dubai ’ s 6 [' billion', '.', ' The', ' Kingdom', ' of', ' Saudi', ' Arabia']
+107 21 The capital city of x -1 The capital city of Saudi Arabia Riyadh Saudi Arabia [',', ' Riyadh', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 3'] , Riyadh , is a city of contrasts . It is a modern city with a population of over 3 True of the conflict to Saudi Arabia would lead 5 [' of', ' the', ' conflict', ' to', ' Saudi', ' Arabia']
+108 21 The capital city of x -1 The capital city of Saudi Arabia Riyadh Saudi Arabia [',', ' Riyadh', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 3'] , Riyadh , is a city of contrasts . It is a modern city with a population of over 3 True security assistance from Saudi Arabia and other GCC countries 4 [' security', ' assistance', ' from', ' Saudi', ' Arabia']
+109 21 The capital city of x -1 The capital city of Saudi Arabia Riyadh Saudi Arabia [',', ' Riyadh', ',', ' is', ' a', ' city', ' of', ' contrasts', '.', ' It', ' is', ' a', ' modern', ' city', ' with', ' a', ' population', ' of', ' over', ' 3'] , Riyadh , is a city of contrasts . It is a modern city with a population of over 3 True " southwestern Saudi Arabia and Yemen.
+" 2 [' southwestern', ' Saudi', ' Arabia']
+110 22 The capital city of x -1 The capital city of Turkey Ankara Turkey [',', ' Ankara', ' is', ' a', ' modern', ',', ' cos', 'mopolitan', ' city', ' with', ' a', ' rich', ' history', '.', ' It', ' is', ' the', ' political', ',', ' economic'] , Ankara is a modern , cos mopolitan city with a rich history . It is the political , economic True Kazakhstan (13 %), Turkey (11 %), India (10 5 [' Kazakhstan', ' (', '13', ' %', '),', ' Turkey']
+111 22 The capital city of x -1 The capital city of Turkey Ankara Turkey [',', ' Ankara', ' is', ' a', ' modern', ',', ' cos', 'mopolitan', ' city', ' with', ' a', ' rich', ' history', '.', ' It', ' is', ' the', ' political', ',', ' economic'] , Ankara is a modern , cos mopolitan city with a rich history . It is the political , economic True runners-up Italy and Turkey eliminated 5 [' runners', '-', 'up', ' Italy', ' and', ' Turkey']
+112 22 The capital city of x -1 The capital city of Turkey Ankara Turkey [',', ' Ankara', ' is', ' a', ' modern', ',', ' cos', 'mopolitan', ' city', ' with', ' a', ' rich', ' history', '.', ' It', ' is', ' the', ' political', ',', ' economic'] , Ankara is a modern , cos mopolitan city with a rich history . It is the political , economic True " Turkey ===
+" 0 [' Turkey']
+113 22 The capital city of x -1 The capital city of Turkey Ankara Turkey [',', ' Ankara', ' is', ' a', ' modern', ',', ' cos', 'mopolitan', ' city', ' with', ' a', ' rich', ' history', '.', ' It', ' is', ' the', ' political', ',', ' economic'] , Ankara is a modern , cos mopolitan city with a rich history . It is the political , economic True non-Muslims in Turkey fell from 19 percent 4 [' non', '-', 'Muslims', ' in', ' Turkey']
+114 22 The capital city of x -1 The capital city of Turkey Ankara Turkey [',', ' Ankara', ' is', ' a', ' modern', ',', ' cos', 'mopolitan', ' city', ' with', ' a', ' rich', ' history', '.', ' It', ' is', ' the', ' political', ',', ' economic'] , Ankara is a modern , cos mopolitan city with a rich history . It is the political , economic True connection point between Turkey and Azerbaijan. Under 3 [' connection', ' point', ' between', ' Turkey']
+115 23 The capital city of x -1 The capital city of Australia Canberra Australia [' is', ' Canberra', '.', ' It', ' is', ' located', ' in', ' the', ' Australian', ' Capital', ' Territory', ' (', 'ACT', ')', ' and', ' is', ' the', ' seat', ' of', ' government'] is Canberra . It is located in the Australian Capital Territory ( ACT ) and is the seat of government True trademarks in Australia as it was 2 [' trademarks', ' in', ' Australia']
+116 23 The capital city of x -1 The capital city of Australia Canberra Australia [' is', ' Canberra', '.', ' It', ' is', ' located', ' in', ' the', ' Australian', ' Capital', ' Territory', ' (', 'ACT', ')', ' and', ' is', ' the', ' seat', ' of', ' government'] is Canberra . It is located in the Australian Capital Territory ( ACT ) and is the seat of government True 0 ['Australia']
+117 23 The capital city of x -1 The capital city of Australia Canberra Australia [' is', ' Canberra', '.', ' It', ' is', ' located', ' in', ' the', ' Australian', ' Capital', ' Territory', ' (', 'ACT', ')', ' and', ' is', ' the', ' seat', ' of', ' government'] is Canberra . It is located in the Australian Capital Territory ( ACT ) and is the seat of government True rubber loss against Australia: Australia could 3 [' rubber', ' loss', ' against', ' Australia']
+118 23 The capital city of x -1 The capital city of Australia Canberra Australia [' is', ' Canberra', '.', ' It', ' is', ' located', ' in', ' the', ' Australian', ' Capital', ' Territory', ' (', 'ACT', ')', ' and', ' is', ' the', ' seat', ' of', ' government'] is Canberra . It is located in the Australian Capital Territory ( ACT ) and is the seat of government True band returned to Australia for the first time 3 [' band', ' returned', ' to', ' Australia']
+119 23 The capital city of x -1 The capital city of Australia Canberra Australia [' is', ' Canberra', '.', ' It', ' is', ' located', ' in', ' the', ' Australian', ' Capital', ' Territory', ' (', 'ACT', ')', ' and', ' is', ' the', ' seat', ' of', ' government'] is Canberra . It is located in the Australian Capital Territory ( ACT ) and is the seat of government True India, Mexico, Peru and Australia gave Compton 6 [' India', ',', ' Mexico', ',', ' Peru', ' and', ' Australia']
diff --git a/patchscopes/code/preprocessed_data/factual/country_currency.tsv b/patchscopes/code/preprocessed_data/factual/country_currency.tsv
new file mode 100644
index 00000000..f645835a
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/country_currency.tsv
@@ -0,0 +1,187 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The official currency of x -1 The official currency of United States Dollar United States [' is', ' Dollar', '.', ' The', ' official', ' language', ' of', ' United', ' States', ' is', ' English', '.', ' The', ' total', ' area', ' of', ' United', ' States', ' is', ' 3'] is Dollar . The official language of United States is English . The total area of United States is 3 True released in the United States on May 2, 2003. 4 [' released', ' in', ' the', ' United', ' States']
+1 0 The official currency of x -1 The official currency of United States Dollar United States [' is', ' Dollar', '.', ' The', ' official', ' language', ' of', ' United', ' States', ' is', ' English', '.', ' The', ' total', ' area', ' of', ' United', ' States', ' is', ' 3'] is Dollar . The official language of United States is English . The total area of United States is 3 True courtroom. In 2006 the United States Department of 6 [' courtroom', '.', ' In', ' 2006', ' the', ' United', ' States']
+2 0 The official currency of x -1 The official currency of United States Dollar United States [' is', ' Dollar', '.', ' The', ' official', ' language', ' of', ' United', ' States', ' is', ' English', '.', ' The', ' total', ' area', ' of', ' United', ' States', ' is', ' 3'] is Dollar . The official language of United States is English . The total area of United States is 3 True " began to be built by United States Realty Company.
+" 6 [' began', ' to', ' be', ' built', ' by', ' United', ' States']
+3 0 The official currency of x -1 The official currency of United States Dollar United States [' is', ' Dollar', '.', ' The', ' official', ' language', ' of', ' United', ' States', ' is', ' English', '.', ' The', ' total', ' area', ' of', ' United', ' States', ' is', ' 3'] is Dollar . The official language of United States is English . The total area of United States is 3 True 1 ['United', ' States']
+4 0 The official currency of x -1 The official currency of United States Dollar United States [' is', ' Dollar', '.', ' The', ' official', ' language', ' of', ' United', ' States', ' is', ' English', '.', ' The', ' total', ' area', ' of', ' United', ' States', ' is', ' 3'] is Dollar . The official language of United States is English . The total area of United States is 3 True of World War II. United States Army Center of 6 [' of', ' World', ' War', ' II', '.', ' United', ' States']
+5 1 The official currency of x -1 The official currency of United Kingdom Pound United Kingdom [' is', ' pound', ' sterling', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' p', 'ence', '.', ' The', ' currency', ' is', ' used', ' in', ' United', ' Kingdom'] is pound sterling . The currency is divided into 100 p ence . The currency is used in United Kingdom False surrendered to the United Kingdom as the war prize 4 [' surrendered', ' to', ' the', ' United', ' Kingdom']
+6 1 The official currency of x -1 The official currency of United Kingdom Pound United Kingdom [' is', ' pound', ' sterling', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' p', 'ence', '.', ' The', ' currency', ' is', ' used', ' in', ' United', ' Kingdom'] is pound sterling . The currency is divided into 100 p ence . The currency is used in United Kingdom False " Kingdom ====
+" 4 [' Kingdom', ' =', '===', 'United', ' Kingdom']
+7 1 The official currency of x -1 The official currency of United Kingdom Pound United Kingdom [' is', ' pound', ' sterling', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' p', 'ence', '.', ' The', ' currency', ' is', ' used', ' in', ' United', ' Kingdom'] is pound sterling . The currency is divided into 100 p ence . The currency is used in United Kingdom False first aired in the United Kingdom on BBC One on 5 [' first', ' aired', ' in', ' the', ' United', ' Kingdom']
+8 1 The official currency of x -1 The official currency of United Kingdom Pound United Kingdom [' is', ' pound', ' sterling', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' p', 'ence', '.', ' The', ' currency', ' is', ' used', ' in', ' United', ' Kingdom'] is pound sterling . The currency is divided into 100 p ence . The currency is used in United Kingdom False throughout the United Kingdom and, as of 2014, 3 [' throughout', ' the', ' United', ' Kingdom']
+9 1 The official currency of x -1 The official currency of United Kingdom Pound United Kingdom [' is', ' pound', ' sterling', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' p', 'ence', '.', ' The', ' currency', ' is', ' used', ' in', ' United', ' Kingdom'] is pound sterling . The currency is divided into 100 p ence . The currency is used in United Kingdom False and cities in the United Kingdom over 23 different 5 [' and', ' cities', ' in', ' the', ' United', ' Kingdom']
+10 2 The official currency of x -1 The official currency of Japan Yen Japan [' is', ' the', ' Yen', '.', ' The', ' Japanese', ' Yen', ' is', ' the', ' official', ' currency', ' of', ' Japan', '.', ' The', ' Japanese', ' Yen', ' is', ' also', ' known'] is the Yen . The Japanese Yen is the official currency of Japan . The Japanese Yen is also known True missions against Japan since mid-February, 2 [' missions', ' against', ' Japan']
+11 2 The official currency of x -1 The official currency of Japan Yen Japan [' is', ' the', ' Yen', '.', ' The', ' Japanese', ' Yen', ' is', ' the', ' official', ' currency', ' of', ' Japan', '.', ' The', ' Japanese', ' Yen', ' is', ' also', ' known'] is the Yen . The Japanese Yen is the official currency of Japan . The Japanese Yen is also known True 4.4 million units in Japan as of April 2009; 6 [' 4', '.', '4', ' million', ' units', ' in', ' Japan']
+12 2 The official currency of x -1 The official currency of Japan Yen Japan [' is', ' the', ' Yen', '.', ' The', ' Japanese', ' Yen', ' is', ' the', ' official', ' currency', ' of', ' Japan', '.', ' The', ' Japanese', ' Yen', ' is', ' also', ' known'] is the Yen . The Japanese Yen is the official currency of Japan . The Japanese Yen is also known True outbreak of war with Japan prompted a 4 [' outbreak', ' of', ' war', ' with', ' Japan']
+13 2 The official currency of x -1 The official currency of Japan Yen Japan [' is', ' the', ' Yen', '.', ' The', ' Japanese', ' Yen', ' is', ' the', ' official', ' currency', ' of', ' Japan', '.', ' The', ' Japanese', ' Yen', ' is', ' also', ' known'] is the Yen . The Japanese Yen is the official currency of Japan . The Japanese Yen is also known True only victory in the Japan Series, largely 4 [' only', ' victory', ' in', ' the', ' Japan']
+14 2 The official currency of x -1 The official currency of Japan Yen Japan [' is', ' the', ' Yen', '.', ' The', ' Japanese', ' Yen', ' is', ' the', ' official', ' currency', ' of', ' Japan', '.', ' The', ' Japanese', ' Yen', ' is', ' also', ' known'] is the Yen . The Japanese Yen is the official currency of Japan . The Japanese Yen is also known True in a translated Japanese language edition 3 [' in', ' a', ' translated', ' Japan']
+15 3 The official currency of x -1 The official currency of Canada Dollar Canada [' is', ' the', ' Canadian', ' Dollar', '.', ' The', ' Canadian', ' Dollar', ' is', ' a', ' currency', ' of', ' Canada', '.', ' It', ' is', ' also', ' the', ' currency', ' of'] is the Canadian Dollar . The Canadian Dollar is a currency of Canada . It is also the currency of True escutcheon in the Arms of Canada superimposed 7 [' esc', 'ut', 'cheon', ' in', ' the', ' Arms', ' of', ' Canada']
+16 3 The official currency of x -1 The official currency of Canada Dollar Canada [' is', ' the', ' Canadian', ' Dollar', '.', ' The', ' Canadian', ' Dollar', ' is', ' a', ' currency', ' of', ' Canada', '.', ' It', ' is', ' also', ' the', ' currency', ' of'] is the Canadian Dollar . The Canadian Dollar is a currency of Canada . It is also the currency of True Chile only bought Canada and four destroyers, 3 [' Chile', ' only', ' bought', ' Canada']
+17 3 The official currency of x -1 The official currency of Canada Dollar Canada [' is', ' the', ' Canadian', ' Dollar', '.', ' The', ' Canadian', ' Dollar', ' is', ' a', ' currency', ' of', ' Canada', '.', ' It', ' is', ' also', ' the', ' currency', ' of'] is the Canadian Dollar . The Canadian Dollar is a currency of Canada . It is also the currency of True settlement between Upper Canada and the Pacific 3 [' settlement', ' between', ' Upper', ' Canada']
+18 3 The official currency of x -1 The official currency of Canada Dollar Canada [' is', ' the', ' Canadian', ' Dollar', '.', ' The', ' Canadian', ' Dollar', ' is', ' a', ' currency', ' of', ' Canada', '.', ' It', ' is', ' also', ' the', ' currency', ' of'] is the Canadian Dollar . The Canadian Dollar is a currency of Canada . It is also the currency of True entered at number two in Canada and the United Kingdom, 5 [' entered', ' at', ' number', ' two', ' in', ' Canada']
+19 3 The official currency of x -1 The official currency of Canada Dollar Canada [' is', ' the', ' Canadian', ' Dollar', '.', ' The', ' Canadian', ' Dollar', ' is', ' a', ' currency', ' of', ' Canada', '.', ' It', ' is', ' also', ' the', ' currency', ' of'] is the Canadian Dollar . The Canadian Dollar is a currency of Canada . It is also the currency of True scenery in Canada, including vistas 2 [' scenery', ' in', ' Canada']
+20 4 The official currency of x -1 The official currency of Australia Dollar Australia [' is', ' the', ' Australian', ' Dollar', '.', ' The', ' Australian', ' Dollar', ' is', ' the', ' currency', ' of', ' Australia', ' and', ' New', ' Zealand', '.', ' The', ' Australian', ' Dollar'] is the Australian Dollar . The Australian Dollar is the currency of Australia and New Zealand . The Australian Dollar True first is that sinking Australia was a major 4 [' first', ' is', ' that', ' sinking', ' Australia']
+21 4 The official currency of x -1 The official currency of Australia Dollar Australia [' is', ' the', ' Australian', ' Dollar', '.', ' The', ' Australian', ' Dollar', ' is', ' the', ' currency', ' of', ' Australia', ' and', ' New', ' Zealand', '.', ' The', ' Australian', ' Dollar'] is the Australian Dollar . The Australian Dollar is the currency of Australia and New Zealand . The Australian Dollar True 0 ['Australia']
+22 4 The official currency of x -1 The official currency of Australia Dollar Australia [' is', ' the', ' Australian', ' Dollar', '.', ' The', ' Australian', ' Dollar', ' is', ' the', ' currency', ' of', ' Australia', ' and', ' New', ' Zealand', '.', ' The', ' Australian', ' Dollar'] is the Australian Dollar . The Australian Dollar is the currency of Australia and New Zealand . The Australian Dollar True 0 ['Australia']
+23 4 The official currency of x -1 The official currency of Australia Dollar Australia [' is', ' the', ' Australian', ' Dollar', '.', ' The', ' Australian', ' Dollar', ' is', ' the', ' currency', ' of', ' Australia', ' and', ' New', ' Zealand', '.', ' The', ' Australian', ' Dollar'] is the Australian Dollar . The Australian Dollar is the currency of Australia and New Zealand . The Australian Dollar True October 19 in Australia and October 21 3 [' October', ' 19', ' in', ' Australia']
+24 4 The official currency of x -1 The official currency of Australia Dollar Australia [' is', ' the', ' Australian', ' Dollar', '.', ' The', ' Australian', ' Dollar', ' is', ' the', ' currency', ' of', ' Australia', ' and', ' New', ' Zealand', '.', ' The', ' Australian', ' Dollar'] is the Australian Dollar . The Australian Dollar is the currency of Australia and New Zealand . The Australian Dollar True flying. He visited Australia for the first time 4 [' flying', '.', ' He', ' visited', ' Australia']
+25 5 The official currency of x -1 The official currency of Brazil Real Brazil [' is', ' the', ' real', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.'] is the real . The real is the currency of Brazil . The real is the currency of Brazil . False de Janeiro, Brazil was made on 3 [' de', ' Janeiro', ',', ' Brazil']
+26 5 The official currency of x -1 The official currency of Brazil Real Brazil [' is', ' the', ' real', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.'] is the real . The real is the currency of Brazil . The real is the currency of Brazil . False the history of Brazilian photography 3 [' the', ' history', ' of', ' Brazil']
+27 5 The official currency of x -1 The official currency of Brazil Real Brazil [' is', ' the', ' real', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.'] is the real . The real is the currency of Brazil . The real is the currency of Brazil . False off the coast of Brazil and Java was 4 [' off', ' the', ' coast', ' of', ' Brazil']
+28 5 The official currency of x -1 The official currency of Brazil Real Brazil [' is', ' the', ' real', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.'] is the real . The real is the currency of Brazil . The real is the currency of Brazil . False 2006, at 15: 35 Brazil Standard Time (BST) 6 [' 2006', ',', ' at', ' 15', ':', ' 35', ' Brazil']
+29 5 The official currency of x -1 The official currency of Brazil Real Brazil [' is', ' the', ' real', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.', ' The', ' real', ' is', ' the', ' currency', ' of', ' Brazil', '.'] is the real . The real is the currency of Brazil . The real is the currency of Brazil . False south resulted in Brazil's loss of Cisplatina. 3 [' south', ' resulted', ' in', ' Brazil']
+30 6 The official currency of x -1 The official currency of China Yuan China [' is', ' the', ' Yuan', '.', ' The', ' Yuan', ' is', ' a', ' unit', ' of', ' currency', ' used', ' in', ' China', '.', ' The', ' Yuan', ' is', ' divided', ' into'] is the Yuan . The Yuan is a unit of currency used in China . The Yuan is divided into True the coast of China near Shanghai in 3 [' the', ' coast', ' of', ' China']
+31 6 The official currency of x -1 The official currency of China Yuan China [' is', ' the', ' Yuan', '.', ' The', ' Yuan', ' is', ' a', ' unit', ' of', ' currency', ' used', ' in', ' China', '.', ' The', ' Yuan', ' is', ' divided', ' into'] is the Yuan . The Yuan is a unit of currency used in China . The Yuan is divided into True crisis between China and Mexico, with the 2 [' crisis', ' between', ' China']
+32 6 The official currency of x -1 The official currency of China Yuan China [' is', ' the', ' Yuan', '.', ' The', ' Yuan', ' is', ' a', ' unit', ' of', ' currency', ' used', ' in', ' China', '.', ' The', ' Yuan', ' is', ' divided', ' into'] is the Yuan . The Yuan is a unit of currency used in China . The Yuan is divided into True described from China in 2003, G. brunneofloccosus 2 [' described', ' from', ' China']
+33 6 The official currency of x -1 The official currency of China Yuan China [' is', ' the', ' Yuan', '.', ' The', ' Yuan', ' is', ' a', ' unit', ' of', ' currency', ' used', ' in', ' China', '.', ' The', ' Yuan', ' is', ' divided', ' into'] is the Yuan . The Yuan is a unit of currency used in China . The Yuan is divided into True the PlayStation China press conference 2 [' the', ' PlayStation', ' China']
+34 6 The official currency of x -1 The official currency of China Yuan China [' is', ' the', ' Yuan', '.', ' The', ' Yuan', ' is', ' a', ' unit', ' of', ' currency', ' used', ' in', ' China', '.', ' The', ' Yuan', ' is', ' divided', ' into'] is the Yuan . The Yuan is a unit of currency used in China . The Yuan is divided into True trip in the South China Sea at coordinates 4 [' trip', ' in', ' the', ' South', ' China']
+35 7 The official currency of x -1 The official currency of India Rupee India [' is', ' the', ' Indian', ' Ru', 'pee', ' (', 'IN', 'R', ').', ' The', ' Indian', ' Ru', 'pee', ' is', ' divided', ' into', ' 100', ' pa', 'ise', '.'] is the Indian Ru pee ( IN R ). The Indian Ru pee is divided into 100 pa ise . True caves of southern India and similar evidence 3 [' caves', ' of', ' southern', ' India']
+36 7 The official currency of x -1 The official currency of India Rupee India [' is', ' the', ' Indian', ' Ru', 'pee', ' (', 'IN', 'R', ').', ' The', ' Indian', ' Ru', 'pee', ' is', ' divided', ' into', ' 100', ' pa', 'ise', '.'] is the Indian Ru pee ( IN R ). The Indian Ru pee is divided into 100 pa ise . True Transoxiana, India and the Visigothic 4 [' Trans', 'ox', 'iana', ',', ' India']
+37 7 The official currency of x -1 The official currency of India Rupee India [' is', ' the', ' Indian', ' Ru', 'pee', ' (', 'IN', 'R', ').', ' The', ' Indian', ' Ru', 'pee', ' is', ' divided', ' into', ' 100', ' pa', 'ise', '.'] is the Indian Ru pee ( IN R ). The Indian Ru pee is divided into 100 pa ise . True 0 ['India']
+38 7 The official currency of x -1 The official currency of India Rupee India [' is', ' the', ' Indian', ' Ru', 'pee', ' (', 'IN', 'R', ').', ' The', ' Indian', ' Ru', 'pee', ' is', ' divided', ' into', ' 100', ' pa', 'ise', '.'] is the Indian Ru pee ( IN R ). The Indian Ru pee is divided into 100 pa ise . True important cities in India. According to the 3 [' important', ' cities', ' in', ' India']
+39 7 The official currency of x -1 The official currency of India Rupee India [' is', ' the', ' Indian', ' Ru', 'pee', ' (', 'IN', 'R', ').', ' The', ' Indian', ' Ru', 'pee', ' is', ' divided', ' into', ' 100', ' pa', 'ise', '.'] is the Indian Ru pee ( IN R ). The Indian Ru pee is divided into 100 pa ise . True time, both in India and abroad, setting 4 [' time', ',', ' both', ' in', ' India']
+40 8 The official currency of x -1 The official currency of Russia Ruble Russia [' is', ' the', ' Russian', ' rub', 'le', '.', ' The', ' Russian', ' rub', 'le', ' is', ' a', ' currency', ' that', ' is', ' issued', ' by', ' the', ' Central', ' Bank'] is the Russian rub le . The Russian rub le is a currency that is issued by the Central Bank False include those with Russia and China. On February 3 [' include', ' those', ' with', ' Russia']
+41 8 The official currency of x -1 The official currency of Russia Ruble Russia [' is', ' the', ' Russian', ' rub', 'le', '.', ' The', ' Russian', ' rub', 'le', ' is', ' a', ' currency', ' that', ' is', ' issued', ' by', ' the', ' Central', ' Bank'] is the Russian rub le . The Russian rub le is a currency that is issued by the Central Bank False entry of Austria and Russia into the war 4 [' entry', ' of', ' Austria', ' and', ' Russia']
+42 8 The official currency of x -1 The official currency of Russia Ruble Russia [' is', ' the', ' Russian', ' rub', 'le', '.', ' The', ' Russian', ' rub', 'le', ' is', ' a', ' currency', ' that', ' is', ' issued', ' by', ' the', ' Central', ' Bank'] is the Russian rub le . The Russian rub le is a currency that is issued by the Central Bank False overtime loss to Russia. Returning for the 3 [' overtime', ' loss', ' to', ' Russia']
+43 8 The official currency of x -1 The official currency of Russia Ruble Russia [' is', ' the', ' Russian', ' rub', 'le', '.', ' The', ' Russian', ' rub', 'le', ' is', ' a', ' currency', ' that', ' is', ' issued', ' by', ' the', ' Central', ' Bank'] is the Russian rub le . The Russian rub le is a currency that is issued by the Central Bank False the Empress of Russia and the Princess 3 [' the', ' Empress', ' of', ' Russia']
+44 8 The official currency of x -1 The official currency of Russia Ruble Russia [' is', ' the', ' Russian', ' rub', 'le', '.', ' The', ' Russian', ' rub', 'le', ' is', ' a', ' currency', ' that', ' is', ' issued', ' by', ' the', ' Central', ' Bank'] is the Russian rub le . The Russian rub le is a currency that is issued by the Central Bank False City outside the Russian consulate on 3 [' City', ' outside', ' the', ' Russia']
+45 9 The official currency of x -1 The official currency of South Africa Rand South Africa [' is', ' the', ' Rand', '.', ' The', ' Rand', ' is', ' divided', ' into', ' 100', ' cents', '.', ' The', ' Rand', ' is', ' the', ' currency', ' of', ' South', ' Africa'] is the Rand . The Rand is divided into 100 cents . The Rand is the currency of South Africa True the public. As South African captain Jack 5 [' the', ' public', '.', ' As', ' South', ' Africa']
+46 9 The official currency of x -1 The official currency of South Africa Rand South Africa [' is', ' the', ' Rand', '.', ' The', ' Rand', ' is', ' divided', ' into', ' 100', ' cents', '.', ' The', ' Rand', ' is', ' the', ' currency', ' of', ' South', ' Africa'] is the Rand . The Rand is divided into 100 cents . The Rand is the currency of South Africa True England captaincy — South Africa played a Test series 5 [' England', ' captain', 'cy', ' —', ' South', ' Africa']
+47 9 The official currency of x -1 The official currency of South Africa Rand South Africa [' is', ' the', ' Rand', '.', ' The', ' Rand', ' is', ' divided', ' into', ' 100', ' cents', '.', ' The', ' Rand', ' is', ' the', ' currency', ' of', ' South', ' Africa'] is the Rand . The Rand is divided into 100 cents . The Rand is the currency of South Africa True from the Cape of South Africa – including ericas, 5 [' from', ' the', ' Cape', ' of', ' South', ' Africa']
+48 9 The official currency of x -1 The official currency of South Africa Rand South Africa [' is', ' the', ' Rand', '.', ' The', ' Rand', ' is', ' divided', ' into', ' 100', ' cents', '.', ' The', ' Rand', ' is', ' the', ' currency', ' of', ' South', ' Africa'] is the Rand . The Rand is divided into 100 cents . The Rand is the currency of South Africa True " South Africa ===
+" 1 [' South', ' Africa']
+49 9 The official currency of x -1 The official currency of South Africa Rand South Africa [' is', ' the', ' Rand', '.', ' The', ' Rand', ' is', ' divided', ' into', ' 100', ' cents', '.', ' The', ' Rand', ' is', ' the', ' currency', ' of', ' South', ' Africa'] is the Rand . The Rand is divided into 100 cents . The Rand is the currency of South Africa True built ten schools in South Africa after apartheid ended, 5 [' built', ' ten', ' schools', ' in', ' South', ' Africa']
+50 10 The official currency of x -1 The official currency of Mexico Peso Mexico [' is', ' the', ' Pes', 'o', '.', ' The', ' Pes', 'o', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.', '\n', '\n', 'The', ' official'] " is the Pes o . The Pes o is divided into 100 Cent av os .
+
+ The official" True Santa Fe, New Mexico opened an 4 [' Santa', ' Fe', ',', ' New', ' Mexico']
+51 10 The official currency of x -1 The official currency of Mexico Peso Mexico [' is', ' the', ' Pes', 'o', '.', ' The', ' Pes', 'o', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.', '\n', '\n', 'The', ' official'] " is the Pes o . The Pes o is divided into 100 Cent av os .
+
+ The official" True while over southeast Mexico, and upon moving 3 [' while', ' over', ' southeast', ' Mexico']
+52 10 The official currency of x -1 The official currency of Mexico Peso Mexico [' is', ' the', ' Pes', 'o', '.', ' The', ' Pes', 'o', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.', '\n', '\n', 'The', ' official'] " is the Pes o . The Pes o is divided into 100 Cent av os .
+
+ The official" True expansion of the Mexico City Metro 12 3 [' expansion', ' of', ' the', ' Mexico']
+53 10 The official currency of x -1 The official currency of Mexico Peso Mexico [' is', ' the', ' Pes', 'o', '.', ' The', ' Pes', 'o', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.', '\n', '\n', 'The', ' official'] " is the Pes o . The Pes o is divided into 100 Cent av os .
+
+ The official" True the Gulf of Mexico as the August hurricane. 3 [' the', ' Gulf', ' of', ' Mexico']
+54 10 The official currency of x -1 The official currency of Mexico Peso Mexico [' is', ' the', ' Pes', 'o', '.', ' The', ' Pes', 'o', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.', '\n', '\n', 'The', ' official'] " is the Pes o . The Pes o is divided into 100 Cent av os .
+
+ The official" True Summer Olympics in Mexico City, two American 3 [' Summer', ' Olympics', ' in', ' Mexico']
+55 11 The official currency of x -1 The official currency of New Zealand Dollar New Zealand [' is', ' the', ' New', ' Zealand', ' Dollar', ' (', 'NZ', 'D', ').', ' The', ' NZ', 'D', ' is', ' divided', ' into', ' 100', ' cents', '.', '\n', '\n'] " is the New Zealand Dollar ( NZ D ). The NZ D is divided into 100 cents .
+
+" True Australia and New Zealand (1948 – 51), 3 [' Australia', ' and', ' New', ' Zealand']
+56 11 The official currency of x -1 The official currency of New Zealand Dollar New Zealand [' is', ' the', ' New', ' Zealand', ' Dollar', ' (', 'NZ', 'D', ').', ' The', ' NZ', 'D', ' is', ' divided', ' into', ' 100', ' cents', '.', '\n', '\n'] " is the New Zealand Dollar ( NZ D ). The NZ D is divided into 100 cents .
+
+" True weeks, along with a New Zealand regiment. At 6 [' weeks', ',', ' along', ' with', ' a', ' New', ' Zealand']
+57 11 The official currency of x -1 The official currency of New Zealand Dollar New Zealand [' is', ' the', ' New', ' Zealand', ' Dollar', ' (', 'NZ', 'D', ').', ' The', ' NZ', 'D', ' is', ' divided', ' into', ' 100', ' cents', '.', '\n', '\n'] " is the New Zealand Dollar ( NZ D ). The NZ D is divided into 100 cents .
+
+" True Trevett of The New Zealand Herald originally 6 [' Tre', 've', 'tt', ' of', ' The', ' New', ' Zealand']
+58 11 The official currency of x -1 The official currency of New Zealand Dollar New Zealand [' is', ' the', ' New', ' Zealand', ' Dollar', ' (', 'NZ', 'D', ').', ' The', ' NZ', 'D', ' is', ' divided', ' into', ' 100', ' cents', '.', '\n', '\n'] " is the New Zealand Dollar ( NZ D ). The NZ D is divided into 100 cents .
+
+" True the reverse of the New Zealand fifty-cent coin, 5 [' the', ' reverse', ' of', ' the', ' New', ' Zealand']
+59 11 The official currency of x -1 The official currency of New Zealand Dollar New Zealand [' is', ' the', ' New', ' Zealand', ' Dollar', ' (', 'NZ', 'D', ').', ' The', ' NZ', 'D', ' is', ' divided', ' into', ' 100', ' cents', '.', '\n', '\n'] " is the New Zealand Dollar ( NZ D ). The NZ D is divided into 100 cents .
+
+" True reached number one in New Zealand and opened inside 5 [' reached', ' number', ' one', ' in', ' New', ' Zealand']
+60 12 The official currency of x -1 The official currency of South Korea Won South Korea [' is', ' the', ' won', ' (', 'KR', 'W', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate', ' is', ' 1'] is the won ( KR W ). The currency is divided into 100 sen . The exchange rate is 1 False surprise invasion of South Korea in June 1950, the North 4 [' surprise', ' invasion', ' of', ' South', ' Korea']
+61 12 The official currency of x -1 The official currency of South Korea Won South Korea [' is', ' the', ' won', ' (', 'KR', 'W', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate', ' is', ' 1'] is the won ( KR W ). The currency is divided into 100 sen . The exchange rate is 1 False and North and South Korea created a delicate 4 [' and', ' North', ' and', ' South', ' Korea']
+62 12 The official currency of x -1 The official currency of South Korea Won South Korea [' is', ' the', ' won', ' (', 'KR', 'W', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate', ' is', ' 1'] is the won ( KR W ). The currency is divided into 100 sen . The exchange rate is 1 False " was set by South Korea at 357.
+" 4 [' was', ' set', ' by', ' South', ' Korea']
+63 12 The official currency of x -1 The official currency of South Korea Won South Korea [' is', ' the', ' won', ' (', 'KR', 'W', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate', ' is', ' 1'] is the won ( KR W ). The currency is divided into 100 sen . The exchange rate is 1 False Canada, Scotland, South Korea and the United 5 [' Canada', ',', ' Scotland', ',', ' South', ' Korea']
+64 12 The official currency of x -1 The official currency of South Korea Won South Korea [' is', ' the', ' won', ' (', 'KR', 'W', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate', ' is', ' 1'] is the won ( KR W ). The currency is divided into 100 sen . The exchange rate is 1 False actively used in South Korea today, mostly for 4 [' actively', ' used', ' in', ' South', ' Korea']
+65 13 The official currency of x -1 The official currency of Switzerland Franc Switzerland [' is', ' the', ' Swiss', ' Franc', '.', ' The', ' Swiss', ' Franc', ' is', ' divided', ' into', ' 100', ' cent', 'imes', '.', ' The', ' Swiss', ' Franc', ' is', ' the'] is the Swiss Franc . The Swiss Franc is divided into 100 cent imes . The Swiss Franc is the True Zoo Basel in Switzerland holds the international 4 [' Zoo', ' Bas', 'el', ' in', ' Switzerland']
+66 13 The official currency of x -1 The official currency of Switzerland Franc Switzerland [' is', ' the', ' Swiss', ' Franc', '.', ' The', ' Swiss', ' Franc', ' is', ' divided', ' into', ' 100', ' cent', 'imes', '.', ' The', ' Swiss', ' Franc', ' is', ' the'] is the Swiss Franc . The Swiss Franc is divided into 100 cent imes . The Swiss Franc is the True Solti went to Switzerland to seek out Toscanini, 4 [' Sol', 'ti', ' went', ' to', ' Switzerland']
+67 13 The official currency of x -1 The official currency of Switzerland Franc Switzerland [' is', ' the', ' Swiss', ' Franc', '.', ' The', ' Swiss', ' Franc', ' is', ' divided', ' into', ' 100', ' cent', 'imes', '.', ' The', ' Swiss', ' Franc', ' is', ' the'] is the Swiss Franc . The Swiss Franc is divided into 100 cent imes . The Swiss Franc is the True his way to Switzerland in April 1943. Neame 3 [' his', ' way', ' to', ' Switzerland']
+68 13 The official currency of x -1 The official currency of Switzerland Franc Switzerland [' is', ' the', ' Swiss', ' Franc', '.', ' The', ' Swiss', ' Franc', ' is', ' divided', ' into', ' 100', ' cent', 'imes', '.', ' The', ' Swiss', ' Franc', ' is', ' the'] is the Swiss Franc . The Swiss Franc is divided into 100 cent imes . The Swiss Franc is the True 1 ['Sw', 'itzerland']
+69 13 The official currency of x -1 The official currency of Switzerland Franc Switzerland [' is', ' the', ' Swiss', ' Franc', '.', ' The', ' Swiss', ' Franc', ' is', ' divided', ' into', ' 100', ' cent', 'imes', '.', ' The', ' Swiss', ' Franc', ' is', ' the'] is the Swiss Franc . The Swiss Franc is divided into 100 cent imes . The Swiss Franc is the True the Union Bank of Switzerland (precursor of 4 [' the', ' Union', ' Bank', ' of', ' Switzerland']
+70 14 The official currency of x -1 The official currency of Turkey Lira Turkey [' is', ' the', ' Turkish', ' L', 'ira', ' (', 'TRY', ').', ' The', ' Turkish', ' L', 'ira', ' is', ' a', ' currency', ' that', ' is', ' pegged', ' to', ' the'] is the Turkish L ira ( TRY ). The Turkish L ira is a currency that is pegged to the True Ukraine, and Wales) Turkey (Iğdır Province), 5 [' Ukraine', ',', ' and', ' Wales', ')', ' Turkey']
+71 14 The official currency of x -1 The official currency of Turkey Lira Turkey [' is', ' the', ' Turkish', ' L', 'ira', ' (', 'TRY', ').', ' The', ' Turkish', ' L', 'ira', ' is', ' a', ' currency', ' that', ' is', ' pegged', ' to', ' the'] is the Turkish L ira ( TRY ). The Turkish L ira is a currency that is pegged to the True " Hill Road Meadow
+" 3 [' Hill', ' Road', ' Meadow', 'Turkey']
+72 14 The official currency of x -1 The official currency of Turkey Lira Turkey [' is', ' the', ' Turkish', ' L', 'ira', ' (', 'TRY', ').', ' The', ' Turkish', ' L', 'ira', ' is', ' a', ' currency', ' that', ' is', ' pegged', ' to', ' the'] is the Turkish L ira ( TRY ). The Turkish L ira is a currency that is pegged to the True annexed by Turkey with French 2 [' annexed', ' by', ' Turkey']
+73 14 The official currency of x -1 The official currency of Turkey Lira Turkey [' is', ' the', ' Turkish', ' L', 'ira', ' (', 'TRY', ').', ' The', ' Turkish', ' L', 'ira', ' is', ' a', ' currency', ' that', ' is', ' pegged', ' to', ' the'] is the Turkish L ira ( TRY ). The Turkish L ira is a currency that is pegged to the True " East Institute in Turkey between 2014 and 2015.
+" 3 [' East', ' Institute', ' in', ' Turkey']
+74 14 The official currency of x -1 The official currency of Turkey Lira Turkey [' is', ' the', ' Turkish', ' L', 'ira', ' (', 'TRY', ').', ' The', ' Turkish', ' L', 'ira', ' is', ' a', ' currency', ' that', ' is', ' pegged', ' to', ' the'] is the Turkish L ira ( TRY ). The Turkish L ira is a currency that is pegged to the True fled towards Turkey and Australia was 2 [' fled', ' towards', ' Turkey']
+75 15 The official currency of x -1 The official currency of Argentina Peso Argentina [' is', ' the', ' Pes', 'o', ' Argent', 'ino', '.', ' The', ' Pes', 'o', ' Argent', 'ino', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.'] is the Pes o Argent ino . The Pes o Argent ino is divided into 100 Cent av os . True 2 ['Ar', 'gent', 'ina']
+76 15 The official currency of x -1 The official currency of Argentina Peso Argentina [' is', ' the', ' Pes', 'o', ' Argent', 'ino', '.', ' The', ' Pes', 'o', ' Argent', 'ino', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.'] is the Pes o Argent ino . The Pes o Argent ino is divided into 100 Cent av os . True never released in Argentina. After ten months, 3 [' never', ' released', ' in', ' Argentina']
+77 15 The official currency of x -1 The official currency of Argentina Peso Argentina [' is', ' the', ' Pes', 'o', ' Argent', 'ino', '.', ' The', ' Pes', 'o', ' Argent', 'ino', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.'] is the Pes o Argent ino . The Pes o Argent ino is divided into 100 Cent av os . True were introduced to Argentina in the early 20th 3 [' were', ' introduced', ' to', ' Argentina']
+78 15 The official currency of x -1 The official currency of Argentina Peso Argentina [' is', ' the', ' Pes', 'o', ' Argent', 'ino', '.', ' The', ' Pes', 'o', ' Argent', 'ino', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.'] is the Pes o Argent ino . The Pes o Argent ino is divided into 100 Cent av os . True children across Argentina studied Berni's 2 [' children', ' across', ' Argentina']
+79 15 The official currency of x -1 The official currency of Argentina Peso Argentina [' is', ' the', ' Pes', 'o', ' Argent', 'ino', '.', ' The', ' Pes', 'o', ' Argent', 'ino', ' is', ' divided', ' into', ' 100', ' Cent', 'av', 'os', '.'] is the Pes o Argent ino . The Pes o Argent ino is divided into 100 Cent av os . True four group winners, Argentina, Yugoslavia, Uruguay 4 [' four', ' group', ' winners', ',', ' Argentina']
+80 16 The official currency of x -1 The official currency of Norway Krone Norway [' is', ' the', ' Norwegian', ' k', 'rone', ' (', 'N', 'OK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Norwegian k rone ( N OK ). The k rone is divided into 100 � � re . False were stationed in Norway at various times 3 [' were', ' stationed', ' in', ' Norway']
+81 16 The official currency of x -1 The official currency of Norway Krone Norway [' is', ' the', ' Norwegian', ' k', 'rone', ' (', 'N', 'OK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Norwegian k rone ( N OK ). The k rone is divided into 100 � � re . False was also flown into Norway towards the end 4 [' was', ' also', ' flown', ' into', ' Norway']
+82 16 The official currency of x -1 The official currency of Norway Krone Norway [' is', ' the', ' Norwegian', ' k', 'rone', ' (', 'N', 'OK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Norwegian k rone ( N OK ). The k rone is divided into 100 � � re . False purchased from Norway at the beginning 2 [' purchased', ' from', ' Norway']
+83 16 The official currency of x -1 The official currency of Norway Krone Norway [' is', ' the', ' Norwegian', ' k', 'rone', ' (', 'N', 'OK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Norwegian k rone ( N OK ). The k rone is divided into 100 � � re . False 1 ['Nor', 'way']
+84 16 The official currency of x -1 The official currency of Norway Krone Norway [' is', ' the', ' Norwegian', ' k', 'rone', ' (', 'N', 'OK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Norwegian k rone ( N OK ). The k rone is divided into 100 � � re . False manager in Stavanger, Norway, who administers 6 [' manager', ' in', ' St', 'av', 'anger', ',', ' Norway']
+85 17 The official currency of x -1 The official currency of Sweden Krona Sweden [' is', ' the', ' Swedish', ' k', 'ron', 'a', ' (', 'SE', 'K', ').', ' The', ' k', 'ron', 'a', ' is', ' divided', ' into', ' 100', ' �', '�'] is the Swedish k ron a ( SE K ). The k ron a is divided into 100 � � False Germany met Sweden in the World 2 [' Germany', ' met', ' Sweden']
+86 17 The official currency of x -1 The official currency of Sweden Krona Sweden [' is', ' the', ' Swedish', ' k', 'ron', 'a', ' (', 'SE', 'K', ').', ' The', ' k', 'ron', 'a', ' is', ' divided', ' into', ' 100', ' �', '�'] is the Swedish k ron a ( SE K ). The k ron a is divided into 100 � � False Government of Sweden was aware 2 [' Government', ' of', ' Sweden']
+87 17 The official currency of x -1 The official currency of Sweden Krona Sweden [' is', ' the', ' Swedish', ' k', 'ron', 'a', ' (', 'SE', 'K', ').', ' The', ' k', 'ron', 'a', ' is', ' divided', ' into', ' 100', ' �', '�'] is the Swedish k ron a ( SE K ). The k ron a is divided into 100 � � False and coal train from Sweden lost its braking 4 [' and', ' coal', ' train', ' from', ' Sweden']
+88 17 The official currency of x -1 The official currency of Sweden Krona Sweden [' is', ' the', ' Swedish', ' k', 'ron', 'a', ' (', 'SE', 'K', ').', ' The', ' k', 'ron', 'a', ' is', ' divided', ' into', ' 100', ' �', '�'] is the Swedish k ron a ( SE K ). The k ron a is divided into 100 � � False " medal beating Sweden 3-0.
+" 2 [' medal', ' beating', ' Sweden']
+89 17 The official currency of x -1 The official currency of Sweden Krona Sweden [' is', ' the', ' Swedish', ' k', 'ron', 'a', ' (', 'SE', 'K', ').', ' The', ' k', 'ron', 'a', ' is', ' divided', ' into', ' 100', ' �', '�'] is the Swedish k ron a ( SE K ). The k ron a is divided into 100 � � False Storglaciären in Sweden has undergone the 7 [' St', 'org', 'l', 'aci', 'ä', 'ren', ' in', ' Sweden']
+90 18 The official currency of x -1 The official currency of Denmark Krone Denmark [' is', ' the', ' Danish', ' k', 'rone', ' (', 'D', 'KK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Danish k rone ( D KK ). The k rone is divided into 100 � � re . False Sweden, Norway, and Denmark (1796), she would 5 [' Sweden', ',', ' Norway', ',', ' and', ' Denmark']
+91 18 The official currency of x -1 The official currency of Denmark Krone Denmark [' is', ' the', ' Danish', ' k', 'rone', ' (', 'D', 'KK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Danish k rone ( D KK ). The k rone is divided into 100 � � re . False – 60) and Denmark – Norway (1657 4 [' –', ' 60', ')', ' and', ' Denmark']
+92 18 The official currency of x -1 The official currency of Denmark Krone Denmark [' is', ' the', ' Danish', ' k', 'rone', ' (', 'D', 'KK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Danish k rone ( D KK ). The k rone is divided into 100 � � re . False the charts in Denmark for three weeks, 3 [' the', ' charts', ' in', ' Denmark']
+93 18 The official currency of x -1 The official currency of Denmark Krone Denmark [' is', ' the', ' Danish', ' k', 'rone', ' (', 'D', 'KK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Danish k rone ( D KK ). The k rone is divided into 100 � � re . False went to Lego's Denmark headquarters and received 4 "[' went', ' to', ' Lego', ""'s"", ' Denmark']"
+94 18 The official currency of x -1 The official currency of Denmark Krone Denmark [' is', ' the', ' Danish', ' k', 'rone', ' (', 'D', 'KK', ').', ' The', ' k', 'rone', ' is', ' divided', ' into', ' 100', ' �', '�', 're', '.'] is the Danish k rone ( D KK ). The k rone is divided into 100 � � re . False coaling at Skagen, Denmark on 7 October. Early 6 [' co', 'aling', ' at', ' Sk', 'agen', ',', ' Denmark']
+95 19 The official currency of x -1 The official currency of Poland Zloty Poland [' is', ' the', ' Polish', ' z', 'lot', 'y', ' (', 'PL', 'N', ').', ' The', ' z', 'lot', 'y', ' is', ' divided', ' into', ' 100', ' g', 'ros'] is the Polish z lot y ( PL N ). The z lot y is divided into 100 g ros False to the Duke of Poland, to whom they were 4 [' to', ' the', ' Duke', ' of', ' Poland']
+96 19 The official currency of x -1 The official currency of Poland Zloty Poland [' is', ' the', ' Polish', ' z', 'lot', 'y', ' (', 'PL', 'N', ').', ' The', ' z', 'lot', 'y', ' is', ' divided', ' into', ' 100', ' g', 'ros'] is the Polish z lot y ( PL N ). The z lot y is divided into 100 g ros False " economy and burials in Poland ===
+" 5 [' economy', ' and', ' bur', 'ials', ' in', ' Poland']
+97 19 The official currency of x -1 The official currency of Poland Zloty Poland [' is', ' the', ' Polish', ' z', 'lot', 'y', ' (', 'PL', 'N', ').', ' The', ' z', 'lot', 'y', ' is', ' divided', ' into', ' 100', ' g', 'ros'] is the Polish z lot y ( PL N ). The z lot y is divided into 100 g ros False demarcation line between Poland and Lithuania on 5 [' dem', 'arc', 'ation', ' line', ' between', ' Poland']
+98 19 The official currency of x -1 The official currency of Poland Zloty Poland [' is', ' the', ' Polish', ' z', 'lot', 'y', ' (', 'PL', 'N', ').', ' The', ' z', 'lot', 'y', ' is', ' divided', ' into', ' 100', ' g', 'ros'] is the Polish z lot y ( PL N ). The z lot y is divided into 100 g ros False disloyalty to Poland during the Soviet 4 [' disl', 'oyal', 'ty', ' to', ' Poland']
+99 19 The official currency of x -1 The official currency of Poland Zloty Poland [' is', ' the', ' Polish', ' z', 'lot', 'y', ' (', 'PL', 'N', ').', ' The', ' z', 'lot', 'y', ' is', ' divided', ' into', ' 100', ' g', 'ros'] is the Polish z lot y ( PL N ). The z lot y is divided into 100 g ros False Minister of Poland who rose out of 2 [' Minister', ' of', ' Poland']
+100 20 The official currency of x -1 The official currency of Hungary Forint Hungary [' is', ' the', ' Hungarian', ' For', 'int', ' (', 'H', 'UF', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' lip', 'é', ' (', 'le', 'i'] is the Hungarian For int ( H UF ). The currency is divided into 100 lip é ( le i True France, Belgium, Hungary and Norway. 4 [' France', ',', ' Belgium', ',', ' Hungary']
+101 20 The official currency of x -1 The official currency of Hungary Forint Hungary [' is', ' the', ' Hungarian', ' For', 'int', ' (', 'H', 'UF', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' lip', 'é', ' (', 'le', 'i'] is the Hungarian For int ( H UF ). The currency is divided into 100 lip é ( le i True Ferdinand as King of Hungary. On 14 February 4 [' Ferdinand', ' as', ' King', ' of', ' Hungary']
+102 20 The official currency of x -1 The official currency of Hungary Forint Hungary [' is', ' the', ' Hungarian', ' For', 'int', ' (', 'H', 'UF', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' lip', 'é', ' (', 'le', 'i'] is the Hungarian For int ( H UF ). The currency is divided into 100 lip é ( le i True common borders of Hungary and the Holy 3 [' common', ' borders', ' of', ' Hungary']
+103 20 The official currency of x -1 The official currency of Hungary Forint Hungary [' is', ' the', ' Hungarian', ' For', 'int', ' (', 'H', 'UF', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' lip', 'é', ' (', 'le', 'i'] is the Hungarian For int ( H UF ). The currency is divided into 100 lip é ( le i True made to induce Hungary and Bulgaria 3 [' made', ' to', ' induce', ' Hungary']
+104 20 The official currency of x -1 The official currency of Hungary Forint Hungary [' is', ' the', ' Hungarian', ' For', 'int', ' (', 'H', 'UF', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' lip', 'é', ' (', 'le', 'i'] is the Hungarian For int ( H UF ). The currency is divided into 100 lip é ( le i True " Mary crowned ""King"" of Hungary only seven days" 6 "[' Mary', ' crowned', ' ""', 'King', '""', ' of', ' Hungary']"
+105 21 The official currency of x -1 The official currency of Czech Republic Koruna Czech Republic [' is', ' the', ' Czech', ' Crown', ' (', 'K', 'or', 'una', ').', ' The', ' Czech', ' Crown', ' is', ' divided', ' into', ' 100', ' K', 'č', '.', ' The'] is the Czech Crown ( K or una ). The Czech Crown is divided into 100 K č . The True games in Prague, Czech Republic and Stockholm, 5 [' games', ' in', ' Prague', ',', ' Czech', ' Republic']
+106 21 The official currency of x -1 The official currency of Czech Republic Koruna Czech Republic [' is', ' the', ' Czech', ' Crown', ' (', 'K', 'or', 'una', ').', ' The', ' Czech', ' Crown', ' is', ' divided', ' into', ' 100', ' K', 'č', '.', ' The'] is the Czech Crown ( K or una ). The Czech Crown is divided into 100 K č . The True Norway, Slovakia, the Czech Republic and Belgium (Wallonia). 6 [' Norway', ',', ' Slovakia', ',', ' the', ' Czech', ' Republic']
+107 21 The official currency of x -1 The official currency of Czech Republic Koruna Czech Republic [' is', ' the', ' Czech', ' Crown', ' (', 'K', 'or', 'una', ').', ' The', ' Czech', ' Crown', ' is', ' divided', ' into', ' 100', ' K', 'č', '.', ' The'] is the Czech Crown ( K or una ). The Czech Crown is divided into 100 K č . The True Norway, Slovakia, the Czech Republic and Belgium 6 [' Norway', ',', ' Slovakia', ',', ' the', ' Czech', ' Republic']
+108 21 The official currency of x -1 The official currency of Czech Republic Koruna Czech Republic [' is', ' the', ' Czech', ' Crown', ' (', 'K', 'or', 'una', ').', ' The', ' Czech', ' Crown', ' is', ' divided', ' into', ' 100', ' K', 'č', '.', ' The'] is the Czech Crown ( K or una ). The Czech Crown is divided into 100 K č . The True Grimm entirely in the Czech Republic over budget constraints. 5 [' Grimm', ' entirely', ' in', ' the', ' Czech', ' Republic']
+109 21 The official currency of x -1 The official currency of Czech Republic Koruna Czech Republic [' is', ' the', ' Czech', ' Crown', ' (', 'K', 'or', 'una', ').', ' The', ' Czech', ' Crown', ' is', ' divided', ' into', ' 100', ' K', 'č', '.', ' The'] is the Czech Crown ( K or una ). The Czech Crown is divided into 100 K č . The True filmed in the Czech Republic because Sheffield's 4 [' filmed', ' in', ' the', ' Czech', ' Republic']
+110 22 The official currency of x -1 The official currency of Israel Shekel Israel [' is', ' the', ' Israeli', ' She', 'kel', ' (', 'ILS', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' N', 'IS', '.', ' The', ' exchange', ' rate'] is the Israeli She kel ( ILS ). The currency is divided into 100 N IS . The exchange rate True the State of Israel, as enunciated by 3 [' the', ' State', ' of', ' Israel']
+111 22 The official currency of x -1 The official currency of Israel Shekel Israel [' is', ' the', ' Israeli', ' She', 'kel', ' (', 'ILS', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' N', 'IS', '.', ' The', ' exchange', ' rate'] is the Israeli She kel ( ILS ). The currency is divided into 100 N IS . The exchange rate True government refusal, Israeli Prime Minister 3 [' government', ' refusal', ',', ' Israel']
+112 22 The official currency of x -1 The official currency of Israel Shekel Israel [' is', ' the', ' Israeli', ' She', 'kel', ' (', 'ILS', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' N', 'IS', '.', ' The', ' exchange', ' rate'] is the Israeli She kel ( ILS ). The currency is divided into 100 N IS . The exchange rate True the children of Israel by their fathers' houses, 3 [' the', ' children', ' of', ' Israel']
+113 22 The official currency of x -1 The official currency of Israel Shekel Israel [' is', ' the', ' Israeli', ' She', 'kel', ' (', 'ILS', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' N', 'IS', '.', ' The', ' exchange', ' rate'] is the Israeli She kel ( ILS ). The currency is divided into 100 N IS . The exchange rate True hospital. When Israel withdrew from the 3 [' hospital', '.', ' When', ' Israel']
+114 22 The official currency of x -1 The official currency of Israel Shekel Israel [' is', ' the', ' Israeli', ' She', 'kel', ' (', 'ILS', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' N', 'IS', '.', ' The', ' exchange', ' rate'] is the Israeli She kel ( ILS ). The currency is divided into 100 N IS . The exchange rate True the independence of Israel and other free nations. 3 [' the', ' independence', ' of', ' Israel']
+115 23 The official currency of x -1 The official currency of Saudi Arabia Riyal Saudi Arabia [' is', ' the', ' Riy', 'al', '.', ' The', ' Riy', 'al', ' is', ' divided', ' into', ' 100', ' Nak', 'fs', '.', ' The', ' Nak', 'fs', ' are', ' divided'] is the Riy al . The Riy al is divided into 100 Nak fs . The Nak fs are divided True with Egypt and Saudi Arabia to confront 4 [' with', ' Egypt', ' and', ' Saudi', ' Arabia']
+116 23 The official currency of x -1 The official currency of Saudi Arabia Riyal Saudi Arabia [' is', ' the', ' Riy', 'al', '.', ' The', ' Riy', 'al', ' is', ' divided', ' into', ' 100', ' Nak', 'fs', '.', ' The', ' Nak', 'fs', ' are', ' divided'] is the Riy al . The Riy al is divided into 100 Nak fs . The Nak fs are divided True allied with Saudi Arabia and Sunni 3 [' allied', ' with', ' Saudi', ' Arabia']
+117 23 The official currency of x -1 The official currency of Saudi Arabia Riyal Saudi Arabia [' is', ' the', ' Riy', 'al', '.', ' The', ' Riy', 'al', ' is', ' divided', ' into', ' 100', ' Nak', 'fs', '.', ' The', ' Nak', 'fs', ' are', ' divided'] is the Riy al . The Riy al is divided into 100 Nak fs . The Nak fs are divided True attack in Aqaba. In Saudi Arabia two women were 8 [' attack', ' in', ' A', 'q', 'aba', '.', ' In', ' Saudi', ' Arabia']
+118 23 The official currency of x -1 The official currency of Saudi Arabia Riyal Saudi Arabia [' is', ' the', ' Riy', 'al', '.', ' The', ' Riy', 'al', ' is', ' divided', ' into', ' 100', ' Nak', 'fs', '.', ' The', ' Nak', 'fs', ' are', ' divided'] is the Riy al . The Riy al is divided into 100 Nak fs . The Nak fs are divided True international organizations. Saudi Arabia was the largest aid 4 [' international', ' organizations', '.', ' Saudi', ' Arabia']
+119 23 The official currency of x -1 The official currency of Saudi Arabia Riyal Saudi Arabia [' is', ' the', ' Riy', 'al', '.', ' The', ' Riy', 'al', ' is', ' divided', ' into', ' 100', ' Nak', 'fs', '.', ' The', ' Nak', 'fs', ' are', ' divided'] is the Riy al . The Riy al is divided into 100 Nak fs . The Nak fs are divided True storm, neighboring Saudi Arabia sent assistance 4 [' storm', ',', ' neighboring', ' Saudi', ' Arabia']
+120 24 The official currency of x -1 The official currency of United Arab Emirates Dirham United Arab Emirates [' is', ' the', ' dir', 'ham', ' (', 'D', 'h', ').', ' The', ' dir', 'ham', ' is', ' divided', ' into', ' 100', ' fil', 's', '.', ' The', ' fil'] is the dir ham ( D h ). The dir ham is divided into 100 fil s . The fil False personnel to the United Arab Emirates in preparation 5 [' personnel', ' to', ' the', ' United', ' Arab', ' Emirates']
+121 24 The official currency of x -1 The official currency of United Arab Emirates Dirham United Arab Emirates [' is', ' the', ' dir', 'ham', ' (', 'D', 'h', ').', ' The', ' dir', 'ham', ' is', ' divided', ' into', ' 100', ' fil', 's', '.', ' The', ' fil'] is the dir ham ( D h ). The dir ham is divided into 100 fil s . The fil False supplies from the United Arab Emirates. Relief distribution 5 [' supplies', ' from', ' the', ' United', ' Arab', ' Emirates']
+122 24 The official currency of x -1 The official currency of United Arab Emirates Dirham United Arab Emirates [' is', ' the', ' dir', 'ham', ' (', 'D', 'h', ').', ' The', ' dir', 'ham', ' is', ' divided', ' into', ' 100', ' fil', 's', '.', ' The', ' fil'] is the dir ham ( D h ). The dir ham is divided into 100 fil s . The fil False debuted in the United Arab Emirates on April 9, 2015. 5 [' debuted', ' in', ' the', ' United', ' Arab', ' Emirates']
+123 24 The official currency of x -1 The official currency of United Arab Emirates Dirham United Arab Emirates [' is', ' the', ' dir', 'ham', ' (', 'D', 'h', ').', ' The', ' dir', 'ham', ' is', ' divided', ' into', ' 100', ' fil', 's', '.', ' The', ' fil'] is the dir ham ( D h ). The dir ham is divided into 100 fil s . The fil False protected. The United Arab Emirates has banned all hunting 5 [' protected', '.', ' The', ' United', ' Arab', ' Emirates']
+124 24 The official currency of x -1 The official currency of United Arab Emirates Dirham United Arab Emirates [' is', ' the', ' dir', 'ham', ' (', 'D', 'h', ').', ' The', ' dir', 'ham', ' is', ' divided', ' into', ' 100', ' fil', 's', '.', ' The', ' fil'] is the dir ham ( D h ). The dir ham is divided into 100 fil s . The fil False Air Base in the United Arab Emirates to support 6 [' Air', ' Base', ' in', ' the', ' United', ' Arab', ' Emirates']
+125 25 The official currency of x -1 The official currency of Singapore Dollar Singapore [' is', ' the', ' Singapore', ' Dollar', '.', ' The', ' Singapore', ' Dollar', ' is', ' pegged', ' to', ' the', ' US', ' Dollar', ' at', ' a', ' rate', ' of', ' 1', ' USD'] is the Singapore Dollar . The Singapore Dollar is pegged to the US Dollar at a rate of 1 USD True 1 ['Sing', 'apore']
+126 25 The official currency of x -1 The official currency of Singapore Dollar Singapore [' is', ' the', ' Singapore', ' Dollar', '.', ' The', ' Singapore', ' Dollar', ' is', ' pegged', ' to', ' the', ' US', ' Dollar', ' at', ' a', ' rate', ' of', ' 1', ' USD'] is the Singapore Dollar . The Singapore Dollar is pegged to the US Dollar at a rate of 1 USD True 1 ['Sing', 'apore']
+127 25 The official currency of x -1 The official currency of Singapore Dollar Singapore [' is', ' the', ' Singapore', ' Dollar', '.', ' The', ' Singapore', ' Dollar', ' is', ' pegged', ' to', ' the', ' US', ' Dollar', ' at', ' a', ' rate', ' of', ' 1', ' USD'] is the Singapore Dollar . The Singapore Dollar is pegged to the US Dollar at a rate of 1 USD True his fortune in Singapore and worked illegally 3 [' his', ' fortune', ' in', ' Singapore']
+128 25 The official currency of x -1 The official currency of Singapore Dollar Singapore [' is', ' the', ' Singapore', ' Dollar', '.', ' The', ' Singapore', ' Dollar', ' is', ' pegged', ' to', ' the', ' US', ' Dollar', ' at', ' a', ' rate', ' of', ' 1', ' USD'] is the Singapore Dollar . The Singapore Dollar is pegged to the US Dollar at a rate of 1 USD True 1 ['Sing', 'apore']
+129 25 The official currency of x -1 The official currency of Singapore Dollar Singapore [' is', ' the', ' Singapore', ' Dollar', '.', ' The', ' Singapore', ' Dollar', ' is', ' pegged', ' to', ' the', ' US', ' Dollar', ' at', ' a', ' rate', ' of', ' 1', ' USD'] is the Singapore Dollar . The Singapore Dollar is pegged to the US Dollar at a rate of 1 USD True 1 ['Sing', 'apore']
+130 26 The official currency of x -1 The official currency of Malaysia Ringgit Malaysia [' is', ' the', ' Malaysian', ' Ring', 'git', ' (', 'MY', 'R', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate'] is the Malaysian Ring git ( MY R ). The currency is divided into 100 sen . The exchange rate True for the 2008 Malaysian Games (Malay: 3 [' for', ' the', ' 2008', ' Malaysia']
+131 26 The official currency of x -1 The official currency of Malaysia Ringgit Malaysia [' is', ' the', ' Malaysian', ' Ring', 'git', ' (', 'MY', 'R', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate'] is the Malaysian Ring git ( MY R ). The currency is divided into 100 sen . The exchange rate True as Tamil Nesan, Malaysia Nanban and Makkal 5 [' as', ' Tamil', ' N', 'esan', ',', ' Malaysia']
+132 26 The official currency of x -1 The official currency of Malaysia Ringgit Malaysia [' is', ' the', ' Malaysian', ' Ring', 'git', ' (', 'MY', 'R', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate'] is the Malaysian Ring git ( MY R ). The currency is divided into 100 sen . The exchange rate True released in South Asia, Malaysia and Hong Kong, 5 [' released', ' in', ' South', ' Asia', ',', ' Malaysia']
+133 26 The official currency of x -1 The official currency of Malaysia Ringgit Malaysia [' is', ' the', ' Malaysian', ' Ring', 'git', ' (', 'MY', 'R', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate'] is the Malaysian Ring git ( MY R ). The currency is divided into 100 sen . The exchange rate True the crash of Malaysia Airlines Flight 3 [' the', ' crash', ' of', ' Malaysia']
+134 26 The official currency of x -1 The official currency of Malaysia Ringgit Malaysia [' is', ' the', ' Malaysian', ' Ring', 'git', ' (', 'MY', 'R', ').', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' exchange', ' rate'] is the Malaysian Ring git ( MY R ). The currency is divided into 100 sen . The exchange rate True Minister of Malaysia, Mahathir Mohamad 2 [' Minister', ' of', ' Malaysia']
+135 27 The official currency of x -1 The official currency of Indonesia Rupiah Indonesia [' is', ' the', ' Indonesian', ' R', 'up', 'iah', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' currency', ' is', ' divided', ' into'] is the Indonesian R up iah . The currency is divided into 100 sen . The currency is divided into True " at Sinematek Indonesia in Jakarta.
+" 5 [' at', ' S', 'inem', 'ate', 'k', ' Indonesia']
+136 27 The official currency of x -1 The official currency of Indonesia Rupiah Indonesia [' is', ' the', ' Indonesian', ' R', 'up', 'iah', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' currency', ' is', ' divided', ' into'] is the Indonesian R up iah . The currency is divided into 100 sen . The currency is divided into True Japanese-occupied Indonesia in advance of Allied 3 [' Japanese', '-', 'occupied', ' Indonesia']
+137 27 The official currency of x -1 The official currency of Indonesia Rupiah Indonesia [' is', ' the', ' Indonesian', ' R', 'up', 'iah', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' currency', ' is', ' divided', ' into'] is the Indonesian R up iah . The currency is divided into 100 sen . The currency is divided into True Konfrontasi between Indonesia and Malaysia, 4 [' Kon', 'front', 'asi', ' between', ' Indonesia']
+138 27 The official currency of x -1 The official currency of Indonesia Rupiah Indonesia [' is', ' the', ' Indonesian', ' R', 'up', 'iah', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' currency', ' is', ' divided', ' into'] is the Indonesian R up iah . The currency is divided into 100 sen . The currency is divided into True number of Indonesian aviation accidents 2 [' number', ' of', ' Indonesia']
+139 27 The official currency of x -1 The official currency of Indonesia Rupiah Indonesia [' is', ' the', ' Indonesian', ' R', 'up', 'iah', '.', ' The', ' currency', ' is', ' divided', ' into', ' 100', ' sen', '.', ' The', ' currency', ' is', ' divided', ' into'] is the Indonesian R up iah . The currency is divided into 100 sen . The currency is divided into True including Australia, Indonesia and Vietnam. Apsaras 3 [' including', ' Australia', ',', ' Indonesia']
+140 28 The official currency of x -1 The official currency of Thailand Baht Thailand [' is', ' the', ' Thai', ' b', 'ah', 't', '.', ' The', ' Thai', ' b', 'ah', 't', ' is', ' divided', ' into', ' 100', ' sat', 'ang', '.', '\n'] " is the Thai b ah t . The Thai b ah t is divided into 100 sat ang .
+" False in northeastern Thailand operating under 2 [' in', ' northeastern', ' Thailand']
+141 28 The official currency of x -1 The official currency of Thailand Baht Thailand [' is', ' the', ' Thai', ' b', 'ah', 't', '.', ' The', ' Thai', ' b', 'ah', 't', ' is', ' divided', ' into', ' 100', ' sat', 'ang', '.', '\n'] " is the Thai b ah t . The Thai b ah t is divided into 100 sat ang .
+" False parts of northern Thailand and occurred 3 [' parts', ' of', ' northern', ' Thailand']
+142 28 The official currency of x -1 The official currency of Thailand Baht Thailand [' is', ' the', ' Thai', ' b', 'ah', 't', '.', ' The', ' Thai', ' b', 'ah', 't', ' is', ' divided', ' into', ' 100', ' sat', 'ang', '.', '\n'] " is the Thai b ah t . The Thai b ah t is divided into 100 sat ang .
+" False chair Airports of Thailand and TOT, he purged 4 [' chair', ' Air', 'ports', ' of', ' Thailand']
+143 28 The official currency of x -1 The official currency of Thailand Baht Thailand [' is', ' the', ' Thai', ' b', 'ah', 't', '.', ' The', ' Thai', ' b', 'ah', 't', ' is', ' divided', ' into', ' 100', ' sat', 'ang', '.', '\n'] " is the Thai b ah t . The Thai b ah t is divided into 100 sat ang .
+" False teaching job in Thailand and he leaves 3 [' teaching', ' job', ' in', ' Thailand']
+144 28 The official currency of x -1 The official currency of Thailand Baht Thailand [' is', ' the', ' Thai', ' b', 'ah', 't', '.', ' The', ' Thai', ' b', 'ah', 't', ' is', ' divided', ' into', ' 100', ' sat', 'ang', '.', '\n'] " is the Thai b ah t . The Thai b ah t is divided into 100 sat ang .
+" False Airports of Thailand (AoT) and 3 [' Air', 'ports', ' of', ' Thailand']
+145 29 The official currency of x -1 The official currency of Philippines Peso Philippines [' is', ' the', ' Pes', 'o', '.', ' The', ' official', ' language', ' is', ' Filipino', '.', ' The', ' official', ' religion', ' is', ' Roman', ' Catholicism', '.', ' The', ' official'] is the Pes o . The official language is Filipino . The official religion is Roman Catholicism . The official True (颱風龍王), known in the Philippines as Typhoon Maring, 13 [' (', '�', '�', '�', '�', '�', '�', '龍', '王', '),', ' known', ' in', ' the', ' Philippines']
+146 29 The official currency of x -1 The official currency of Philippines Peso Philippines [' is', ' the', ' Pes', 'o', '.', ' The', ' official', ' language', ' is', ' Filipino', '.', ' The', ' official', ' religion', ' is', ' Roman', ' Catholicism', '.', ' The', ' official'] is the Pes o . The official language is Filipino . The official religion is Roman Catholicism . The official True Museum of the Philippines and partly in 3 [' Museum', ' of', ' the', ' Philippines']
+147 29 The official currency of x -1 The official currency of Philippines Peso Philippines [' is', ' the', ' Pes', 'o', '.', ' The', ' official', ' language', ' is', ' Filipino', '.', ' The', ' official', ' religion', ' is', ' Roman', ' Catholicism', '.', ' The', ' official'] is the Pes o . The official language is Filipino . The official religion is Roman Catholicism . The official True Atlantic, and the Philippines and Guam in 4 [' Atlantic', ',', ' and', ' the', ' Philippines']
+148 29 The official currency of x -1 The official currency of Philippines Peso Philippines [' is', ' the', ' Pes', 'o', '.', ' The', ' official', ' language', ' is', ' Filipino', '.', ' The', ' official', ' religion', ' is', ' Roman', ' Catholicism', '.', ' The', ' official'] is the Pes o . The official language is Filipino . The official religion is Roman Catholicism . The official True " Philippines ===
+" 0 [' Philippines']
+149 29 The official currency of x -1 The official currency of Philippines Peso Philippines [' is', ' the', ' Pes', 'o', '.', ' The', ' official', ' language', ' is', ' Filipino', '.', ' The', ' official', ' religion', ' is', ' Roman', ' Catholicism', '.', ' The', ' official'] is the Pes o . The official language is Filipino . The official religion is Roman Catholicism . The official True " Philippines ===
+" 0 [' Philippines']
diff --git a/patchscopes/code/preprocessed_data/factual/country_largest_city.tsv b/patchscopes/code/preprocessed_data/factual/country_largest_city.tsv
new file mode 100644
index 00000000..426af773
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/country_largest_city.tsv
@@ -0,0 +1,131 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The largest city in x -1 The largest city in United States New York City United States [',', ' New', ' York', ' is', ' a', ' major', ' center', ' of', ' commerce', ',', ' finance', ',', ' media', ',', ' fashion', ',', ' education', ',', ' and', ' culture'] , New York is a major center of commerce , finance , media , fashion , education , and culture False upset loss to the United States at the 1980 Winter 5 [' upset', ' loss', ' to', ' the', ' United', ' States']
+1 0 The largest city in x -1 The largest city in United States New York City United States [',', ' New', ' York', ' is', ' a', ' major', ' center', ' of', ' commerce', ',', ' finance', ',', ' media', ',', ' fashion', ',', ' education', ',', ' and', ' culture'] , New York is a major center of commerce , finance , media , fashion , education , and culture False cooperation with the United States Navy, to bombard 4 [' cooperation', ' with', ' the', ' United', ' States']
+2 0 The largest city in x -1 The largest city in United States New York City United States [',', ' New', ' York', ' is', ' a', ' major', ' center', ' of', ' commerce', ',', ' finance', ',', ' media', ',', ' fashion', ',', ' education', ',', ' and', ' culture'] , New York is a major center of commerce , finance , media , fashion , education , and culture False supporting tours of the United States while traveling internationally 5 [' supporting', ' tours', ' of', ' the', ' United', ' States']
+3 0 The largest city in x -1 The largest city in United States New York City United States [',', ' New', ' York', ' is', ' a', ' major', ' center', ' of', ' commerce', ',', ' finance', ',', ' media', ',', ' fashion', ',', ' education', ',', ' and', ' culture'] , New York is a major center of commerce , finance , media , fashion , education , and culture False the Southern United States, and William P. 3 [' the', ' Southern', ' United', ' States']
+4 0 The largest city in x -1 The largest city in United States New York City United States [',', ' New', ' York', ' is', ' a', ' major', ' center', ' of', ' commerce', ',', ' finance', ',', ' media', ',', ' fashion', ',', ' education', ',', ' and', ' culture'] , New York is a major center of commerce , finance , media , fashion , education , and culture False 1 ['United', ' States']
+5 1 The largest city in x -1 The largest city in China Shanghai China [',', ' Shanghai', ' is', ' a', ' major', ' international', ' financial', ' center', ' and', ' a', ' major', ' manufacturing', ' center', '.', ' It', ' is', ' also', ' a', ' major', ' tourist'] , Shanghai is a major international financial center and a major manufacturing center . It is also a major tourist True misconceptions about China as well as criticising 2 [' misconceptions', ' about', ' China']
+6 1 The largest city in x -1 The largest city in China Shanghai China [',', ' Shanghai', ' is', ' a', ' major', ' international', ' financial', ' center', ' and', ' a', ' major', ' manufacturing', ' center', '.', ' It', ' is', ' also', ' a', ' major', ' tourist'] , Shanghai is a major international financial center and a major manufacturing center . It is also a major tourist True entering the South China Sea, the JTWC 3 [' entering', ' the', ' South', ' China']
+7 1 The largest city in x -1 The largest city in China Shanghai China [',', ' Shanghai', ' is', ' a', ' major', ' international', ' financial', ' center', ' and', ' a', ' major', ' manufacturing', ' center', '.', ' It', ' is', ' also', ' a', ' major', ' tourist'] , Shanghai is a major international financial center and a major manufacturing center . It is also a major tourist True agencies, Malaysia, and China to plan and carry 5 [' agencies', ',', ' Malaysia', ',', ' and', ' China']
+8 1 The largest city in x -1 The largest city in China Shanghai China [',', ' Shanghai', ' is', ' a', ' major', ' international', ' financial', ' center', ' and', ' a', ' major', ' manufacturing', ' center', '.', ' It', ' is', ' also', ' a', ' major', ' tourist'] , Shanghai is a major international financial center and a major manufacturing center . It is also a major tourist True missionaries in China, but Johannes Kepler 2 [' missionaries', ' in', ' China']
+9 1 The largest city in x -1 The largest city in China Shanghai China [',', ' Shanghai', ' is', ' a', ' major', ' international', ' financial', ' center', ' and', ' a', ' major', ' manufacturing', ' center', '.', ' It', ' is', ' also', ' a', ' major', ' tourist'] , Shanghai is a major international financial center and a major manufacturing center . It is also a major tourist True Bohlinia entered China and northern 4 [' Boh', 'lin', 'ia', ' entered', ' China']
+10 2 The largest city in x -1 The largest city in Japan Tokyo Japan [' is', ' Tokyo', ',', ' and', ' it', ' is', ' the', ' capital', ' of', ' Japan', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ','] is Tokyo , and it is the capital of Japan . It is the most populous city in Japan , True Switzerland and Japan to push down 2 [' Switzerland', ' and', ' Japan']
+11 2 The largest city in x -1 The largest city in Japan Tokyo Japan [' is', ' Tokyo', ',', ' and', ' it', ' is', ' the', ' capital', ' of', ' Japan', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ','] is Tokyo , and it is the capital of Japan . It is the most populous city in Japan , True cooperation with the Japan Aerospace Exploration 3 [' cooperation', ' with', ' the', ' Japan']
+12 2 The largest city in x -1 The largest city in Japan Tokyo Japan [' is', ' Tokyo', ',', ' and', ' it', ' is', ' the', ' capital', ' of', ' Japan', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ','] is Tokyo , and it is the capital of Japan . It is the most populous city in Japan , True reached # 2 on the Japan Hot 100 chart 5 [' reached', ' #', ' 2', ' on', ' the', ' Japan']
+13 2 The largest city in x -1 The largest city in Japan Tokyo Japan [' is', ' Tokyo', ',', ' and', ' it', ' is', ' the', ' capital', ' of', ' Japan', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ','] is Tokyo , and it is the capital of Japan . It is the most populous city in Japan , True Following the surrender of Japan in August 1945, 4 [' Following', ' the', ' surrender', ' of', ' Japan']
+14 2 The largest city in x -1 The largest city in Japan Tokyo Japan [' is', ' Tokyo', ',', ' and', ' it', ' is', ' the', ' capital', ' of', ' Japan', '.', ' It', ' is', ' the', ' most', ' populous', ' city', ' in', ' Japan', ','] is Tokyo , and it is the capital of Japan . It is the most populous city in Japan , True " ==== Arguments from Japan ====
+" 5 [' =', '===', ' Arg', 'uments', ' from', ' Japan']
+15 3 The largest city in x -1 The largest city in Russia Moscow Russia [',', ' Moscow', ' is', ' a', ' major', ' cultural', ',', ' political', ',', ' and', ' economic', ' center', ' of', ' Russia', '.', ' It', ' is', ' the', ' capital', ' of'] , Moscow is a major cultural , political , and economic center of Russia . It is the capital of True 181 points, behind Russia (194 points) 4 [' 181', ' points', ',', ' behind', ' Russia']
+16 3 The largest city in x -1 The largest city in Russia Moscow Russia [',', ' Moscow', ' is', ' a', ' major', ' cultural', ',', ' political', ',', ' and', ' economic', ' center', ' of', ' Russia', '.', ' It', ' is', ' the', ' capital', ' of'] , Moscow is a major cultural , political , and economic center of Russia . It is the capital of True suggested that Russia was too big 2 [' suggested', ' that', ' Russia']
+17 3 The largest city in x -1 The largest city in Russia Moscow Russia [',', ' Moscow', ' is', ' a', ' major', ' cultural', ',', ' political', ',', ' and', ' economic', ' center', ' of', ' Russia', '.', ' It', ' is', ' the', ' capital', ' of'] , Moscow is a major cultural , political , and economic center of Russia . It is the capital of True total of 372 points. Russia finished in second 5 [' total', ' of', ' 372', ' points', '.', ' Russia']
+18 3 The largest city in x -1 The largest city in Russia Moscow Russia [',', ' Moscow', ' is', ' a', ' major', ' cultural', ',', ' political', ',', ' and', ' economic', ' center', ' of', ' Russia', '.', ' It', ' is', ' the', ' capital', ' of'] , Moscow is a major cultural , political , and economic center of Russia . It is the capital of True symphonies in Europe, Russia and America, the 6 [' sym', 'ph', 'onies', ' in', ' Europe', ',', ' Russia']
+19 3 The largest city in x -1 The largest city in Russia Moscow Russia [',', ' Moscow', ' is', ' a', ' major', ' cultural', ',', ' political', ',', ' and', ' economic', ' center', ' of', ' Russia', '.', ' It', ' is', ' the', ' capital', ' of'] , Moscow is a major cultural , political , and economic center of Russia . It is the capital of True " publicity"". A Russia vs Germany match" 3 "[' publicity', '"".', ' A', ' Russia']"
+20 4 The largest city in x -1 The largest city in India Mumbai India [',', ' Mumbai', ' is', ' a', ' major', ' financial', ' and', ' commercial', ' center', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' India', ','] , Mumbai is a major financial and commercial center . It is also the most populous city in India , True The situation in India continued to 3 [' The', ' situation', ' in', ' India']
+21 4 The largest city in x -1 The largest city in India Mumbai India [',', ' Mumbai', ' is', ' a', ' major', ' financial', ' and', ' commercial', ' center', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' India', ','] , Mumbai is a major financial and commercial center . It is also the most populous city in India , True " on the Parliament of India on 13 December 2001.
+" 4 [' on', ' the', ' Parliament', ' of', ' India']
+22 4 The largest city in x -1 The largest city in India Mumbai India [',', ' Mumbai', ' is', ' a', ' major', ' financial', ' and', ' commercial', ' center', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' India', ','] , Mumbai is a major financial and commercial center . It is also the most populous city in India , True sub-Saharan Africa and in India (where an endangered 6 [' sub', '-', 'Saharan', ' Africa', ' and', ' in', ' India']
+23 4 The largest city in x -1 The largest city in India Mumbai India [',', ' Mumbai', ' is', ' a', ' major', ' financial', ' and', ' commercial', ' center', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' India', ','] , Mumbai is a major financial and commercial center . It is also the most populous city in India , True few actors in India who garners pan-Indian 3 [' few', ' actors', ' in', ' India']
+24 4 The largest city in x -1 The largest city in India Mumbai India [',', ' Mumbai', ' is', ' a', ' major', ' financial', ' and', ' commercial', ' center', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' India', ','] , Mumbai is a major financial and commercial center . It is also the most populous city in India , True was posted to both India and Java. After the 4 [' was', ' posted', ' to', ' both', ' India']
+25 5 The largest city in x -1 The largest city in Brazil São Paulo Brazil [',', ' S', 'ão', ' Paulo', ' is', ' the', ' capital', ' of', ' the', ' state', ' of', ' S', 'ão', ' Paulo', '.', ' It', ' is', ' the', ' most', ' populous'] , S ão Paulo is the capital of the state of S ão Paulo . It is the most populous True entered the Brazilian market 2 [' entered', ' the', ' Brazil']
+26 5 The largest city in x -1 The largest city in Brazil São Paulo Brazil [',', ' S', 'ão', ' Paulo', ' is', ' the', ' capital', ' of', ' the', ' state', ' of', ' S', 'ão', ' Paulo', '.', ' It', ' is', ' the', ' most', ' populous'] , S ão Paulo is the capital of the state of S ão Paulo . It is the most populous True Americas, including Brazil and the United 3 [' Americas', ',', ' including', ' Brazil']
+27 5 The largest city in x -1 The largest city in Brazil São Paulo Brazil [',', ' S', 'ão', ' Paulo', ' is', ' the', ' capital', ' of', ' the', ' state', ' of', ' S', 'ão', ' Paulo', '.', ' It', ' is', ' the', ' most', ' populous'] , S ão Paulo is the capital of the state of S ão Paulo . It is the most populous True the game against Brazil and two goals 3 [' the', ' game', ' against', ' Brazil']
+28 5 The largest city in x -1 The largest city in Brazil São Paulo Brazil [',', ' S', 'ão', ' Paulo', ' is', ' the', ' capital', ' of', ' the', ' state', ' of', ' S', 'ão', ' Paulo', '.', ' It', ' is', ' the', ' most', ' populous'] , S ão Paulo is the capital of the state of S ão Paulo . It is the most populous True while directing Brazil toward policies 2 [' while', ' directing', ' Brazil']
+29 5 The largest city in x -1 The largest city in Brazil São Paulo Brazil [',', ' S', 'ão', ' Paulo', ' is', ' the', ' capital', ' of', ' the', ' state', ' of', ' S', 'ão', ' Paulo', '.', ' It', ' is', ' the', ' most', ' populous'] , S ão Paulo is the capital of the state of S ão Paulo . It is the most populous True plants operating in Brazil by July 2008, 126 dedicated 3 [' plants', ' operating', ' in', ' Brazil']
+30 6 The largest city in x -1 The largest city in Australia Sydney Australia [' is', ' Sydney', ',', ' which', ' is', ' located', ' in', ' the', ' state', ' of', ' New', ' South', ' Wales', '.', ' Sydney', ' is', ' the', ' most', ' populous', ' city'] is Sydney , which is located in the state of New South Wales . Sydney is the most populous city True Grainger left Australia at the age of 13 3 [' Gra', 'inger', ' left', ' Australia']
+31 6 The largest city in x -1 The largest city in Australia Sydney Australia [' is', ' Sydney', ',', ' which', ' is', ' located', ' in', ' the', ' state', ' of', ' New', ' South', ' Wales', '.', ' Sydney', ' is', ' the', ' most', ' populous', ' city'] is Sydney , which is located in the state of New South Wales . Sydney is the most populous city True 115 minutes. Australia declared at 5 / 549 3 [' 115', ' minutes', '.', ' Australia']
+32 6 The largest city in x -1 The largest city in Australia Sydney Australia [' is', ' Sydney', ',', ' which', ' is', ' located', ' in', ' the', ' state', ' of', ' New', ' South', ' Wales', '.', ' Sydney', ' is', ' the', ' most', ' populous', ' city'] is Sydney , which is located in the state of New South Wales . Sydney is the most populous city True bring a team to Australia during the English 4 [' bring', ' a', ' team', ' to', ' Australia']
+33 6 The largest city in x -1 The largest city in Australia Sydney Australia [' is', ' Sydney', ',', ' which', ' is', ' located', ' in', ' the', ' state', ' of', ' New', ' South', ' Wales', '.', ' Sydney', ' is', ' the', ' most', ' populous', ' city'] is Sydney , which is located in the state of New South Wales . Sydney is the most populous city True States, Canada, and Australia — peaking at number 5 [' States', ',', ' Canada', ',', ' and', ' Australia']
+34 6 The largest city in x -1 The largest city in Australia Sydney Australia [' is', ' Sydney', ',', ' which', ' is', ' located', ' in', ' the', ' state', ' of', ' New', ' South', ' Wales', '.', ' Sydney', ' is', ' the', ' most', ' populous', ' city'] is Sydney , which is located in the state of New South Wales . Sydney is the most populous city True missionary in Western Australia in country towns. 3 [' missionary', ' in', ' Western', ' Australia']
+35 7 The largest city in x -1 The largest city in Canada Toronto Canada [',', ' Toronto', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Ontario', ' and', ' the', ' largest', ' city', ' in'] , Toronto is a great place to visit . It is the capital of Ontario and the largest city in True Queen's Own Rifles of Canada — took heavy casualties 6 "[' Queen', ""'s"", ' Own', ' R', 'ifles', ' of', ' Canada']"
+36 7 The largest city in x -1 The largest city in Canada Toronto Canada [',', ' Toronto', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Ontario', ' and', ' the', ' largest', ' city', ' in'] , Toronto is a great place to visit . It is the capital of Ontario and the largest city in True organizations of Canada ran a full-page 2 [' organizations', ' of', ' Canada']
+37 7 The largest city in x -1 The largest city in Canada Toronto Canada [',', ' Toronto', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Ontario', ' and', ' the', ' largest', ' city', ' in'] , Toronto is a great place to visit . It is the capital of Ontario and the largest city in True 0 ['Canada']
+38 7 The largest city in x -1 The largest city in Canada Toronto Canada [',', ' Toronto', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Ontario', ' and', ' the', ' largest', ' city', ' in'] , Toronto is a great place to visit . It is the capital of Ontario and the largest city in True major junior ranks in Canada or staying in the 4 [' major', ' junior', ' ranks', ' in', ' Canada']
+39 7 The largest city in x -1 The largest city in Canada Toronto Canada [',', ' Toronto', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Ontario', ' and', ' the', ' largest', ' city', ' in'] , Toronto is a great place to visit . It is the capital of Ontario and the largest city in True General of Canada and the Canadian Prime 2 [' General', ' of', ' Canada']
+40 8 The largest city in x -1 The largest city in United Kingdom London United Kingdom [',', ' London', ' is', ' a', ' major', ' global', ' city', ',', ' and', ' the', ' largest', ' metropolitan', ' area', ' in', ' the', ' European', ' Union', '.', ' It', ' is'] , London is a major global city , and the largest metropolitan area in the European Union . It is True Channel 4 in the United Kingdom premiered on June 5 [' Channel', ' 4', ' in', ' the', ' United', ' Kingdom']
+41 8 The largest city in x -1 The largest city in United Kingdom London United Kingdom [',', ' London', ' is', ' a', ' major', ' global', ' city', ',', ' and', ' the', ' largest', ' metropolitan', ' area', ' in', ' the', ' European', ' Union', '.', ' It', ' is'] , London is a major global city , and the largest metropolitan area in the European Union . It is True 1 ['United', ' Kingdom']
+42 8 The largest city in x -1 The largest city in United Kingdom London United Kingdom [',', ' London', ' is', ' a', ' major', ' global', ' city', ',', ' and', ' the', ' largest', ' metropolitan', ' area', ' in', ' the', ' European', ' Union', '.', ' It', ' is'] , London is a major global city , and the largest metropolitan area in the European Union . It is True BBC Radio 1 in the United Kingdom on 2 July 2012 and 6 [' BBC', ' Radio', ' 1', ' in', ' the', ' United', ' Kingdom']
+43 8 The largest city in x -1 The largest city in United Kingdom London United Kingdom [',', ' London', ' is', ' a', ' major', ' global', ' city', ',', ' and', ' the', ' largest', ' metropolitan', ' area', ' in', ' the', ' European', ' Union', '.', ' It', ' is'] , London is a major global city , and the largest metropolitan area in the European Union . It is True refit in the United Kingdom in 1937, later 5 [' ref', 'it', ' in', ' the', ' United', ' Kingdom']
+44 8 The largest city in x -1 The largest city in United Kingdom London United Kingdom [',', ' London', ' is', ' a', ' major', ' global', ' city', ',', ' and', ' the', ' largest', ' metropolitan', ' area', ' in', ' the', ' European', ' Union', '.', ' It', ' is'] , London is a major global city , and the largest metropolitan area in the European Union . It is True Reviewers in the United Kingdom were generally 5 [' Review', 'ers', ' in', ' the', ' United', ' Kingdom']
+45 9 The largest city in x -1 The largest city in France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True President of France in 1981, he 2 [' President', ' of', ' France']
+46 9 The largest city in x -1 The largest city in France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True accused of leading France into a civil war because 3 [' accused', ' of', ' leading', ' France']
+47 9 The largest city in x -1 The largest city in France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True was dispatched to France in July 1918. Minton 3 [' was', ' dispatched', ' to', ' France']
+48 9 The largest city in x -1 The largest city in France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True against class or rank; France awarded the Légion 5 [' against', ' class', ' or', ' rank', ';', ' France']
+49 9 The largest city in x -1 The largest city in France Paris France [',', ' Paris', ' is', ' a', ' city', ' of', ' culture', ',', ' history', ',', ' and', ' romance', '.', ' It', ' is', ' also', ' a', ' city', ' of', ' great'] , Paris is a city of culture , history , and romance . It is also a city of great True King of England and France, Defender of the 4 [' King', ' of', ' England', ' and', ' France']
+50 10 The largest city in x -1 The largest city in Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Germany', ' and', ' the', ' largest', ' city', ' in'] , Berlin is a great place to visit . It is the capital of Germany and the largest city in True going Gold in Germany and reaching 3 [' going', ' Gold', ' in', ' Germany']
+51 10 The largest city in x -1 The largest city in Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Germany', ' and', ' the', ' largest', ' city', ' in'] , Berlin is a great place to visit . It is the capital of Germany and the largest city in True 0 ['Germany']
+52 10 The largest city in x -1 The largest city in Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Germany', ' and', ' the', ' largest', ' city', ' in'] , Berlin is a great place to visit . It is the capital of Germany and the largest city in True describe life under Nazi Germany and to expose 4 [' describe', ' life', ' under', ' Nazi', ' Germany']
+53 10 The largest city in x -1 The largest city in Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Germany', ' and', ' the', ' largest', ' city', ' in'] , Berlin is a great place to visit . It is the capital of Germany and the largest city in True " Push across Germany and victory ===
+" 2 [' Push', ' across', ' Germany']
+54 10 The largest city in x -1 The largest city in Germany Berlin Germany [',', ' Berlin', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Germany', ' and', ' the', ' largest', ' city', ' in'] , Berlin is a great place to visit . It is the capital of Germany and the largest city in True Alliance partners — Germany and Austria-Hungary 3 [' Alliance', ' partners', ' —', ' Germany']
+55 11 The largest city in x -1 The largest city in Italy Rome Italy [',', ' Rome', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Italy', ' and', ' the', ' largest', ' city', ' in'] , Rome is a great place to visit . It is the capital of Italy and the largest city in True opera all over Italy and Europe fully 3 [' opera', ' all', ' over', ' Italy']
+56 11 The largest city in x -1 The largest city in Italy Rome Italy [',', ' Rome', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Italy', ' and', ' the', ' largest', ' city', ' in'] , Rome is a great place to visit . It is the capital of Italy and the largest city in True Caporetto. In fear that Italy might be put 7 [' Cap', 'ore', 'tto', '.', ' In', ' fear', ' that', ' Italy']
+57 11 The largest city in x -1 The largest city in Italy Rome Italy [',', ' Rome', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Italy', ' and', ' the', ' largest', ' city', ' in'] , Rome is a great place to visit . It is the capital of Italy and the largest city in True expansion within Italy. Despite their 2 [' expansion', ' within', ' Italy']
+58 11 The largest city in x -1 The largest city in Italy Rome Italy [',', ' Rome', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Italy', ' and', ' the', ' largest', ' city', ' in'] , Rome is a great place to visit . It is the capital of Italy and the largest city in True won the gold medal, Italy picked up the silver, 5 [' won', ' the', ' gold', ' medal', ',', ' Italy']
+59 11 The largest city in x -1 The largest city in Italy Rome Italy [',', ' Rome', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Italy', ' and', ' the', ' largest', ' city', ' in'] , Rome is a great place to visit . It is the capital of Italy and the largest city in True 0 ['Italy']
+60 12 The largest city in x -1 The largest city in Mexico Mexico City Mexico [',', ' Gu', 'adal', 'aj', 'ara', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' state', ' of'] , Gu adal aj ara is a great place to visit . It is the capital of the state of False Harbor, after which New Mexico sailed to the island 5 [' Harbor', ',', ' after', ' which', ' New', ' Mexico']
+61 12 The largest city in x -1 The largest city in Mexico Mexico City Mexico [',', ' Gu', 'adal', 'aj', 'ara', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' state', ' of'] , Gu adal aj ara is a great place to visit . It is the capital of the state of False treaty with Mexico guaranteeing that 2 [' treaty', ' with', ' Mexico']
+62 12 The largest city in x -1 The largest city in Mexico Mexico City Mexico [',', ' Gu', 'adal', 'aj', 'ara', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' state', ' of'] , Gu adal aj ara is a great place to visit . It is the capital of the state of False " two dates in Mexico City.
+" 3 [' two', ' dates', ' in', ' Mexico']
+63 12 The largest city in x -1 The largest city in Mexico Mexico City Mexico [',', ' Gu', 'adal', 'aj', 'ara', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' state', ' of'] , Gu adal aj ara is a great place to visit . It is the capital of the state of False km) south of Mexico from an area of 4 [' km', ')', ' south', ' of', ' Mexico']
+64 12 The largest city in x -1 The largest city in Mexico Mexico City Mexico [',', ' Gu', 'adal', 'aj', 'ara', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' state', ' of'] , Gu adal aj ara is a great place to visit . It is the capital of the state of False participation at Mexico City, but he was 2 [' participation', ' at', ' Mexico']
+65 13 The largest city in x -1 The largest city in South Korea Seoul South Korea [',', ' Seoul', ' is', ' a', ' bustling', ' met', 'ropolis', ' with', ' a', ' population', ' of', ' over', ' 10', ' million', '.', ' It', ' is', ' the', ' capital', ' of'] , Seoul is a bustling met ropolis with a population of over 10 million . It is the capital of True travelled back to South Korea for further training, 4 [' travelled', ' back', ' to', ' South', ' Korea']
+66 13 The largest city in x -1 The largest city in South Korea Seoul South Korea [',', ' Seoul', ' is', ' a', ' bustling', ' met', 'ropolis', ' with', ' a', ' population', ' of', ' over', ' 10', ' million', '.', ' It', ' is', ' the', ' capital', ' of'] , Seoul is a bustling met ropolis with a population of over 10 million . It is the capital of True " sports in Japan, South Korea and Taiwan.
+" 5 [' sports', ' in', ' Japan', ',', ' South', ' Korea']
+67 13 The largest city in x -1 The largest city in South Korea Seoul South Korea [',', ' Seoul', ' is', ' a', ' bustling', ' met', 'ropolis', ' with', ' a', ' population', ' of', ' over', ' 10', ' million', '.', ' It', ' is', ' the', ' capital', ' of'] , Seoul is a bustling met ropolis with a population of over 10 million . It is the capital of True during which South Korea had no relations with 3 [' during', ' which', ' South', ' Korea']
+68 13 The largest city in x -1 The largest city in South Korea Seoul South Korea [',', ' Seoul', ' is', ' a', ' bustling', ' met', 'ropolis', ' with', ' a', ' population', ' of', ' over', ' 10', ' million', '.', ' It', ' is', ' the', ' capital', ' of'] , Seoul is a bustling met ropolis with a population of over 10 million . It is the capital of True " South Korea ====
+" 1 [' South', ' Korea']
+69 13 The largest city in x -1 The largest city in South Korea Seoul South Korea [',', ' Seoul', ' is', ' a', ' bustling', ' met', 'ropolis', ' with', ' a', ' population', ' of', ' over', ' 10', ' million', '.', ' It', ' is', ' the', ' capital', ' of'] , Seoul is a bustling met ropolis with a population of over 10 million . It is the capital of True Korean invasion of South Korea in June 1950, 4 [' Korean', ' invasion', ' of', ' South', ' Korea']
+70 14 The largest city in x -1 The largest city in Turkey Istanbul Turkey [',', ' Istanbul', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Turkey', ' and', ' the', ' largest', ' city', ' in'] , Istanbul is a great place to visit . It is the capital of Turkey and the largest city in True 0 ['Turkey']
+71 14 The largest city in x -1 The largest city in Turkey Istanbul Turkey [',', ' Istanbul', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Turkey', ' and', ' the', ' largest', ' city', ' in'] , Istanbul is a great place to visit . It is the capital of Turkey and the largest city in True areas such as India, Turkey and Cyprus, just 5 [' areas', ' such', ' as', ' India', ',', ' Turkey']
+72 14 The largest city in x -1 The largest city in Turkey Istanbul Turkey [',', ' Istanbul', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Turkey', ' and', ' the', ' largest', ' city', ' in'] , Istanbul is a great place to visit . It is the capital of Turkey and the largest city in True south-eastern Turkey to the Sinai 4 [' south', '-', 'e', 'astern', ' Turkey']
+73 14 The largest city in x -1 The largest city in Turkey Istanbul Turkey [',', ' Istanbul', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Turkey', ' and', ' the', ' largest', ' city', ' in'] , Istanbul is a great place to visit . It is the capital of Turkey and the largest city in True qualification victory over Turkey on 28 March 2009, 3 [' qualification', ' victory', ' over', ' Turkey']
+74 14 The largest city in x -1 The largest city in Turkey Istanbul Turkey [',', ' Istanbul', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Turkey', ' and', ' the', ' largest', ' city', ' in'] , Istanbul is a great place to visit . It is the capital of Turkey and the largest city in True world — from Turkey to Prague to Atlanta 3 [' world', ' —', ' from', ' Turkey']
+75 15 The largest city in x -1 The largest city in Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Spain', ' and', ' the', ' largest', ' city', ' in'] , Madrid is a great place to visit . It is the capital of Spain and the largest city in True allied with Spain in the Peninsular 2 [' allied', ' with', ' Spain']
+76 15 The largest city in x -1 The largest city in Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Spain', ' and', ' the', ' largest', ' city', ' in'] , Madrid is a great place to visit . It is the capital of Spain and the largest city in True Itza kingdom to Spain was a critical turning 4 [' It', 'za', ' kingdom', ' to', ' Spain']
+77 15 The largest city in x -1 The largest city in Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Spain', ' and', ' the', ' largest', ' city', ' in'] , Madrid is a great place to visit . It is the capital of Spain and the largest city in True games, losing 2 – 1 to Spain and 3 – 0 to Brazil. 7 [' games', ',', ' losing', ' 2', ' –', ' 1', ' to', ' Spain']
+78 15 The largest city in x -1 The largest city in Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Spain', ' and', ' the', ' largest', ' city', ' in'] , Madrid is a great place to visit . It is the capital of Spain and the largest city in True World Cup in Spain under Stein's 3 [' World', ' Cup', ' in', ' Spain']
+79 15 The largest city in x -1 The largest city in Spain Madrid Spain [',', ' Madrid', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Spain', ' and', ' the', ' largest', ' city', ' in'] , Madrid is a great place to visit . It is the capital of Spain and the largest city in True death in 1833 that Spain finally abandoned 5 [' death', ' in', ' 18', '33', ' that', ' Spain']
+80 16 The largest city in x -1 The largest city in Argentina Buenos Aires Argentina [' is', ' Buenos', ' Aires', ',', ' which', ' is', ' located', ' in', ' the', ' province', ' of', ' Buenos', ' Aires', '.', ' The', ' city', ' is', ' the', ' capital', ' of'] is Buenos Aires , which is located in the province of Buenos Aires . The city is the capital of True 2 ['Ar', 'gent', 'ina']
+81 16 The largest city in x -1 The largest city in Argentina Buenos Aires Argentina [' is', ' Buenos', ' Aires', ',', ' which', ' is', ' located', ' in', ' the', ' province', ' of', ' Buenos', ' Aires', '.', ' The', ' city', ' is', ' the', ' capital', ' of'] is Buenos Aires , which is located in the province of Buenos Aires . The city is the capital of True local media of both Argentina and Venezuela 4 [' local', ' media', ' of', ' both', ' Argentina']
+82 16 The largest city in x -1 The largest city in Argentina Buenos Aires Argentina [' is', ' Buenos', ' Aires', ',', ' which', ' is', ' located', ' in', ' the', ' province', ' of', ' Buenos', ' Aires', '.', ' The', ' city', ' is', ' the', ' capital', ' of'] is Buenos Aires , which is located in the province of Buenos Aires . The city is the capital of True " Cry for Me Argentina =
+" 3 [' Cry', ' for', ' Me', ' Argentina']
+83 16 The largest city in x -1 The largest city in Argentina Buenos Aires Argentina [' is', ' Buenos', ' Aires', ',', ' which', ' is', ' located', ' in', ' the', ' province', ' of', ' Buenos', ' Aires', '.', ' The', ' city', ' is', ' the', ' capital', ' of'] is Buenos Aires , which is located in the province of Buenos Aires . The city is the capital of True of the Casa Rosada, Argentina's government 7 [' of', ' the', ' Cas', 'a', ' Ros', 'ada', ',', ' Argentina']
+84 16 The largest city in x -1 The largest city in Argentina Buenos Aires Argentina [' is', ' Buenos', ' Aires', ',', ' which', ' is', ' located', ' in', ' the', ' province', ' of', ' Buenos', ' Aires', '.', ' The', ' city', ' is', ' the', ' capital', ' of'] is Buenos Aires , which is located in the province of Buenos Aires . The city is the capital of True countries. Newspapers in Argentina including La 5 [' countries', '.', ' Newsp', 'apers', ' in', ' Argentina']
+85 17 The largest city in x -1 The largest city in South Africa Johannesburg South Africa [',', ' Johannes', 'burg', ' is', ' a', ' major', ' economic', ' hub', ' and', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' also', ' the', ' country', '�'] , Johannes burg is a major economic hub and a major tourist destination . It is also the country � True his Queen's South Africa Medal to a 4 "[' his', ' Queen', ""'s"", ' South', ' Africa']"
+86 17 The largest city in x -1 The largest city in South Africa Johannesburg South Africa [',', ' Johannes', 'burg', ' is', ' a', ' major', ' economic', ' hub', ' and', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' also', ' the', ' country', '�'] , Johannes burg is a major economic hub and a major tourist destination . It is also the country � True localities in South Africa and in freshwater 4 [' local', 'ities', ' in', ' South', ' Africa']
+87 17 The largest city in x -1 The largest city in South Africa Johannesburg South Africa [',', ' Johannes', 'burg', ' is', ' a', ' major', ' economic', ' hub', ' and', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' also', ' the', ' country', '�'] , Johannes burg is a major economic hub and a major tourist destination . It is also the country � True filmed in Cape Town, South Africa could instead 6 [' filmed', ' in', ' Cape', ' Town', ',', ' South', ' Africa']
+88 17 The largest city in x -1 The largest city in South Africa Johannesburg South Africa [',', ' Johannes', 'burg', ' is', ' a', ' major', ' economic', ' hub', ' and', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' also', ' the', ' country', '�'] , Johannes burg is a major economic hub and a major tourist destination . It is also the country � True place with South Africa until cricket 3 [' place', ' with', ' South', ' Africa']
+89 17 The largest city in x -1 The largest city in South Africa Johannesburg South Africa [',', ' Johannes', 'burg', ' is', ' a', ' major', ' economic', ' hub', ' and', ' a', ' major', ' tourist', ' destination', '.', ' It', ' is', ' also', ' the', ' country', '�'] , Johannes burg is a major economic hub and a major tourist destination . It is also the country � True " Mithyane from South Africa commented, ""The" 5 [' Mith', 'y', 'ane', ' from', ' South', ' Africa']
+90 18 The largest city in x -1 The largest city in Poland Warsaw Poland [',', ' Warsaw', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Poland', ' and', ' the', ' largest', ' city', ' in'] , Warsaw is a great place to visit . It is the capital of Poland and the largest city in True Anti-communist resistance in Poland was also bolstered, 6 [' Anti', '-', 'commun', 'ist', ' resistance', ' in', ' Poland']
+91 18 The largest city in x -1 The largest city in Poland Warsaw Poland [',', ' Warsaw', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Poland', ' and', ' the', ' largest', ' city', ' in'] , Warsaw is a great place to visit . It is the capital of Poland and the largest city in True Kosciuszko in Poland at Kraków 5 [' Kos', 'cius', 'z', 'ko', ' in', ' Poland']
+92 18 The largest city in x -1 The largest city in Poland Warsaw Poland [',', ' Warsaw', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Poland', ' and', ' the', ' largest', ' city', ' in'] , Warsaw is a great place to visit . It is the capital of Poland and the largest city in True Nazis invaded Poland and forced 2 [' Nazis', ' invaded', ' Poland']
+93 18 The largest city in x -1 The largest city in Poland Warsaw Poland [',', ' Warsaw', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Poland', ' and', ' the', ' largest', ' city', ' in'] , Warsaw is a great place to visit . It is the capital of Poland and the largest city in True Łódź and across Poland over the next few months, 8 [' �', '�', 'ó', 'd', '�', '�', ' and', ' across', ' Poland']
+94 18 The largest city in x -1 The largest city in Poland Warsaw Poland [',', ' Warsaw', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' Poland', ' and', ' the', ' largest', ' city', ' in'] , Warsaw is a great place to visit . It is the capital of Poland and the largest city in True Governorate to autonomous Poland were reprinted in 4 [' Governor', 'ate', ' to', ' autonomous', ' Poland']
+95 19 The largest city in x -1 The largest city in Nigeria Lagos Nigeria [',', ' Lag', 'os', ' is', ' a', ' major', ' commercial', ' and', ' industrial', ' centre', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' Africa'] , Lag os is a major commercial and industrial centre . It is also the most populous city in Africa True southeastern region of Nigeria broke away to form 3 [' southeastern', ' region', ' of', ' Nigeria']
+96 19 The largest city in x -1 The largest city in Nigeria Lagos Nigeria [',', ' Lag', 'os', ' is', ' a', ' major', ' commercial', ' and', ' industrial', ' centre', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' Africa'] , Lag os is a major commercial and industrial centre . It is also the most populous city in Africa True agents and leaders in Nigeria and Cameroon, 4 [' agents', ' and', ' leaders', ' in', ' Nigeria']
+97 19 The largest city in x -1 The largest city in Nigeria Lagos Nigeria [',', ' Lag', 'os', ' is', ' a', ' major', ' commercial', ' and', ' industrial', ' centre', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' Africa'] , Lag os is a major commercial and industrial centre . It is also the most populous city in Africa True in present-day Nigeria and increased encounters 4 [' in', ' present', '-', 'day', ' Nigeria']
+98 19 The largest city in x -1 The largest city in Nigeria Lagos Nigeria [',', ' Lag', 'os', ' is', ' a', ' major', ' commercial', ' and', ' industrial', ' centre', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' Africa'] , Lag os is a major commercial and industrial centre . It is also the most populous city in Africa True Rifles and the Nigeria and Gold Coast 4 [' R', 'ifles', ' and', ' the', ' Nigeria']
+99 19 The largest city in x -1 The largest city in Nigeria Lagos Nigeria [',', ' Lag', 'os', ' is', ' a', ' major', ' commercial', ' and', ' industrial', ' centre', '.', ' It', ' is', ' also', ' the', ' most', ' populous', ' city', ' in', ' Africa'] , Lag os is a major commercial and industrial centre . It is also the most populous city in Africa True far eastern Nigeria (British Cameroons) 2 [' far', ' eastern', ' Nigeria']
+100 20 The largest city in x -1 The largest city in New Zealand Auckland New Zealand [' is', ' Auckland', ',', ' which', ' is', ' located', ' on', ' the', ' North', ' Island', '.', ' Auckland', ' is', ' the', ' most', ' populous', ' city', ' in', ' New', ' Zealand'] is Auckland , which is located on the North Island . Auckland is the most populous city in New Zealand True and the Maori of New Zealand display similar 6 [' and', ' the', ' Ma', 'ori', ' of', ' New', ' Zealand']
+101 20 The largest city in x -1 The largest city in New Zealand Auckland New Zealand [' is', ' Auckland', ',', ' which', ' is', ' located', ' on', ' the', ' North', ' Island', '.', ' Auckland', ' is', ' the', ' most', ' populous', ' city', ' in', ' New', ' Zealand'] is Auckland , which is located on the North Island . Auckland is the most populous city in New Zealand True and later sailed for New Zealand to join HMAS Psyche 5 [' and', ' later', ' sailed', ' for', ' New', ' Zealand']
+102 20 The largest city in x -1 The largest city in New Zealand Auckland New Zealand [' is', ' Auckland', ',', ' which', ' is', ' located', ' on', ' the', ' North', ' Island', '.', ' Auckland', ' is', ' the', ' most', ' populous', ' city', ' in', ' New', ' Zealand'] is Auckland , which is located on the North Island . Auckland is the most populous city in New Zealand True series win in New Zealand since 1981 and 4 [' series', ' win', ' in', ' New', ' Zealand']
+103 20 The largest city in x -1 The largest city in New Zealand Auckland New Zealand [' is', ' Auckland', ',', ' which', ' is', ' located', ' on', ' the', ' North', ' Island', '.', ' Auckland', ' is', ' the', ' most', ' populous', ' city', ' in', ' New', ' Zealand'] is Auckland , which is located on the North Island . Auckland is the most populous city in New Zealand True attacks. The New Zealand artillery again 4 [' attacks', '.', ' The', ' New', ' Zealand']
+104 20 The largest city in x -1 The largest city in New Zealand Auckland New Zealand [' is', ' Auckland', ',', ' which', ' is', ' located', ' on', ' the', ' North', ' Island', '.', ' Auckland', ' is', ' the', ' most', ' populous', ' city', ' in', ' New', ' Zealand'] is Auckland , which is located on the North Island . Auckland is the most populous city in New Zealand True Promising Group' at the New Zealand Music Awards. 7 "[' Prom', 'ising', ' Group', ""'"", ' at', ' the', ' New', ' Zealand']"
+105 21 The largest city in x -1 The largest city in Switzerland Zurich Switzerland [',', ' Geneva', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' cant', 'on', ' of', ' Geneva', ' and'] , Geneva is a great place to visit . It is the capital of the cant on of Geneva and False the Embassy of Switzerland in Tehran, which 3 [' the', ' Embassy', ' of', ' Switzerland']
+106 21 The largest city in x -1 The largest city in Switzerland Zurich Switzerland [',', ' Geneva', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' cant', 'on', ' of', ' Geneva', ' and'] , Geneva is a great place to visit . It is the capital of the cant on of Geneva and False club travelled to Switzerland on the 13th, 3 [' club', ' travelled', ' to', ' Switzerland']
+107 21 The largest city in x -1 The largest city in Switzerland Zurich Switzerland [',', ' Geneva', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' cant', 'on', ' of', ' Geneva', ' and'] , Geneva is a great place to visit . It is the capital of the cant on of Geneva and False " Norway, Spain, Sweden, Switzerland and Germany.
+" 6 [' Norway', ',', ' Spain', ',', ' Sweden', ',', ' Switzerland']
+108 21 The largest city in x -1 The largest city in Switzerland Zurich Switzerland [',', ' Geneva', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' cant', 'on', ' of', ' Geneva', ' and'] , Geneva is a great place to visit . It is the capital of the cant on of Geneva and False basing themselves in Switzerland during the war (including 4 [' bas', 'ing', ' themselves', ' in', ' Switzerland']
+109 21 The largest city in x -1 The largest city in Switzerland Zurich Switzerland [',', ' Geneva', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' cant', 'on', ' of', ' Geneva', ' and'] , Geneva is a great place to visit . It is the capital of the cant on of Geneva and False The Army of Switzerland and portions of the 3 [' The', ' Army', ' of', ' Switzerland']
+110 22 The largest city in x -1 The largest city in Netherlands Amsterdam Netherlands [',', ' Amsterdam', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' country', ' and', ' the', ' largest', ' city'] , Amsterdam is a great place to visit . It is the capital of the country and the largest city True prosperity for the Netherlands and a time when its 3 [' prosperity', ' for', ' the', ' Netherlands']
+111 22 The largest city in x -1 The largest city in Netherlands Amsterdam Netherlands [',', ' Amsterdam', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' country', ' and', ' the', ' largest', ' city'] , Amsterdam is a great place to visit . It is the capital of the country and the largest city True transferred back to the Netherlands and set up 4 [' transferred', ' back', ' to', ' the', ' Netherlands']
+112 22 The largest city in x -1 The largest city in Netherlands Amsterdam Netherlands [',', ' Amsterdam', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' country', ' and', ' the', ' largest', ' city'] , Amsterdam is a great place to visit . It is the capital of the country and the largest city True " Netherlands ===
+" 0 [' Netherlands']
+113 22 The largest city in x -1 The largest city in Netherlands Amsterdam Netherlands [',', ' Amsterdam', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' country', ' and', ' the', ' largest', ' city'] , Amsterdam is a great place to visit . It is the capital of the country and the largest city True exile in the Netherlands as the war concluded. 3 [' exile', ' in', ' the', ' Netherlands']
+114 22 The largest city in x -1 The largest city in Netherlands Amsterdam Netherlands [',', ' Amsterdam', ' is', ' a', ' great', ' place', ' to', ' visit', '.', ' It', ' is', ' the', ' capital', ' of', ' the', ' country', ' and', ' the', ' largest', ' city'] , Amsterdam is a great place to visit . It is the capital of the country and the largest city True annexed the Austrian Netherlands (modern Belgium), 3 [' annexed', ' the', ' Austrian', ' Netherlands']
+115 23 The largest city in x -1 The largest city in Pakistan Karachi Pakistan [',', ' Karachi', ' is', ' a', ' major', ' port', ' and', ' commercial', ' center', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' Pakistan', ' and', ' the', ' third'] , Karachi is a major port and commercial center . It is the largest city in Pakistan and the third True " common breeder in Pakistan and Kashmir.
+" 4 [' common', ' bre', 'eder', ' in', ' Pakistan']
+116 23 The largest city in x -1 The largest city in Pakistan Karachi Pakistan [',', ' Karachi', ' is', ' a', ' major', ' port', ' and', ' commercial', ' center', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' Pakistan', ' and', ' the', ' third'] , Karachi is a major port and commercial center . It is the largest city in Pakistan and the third True New Zealand, Pakistan (including 3 [' New', ' Zealand', ',', ' Pakistan']
+117 23 The largest city in x -1 The largest city in Pakistan Karachi Pakistan [',', ' Karachi', ' is', ' a', ' major', ' port', ' and', ' commercial', ' center', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' Pakistan', ' and', ' the', ' third'] , Karachi is a major port and commercial center . It is the largest city in Pakistan and the third True dismantled state in Pakistan. They flew for 3 [' dismantled', ' state', ' in', ' Pakistan']
+118 23 The largest city in x -1 The largest city in Pakistan Karachi Pakistan [',', ' Karachi', ' is', ' a', ' major', ' port', ' and', ' commercial', ' center', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' Pakistan', ' and', ' the', ' third'] , Karachi is a major port and commercial center . It is the largest city in Pakistan and the third True Indies then moved to Pakistan; in a non first-class 4 [' Indies', ' then', ' moved', ' to', ' Pakistan']
+119 23 The largest city in x -1 The largest city in Pakistan Karachi Pakistan [',', ' Karachi', ' is', ' a', ' major', ' port', ' and', ' commercial', ' center', '.', ' It', ' is', ' the', ' largest', ' city', ' in', ' Pakistan', ' and', ' the', ' third'] , Karachi is a major port and commercial center . It is the largest city in Pakistan and the third True Muslim community from Pakistan and India from 3 [' Muslim', ' community', ' from', ' Pakistan']
diff --git a/patchscopes/code/preprocessed_data/factual/food_from_country.tsv b/patchscopes/code/preprocessed_data/factual/food_from_country.tsv
new file mode 100644
index 00000000..0c7fcf79
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/food_from_country.tsv
@@ -0,0 +1,206 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The country of origin for x -1 The country of origin for Pizza Italy Pizza [' Hut', ' is', ' Italy', '.', '\n', '\n', 'P', 'izza', ' Hut', ' is', ' a', ' global', ' pizza', ' chain', ' that', ' was', ' founded', ' in', ' 1958', ' in'] " Hut is Italy .
+
+ P izza Hut is a global pizza chain that was founded in 1958 in" True after moving into Pizza Hut Park in 2005. 3 [' after', ' moving', ' into', ' Pizza']
+1 0 The country of origin for x -1 The country of origin for Pizza Italy Pizza [' Hut', ' is', ' Italy', '.', '\n', '\n', 'P', 'izza', ' Hut', ' is', ' a', ' global', ' pizza', ' chain', ' that', ' was', ' founded', ' in', ' 1958', ' in'] " Hut is Italy .
+
+ P izza Hut is a global pizza chain that was founded in 1958 in" True MOSPEADA or Samurai Pizza Cats, despite 6 [' M', 'OS', 'PE', 'ADA', ' or', ' Samurai', ' Pizza']
+2 0 The country of origin for x -1 The country of origin for Pizza Italy Pizza [' Hut', ' is', ' Italy', '.', '\n', '\n', 'P', 'izza', ' Hut', ' is', ' a', ' global', ' pizza', ' chain', ' that', ' was', ' founded', ' in', ' 1958', ' in'] " Hut is Italy .
+
+ P izza Hut is a global pizza chain that was founded in 1958 in" True campaigns by Sony and Pizza Hut. The Japanese version 4 [' campaigns', ' by', ' Sony', ' and', ' Pizza']
+3 0 The country of origin for x -1 The country of origin for Pizza Italy Pizza [' Hut', ' is', ' Italy', '.', '\n', '\n', 'P', 'izza', ' Hut', ' is', ' a', ' global', ' pizza', ' chain', ' that', ' was', ' founded', ' in', ' 1958', ' in'] " Hut is Italy .
+
+ P izza Hut is a global pizza chain that was founded in 1958 in" True Ozzfest at the Pizza Hut Park on 9 August. 5 [' O', 'zz', 'fest', ' at', ' the', ' Pizza']
+4 0 The country of origin for x -1 The country of origin for Pizza Italy Pizza [' Hut', ' is', ' Italy', '.', '\n', '\n', 'P', 'izza', ' Hut', ' is', ' a', ' global', ' pizza', ' chain', ' that', ' was', ' founded', ' in', ' 1958', ' in'] " Hut is Italy .
+
+ P izza Hut is a global pizza chain that was founded in 1958 in" True Richmond Kickers at Pizza Hut Park in Frisco, 4 [' Richmond', ' Kick', 'ers', ' at', ' Pizza']
+5 1 The country of origin for x -1 The country of origin for Sushi Japan Sushi [' is', ' Japan', '.', '\n', '\n', 'S', 'ushi', ' is', ' a', ' Japanese', ' dish', ' that', ' is', ' made', ' from', ' raw', ' fish', ',', ' usually', ' raw'] " is Japan .
+
+ S ushi is a Japanese dish that is made from raw fish , usually raw" True Jon Ledford of Sushi Arcade included her 5 [' Jon', ' Led', 'ford', ' of', ' S', 'ushi']
+6 1 The country of origin for x -1 The country of origin for Sushi Japan Sushi [' is', ' Japan', '.', '\n', '\n', 'S', 'ushi', ' is', ' a', ' Japanese', ' dish', ' that', ' is', ' made', ' from', ' raw', ' fish', ',', ' usually', ' raw'] " is Japan .
+
+ S ushi is a Japanese dish that is made from raw fish , usually raw" True 2013, Jon Ledford of Sushi Arcade included 7 [' 2013', ',', ' Jon', ' Led', 'ford', ' of', ' S', 'ushi']
+7 1 The country of origin for x -1 The country of origin for Sushi Japan Sushi [' is', ' Japan', '.', '\n', '\n', 'S', 'ushi', ' is', ' a', ' Japanese', ' dish', ' that', ' is', ' made', ' from', ' raw', ' fish', ',', ' usually', ' raw'] " is Japan .
+
+ S ushi is a Japanese dish that is made from raw fish , usually raw" True " Ledford of Arcade Sushi opined that ""in" 5 [' Led', 'ford', ' of', ' Arcade', ' S', 'ushi']
+8 2 The country of origin for x -1 The country of origin for Tacos Mexico Tacos [' is', ' Mexico', '.', '\n', '\n', 'The', ' country', ' of', ' origin', ' for', ' Tac', 'os', ' is', ' Mexico', '.', '\n', '\n', 'The', ' country', ' of'] " is Mexico .
+
+ The country of origin for Tac os is Mexico .
+
+ The country of" True at Ricky's Tacos speaks to Michael, 4 "[' at', ' Ricky', ""'s"", ' Tac', 'os']"
+9 2 The country of origin for x -1 The country of origin for Tacos Mexico Tacos [' is', ' Mexico', '.', '\n', '\n', 'The', ' country', ' of', ' origin', ' for', ' Tac', 'os', ' is', ' Mexico', '.', '\n', '\n', 'The', ' country', ' of'] " is Mexico .
+
+ The country of origin for Tac os is Mexico .
+
+ The country of" True microphone at Ricky's Tacos speaks to Michael, 5 "[' microphone', ' at', ' Ricky', ""'s"", ' Tac', 'os']"
+10 2 The country of origin for x -1 The country of origin for Tacos Mexico Tacos [' is', ' Mexico', '.', '\n', '\n', 'The', ' country', ' of', ' origin', ' for', ' Tac', 'os', ' is', ' Mexico', '.', '\n', '\n', 'The', ' country', ' of'] " is Mexico .
+
+ The country of origin for Tac os is Mexico .
+
+ The country of" True microphone at Ricky's Tacos speaks to 5 "[' microphone', ' at', ' Ricky', ""'s"", ' Tac', 'os']"
+11 3 The country of origin for x -1 The country of origin for Baguette France Baguette [' is', ' France', '.', '\n', '\n', 'B', 'agu', 'ette', ' is', ' a', ' French', ' bread', ' that', ' is', ' made', ' with', ' a', ' long', ',', ' thin'] " is France .
+
+ B agu ette is a French bread that is made with a long , thin" True 2 ['B', 'agu', 'ette']
+12 3 The country of origin for x -1 The country of origin for Baguette France Baguette [' is', ' France', '.', '\n', '\n', 'B', 'agu', 'ette', ' is', ' a', ' French', ' bread', ' that', ' is', ' made', ' with', ' a', ' long', ',', ' thin'] " is France .
+
+ B agu ette is a French bread that is made with a long , thin" True that the Chicken BLT Baguette sandwich, 7 [' that', ' the', ' Chicken', ' BL', 'T', ' B', 'agu', 'ette']
+13 3 The country of origin for x -1 The country of origin for Baguette France Baguette [' is', ' France', '.', '\n', '\n', 'B', 'agu', 'ette', ' is', ' a', ' French', ' bread', ' that', ' is', ' made', ' with', ' a', ' long', ',', ' thin'] " is France .
+
+ B agu ette is a French bread that is made with a long , thin" True BK introduced its BK Baguette line of sandwiches 8 [' B', 'K', ' introduced', ' its', ' B', 'K', ' B', 'agu', 'ette']
+14 3 The country of origin for x -1 The country of origin for Baguette France Baguette [' is', ' France', '.', '\n', '\n', 'B', 'agu', 'ette', ' is', ' a', ' French', ' bread', ' that', ' is', ' made', ' with', ' a', ' long', ',', ' thin'] " is France .
+
+ B agu ette is a French bread that is made with a long , thin" True introduced its BK Baguette line of sandwiches 6 [' introduced', ' its', ' B', 'K', ' B', 'agu', 'ette']
+15 3 The country of origin for x -1 The country of origin for Baguette France Baguette [' is', ' France', '.', '\n', '\n', 'B', 'agu', 'ette', ' is', ' a', ' French', ' bread', ' that', ' is', ' made', ' with', ' a', ' long', ',', ' thin'] " is France .
+
+ B agu ette is a French bread that is made with a long , thin" True introduced its BK Baguette line of sandwiches 6 [' introduced', ' its', ' B', 'K', ' B', 'agu', 'ette']
+16 5 The country of origin for x -1 The country of origin for Paella Spain Paella [' is', ' Spain', '.', '\n', '\n', 'The', ' Spanish', ' word', ' for', ' pa', 'ella', ' is', ' pa', 'ella', '.', '\n', '\n', 'The', ' Spanish', ' word'] " is Spain .
+
+ The Spanish word for pa ella is pa ella .
+
+ The Spanish word" True Balearic islands. Paella is a characteristic 6 [' Bal', 'ear', 'ic', ' islands', '.', ' Pa', 'ella']
+17 5 The country of origin for x -1 The country of origin for Paella Spain Paella [' is', ' Spain', '.', '\n', '\n', 'The', ' Spanish', ' word', ' for', ' pa', 'ella', ' is', ' pa', 'ella', '.', '\n', '\n', 'The', ' Spanish', ' word'] " is Spain .
+
+ The Spanish word for pa ella is pa ella .
+
+ The Spanish word" True Balearic islands. Paella is a characteristic 6 [' Bal', 'ear', 'ic', ' islands', '.', ' Pa', 'ella']
+18 5 The country of origin for x -1 The country of origin for Paella Spain Paella [' is', ' Spain', '.', '\n', '\n', 'The', ' Spanish', ' word', ' for', ' pa', 'ella', ' is', ' pa', 'ella', '.', '\n', '\n', 'The', ' Spanish', ' word'] " is Spain .
+
+ The Spanish word for pa ella is pa ella .
+
+ The Spanish word" True Balearic islands. Paella is a characteristic 6 [' Bal', 'ear', 'ic', ' islands', '.', ' Pa', 'ella']
+19 5 The country of origin for x -1 The country of origin for Paella Spain Paella [' is', ' Spain', '.', '\n', '\n', 'The', ' Spanish', ' word', ' for', ' pa', 'ella', ' is', ' pa', 'ella', '.', '\n', '\n', 'The', ' Spanish', ' word'] " is Spain .
+
+ The Spanish word for pa ella is pa ella .
+
+ The Spanish word" True Balearic islands. Paella is a characteristic 6 [' Bal', 'ear', 'ic', ' islands', '.', ' Pa', 'ella']
+20 7 The country of origin for x -1 The country of origin for Baklava Turkey Baklava "[' is', ' Greece', '.', ' The', ' name', ' Bak', 'l', 'ava', ' comes', ' from', ' the', ' Turkish', ' word', ' ""', 'b', 'ak', 'l', 'ava', '""', ' which']" " is Greece . The name Bak l ava comes from the Turkish word "" b ak l ava "" which" False 3 ['B', 'ak', 'l', 'ava']
+21 8 The country of origin for x -1 The country of origin for Feijoada Brazil Feijoada [' is', ' Brazil', ',', ' but', ' it', ' is', ' a', ' dish', ' that', ' has', ' been', ' adopted', ' by', ' many', ' other', ' countries', '.', ' It', ' is', ' a'] is Brazil , but it is a dish that has been adopted by many other countries . It is a True Other dishes include Feijoada Poveira, made with 6 [' Other', ' dishes', ' include', ' Fe', 'ij', 'oad', 'a']
+22 8 The country of origin for x -1 The country of origin for Feijoada Brazil Feijoada [' is', ' Brazil', ',', ' but', ' it', ' is', ' a', ' dish', ' that', ' has', ' been', ' adopted', ' by', ' many', ' other', ' countries', '.', ' It', ' is', ' a'] is Brazil , but it is a dish that has been adopted by many other countries . It is a True Other dishes include Feijoada Poveira, made 6 [' Other', ' dishes', ' include', ' Fe', 'ij', 'oad', 'a']
+23 8 The country of origin for x -1 The country of origin for Feijoada Brazil Feijoada [' is', ' Brazil', ',', ' but', ' it', ' is', ' a', ' dish', ' that', ' has', ' been', ' adopted', ' by', ' many', ' other', ' countries', '.', ' It', ' is', ' a'] is Brazil , but it is a dish that has been adopted by many other countries . It is a True dishes include Feijoada Poveira, made 5 [' dishes', ' include', ' Fe', 'ij', 'oad', 'a']
+24 8 The country of origin for x -1 The country of origin for Feijoada Brazil Feijoada [' is', ' Brazil', ',', ' but', ' it', ' is', ' a', ' dish', ' that', ' has', ' been', ' adopted', ' by', ' many', ' other', ' countries', '.', ' It', ' is', ' a'] is Brazil , but it is a dish that has been adopted by many other countries . It is a True Other dishes include Feijoada Poveira, made with 6 [' Other', ' dishes', ' include', ' Fe', 'ij', 'oad', 'a']
+25 9 The country of origin for x -1 The country of origin for Borscht Ukraine Borscht [' is', ' Russia', ',', ' and', ' the', ' name', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' beet', 'root', '.', ' The', ' beet', 'root', ' is'] is Russia , and the name is derived from the Russian word for beet root . The beet root is False borscht garnishes. Borscht based on bone 8 [' b', 'ors', 'cht', ' garn', 'ishes', '.', ' B', 'ors', 'cht']
+26 9 The country of origin for x -1 The country of origin for Borscht Ukraine Borscht [' is', ' Russia', ',', ' and', ' the', ' name', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' beet', 'root', '.', ' The', ' beet', 'root', ' is'] is Russia , and the name is derived from the Russian word for beet root . The beet root is False 2 ['B', 'ors', 'cht']
+27 9 The country of origin for x -1 The country of origin for Borscht Ukraine Borscht [' is', ' Russia', ',', ' and', ' the', ' name', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' beet', 'root', '.', ' The', ' beet', 'root', ' is'] is Russia , and the name is derived from the Russian word for beet root . The beet root is False borscht garnishes. Borscht based on bone 8 [' b', 'ors', 'cht', ' garn', 'ishes', '.', ' B', 'ors', 'cht']
+28 9 The country of origin for x -1 The country of origin for Borscht Ukraine Borscht [' is', ' Russia', ',', ' and', ' the', ' name', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' beet', 'root', '.', ' The', ' beet', 'root', ' is'] is Russia , and the name is derived from the Russian word for beet root . The beet root is False 2 ['B', 'ors', 'cht']
+29 9 The country of origin for x -1 The country of origin for Borscht Ukraine Borscht [' is', ' Russia', ',', ' and', ' the', ' name', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' beet', 'root', '.', ' The', ' beet', 'root', ' is'] is Russia , and the name is derived from the Russian word for beet root . The beet root is False " = Borscht =
+" 3 [' =', ' B', 'ors', 'cht']
+30 11 The country of origin for x -1 The country of origin for Dim Sum China Dim Sum [' is', ' China', '.', '\n', '\n', 'Dim', ' Sum', ' is', ' a', ' type', ' of', ' Chinese', ' food', ' that', ' is', ' usually', ' served', ' in', ' small', ' portions'] " is China .
+
+ Dim Sum is a type of Chinese food that is usually served in small portions" True " Fairy Wine and Dim Sum ’ s.
+" 4 [' Fairy', ' Wine', ' and', ' Dim', ' Sum']
+31 15 The country of origin for x -1 The country of origin for Pho Vietnam Pho ['ebus', ' is', ' unknown', '.', '\n', '\n', 'The', ' Ph', 'o', 'ebus', ' is', ' a', ' very', ' rare', ' breed', ' of', ' dog', '.', ' It', ' is'] "ebus is unknown .
+
+ The Ph o ebus is a very rare breed of dog . It is" False confluence of the Pho Chhu (father) and Mo 5 [' conf', 'luence', ' of', ' the', ' Ph', 'o']
+32 15 The country of origin for x -1 The country of origin for Pho Vietnam Pho ['ebus', ' is', ' unknown', '.', '\n', '\n', 'The', ' Ph', 'o', 'ebus', ' is', ' a', ' very', ' rare', ' breed', ' of', ' dog', '.', ' It', ' is'] "ebus is unknown .
+
+ The Ph o ebus is a very rare breed of dog . It is" False Mo Chu and Pho Chu rivers, and in 4 [' Mo', ' Chu', ' and', ' Ph', 'o']
+33 15 The country of origin for x -1 The country of origin for Pho Vietnam Pho ['ebus', ' is', ' unknown', '.', '\n', '\n', 'The', ' Ph', 'o', 'ebus', ' is', ' a', ' very', ' rare', ' breed', ' of', ' dog', '.', ' It', ' is'] "ebus is unknown .
+
+ The Ph o ebus is a very rare breed of dog . It is" False delta region, Pho Duc Chinh fled and 4 [' delta', ' region', ',', ' Ph', 'o']
+34 15 The country of origin for x -1 The country of origin for Pho Vietnam Pho ['ebus', ' is', ' unknown', '.', '\n', '\n', 'The', ' Ph', 'o', 'ebus', ' is', ' a', ' very', ' rare', ' breed', ' of', ' dog', '.', ' It', ' is'] "ebus is unknown .
+
+ The Ph o ebus is a very rare breed of dog . It is" False in the Mo Chu and Pho Chu rivers, and in 6 [' in', ' the', ' Mo', ' Chu', ' and', ' Ph', 'o']
+35 15 The country of origin for x -1 The country of origin for Pho Vietnam Pho ['ebus', ' is', ' unknown', '.', '\n', '\n', 'The', ' Ph', 'o', 'ebus', ' is', ' a', ' very', ' rare', ' breed', ' of', ' dog', '.', ' It', ' is'] "ebus is unknown .
+
+ The Ph o ebus is a very rare breed of dog . It is" False Battalion conceded Pho Lai to the VPA 3 [' Battalion', ' conceded', ' Ph', 'o']
+36 16 The country of origin for x -1 The country of origin for Hummus Lebanon Hummus [' is', ' Lebanon', '.', '\n', '\n', 'H', 'umm', 'us', ' is', ' a', ' Middle', ' Eastern', ' dip', ' made', ' from', ' chick', 'pe', 'as', ',', ' t'] " is Lebanon .
+
+ H umm us is a Middle Eastern dip made from chick pe as , t" True 2 ['H', 'umm', 'us']
+37 16 The country of origin for x -1 The country of origin for Hummus Lebanon Hummus [' is', ' Lebanon', '.', '\n', '\n', 'H', 'umm', 'us', ' is', ' a', ' Middle', ' Eastern', ' dip', ' made', ' from', ' chick', 'pe', 'as', ',', ' t'] " is Lebanon .
+
+ H umm us is a Middle Eastern dip made from chick pe as , t" True za 'atar, or jams. Hummus bi tahini is also eaten 9 "[' z', 'a', "" '"", 'atar', ',', ' or', ' jams', '.', ' Hum', 'mus']"
+38 16 The country of origin for x -1 The country of origin for Hummus Lebanon Hummus [' is', ' Lebanon', '.', '\n', '\n', 'H', 'umm', 'us', ' is', ' a', ' Middle', ' Eastern', ' dip', ' made', ' from', ' chick', 'pe', 'as', ',', ' t'] " is Lebanon .
+
+ H umm us is a Middle Eastern dip made from chick pe as , t" True 2 ['H', 'umm', 'us']
+39 16 The country of origin for x -1 The country of origin for Hummus Lebanon Hummus [' is', ' Lebanon', '.', '\n', '\n', 'H', 'umm', 'us', ' is', ' a', ' Middle', ' Eastern', ' dip', ' made', ' from', ' chick', 'pe', 'as', ',', ' t'] " is Lebanon .
+
+ H umm us is a Middle Eastern dip made from chick pe as , t" True 2 ['H', 'umm', 'us']
+40 16 The country of origin for x -1 The country of origin for Hummus Lebanon Hummus [' is', ' Lebanon', '.', '\n', '\n', 'H', 'umm', 'us', ' is', ' a', ' Middle', ' Eastern', ' dip', ' made', ' from', ' chick', 'pe', 'as', ',', ' t'] " is Lebanon .
+
+ H umm us is a Middle Eastern dip made from chick pe as , t" True 2 ['H', 'umm', 'us']
+41 17 The country of origin for x -1 The country of origin for Gyro Greece Gyro ['-', 'M', 'atic', ' is', ' the', ' United', ' States', '.', '\n', '\n', 'Gy', 'ro', '-', 'M', 'atic', ' is', ' a', ' registered', ' trademark', ' of'] "- M atic is the United States .
+
+ Gy ro - M atic is a registered trademark of" False " pioneered the use of the Gyro Rate Unit.
+" 6 [' pioneered', ' the', ' use', ' of', ' the', ' Gy', 'ro']
+42 17 The country of origin for x -1 The country of origin for Gyro Greece Gyro ['-', 'M', 'atic', ' is', ' the', ' United', ' States', '.', '\n', '\n', 'Gy', 'ro', '-', 'M', 'atic', ' is', ' a', ' registered', ' trademark', ' of'] "- M atic is the United States .
+
+ Gy ro - M atic is a registered trademark of" False " pioneered the use of the Gyro Rate Unit.
+" 6 [' pioneered', ' the', ' use', ' of', ' the', ' Gy', 'ro']
+43 17 The country of origin for x -1 The country of origin for Gyro Greece Gyro ['-', 'M', 'atic', ' is', ' the', ' United', ' States', '.', '\n', '\n', 'Gy', 'ro', '-', 'M', 'atic', ' is', ' a', ' registered', ' trademark', ' of'] "- M atic is the United States .
+
+ Gy ro - M atic is a registered trademark of" False " the use of the Gyro Rate Unit.
+" 5 [' the', ' use', ' of', ' the', ' Gy', 'ro']
+44 17 The country of origin for x -1 The country of origin for Gyro Greece Gyro ['-', 'M', 'atic', ' is', ' the', ' United', ' States', '.', '\n', '\n', 'Gy', 'ro', '-', 'M', 'atic', ' is', ' a', ' registered', ' trademark', ' of'] "- M atic is the United States .
+
+ Gy ro - M atic is a registered trademark of" False " pioneered the use of the Gyro Rate Unit.
+" 6 [' pioneered', ' the', ' use', ' of', ' the', ' Gy', 'ro']
+45 20 The country of origin for x -1 The country of origin for Moussaka Greece Moussaka "[' is', ' Greece', ',', ' and', ' the', ' name', ' is', ' derived', ' from', ' the', ' Greek', ' word', ' for', ' ""', 'au', 'berg', 'ine', '""', ' (', 'egg']" " is Greece , and the name is derived from the Greek word for "" au berg ine "" ( egg" True 3 ['M', 'ous', 's', 'aka']
+46 20 The country of origin for x -1 The country of origin for Moussaka Greece Moussaka "[' is', ' Greece', ',', ' and', ' the', ' name', ' is', ' derived', ' from', ' the', ' Greek', ' word', ' for', ' ""', 'au', 'berg', 'ine', '""', ' (', 'egg']" " is Greece , and the name is derived from the Greek word for "" au berg ine "" ( egg" True 3 ['M', 'ous', 's', 'aka']
+47 21 The country of origin for x -1 The country of origin for Pavlova New Zealand Pavlova "[' is', ' Russia', '.', '\n', '\n', 'The', ' name', ' Pav', 'l', 'ova', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' ""', 'pe', 'ach']" " is Russia .
+
+ The name Pav l ova is derived from the Russian word for "" pe ach" False Russian ballerina Anna Pavlova and Mrs. Patrick Campbell. 7 [' Russian', ' ball', 'er', 'ina', ' Anna', ' Pav', 'l', 'ova']
+48 21 The country of origin for x -1 The country of origin for Pavlova New Zealand Pavlova "[' is', ' Russia', '.', '\n', '\n', 'The', ' name', ' Pav', 'l', 'ova', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' ""', 'pe', 'ach']" " is Russia .
+
+ The name Pav l ova is derived from the Russian word for "" pe ach" False first performed by Pavlova at a gala in the 5 [' first', ' performed', ' by', ' Pav', 'l', 'ova']
+49 21 The country of origin for x -1 The country of origin for Pavlova New Zealand Pavlova "[' is', ' Russia', '.', '\n', '\n', 'The', ' name', ' Pav', 'l', 'ova', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' ""', 'pe', 'ach']" " is Russia .
+
+ The name Pav l ova is derived from the Russian word for "" pe ach" False 3 ['P', 'av', 'l', 'ova']
+50 21 The country of origin for x -1 The country of origin for Pavlova New Zealand Pavlova "[' is', ' Russia', '.', '\n', '\n', 'The', ' name', ' Pav', 'l', 'ova', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' ""', 'pe', 'ach']" " is Russia .
+
+ The name Pav l ova is derived from the Russian word for "" pe ach" False considered one of Pavlova's trademarks. 5 [' considered', ' one', ' of', ' Pav', 'l', 'ova']
+51 21 The country of origin for x -1 The country of origin for Pavlova New Zealand Pavlova "[' is', ' Russia', '.', '\n', '\n', 'The', ' name', ' Pav', 'l', 'ova', ' is', ' derived', ' from', ' the', ' Russian', ' word', ' for', ' ""', 'pe', 'ach']" " is Russia .
+
+ The name Pav l ova is derived from the Russian word for "" pe ach" False 3 ['P', 'av', 'l', 'ova']
+52 22 The country of origin for x -1 The country of origin for Shawarma Middle East Shawarma [' is', ' Turkey', '.', '\n', '\n', 'Sh', 'aw', 'arma', ' is', ' a', ' type', ' of', ' Middle', ' Eastern', ' food', ' that', ' is', ' made', ' from', ' a'] " is Turkey .
+
+ Sh aw arma is a type of Middle Eastern food that is made from a" True the world premiere. Shawarma sales in Los Angeles, 5 [' the', ' world', ' premiere', '.', ' Shaw', 'arma']
+53 22 The country of origin for x -1 The country of origin for Shawarma Middle East Shawarma [' is', ' Turkey', '.', '\n', '\n', 'Sh', 'aw', 'arma', ' is', ' a', ' type', ' of', ' Middle', ' Eastern', ' food', ' that', ' is', ' made', ' from', ' a'] " is Turkey .
+
+ Sh aw arma is a type of Middle Eastern food that is made from a" True khubz, such as Shawarma and Falafel are also 7 [' k', 'hub', 'z', ',', ' such', ' as', ' Shaw', 'arma']
+54 22 The country of origin for x -1 The country of origin for Shawarma Middle East Shawarma [' is', ' Turkey', '.', '\n', '\n', 'Sh', 'aw', 'arma', ' is', ' a', ' type', ' of', ' Middle', ' Eastern', ' food', ' that', ' is', ' made', ' from', ' a'] " is Turkey .
+
+ Sh aw arma is a type of Middle Eastern food that is made from a" True khubz, such as Shawarma and Falafel 7 [' k', 'hub', 'z', ',', ' such', ' as', ' Shaw', 'arma']
+55 23 The country of origin for x -1 The country of origin for Falafel Middle East Falafel [' is', ' Israel', '.', '\n', '\n', 'Fal', 'af', 'el', ' is', ' a', ' traditional', ' Middle', ' Eastern', ' dish', ' made', ' from', ' chick', 'pe', 'as', ','] " is Israel .
+
+ Fal af el is a traditional Middle Eastern dish made from chick pe as ," True 2 ['Fal', 'af', 'el']
+56 23 The country of origin for x -1 The country of origin for Falafel Middle East Falafel [' is', ' Israel', '.', '\n', '\n', 'Fal', 'af', 'el', ' is', ' a', ' traditional', ' Middle', ' Eastern', ' dish', ' made', ' from', ' chick', 'pe', 'as', ','] " is Israel .
+
+ Fal af el is a traditional Middle Eastern dish made from chick pe as ," True " = Falafel =
+" 3 [' =', ' Fal', 'af', 'el']
+57 23 The country of origin for x -1 The country of origin for Falafel Middle East Falafel [' is', ' Israel', '.', '\n', '\n', 'Fal', 'af', 'el', ' is', ' a', ' traditional', ' Middle', ' Eastern', ' dish', ' made', ' from', ' chick', 'pe', 'as', ','] " is Israel .
+
+ Fal af el is a traditional Middle Eastern dish made from chick pe as ," True 2 ['Fal', 'af', 'el']
+58 23 The country of origin for x -1 The country of origin for Falafel Middle East Falafel [' is', ' Israel', '.', '\n', '\n', 'Fal', 'af', 'el', ' is', ' a', ' traditional', ' Middle', ' Eastern', ' dish', ' made', ' from', ' chick', 'pe', 'as', ','] " is Israel .
+
+ Fal af el is a traditional Middle Eastern dish made from chick pe as ," True regarding the origins of Falafel, and reassert the 6 [' regarding', ' the', ' origins', ' of', ' Fal', 'af', 'el']
+59 23 The country of origin for x -1 The country of origin for Falafel Middle East Falafel [' is', ' Israel', '.', '\n', '\n', 'Fal', 'af', 'el', ' is', ' a', ' traditional', ' Middle', ' Eastern', ' dish', ' made', ' from', ' chick', 'pe', 'as', ','] " is Israel .
+
+ Fal af el is a traditional Middle Eastern dish made from chick pe as ," True 2 ['Fal', 'af', 'el']
+60 25 The country of origin for x -1 The country of origin for Biryani India Biryani [' is', ' India', '.', '\n', '\n', 'The', ' country', ' of', ' origin', ' for', ' B', 'iry', 'ani', ' is', ' India', '.', '\n', '\n', 'The', ' country'] " is India .
+
+ The country of origin for B iry ani is India .
+
+ The country" True 2 ['B', 'iry', 'ani']
+61 25 The country of origin for x -1 The country of origin for Biryani India Biryani [' is', ' India', '.', '\n', '\n', 'The', ' country', ' of', ' origin', ' for', ' B', 'iry', 'ani', ' is', ' India', '.', '\n', '\n', 'The', ' country'] " is India .
+
+ The country of origin for B iry ani is India .
+
+ The country" True chicken pakoras. Biryani and murgh pilao 7 [' chicken', ' p', 'ak', 'oras', '.', ' B', 'iry', 'ani']
+62 25 The country of origin for x -1 The country of origin for Biryani India Biryani [' is', ' India', '.', '\n', '\n', 'The', ' country', ' of', ' origin', ' for', ' B', 'iry', 'ani', ' is', ' India', '.', '\n', '\n', 'The', ' country'] " is India .
+
+ The country of origin for B iry ani is India .
+
+ The country" True 2 ['B', 'iry', 'ani']
+63 25 The country of origin for x -1 The country of origin for Biryani India Biryani [' is', ' India', '.', '\n', '\n', 'The', ' country', ' of', ' origin', ' for', ' B', 'iry', 'ani', ' is', ' India', '.', '\n', '\n', 'The', ' country'] " is India .
+
+ The country of origin for B iry ani is India .
+
+ The country" True chicken pakoras. Biryani and murgh pilao 7 [' chicken', ' p', 'ak', 'oras', '.', ' B', 'iry', 'ani']
+64 27 The country of origin for x -1 The country of origin for Fondue Switzerland Fondue [' is', ' Switzerland', '.', '\n', '\n', 'F', 'ond', 'ue', ' is', ' a', ' Swiss', ' dish', ' that', ' is', ' made', ' with', ' a', ' cheese', ' and', ' wine'] " is Switzerland .
+
+ F ond ue is a Swiss dish that is made with a cheese and wine" True " for Brittany's ""Fondue for Two"" segments" 6 "[' for', ' Brittany', ""'s"", ' ""', 'F', 'ond', 'ue']"
+65 27 The country of origin for x -1 The country of origin for Fondue Switzerland Fondue [' is', ' Switzerland', '.', '\n', '\n', 'F', 'ond', 'ue', ' is', ' a', ' Swiss', ' dish', ' that', ' is', ' made', ' with', ' a', ' cheese', ' and', ' wine'] " is Switzerland .
+
+ F ond ue is a Swiss dish that is made with a cheese and wine" True " for Brittany's ""Fondue for Two"" segments" 6 "[' for', ' Brittany', ""'s"", ' ""', 'F', 'ond', 'ue']"
+66 29 The country of origin for x -1 The country of origin for Miso Soup Japan Miso Soup [' is', ' Japan', '.', ' It', ' is', ' a', ' traditional', ' Japanese', ' soup', ' made', ' from', ' fermented', ' soy', 'beans', ',', ' rice', ',', ' and', ' salt', '.'] is Japan . It is a traditional Japanese soup made from fermented soy beans , rice , and salt . True Hana-chan no Miso Soup (30 August 2014, 7 [' H', 'ana', '-', 'chan', ' no', ' Mis', 'o', ' Soup']
+67 29 The country of origin for x -1 The country of origin for Miso Soup Japan Miso Soup [' is', ' Japan', '.', ' It', ' is', ' a', ' traditional', ' Japanese', ' soup', ' made', ' from', ' fermented', ' soy', 'beans', ',', ' rice', ',', ' and', ' salt', '.'] is Japan . It is a traditional Japanese soup made from fermented soy beans , rice , and salt . True Hana-chan no Miso Soup (30 August 2014, NTV) 7 [' H', 'ana', '-', 'chan', ' no', ' Mis', 'o', ' Soup']
diff --git a/patchscopes/code/preprocessed_data/factual/person_father.tsv b/patchscopes/code/preprocessed_data/factual/person_father.tsv
new file mode 100644
index 00000000..47947284
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/person_father.tsv
@@ -0,0 +1,15016 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 Name of father of x -1 Name of father of Beyoncé Knowles Mathew Knowles Beyoncé Knowles "['-' 'Carter' ',' ' the' ' singer' ',' ' and' ' Jay' '-' 'Z' ',' ' the'
+ ' rapper' '.' '\n' '\n' 'The' ' couple' ',' ' who']" "- Carter , the singer , and Jay - Z , the rapper .
+
+ The couple , who" False Coppola, starring Beyoncé Knowles and Eva Longoria. 8 [' Co', 'pp', 'ola', ',', ' starring', ' Beyon', 'cé', ' Know', 'les']
+1 0 Name of father of x -1 Name of father of Beyoncé Knowles Mathew Knowles Beyoncé Knowles "['-' 'Carter' ',' ' the' ' singer' ',' ' and' ' Jay' '-' 'Z' ',' ' the'
+ ' rapper' '.' '\n' '\n' 'The' ' couple' ',' ' who']" "- Carter , the singer , and Jay - Z , the rapper .
+
+ The couple , who" False " as Sharon Charles
+" 7 [' as', ' Sharon', ' Charles', 'Bey', 'on', 'cé', ' Know', 'les']
+2 0 Name of father of x -1 Name of father of Beyoncé Knowles Mathew Knowles Beyoncé Knowles "['-' 'Carter' ',' ' the' ' singer' ',' ' and' ' Jay' '-' 'Z' ',' ' the'
+ ' rapper' '.' '\n' '\n' 'The' ' couple' ',' ' who']" "- Carter , the singer , and Jay - Z , the rapper .
+
+ The couple , who" False taking shots at Beyoncé Knowles and Ciara, in lines 6 [' taking', ' shots', ' at', ' Beyon', 'cé', ' Know', 'les']
+3 0 Name of father of x -1 Name of father of Beyoncé Knowles Mathew Knowles Beyoncé Knowles "['-' 'Carter' ',' ' the' ' singer' ',' ' and' ' Jay' '-' 'Z' ',' ' the'
+ ' rapper' '.' '\n' '\n' 'The' ' couple' ',' ' who']" "- Carter , the singer , and Jay - Z , the rapper .
+
+ The couple , who" False " production: Beyoncé Knowles and Kelly Rowland
+" 5 [' production', ':', ' Beyon', 'cé', ' Know', 'les']
+4 0 Name of father of x -1 Name of father of Beyoncé Knowles Mathew Knowles Beyoncé Knowles "['-' 'Carter' ',' ' the' ' singer' ',' ' and' ' Jay' '-' 'Z' ',' ' the'
+ ' rapper' '.' '\n' '\n' 'The' ' couple' ',' ' who']" "- Carter , the singer , and Jay - Z , the rapper .
+
+ The couple , who" False Justin Timberlake and Beyoncé Knowles were also working 7 [' Justin', ' Timber', 'lake', ' and', ' Beyon', 'cé', ' Know', 'les']
+5 1 Name of father of x -1 Name of father of David Beckham David Edward Alan Beckham David Beckham "[',' ' the' ' footballer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Beckham' 's' ',' ' the' ' family' '.' '\n' '\n' 'The' ' Beckham' 's']" ", the footballer , and the father of the Beckham s , the family .
+
+ The Beckham s" False most notably David Beckham (who quickly developed 3 [' most', ' notably', ' David', ' Beckham']
+6 1 Name of father of x -1 Name of father of David Beckham David Edward Alan Beckham David Beckham "[',' ' the' ' footballer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Beckham' 's' ',' ' the' ' family' '.' '\n' '\n' 'The' ' Beckham' 's']" ", the footballer , and the father of the Beckham s , the family .
+
+ The Beckham s" False Manchester United player David Beckham was seen wearing 4 [' Manchester', ' United', ' player', ' David', ' Beckham']
+7 1 Name of father of x -1 Name of father of David Beckham David Edward Alan Beckham David Beckham "[',' ' the' ' footballer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Beckham' 's' ',' ' the' ' family' '.' '\n' '\n' 'The' ' Beckham' 's']" ", the footballer , and the father of the Beckham s , the family .
+
+ The Beckham s" False when midfielder David Beckham scored from a 3 [' when', ' midfielder', ' David', ' Beckham']
+8 1 Name of father of x -1 Name of father of David Beckham David Edward Alan Beckham David Beckham "[',' ' the' ' footballer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Beckham' 's' ',' ' the' ' family' '.' '\n' '\n' 'The' ' Beckham' 's']" ", the footballer , and the father of the Beckham s , the family .
+
+ The Beckham s" False " = Homme by David Beckham =
+" 5 [' =', ' Hom', 'me', ' by', ' David', ' Beckham']
+9 1 Name of father of x -1 Name of father of David Beckham David Edward Alan Beckham David Beckham "[',' ' the' ' footballer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Beckham' 's' ',' ' the' ' family' '.' '\n' '\n' 'The' ' Beckham' 's']" ", the footballer , and the father of the Beckham s , the family .
+
+ The Beckham s" False advert for Homme by David Beckham, which was 6 [' advert', ' for', ' Hom', 'me', ' by', ' David', ' Beckham']
+10 2 Name of father of x -1 Name of father of Reese Witherspoon John Witherspoon Reese Witherspoon "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False Laura Jeanne Reese Witherspoon (/ riːs ˈwɪðərspuːn 6 [' Laura', ' Jeanne', ' Reese', ' W', 'ither', 'sp', 'oon']
+11 2 Name of father of x -1 Name of father of Reese Witherspoon John Witherspoon Reese Witherspoon "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False Perrotta and portrayed by Reese Witherspoon in the 1999 film 10 [' Per', 'ro', 'tta', ' and', ' portrayed', ' by', ' Reese', ' W', 'ither', 'sp', 'oon']
+12 2 Name of father of x -1 Name of father of Reese Witherspoon John Witherspoon Reese Witherspoon "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False " Reese Witherspoon =
+" 4 [' Reese', ' W', 'ither', 'sp', 'oon']
+13 2 Name of father of x -1 Name of father of Reese Witherspoon John Witherspoon Reese Witherspoon "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False 5 ['Re', 'ese', ' W', 'ither', 'sp', 'oon']
+14 2 Name of father of x -1 Name of father of Reese Witherspoon John Witherspoon Reese Witherspoon "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False comedy starring Reese Witherspoon as a disregarded 6 [' comedy', ' starring', ' Reese', ' W', 'ither', 'sp', 'oon']
+15 4 Name of father of x -1 Name of father of Scott Eastwood Clint Eastwood Scott Eastwood "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' iconic' ' character' ' of' ' the' ' cowboy' ' in' ' the' ' movie' ' ""'
+ 'D' 'irty']" ", the actor who played the role of the iconic character of the cowboy in the movie "" D irty" False with Reeves: a son Scott Eastwood (born Scott 7 [' with', ' Reeves', ':', ' a', ' son', ' Scott', ' East', 'wood']
+16 4 Name of father of x -1 Name of father of Scott Eastwood Clint Eastwood Scott Eastwood "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' iconic' ' character' ' of' ' the' ' cowboy' ' in' ' the' ' movie' ' ""'
+ 'D' 'irty']" ", the actor who played the role of the iconic character of the cowboy in the movie "" D irty" False Reeves: a son Scott Eastwood (born Scott 6 [' Reeves', ':', ' a', ' son', ' Scott', ' East', 'wood']
+17 4 Name of father of x -1 Name of father of Scott Eastwood Clint Eastwood Scott Eastwood "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' iconic' ' character' ' of' ' the' ' cowboy' ' in' ' the' ' movie' ' ""'
+ 'D' 'irty']" ", the actor who played the role of the iconic character of the cowboy in the movie "" D irty" False with Reeves: a son Scott Eastwood (born Scott Reeves; 7 [' with', ' Reeves', ':', ' a', ' son', ' Scott', ' East', 'wood']
+18 5 Name of father of x -1 Name of father of Rumer Willis Bruce Willis Rumer Willis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Hunger' ' Games' '""'
+ '\n']" ", the actress who played the role of the young girl in the movie "" The Hunger Games ""
+" False session in December. Rumer Willis later sang the track 6 [' session', ' in', ' December', '.', ' R', 'umer', ' Willis']
+19 5 Name of father of x -1 Name of father of Rumer Willis Bruce Willis Rumer Willis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Hunger' ' Games' '""'
+ '\n']" ", the actress who played the role of the young girl in the movie "" The Hunger Games ""
+" False daughters: Rumer Willis (born August 4 [' daughters', ':', ' R', 'umer', ' Willis']
+20 5 Name of father of x -1 Name of father of Rumer Willis Bruce Willis Rumer Willis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Hunger' ' Games' '""'
+ '\n']" ", the actress who played the role of the young girl in the movie "" The Hunger Games ""
+" False session in December. Rumer Willis later sang 6 [' session', ' in', ' December', '.', ' R', 'umer', ' Willis']
+21 5 Name of father of x -1 Name of father of Rumer Willis Bruce Willis Rumer Willis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Hunger' ' Games' '""'
+ '\n']" ", the actress who played the role of the young girl in the movie "" The Hunger Games ""
+" False had three daughters: Rumer Willis (born August 16, 1988), 6 [' had', ' three', ' daughters', ':', ' R', 'umer', ' Willis']
+22 5 Name of father of x -1 Name of father of Rumer Willis Bruce Willis Rumer Willis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Hunger' ' Games' '""'
+ '\n']" ", the actress who played the role of the young girl in the movie "" The Hunger Games ""
+" False session in December. Rumer Willis later sang the 6 [' session', ' in', ' December', '.', ' R', 'umer', ' Willis']
+23 6 Name of father of x -1 Name of father of Zoe Kravitz Lenny Kravitz Zoe Kravitz "[',' ' the' ' actress' ' who' ' plays' ' the' ' title' ' character' ' in'
+ ' the' ' new' ' film' ',' ' �' '�' 'The' ' Kids' ' Are' ' All' ' Right']" , the actress who plays the title character in the new film , � � The Kids Are All Right False Q as Tori, Zoe Kravitz as Christina, 8 [' Q', ' as', ' Tor', 'i', ',', ' Zoe', ' K', 'rav', 'itz']
+24 6 Name of father of x -1 Name of father of Zoe Kravitz Lenny Kravitz Zoe Kravitz "[',' ' the' ' actress' ' who' ' plays' ' the' ' title' ' character' ' in'
+ ' the' ' new' ' film' ',' ' �' '�' 'The' ' Kids' ' Are' ' All' ' Right']" , the actress who plays the title character in the new film , � � The Kids Are All Right False Maggie Q as Tori, Zoe Kravitz as Christina, Ansel 9 [' Maggie', ' Q', ' as', ' Tor', 'i', ',', ' Zoe', ' K', 'rav', 'itz']
+25 6 Name of father of x -1 Name of father of Zoe Kravitz Lenny Kravitz Zoe Kravitz "[',' ' the' ' actress' ' who' ' plays' ' the' ' title' ' character' ' in'
+ ' the' ' new' ' film' ',' ' �' '�' 'The' ' Kids' ' Are' ' All' ' Right']" , the actress who plays the title character in the new film , � � The Kids Are All Right False Maggie Q as Tori, Zoe Kravitz as Christina, 9 [' Maggie', ' Q', ' as', ' Tor', 'i', ',', ' Zoe', ' K', 'rav', 'itz']
+26 9 Name of father of x -1 Name of father of Georgia May Jagger Mick Jagger Georgia May Jagger "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Paul Smith), Georgia May Jagger (wearing Victoria 6 [' Paul', ' Smith', '),', ' Georgia', ' May', ' J', 'agger']
+27 10 Name of father of x -1 Name of father of Paula Abdul Harry Abdul Paula Abdul "[',' ' the' ' pop' ' singer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' late' ' pop' ' singer' ',' ' Abdul' ' Ab' 'ul' 'bul' ' Amir' '.']" , the pop singer , and the father of the late pop singer , Abdul Ab ul bul Amir . False " of me. Thank you."" Paula Abdul said: ""You are" 7 "[' of', ' me', '.', ' Thank', ' you', '.""', ' Paula', ' Abdul']"
+28 10 Name of father of x -1 Name of father of Paula Abdul Harry Abdul Paula Abdul "[',' ' the' ' pop' ' singer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' late' ' pop' ' singer' ',' ' Abdul' ' Ab' 'ul' 'bul' ' Amir' '.']" , the pop singer , and the father of the late pop singer , Abdul Ab ul bul Amir . False the night, Paula Abdul calling the performance 4 [' the', ' night', ',', ' Paula', ' Abdul']
+29 10 Name of father of x -1 Name of father of Paula Abdul Harry Abdul Paula Abdul "[',' ' the' ' pop' ' singer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' late' ' pop' ' singer' ',' ' Abdul' ' Ab' 'ul' 'bul' ' Amir' '.']" , the pop singer , and the father of the late pop singer , Abdul Ab ul bul Amir . False " (""Breakthrough Video""), Paula Abdul (""Best Editing""), and" 6 "[' (""', 'Break', 'through', ' Video', '""),', ' Paula', ' Abdul']"
+30 10 Name of father of x -1 Name of father of Paula Abdul Harry Abdul Paula Abdul "[',' ' the' ' pop' ' singer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' late' ' pop' ' singer' ',' ' Abdul' ' Ab' 'ul' 'bul' ' Amir' '.']" , the pop singer , and the father of the late pop singer , Abdul Ab ul bul Amir . False 2 ['Paul', 'a', ' Abdul']
+31 10 Name of father of x -1 Name of father of Paula Abdul Harry Abdul Paula Abdul "[',' ' the' ' pop' ' singer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' late' ' pop' ' singer' ',' ' Abdul' ' Ab' 'ul' 'bul' ' Amir' '.']" , the pop singer , and the father of the late pop singer , Abdul Ab ul bul Amir . False " Thank you."" Paula Abdul said: ""You are the" 4 "[' Thank', ' you', '.""', ' Paula', ' Abdul']"
+32 14 Name of father of x -1 Name of father of Tracy Reiner Rob Reiner Tracy Reiner "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' movie' ','
+ ' and' ' the' ' movie' ' is' ' a' ' great' ' movie' '.' '\n']" ", the author of the book , and the movie , and the movie is a great movie .
+" False " Reiner as Emily
+" 7 [' Re', 'iner', ' as', ' Emily', 'Tr', 'acy', ' Re', 'iner']
+33 14 Name of father of x -1 Name of father of Tracy Reiner Rob Reiner Tracy Reiner "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' movie' ','
+ ' and' ' the' ' movie' ' is' ' a' ' great' ' movie' '.' '\n']" ", the author of the book , and the movie , and the movie is a great movie .
+" False " Reiner as Emily
+" 7 [' Re', 'iner', ' as', ' Emily', 'Tr', 'acy', ' Re', 'iner']
+34 14 Name of father of x -1 Name of father of Tracy Reiner Rob Reiner Tracy Reiner "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' movie' ','
+ ' and' ' the' ' movie' ' is' ' a' ' great' ' movie' '.' '\n']" ", the author of the book , and the movie , and the movie is a great movie .
+" False " Reiner as Emily
+" 7 [' Re', 'iner', ' as', ' Emily', 'Tr', 'acy', ' Re', 'iner']
+35 15 Name of father of x -1 Name of father of Jakob Dylan Bob Dylan Jakob Dylan "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' to' ' discover' ' the'
+ '\n' '\n' 'Name' ' of' ' father' ' of' ' Jak' 'ob' ' Dylan' ',']" ", the man who was the first to discover the
+
+ Name of father of Jak ob Dylan ," False " Soundsystem's ""punk-funk"". Jakob Dylan of the Wallflowers" 11 "[' Sound', 'system', ""'s"", ' ""', 'punk', '-', 'f', 'unk', '"".', ' Jak', 'ob', ' Dylan']"
+36 18 Name of father of x -1 Name of father of Jason Schwartzman Francis Ford Coppola Jason Schwartzman "[',' ' the' ' actor' ' who' ' played' ' the' ' title' ' character' ' in'
+ ' the' ' film' ' �' '�' 'Rush' 'more' '�' '�' ' and' ' the' ' �']" , the actor who played the title character in the film � � Rush more � � and the � False 2 ['Jason', ' Schwartz', 'man']
+37 18 Name of father of x -1 Name of father of Jason Schwartzman Francis Ford Coppola Jason Schwartzman "[',' ' the' ' actor' ' who' ' played' ' the' ' title' ' character' ' in'
+ ' the' ' film' ' �' '�' 'Rush' 'more' '�' '�' ' and' ' the' ' �']" , the actor who played the title character in the film � � Rush more � � and the � False Maddin and actor Jason Schwartzman were on board, and 6 [' Madd', 'in', ' and', ' actor', ' Jason', ' Schwartz', 'man']
+38 18 Name of father of x -1 Name of father of Jason Schwartzman Francis Ford Coppola Jason Schwartzman "[',' ' the' ' actor' ' who' ' played' ' the' ' title' ' character' ' in'
+ ' the' ' film' ' �' '�' 'Rush' 'more' '�' '�' ' and' ' the' ' �']" , the actor who played the title character in the film � � Rush more � � and the � False that Maddin and actor Jason Schwartzman were on board, and 7 [' that', ' Madd', 'in', ' and', ' actor', ' Jason', ' Schwartz', 'man']
+39 18 Name of father of x -1 Name of father of Jason Schwartzman Francis Ford Coppola Jason Schwartzman "[',' ' the' ' actor' ' who' ' played' ' the' ' title' ' character' ' in'
+ ' the' ' film' ' �' '�' 'Rush' 'more' '�' '�' ' and' ' the' ' �']" , the actor who played the title character in the film � � Rush more � � and the � False 2007. Starring Jason Schwartzman and Natalie Portman 6 [' 2007', '.', ' Star', 'ring', ' Jason', ' Schwartz', 'man']
+40 18 Name of father of x -1 Name of father of Jason Schwartzman Francis Ford Coppola Jason Schwartzman "[',' ' the' ' actor' ' who' ' played' ' the' ' title' ' character' ' in'
+ ' the' ' film' ' �' '�' 'Rush' 'more' '�' '�' ' and' ' the' ' �']" , the actor who played the title character in the film � � Rush more � � and the � False Maddin and actor Jason Schwartzman were on board, and 6 [' Madd', 'in', ' and', ' actor', ' Jason', ' Schwartz', 'man']
+41 21 Name of father of x -1 Name of father of Stella McCartney Paul McCartney Stella McCartney "[',' ' the' ' fashion' ' designer' ',' ' and' ' the' ' daughter' ' of'
+ ' the' ' late' ' Sir' ' Paul' ' McCartney' '.' '\n' '\n' 'The' ' couple'
+ ',']" ", the fashion designer , and the daughter of the late Sir Paul McCartney .
+
+ The couple ," True Fürstenberg, Valentino, Stella McCartney and Missoni, 9 [' F', 'ür', 'st', 'enberg', ',', ' Valent', 'ino', ',', ' Stella', ' McCartney']
+42 21 Name of father of x -1 Name of father of Stella McCartney Paul McCartney Stella McCartney "[',' ' the' ' fashion' ' designer' ',' ' and' ' the' ' daughter' ' of'
+ ' the' ' late' ' Sir' ' Paul' ' McCartney' '.' '\n' '\n' 'The' ' couple'
+ ',']" ", the fashion designer , and the daughter of the late Sir Paul McCartney .
+
+ The couple ," True Fürstenberg, Valentino, Stella McCartney and Missoni, while 9 [' F', 'ür', 'st', 'enberg', ',', ' Valent', 'ino', ',', ' Stella', ' McCartney']
+43 21 Name of father of x -1 Name of father of Stella McCartney Paul McCartney Stella McCartney "[',' ' the' ' fashion' ' designer' ',' ' and' ' the' ' daughter' ' of'
+ ' the' ' late' ' Sir' ' Paul' ' McCartney' '.' '\n' '\n' 'The' ' couple'
+ ',']" ", the fashion designer , and the daughter of the late Sir Paul McCartney .
+
+ The couple ," True campaigns for both Stella McCartney and Roberto 4 [' campaigns', ' for', ' both', ' Stella', ' McCartney']
+44 21 Name of father of x -1 Name of father of Stella McCartney Paul McCartney Stella McCartney "[',' ' the' ' fashion' ' designer' ',' ' and' ' the' ' daughter' ' of'
+ ' the' ' late' ' Sir' ' Paul' ' McCartney' '.' '\n' '\n' 'The' ' couple'
+ ',']" ", the fashion designer , and the daughter of the late Sir Paul McCartney .
+
+ The couple ," True dressed in a black Stella McCartney suit and accompanied 5 [' dressed', ' in', ' a', ' black', ' Stella', ' McCartney']
+45 21 Name of father of x -1 Name of father of Stella McCartney Paul McCartney Stella McCartney "[',' ' the' ' fashion' ' designer' ',' ' and' ' the' ' daughter' ' of'
+ ' the' ' late' ' Sir' ' Paul' ' McCartney' '.' '\n' '\n' 'The' ' couple'
+ ',']" ", the fashion designer , and the daughter of the late Sir Paul McCartney .
+
+ The couple ," True campaigns for both Stella McCartney and Roberto 4 [' campaigns', ' for', ' both', ' Stella', ' McCartney']
+46 22 Name of father of x -1 Name of father of Jessica Capshaw Steven Spielberg Jessica Capshaw "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' James' ' Cap'
+ 'shaw' ',' ' who' ' died' ' in' ' a' ' car' ' accident' ' in' ' the'
+ ' early']" , the daughter of the late actor James Cap shaw , who died in a car accident in the early False Torres'father, Carlos. Jessica Capshaw has an incredible ability 8 "[' Torres', ""'"", 'father', ',', ' Carlos', '.', ' Jessica', ' Cap', 'shaw']"
+47 22 Name of father of x -1 Name of father of Jessica Capshaw Steven Spielberg Jessica Capshaw "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' James' ' Cap'
+ 'shaw' ',' ' who' ' died' ' in' ' a' ' car' ' accident' ' in' ' the'
+ ' early']" , the daughter of the late actor James Cap shaw , who died in a car accident in the early False his contract, and Jessica Capshaw (Dr. Arizona Robbins) 6 [' his', ' contract', ',', ' and', ' Jessica', ' Cap', 'shaw']
+48 22 Name of father of x -1 Name of father of Jessica Capshaw Steven Spielberg Jessica Capshaw "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' James' ' Cap'
+ 'shaw' ',' ' who' ' died' ' in' ' a' ' car' ' accident' ' in' ' the'
+ ' early']" , the daughter of the late actor James Cap shaw , who died in a car accident in the early False arc. In addition, Jessica Capshaw (Dr. Arizona Robbins) 7 [' arc', '.', ' In', ' addition', ',', ' Jessica', ' Cap', 'shaw']
+49 22 Name of father of x -1 Name of father of Jessica Capshaw Steven Spielberg Jessica Capshaw "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' James' ' Cap'
+ 'shaw' ',' ' who' ' died' ' in' ' a' ' car' ' accident' ' in' ' the'
+ ' early']" , the daughter of the late actor James Cap shaw , who died in a car accident in the early False previous recurring star Jessica Capshaw was promoted to a 5 [' previous', ' recurring', ' star', ' Jessica', ' Cap', 'shaw']
+50 22 Name of father of x -1 Name of father of Jessica Capshaw Steven Spielberg Jessica Capshaw "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' James' ' Cap'
+ 'shaw' ',' ' who' ' died' ' in' ' a' ' car' ' accident' ' in' ' the'
+ ' early']" , the daughter of the late actor James Cap shaw , who died in a car accident in the early False In addition, Jessica Capshaw (Dr. Arizona 5 [' In', ' addition', ',', ' Jessica', ' Cap', 'shaw']
+51 23 Name of father of x -1 Name of father of Peter Paul Rubens Jan Rubens Peter Paul Rubens "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Ant' 'wer' 'p'
+ ' in' ' 15' '77' '.' '\n' '\n' 'The' ' name' ' of']" ", the painter , who was born in Ant wer p in 15 77 .
+
+ The name of" False begin at present-day Peter Paul Rubens Street (Maltese: 8 [' begin', ' at', ' present', '-', 'day', ' Peter', ' Paul', ' Rub', 'ens']
+52 23 Name of father of x -1 Name of father of Peter Paul Rubens Jan Rubens Peter Paul Rubens "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Ant' 'wer' 'p'
+ ' in' ' 15' '77' '.' '\n' '\n' 'The' ' name' ' of']" ", the painter , who was born in Ant wer p in 15 77 .
+
+ The name of" False the Renaissance, Peter Paul Rubens and Johann Baptist 6 [' the', ' Renaissance', ',', ' Peter', ' Paul', ' Rub', 'ens']
+53 23 Name of father of x -1 Name of father of Peter Paul Rubens Jan Rubens Peter Paul Rubens "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Ant' 'wer' 'p'
+ ' in' ' 15' '77' '.' '\n' '\n' 'The' ' name' ' of']" ", the painter , who was born in Ant wer p in 15 77 .
+
+ The name of" False the work of Peter Paul Rubens — and broadened his 6 [' the', ' work', ' of', ' Peter', ' Paul', ' Rub', 'ens']
+54 23 Name of father of x -1 Name of father of Peter Paul Rubens Jan Rubens Peter Paul Rubens "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Ant' 'wer' 'p'
+ ' in' ' 15' '77' '.' '\n' '\n' 'The' ' name' ' of']" ", the painter , who was born in Ant wer p in 15 77 .
+
+ The name of" False Leonardo da Vinci, Peter Paul Rubens and Gottfried 8 [' Leonardo', ' da', ' Vin', 'ci', ',', ' Peter', ' Paul', ' Rub', 'ens']
+55 23 Name of father of x -1 Name of father of Peter Paul Rubens Jan Rubens Peter Paul Rubens "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Ant' 'wer' 'p'
+ ' in' ' 15' '77' '.' '\n' '\n' 'The' ' name' ' of']" ", the painter , who was born in Ant wer p in 15 77 .
+
+ The name of" False collection of Peter Paul Rubens paintings. The 5 [' collection', ' of', ' Peter', ' Paul', ' Rub', 'ens']
+56 24 Name of father of x -1 Name of father of Madonna Silvio P. (Tony) Ciccone Madonna "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' mother' ' of' ' the']" ", the mother of the child , and the mother of the child .
+
+ The mother of the" False ranked alongside Madonna's best. [...] 2 [' ranked', ' alongside', ' Madonna']
+57 24 Name of father of x -1 Name of father of Madonna Silvio P. (Tony) Ciccone Madonna "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' mother' ' of' ' the']" ", the mother of the child , and the mother of the child .
+
+ The mother of the" False in Finland. Madonna has performed 3 [' in', ' Finland', '.', ' Madonna']
+58 24 Name of father of x -1 Name of father of Madonna Silvio P. (Tony) Ciccone Madonna "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' mother' ' of' ' the']" ", the mother of the child , and the mother of the child .
+
+ The mother of the" False " of five, saying that Madonna felt to be ""overly""" 5 [' of', ' five', ',', ' saying', ' that', ' Madonna']
+59 24 Name of father of x -1 Name of father of Madonna Silvio P. (Tony) Ciccone Madonna "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' mother' ' of' ' the']" ", the mother of the child , and the mother of the child .
+
+ The mother of the" False of F major with Madonna's voice ranging from 4 [' of', ' F', ' major', ' with', ' Madonna']
+60 24 Name of father of x -1 Name of father of Madonna Silvio P. (Tony) Ciccone Madonna "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' mother' ' of' ' the']" ", the mother of the child , and the mother of the child .
+
+ The mother of the" False " song ""Holiday"", which Madonna recorded and released" 6 "[' song', ' ""', 'Hol', 'iday', '"",', ' which', ' Madonna']"
+61 25 Name of father of x -1 Name of father of Pablo Picasso José Ruiz y Blanco Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' mother' ',' ' Marie' '-'
+ 'Th' 'ér' 'è' 'se' ',' ' who' ' was' ' a' ' painter']" , the painter , and of his mother , Marie - Th ér è se , who was a painter False Symbolists, and Pablo Picasso during his Blue 6 [' Symbol', 'ists', ',', ' and', ' Pablo', ' Pic', 'asso']
+62 25 Name of father of x -1 Name of father of Pablo Picasso José Ruiz y Blanco Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' mother' ',' ' Marie' '-'
+ 'Th' 'ér' 'è' 'se' ',' ' who' ' was' ' a' ' painter']" , the painter , and of his mother , Marie - Th ér è se , who was a painter False Isadora Duncan, Pablo Picasso and Aleister 7 [' Is', 'ad', 'ora', ' Duncan', ',', ' Pablo', ' Pic', 'asso']
+63 25 Name of father of x -1 Name of father of Pablo Picasso José Ruiz y Blanco Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' mother' ',' ' Marie' '-'
+ 'Th' 'ér' 'è' 'se' ',' ' who' ' was' ' a' ' painter']" , the painter , and of his mother , Marie - Th ér è se , who was a painter False Isadora Duncan, Pablo Picasso and Aleister Crowley. 7 [' Is', 'ad', 'ora', ' Duncan', ',', ' Pablo', ' Pic', 'asso']
+64 25 Name of father of x -1 Name of father of Pablo Picasso José Ruiz y Blanco Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' mother' ',' ' Marie' '-'
+ 'Th' 'ér' 'è' 'se' ',' ' who' ' was' ' a' ' painter']" , the painter , and of his mother , Marie - Th ér è se , who was a painter False " objects, and credited Pablo Picasso for ""giving us the" 6 [' objects', ',', ' and', ' credited', ' Pablo', ' Pic', 'asso']
+65 25 Name of father of x -1 Name of father of Pablo Picasso José Ruiz y Blanco Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' mother' ',' ' Marie' '-'
+ 'Th' 'ér' 'è' 'se' ',' ' who' ' was' ' a' ' painter']" , the painter , and of his mother , Marie - Th ér è se , who was a painter False " and credited Pablo Picasso for ""giving us the" 4 [' and', ' credited', ' Pablo', ' Pic', 'asso']
+66 26 Name of father of x -1 Name of father of Lady Gaga Joe Germanotta Lady Gaga "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' producer' ','
+ ' and' ' manager' ',' ' Taylor' ' Kin' 'ney' ',' ' were' ' arrested'
+ ' in']" , the singer , and her husband , producer , and manager , Taylor Kin ney , were arrested in False 1 ['Lady', ' Gaga']
+67 26 Name of father of x -1 Name of father of Lady Gaga Joe Germanotta Lady Gaga "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' producer' ','
+ ' and' ' manager' ',' ' Taylor' ' Kin' 'ney' ',' ' were' ' arrested'
+ ' in']" , the singer , and her husband , producer , and manager , Taylor Kin ney , were arrested in False Upon release, Lady Gaga x Terry Richardson 4 [' Upon', ' release', ',', ' Lady', ' Gaga']
+68 26 Name of father of x -1 Name of father of Lady Gaga Joe Germanotta Lady Gaga "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' producer' ','
+ ' and' ' manager' ',' ' Taylor' ' Kin' 'ney' ',' ' were' ' arrested'
+ ' in']" , the singer , and her husband , producer , and manager , Taylor Kin ney , were arrested in False " group members include Lady Gaga and Alec Baldwin.
+" 4 [' group', ' members', ' include', ' Lady', ' Gaga']
+69 26 Name of father of x -1 Name of father of Lady Gaga Joe Germanotta Lady Gaga "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' producer' ','
+ ' and' ' manager' ',' ' Taylor' ' Kin' 'ney' ',' ' were' ' arrested'
+ ' in']" , the singer , and her husband , producer , and manager , Taylor Kin ney , were arrested in False " Madonna, Pink, Lady Gaga and Katy Perry.
+" 5 [' Madonna', ',', ' Pink', ',', ' Lady', ' Gaga']
+70 26 Name of father of x -1 Name of father of Lady Gaga Joe Germanotta Lady Gaga "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' producer' ','
+ ' and' ' manager' ',' ' Taylor' ' Kin' 'ney' ',' ' were' ' arrested'
+ ' in']" , the singer , and her husband , producer , and manager , Taylor Kin ney , were arrested in False " ""Alejandro"" is written by Lady Gaga and RedOne, while" 8 "[' ""', 'Ale', 'jandro', '""', ' is', ' written', ' by', ' Lady', ' Gaga']"
+71 27 Name of father of x -1 Name of father of Rembrandt Harmen Gerritszoon van Rijn Rembrandt "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Sask' 'ia' ' van'
+ ' U' 'yl' 'en' 'burgh' ',' ' who' ' was' ' the' ' daughter']" , the painter , and his wife , Sask ia van U yl en burgh , who was the daughter False work in this manner; Rembrandt had sought a similar 7 [' work', ' in', ' this', ' manner', ';', ' Rem', 'brand', 't']
+72 27 Name of father of x -1 Name of father of Rembrandt Harmen Gerritszoon van Rijn Rembrandt "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Sask' 'ia' ' van'
+ ' U' 'yl' 'en' 'burgh' ',' ' who' ' was' ' the' ' daughter']" , the painter , and his wife , Sask ia van U yl en burgh , who was the daughter False purchased by The Rembrandt Society for 5 [' purchased', ' by', ' The', ' Rem', 'brand', 't']
+73 27 Name of father of x -1 Name of father of Rembrandt Harmen Gerritszoon van Rijn Rembrandt "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Sask' 'ia' ' van'
+ ' U' 'yl' 'en' 'burgh' ',' ' who' ' was' ' the' ' daughter']" , the painter , and his wife , Sask ia van U yl en burgh , who was the daughter False the civic militia. Rembrandt departed from convention, 6 [' the', ' civic', ' militia', '.', ' Rem', 'brand', 't']
+74 27 Name of father of x -1 Name of father of Rembrandt Harmen Gerritszoon van Rijn Rembrandt "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Sask' 'ia' ' van'
+ ' U' 'yl' 'en' 'burgh' ',' ' who' ' was' ' the' ' daughter']" , the painter , and his wife , Sask ia van U yl en burgh , who was the daughter False 27, 2014: From Rembrandt to Rosenquist: Works 7 [' 27', ',', ' 2014', ':', ' From', ' Rem', 'brand', 't']
+75 27 Name of father of x -1 Name of father of Rembrandt Harmen Gerritszoon van Rijn Rembrandt "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Sask' 'ia' ' van'
+ ' U' 'yl' 'en' 'burgh' ',' ' who' ' was' ' the' ' daughter']" , the painter , and his wife , Sask ia van U yl en burgh , who was the daughter False Francis Bacon, Rembrandt and Andrew Wyeth. 5 [' Francis', ' Bacon', ',', ' Rem', 'brand', 't']
+76 28 Name of father of x -1 Name of father of Vladimir Putin Vladimir Spiridonovich Putin Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Ly' 'ud' 'm' 'ila' ' Put' 'ina' ','
+ ' were']" , the Russian president , and the Russian president 's wife , Ly ud m ila Put ina , were False Russian president Vladimir Putin denied any Russian 3 [' Russian', ' president', ' Vladimir', ' Putin']
+77 28 Name of father of x -1 Name of father of Vladimir Putin Vladimir Spiridonovich Putin Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Ly' 'ud' 'm' 'ila' ' Put' 'ina' ','
+ ' were']" , the Russian president , and the Russian president 's wife , Ly ud m ila Put ina , were False upset when he sees Vladimir Putin touching her back. 5 [' upset', ' when', ' he', ' sees', ' Vladimir', ' Putin']
+78 28 Name of father of x -1 Name of father of Vladimir Putin Vladimir Spiridonovich Putin Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Ly' 'ud' 'm' 'ila' ' Put' 'ina' ','
+ ' were']" , the Russian president , and the Russian president 's wife , Ly ud m ila Put ina , were False President Vladimir Putin of being responsible 2 [' President', ' Vladimir', ' Putin']
+79 28 Name of father of x -1 Name of father of Vladimir Putin Vladimir Spiridonovich Putin Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Ly' 'ud' 'm' 'ila' ' Put' 'ina' ','
+ ' were']" , the Russian president , and the Russian president 's wife , Ly ud m ila Put ina , were False " when a successor to Vladimir Putin was elected.
+" 5 [' when', ' a', ' successor', ' to', ' Vladimir', ' Putin']
+80 28 Name of father of x -1 Name of father of Vladimir Putin Vladimir Spiridonovich Putin Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Ly' 'ud' 'm' 'ila' ' Put' 'ina' ','
+ ' were']" , the Russian president , and the Russian president 's wife , Ly ud m ila Put ina , were False " leaders, including Vladimir Putin and Tony Blair.
+" 4 [' leaders', ',', ' including', ' Vladimir', ' Putin']
+81 29 Name of father of x -1 Name of father of Anthony van Dyck Franchois Van Dyck Anthony van Dyck "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Anthony' ' van' ' Dy' 'ck' '.' '\n']" ", the painter .
+
+ The name of the father of the painter Anthony van Dy ck .
+" False in earlier works by Anthony van Dyck during the 17th 7 [' in', ' earlier', ' works', ' by', ' Anthony', ' van', ' Dy', 'ck']
+82 29 Name of father of x -1 Name of father of Anthony van Dyck Franchois Van Dyck Anthony van Dyck "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Anthony' ' van' ' Dy' 'ck' '.' '\n']" ", the painter .
+
+ The name of the father of the painter Anthony van Dy ck .
+" False earlier works by Anthony van Dyck during the 6 [' earlier', ' works', ' by', ' Anthony', ' van', ' Dy', 'ck']
+83 30 Name of father of x -1 Name of father of Michael Jackson Joe Jackson Michael Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ""'s"" ' children' ',' ' Michael' ' Jackson' ',' ' Jr' '.' ' and' ' Paris']" , the singer , and the father of the singer 's children , Michael Jackson , Jr . and Paris False 1 ['Michael', ' Jackson']
+84 30 Name of father of x -1 Name of father of Michael Jackson Joe Jackson Michael Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ""'s"" ' children' ',' ' Michael' ' Jackson' ',' ' Jr' '.' ' and' ' Paris']" , the singer , and the father of the singer 's children , Michael Jackson , Jr . and Paris False Sega brought in Michael Jackson to compose 4 [' Sega', ' brought', ' in', ' Michael', ' Jackson']
+85 30 Name of father of x -1 Name of father of Michael Jackson Joe Jackson Michael Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ""'s"" ' children' ',' ' Michael' ' Jackson' ',' ' Jr' '.' ' and' ' Paris']" , the singer , and the father of the singer 's children , Michael Jackson , Jr . and Paris False of previous Michael Jackson music videos 3 [' of', ' previous', ' Michael', ' Jackson']
+86 30 Name of father of x -1 Name of father of Michael Jackson Joe Jackson Michael Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ""'s"" ' children' ',' ' Michael' ' Jackson' ',' ' Jr' '.' ' and' ' Paris']" , the singer , and the father of the singer 's children , Michael Jackson , Jr . and Paris False paid homage to Michael Jackson. A reviewer 4 [' paid', ' homage', ' to', ' Michael', ' Jackson']
+87 30 Name of father of x -1 Name of father of Michael Jackson Joe Jackson Michael Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ""'s"" ' children' ',' ' Michael' ' Jackson' ',' ' Jr' '.' ' and' ' Paris']" , the singer , and the father of the singer 's children , Michael Jackson , Jr . and Paris False successful acts such as Michael Jackson and Diana 5 [' successful', ' acts', ' such', ' as', ' Michael', ' Jackson']
+88 31 Name of father of x -1 Name of father of William Shakespeare John Shakespeare William Shakespeare "[',' ' the' ' poet' ',' ' the' ' dram' 'at' 'ist' ',' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , the dram at ist , the
+ " False to the name William Shakespeare — Venus and 4 [' to', ' the', ' name', ' William', ' Shakespeare']
+89 31 Name of father of x -1 Name of father of William Shakespeare John Shakespeare William Shakespeare "[',' ' the' ' poet' ',' ' the' ' dram' 'at' 'ist' ',' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , the dram at ist , the
+ " False Macbeth by William Shakespeare using only Scottish 5 [' Mac', 'b', 'eth', ' by', ' William', ' Shakespeare']
+90 31 Name of father of x -1 Name of father of William Shakespeare John Shakespeare William Shakespeare "[',' ' the' ' poet' ',' ' the' ' dram' 'at' 'ist' ',' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , the dram at ist , the
+ " False is a portrait of William Shakespeare engraved by Martin 5 [' is', ' a', ' portrait', ' of', ' William', ' Shakespeare']
+91 31 Name of father of x -1 Name of father of William Shakespeare John Shakespeare William Shakespeare "[',' ' the' ' poet' ',' ' the' ' dram' 'at' 'ist' ',' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , the dram at ist , the
+ " False created by William Shakespeare or Alexander 3 [' created', ' by', ' William', ' Shakespeare']
+92 31 Name of father of x -1 Name of father of William Shakespeare John Shakespeare William Shakespeare "[',' ' the' ' poet' ',' ' the' ' dram' 'at' 'ist' ',' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , the dram at ist , the
+ " False Elizabethan England, William Shakespeare was entitled 5 [' Elizabeth', 'an', ' England', ',', ' William', ' Shakespeare']
+93 32 Name of father of x -1 Name of father of Albrecht Dürer Albrecht Dürer the Elder Albrecht Dürer "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Al' 'bre' 'cht'
+ ' D' 'ü' 'rer' ' is' ' unknown' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Al bre cht D ü rer is unknown .
+
+ The" False observed that Albrecht Dürer sketched similar motifs 7 [' observed', ' that', ' Al', 'bre', 'cht', ' D', 'ü', 'rer']
+94 32 Name of father of x -1 Name of father of Albrecht Dürer Albrecht Dürer the Elder Albrecht Dürer "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Al' 'bre' 'cht'
+ ' D' 'ü' 'rer' ' is' ' unknown' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Al bre cht D ü rer is unknown .
+
+ The" False paintings. The engraver Albrecht Dürer made many references 11 [' paintings', '.', ' The', ' eng', 'ra', 'ver', ' Al', 'bre', 'cht', ' D', 'ü', 'rer']
+95 32 Name of father of x -1 Name of father of Albrecht Dürer Albrecht Dürer the Elder Albrecht Dürer "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Al' 'bre' 'cht'
+ ' D' 'ü' 'rer' ' is' ' unknown' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Al bre cht D ü rer is unknown .
+
+ The" False 5 ['Al', 'bre', 'cht', ' D', 'ü', 'rer']
+96 32 Name of father of x -1 Name of father of Albrecht Dürer Albrecht Dürer the Elder Albrecht Dürer "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Al' 'bre' 'cht'
+ ' D' 'ü' 'rer' ' is' ' unknown' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Al bre cht D ü rer is unknown .
+
+ The" False it is the painting Albrecht Dürer mentions in 9 [' it', ' is', ' the', ' painting', ' Al', 'bre', 'cht', ' D', 'ü', 'rer']
+97 32 Name of father of x -1 Name of father of Albrecht Dürer Albrecht Dürer the Elder Albrecht Dürer "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Al' 'bre' 'cht'
+ ' D' 'ü' 'rer' ' is' ' unknown' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Al bre cht D ü rer is unknown .
+
+ The" False and printmaker Albrecht Dürer in 1515. The image 8 [' and', ' print', 'maker', ' Al', 'bre', 'cht', ' D', 'ü', 'rer']
+98 33 Name of father of x -1 Name of father of Vincent van Gogh Theodorus van Gogh Vincent van Gogh "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' a' ' son' ' who' ' is' ' a'
+ ' student' ' at' ' the' ' University' ' of' ' Toronto' '.' ' I']" "
+
+ I am a father of a son who is a student at the University of Toronto . I" False nine weeks with Vincent van Gogh at his Yellow 6 [' nine', ' weeks', ' with', ' Vincent', ' van', ' Go', 'gh']
+99 33 Name of father of x -1 Name of father of Vincent van Gogh Theodorus van Gogh Vincent van Gogh "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' a' ' son' ' who' ' is' ' a'
+ ' student' ' at' ' the' ' University' ' of' ' Toronto' '.' ' I']" "
+
+ I am a father of a son who is a student at the University of Toronto . I" False bought the Vincent van Gogh painting Irises 5 [' bought', ' the', ' Vincent', ' van', ' Go', 'gh']
+100 33 Name of father of x -1 Name of father of Vincent van Gogh Theodorus van Gogh Vincent van Gogh "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' a' ' son' ' who' ' is' ' a'
+ ' student' ' at' ' the' ' University' ' of' ' Toronto' '.' ' I']" "
+
+ I am a father of a son who is a student at the University of Toronto . I" False the 19th century, Vincent van Gogh acknowledged 8 [' the', ' 19', 'th', ' century', ',', ' Vincent', ' van', ' Go', 'gh']
+101 33 Name of father of x -1 Name of father of Vincent van Gogh Theodorus van Gogh Vincent van Gogh "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' a' ' son' ' who' ' is' ' a'
+ ' student' ' at' ' the' ' University' ' of' ' Toronto' '.' ' I']" "
+
+ I am a father of a son who is a student at the University of Toronto . I" False Alan Bond bought the Vincent van Gogh painting Irises 7 [' Alan', ' Bond', ' bought', ' the', ' Vincent', ' van', ' Go', 'gh']
+102 33 Name of father of x -1 Name of father of Vincent van Gogh Theodorus van Gogh Vincent van Gogh "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' a' ' son' ' who' ' is' ' a'
+ ' student' ' at' ' the' ' University' ' of' ' Toronto' '.' ' I']" "
+
+ I am a father of a son who is a student at the University of Toronto . I" False Bond bought the Vincent van Gogh painting Irises 6 [' Bond', ' bought', ' the', ' Vincent', ' van', ' Go', 'gh']
+103 34 Name of father of x -1 Name of father of Aretha Franklin C. L. Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' ' herself' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',']" ", the Queen of Soul , and the Queen of Soul herself .
+
+ The Queen of Soul ," False because I remember Aretha Franklin said a great 5 [' because', ' I', ' remember', ' Are', 'tha', ' Franklin']
+104 34 Name of father of x -1 Name of father of Aretha Franklin C. L. Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' ' herself' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',']" ", the Queen of Soul , and the Queen of Soul herself .
+
+ The Queen of Soul ," False according to Redding. Aretha Franklin covered this 7 [' according', ' to', ' Red', 'ding', '.', ' Are', 'tha', ' Franklin']
+105 34 Name of father of x -1 Name of father of Aretha Franklin C. L. Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' ' herself' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',']" ", the Queen of Soul , and the Queen of Soul herself .
+
+ The Queen of Soul ," False " awards for the Aretha Franklin ""Entertainer" 5 [' awards', ' for', ' the', ' Are', 'tha', ' Franklin']
+106 34 Name of father of x -1 Name of father of Aretha Franklin C. L. Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' ' herself' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',']" ", the Queen of Soul , and the Queen of Soul herself .
+
+ The Queen of Soul ," False 2 ['Are', 'tha', ' Franklin']
+107 34 Name of father of x -1 Name of father of Aretha Franklin C. L. Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' ' herself' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',']" ", the Queen of Soul , and the Queen of Soul herself .
+
+ The Queen of Soul ," False Whitney Houston (and Aretha Franklin and Barbra 6 [' Whitney', ' Houston', ' (', 'and', ' Are', 'tha', ' Franklin']
+108 35 Name of father of x -1 Name of father of Albert Einstein Hermann Einstein Albert Einstein "[',' ' the' ' physicist' ',' ' and' ' his' ' wife' ',' ' Mile' 'va' ' Mar'
+ 'ic' ',' ' who' ' was' ' a' ' Serbian' '-' 'born' ' physicist']" , the physicist , and his wife , Mile va Mar ic , who was a Serbian - born physicist False Szilárd and signed by Albert Einstein that was sent to 8 [' Sz', 'il', 'á', 'rd', ' and', ' signed', ' by', ' Albert', ' Einstein']
+109 35 Name of father of x -1 Name of father of Albert Einstein Hermann Einstein Albert Einstein "[',' ' the' ' physicist' ',' ' and' ' his' ' wife' ',' ' Mile' 'va' ' Mar'
+ 'ic' ',' ' who' ' was' ' a' ' Serbian' '-' 'born' ' physicist']" , the physicist , and his wife , Mile va Mar ic , who was a Serbian - born physicist False 1917, when Albert Einstein first applied his 4 [' 1917', ',', ' when', ' Albert', ' Einstein']
+110 35 Name of father of x -1 Name of father of Albert Einstein Hermann Einstein Albert Einstein "[',' ' the' ' physicist' ',' ' and' ' his' ' wife' ',' ' Mile' 'va' ' Mar'
+ 'ic' ',' ' who' ' was' ' a' ' Serbian' '-' 'born' ' physicist']" , the physicist , and his wife , Mile va Mar ic , who was a Serbian - born physicist False electricity. In 1905, Albert Einstein suggested that electromagnetic 6 [' electricity', '.', ' In', ' 1905', ',', ' Albert', ' Einstein']
+111 35 Name of father of x -1 Name of father of Albert Einstein Hermann Einstein Albert Einstein "[',' ' the' ' physicist' ',' ' and' ' his' ' wife' ',' ' Mile' 'va' ' Mar'
+ 'ic' ',' ' who' ' was' ' a' ' Serbian' '-' 'born' ' physicist']" , the physicist , and his wife , Mile va Mar ic , who was a Serbian - born physicist False Goldstein, asking if Albert Einstein believed in 5 [' Goldstein', ',', ' asking', ' if', ' Albert', ' Einstein']
+112 35 Name of father of x -1 Name of father of Albert Einstein Hermann Einstein Albert Einstein "[',' ' the' ' physicist' ',' ' and' ' his' ' wife' ',' ' Mile' 'va' ' Mar'
+ 'ic' ',' ' who' ' was' ' a' ' Serbian' '-' 'born' ' physicist']" , the physicist , and his wife , Mile va Mar ic , who was a Serbian - born physicist False was considered by Albert Einstein as not needing experimental 4 [' was', ' considered', ' by', ' Albert', ' Einstein']
+113 36 Name of father of x -1 Name of father of Leonardo da Vinci Ser Piero da Vinci Leonardo da Vinci "[',' ' the' ' great' ' artist' ',' ' who' ' was' ' born' ' in' ' 14' '52'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of']" ", the great artist , who was born in 14 52 .
+
+ The name of the mother of" False luminaries such as Leonardo da Vinci and Isaac Newton, 7 [' lumin', 'aries', ' such', ' as', ' Leonardo', ' da', ' Vin', 'ci']
+114 36 Name of father of x -1 Name of father of Leonardo da Vinci Ser Piero da Vinci Leonardo da Vinci "[',' ' the' ' great' ' artist' ',' ' who' ' was' ' born' ' in' ' 14' '52'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of']" ", the great artist , who was born in 14 52 .
+
+ The name of the mother of" False Michelangelo and brought Leonardo da Vinci to his court. After 7 [' Michel', 'angelo', ' and', ' brought', ' Leonardo', ' da', ' Vin', 'ci']
+115 36 Name of father of x -1 Name of father of Leonardo da Vinci Ser Piero da Vinci Leonardo da Vinci "[',' ' the' ' great' ' artist' ',' ' who' ' was' ' born' ' in' ' 14' '52'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of']" ", the great artist , who was born in 14 52 .
+
+ The name of the mother of" False in the time of Leonardo da Vinci in Europe). 7 [' in', ' the', ' time', ' of', ' Leonardo', ' da', ' Vin', 'ci']
+116 36 Name of father of x -1 Name of father of Leonardo da Vinci Ser Piero da Vinci Leonardo da Vinci "[',' ' the' ' great' ' artist' ',' ' who' ' was' ' born' ' in' ' 14' '52'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of']" ", the great artist , who was born in 14 52 .
+
+ The name of the mother of" False drawing) for Leonardo da Vinci ’ s The Virgin 6 [' drawing', ')', ' for', ' Leonardo', ' da', ' Vin', 'ci']
+117 36 Name of father of x -1 Name of father of Leonardo da Vinci Ser Piero da Vinci Leonardo da Vinci "[',' ' the' ' great' ' artist' ',' ' who' ' was' ' born' ' in' ' 14' '52'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of']" ", the great artist , who was born in 14 52 .
+
+ The name of the mother of" False planned to modernize Leonardo da Vinci by replacing her 7 [' planned', ' to', ' modern', 'ize', ' Leonardo', ' da', ' Vin', 'ci']
+118 37 Name of father of x -1 Name of father of Adolf Hitler Alois Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Adolf'
+ ' Hitler' ' is' ' not' ' Adolf' ' Hitler' '.' '\n' '\n' 'The' ' name']" ".
+
+ The name of the father of Adolf Hitler is not Adolf Hitler .
+
+ The name" False " driving Mussolini and Adolf Hitler into an alliance.
+" 5 [' driving', ' Muss', 'olini', ' and', ' Adolf', ' Hitler']
+119 37 Name of father of x -1 Name of father of Adolf Hitler Alois Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Adolf'
+ ' Hitler' ' is' ' not' ' Adolf' ' Hitler' '.' '\n' '\n' 'The' ' name']" ".
+
+ The name of the father of Adolf Hitler is not Adolf Hitler .
+
+ The name" False reference by Card to Adolf Hitler and criticizes 5 [' reference', ' by', ' Card', ' to', ' Adolf', ' Hitler']
+120 37 Name of father of x -1 Name of father of Adolf Hitler Alois Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Adolf'
+ ' Hitler' ' is' ' not' ' Adolf' ' Hitler' '.' '\n' '\n' 'The' ' name']" ".
+
+ The name of the father of Adolf Hitler is not Adolf Hitler .
+
+ The name" False to resurrect Adolf Hitler with the philosophers' 3 [' to', ' resurrect', ' Adolf', ' Hitler']
+121 37 Name of father of x -1 Name of father of Adolf Hitler Alois Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Adolf'
+ ' Hitler' ' is' ' not' ' Adolf' ' Hitler' '.' '\n' '\n' 'The' ' name']" ".
+
+ The name of the father of Adolf Hitler is not Adolf Hitler .
+
+ The name" False resemblance to Adolf Hitler is seen. In a flashback 3 [' resemblance', ' to', ' Adolf', ' Hitler']
+122 37 Name of father of x -1 Name of father of Adolf Hitler Alois Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Adolf'
+ ' Hitler' ' is' ' not' ' Adolf' ' Hitler' '.' '\n' '\n' 'The' ' name']" ".
+
+ The name of the father of Adolf Hitler is not Adolf Hitler .
+
+ The name" False the Soviet Union, Adolf Hitler began placing heavy 5 [' the', ' Soviet', ' Union', ',', ' Adolf', ' Hitler']
+123 38 Name of father of x -1 Name of father of Johann Wolfgang von Goethe Johann Caspar Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' writer' '.' '\n' '\n' 'The' ' name'
+ ' Go' 'ethe' ' is' ' a' ' German' ' name' '.' ' It' ' is']" ", the German poet and writer .
+
+ The name Go ethe is a German name . It is" False scientist Johann Wolfgang von Goethe lauded Ruisdael as 5 [' scientist', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+124 38 Name of father of x -1 Name of father of Johann Wolfgang von Goethe Johann Caspar Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' writer' '.' '\n' '\n' 'The' ' name'
+ ' Go' 'ethe' ' is' ' a' ' German' ' name' '.' ' It' ' is']" ", the German poet and writer .
+
+ The name Go ethe is a German name . It is" False the works of Johann Wolfgang von Goethe and Michel 7 [' the', ' works', ' of', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+125 38 Name of father of x -1 Name of father of Johann Wolfgang von Goethe Johann Caspar Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' writer' '.' '\n' '\n' 'The' ' name'
+ ' Go' 'ethe' ' is' ' a' ' German' ' name' '.' ' It' ' is']" ", the German poet and writer .
+
+ The name Go ethe is a German name . It is" False West-östlicher Diwan of Johann Wolfgang von Goethe that is a collection 13 [' West', '-', 'ö', 'st', 'lic', 'her', ' Di', 'wan', ' of', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+126 38 Name of father of x -1 Name of father of Johann Wolfgang von Goethe Johann Caspar Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' writer' '.' '\n' '\n' 'The' ' name'
+ ' Go' 'ethe' ' is' ' a' ' German' ' name' '.' ' It' ' is']" ", the German poet and writer .
+
+ The name Go ethe is a German name . It is" False Valmiki, Homer, Johann Wolfgang von Goethe and William Shakespeare. 10 [' Val', 'm', 'iki', ',', ' Homer', ',', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+127 38 Name of father of x -1 Name of father of Johann Wolfgang von Goethe Johann Caspar Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' writer' '.' '\n' '\n' 'The' ' name'
+ ' Go' 'ethe' ' is' ' a' ' German' ' name' '.' ' It' ' is']" ", the German poet and writer .
+
+ The name Go ethe is a German name . It is" False poem written by Johann Wolfgang von Goethe and set to the 7 [' poem', ' written', ' by', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+128 39 Name of father of x -1 Name of father of Elizabeth II George VI Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' Queen' ' of' ' England']" ", the Queen of England , and the Queen of the United Kingdom .
+
+ The Queen of England" False " of cricket, Queen Elizabeth II asked him ""I" 5 [' of', ' cricket', ',', ' Queen', ' Elizabeth', ' II']
+129 39 Name of father of x -1 Name of father of Elizabeth II George VI Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' Queen' ' of' ' England']" ", the Queen of England , and the Queen of the United Kingdom .
+
+ The Queen of England" False knighted by Queen Elizabeth II in 1975 for 5 [' knight', 'ed', ' by', ' Queen', ' Elizabeth', ' II']
+130 39 Name of father of x -1 Name of father of Elizabeth II George VI Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' Queen' ' of' ' England']" ", the Queen of England , and the Queen of the United Kingdom .
+
+ The Queen of England" False visits. Queen Elizabeth II and Prime Ministers 4 [' visits', '.', ' Queen', ' Elizabeth', ' II']
+131 39 Name of father of x -1 Name of father of Elizabeth II George VI Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' Queen' ' of' ' England']" ", the Queen of England , and the Queen of the United Kingdom .
+
+ The Queen of England" False remains – Queen Elizabeth II was reportedly 4 [' remains', ' –', ' Queen', ' Elizabeth', ' II']
+132 39 Name of father of x -1 Name of father of Elizabeth II George VI Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' Queen' ' of' ' England']" ", the Queen of England , and the Queen of the United Kingdom .
+
+ The Queen of England" False leaders such as Queen Elizabeth II of the United 5 [' leaders', ' such', ' as', ' Queen', ' Elizabeth', ' II']
+133 40 Name of father of x -1 Name of father of Leo Tolstoy Nikolay Tolstoy Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the great Russian writer , who was born in 18 28 .
+
+ The name of the mother" False having expected Leo Tolstoy to be awarded. 5 [' having', ' expected', ' Leo', ' Tol', 'st', 'oy']
+134 40 Name of father of x -1 Name of father of Leo Tolstoy Nikolay Tolstoy Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the great Russian writer , who was born in 18 28 .
+
+ The name of the mother" False " Russian novelist Leo Tolstoy wrote,
+" 5 [' Russian', ' novelist', ' Leo', ' Tol', 'st', 'oy']
+135 40 Name of father of x -1 Name of father of Leo Tolstoy Nikolay Tolstoy Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the great Russian writer , who was born in 18 28 .
+
+ The name of the mother" False christian anarchist Leo Tolstoy established a school 6 [' christ', 'ian', ' anarchist', ' Leo', ' Tol', 'st', 'oy']
+136 40 Name of father of x -1 Name of father of Leo Tolstoy Nikolay Tolstoy Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the great Russian writer , who was born in 18 28 .
+
+ The name of the mother" False Balzac in France, Leo Tolstoy in Russia and Alessandro 8 [' Bal', 'zac', ' in', ' France', ',', ' Leo', ' Tol', 'st', 'oy']
+137 40 Name of father of x -1 Name of father of Leo Tolstoy Nikolay Tolstoy Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the great Russian writer , who was born in 18 28 .
+
+ The name of the mother" False processing: in 1840, Leo Tolstoy played the 8 [' processing', ':', ' in', ' 1840', ',', ' Leo', ' Tol', 'st', 'oy']
+138 41 Name of father of x -1 Name of father of Taylor Swift Scott Swift Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Al' 'wyn'
+ ',' ' who' ' is' ' the' ' son' ' of' ' the' ' late' ' actor']" , the singer , and her husband , Joe Al wyn , who is the son of the late actor False " to Burn"" by Taylor Swift and ""Walking" 5 "[' to', ' Burn', '""', ' by', ' Taylor', ' Swift']"
+139 41 Name of father of x -1 Name of father of Taylor Swift Scott Swift Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Al' 'wyn'
+ ',' ' who' ' is' ' the' ' son' ' of' ' the' ' late' ' actor']" , the singer , and her husband , Joe Al wyn , who is the son of the late actor False photographs of Taylor Swift with anti-Semitic 3 [' photographs', ' of', ' Taylor', ' Swift']
+140 41 Name of father of x -1 Name of father of Taylor Swift Scott Swift Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Al' 'wyn'
+ ',' ' who' ' is' ' the' ' son' ' of' ' the' ' late' ' actor']" , the singer , and her husband , Joe Al wyn , who is the son of the late actor False eponymous debut album, Taylor Swift (2006). While writing 6 [' ep', 'onymous', ' debut', ' album', ',', ' Taylor', ' Swift']
+141 41 Name of father of x -1 Name of father of Taylor Swift Scott Swift Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Al' 'wyn'
+ ',' ' who' ' is' ' the' ' son' ' of' ' the' ' late' ' actor']" , the singer , and her husband , Joe Al wyn , who is the son of the late actor False while accompanying Taylor Swift on the United 3 [' while', ' accompanying', ' Taylor', ' Swift']
+142 41 Name of father of x -1 Name of father of Taylor Swift Scott Swift Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Al' 'wyn'
+ ',' ' who' ' is' ' the' ' son' ' of' ' the' ' late' ' actor']" , the singer , and her husband , Joe Al wyn , who is the son of the late actor False 1 ['Taylor', ' Swift']
+143 42 Name of father of x -1 Name of father of Donald Trump Fred Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' a' ' son' ','
+ ' a' ' husband' ',' ' a' ' son' '-' 'in' '-']" ".
+
+ I am a proud father of a son , a husband , a son - in -" False businessman Donald Trump (who also paid 2 [' businessman', ' Donald', ' Trump']
+144 42 Name of father of x -1 Name of father of Donald Trump Fred Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' a' ' son' ','
+ ' a' ' husband' ',' ' a' ' son' '-' 'in' '-']" ".
+
+ I am a proud father of a son , a husband , a son - in -" False 2015, after Donald Trump announced 4 [' 2015', ',', ' after', ' Donald', ' Trump']
+145 42 Name of father of x -1 Name of father of Donald Trump Fred Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' a' ' son' ','
+ ' a' ' husband' ',' ' a' ' son' '-' 'in' '-']" ".
+
+ I am a proud father of a son , a husband , a son - in -" False increased congestion. Donald Trump immediately 4 [' increased', ' congestion', '.', ' Donald', ' Trump']
+146 42 Name of father of x -1 Name of father of Donald Trump Fred Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' a' ' son' ','
+ ' a' ' husband' ',' ' a' ' son' '-' 'in' '-']" ".
+
+ I am a proud father of a son , a husband , a son - in -" False Jim discuss who Donald Trump fired on The 4 [' Jim', ' discuss', ' who', ' Donald', ' Trump']
+147 42 Name of father of x -1 Name of father of Donald Trump Fred Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' a' ' son' ','
+ ' a' ' husband' ',' ' a' ' son' '-' 'in' '-']" ".
+
+ I am a proud father of a son , a husband , a son - in -" False she became Donald Trump ’ s co-host 3 [' she', ' became', ' Donald', ' Trump']
+148 43 Name of father of x -1 Name of father of Jennifer Lopez Juan Lopez Jennifer Lopez "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' '.'
+ '\n' '\n']" ", the actress who plays the role of the mother of the main character in the movie .
+
+" False " judges of the show. Jennifer Lopez commented, ""You" 6 [' judges', ' of', ' the', ' show', '.', ' Jennifer', ' Lopez']
+149 43 Name of father of x -1 Name of father of Jennifer Lopez Juan Lopez Jennifer Lopez "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' '.'
+ '\n' '\n']" ", the actress who plays the role of the mother of the main character in the movie .
+
+" False American singer Jennifer Lopez for her fifth studio 3 [' American', ' singer', ' Jennifer', ' Lopez']
+150 43 Name of father of x -1 Name of father of Jennifer Lopez Juan Lopez Jennifer Lopez "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' '.'
+ '\n' '\n']" ", the actress who plays the role of the mother of the main character in the movie .
+
+" False " crossover before Jennifer Lopez and all the rest.""" 3 [' crossover', ' before', ' Jennifer', ' Lopez']
+151 43 Name of father of x -1 Name of father of Jennifer Lopez Juan Lopez Jennifer Lopez "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' '.'
+ '\n' '\n']" ", the actress who plays the role of the mother of the main character in the movie .
+
+" False " May 1, 2015 Jennifer Lopez performed ""A Selena" 5 [' May', ' 1', ',', ' 2015', ' Jennifer', ' Lopez']
+152 43 Name of father of x -1 Name of father of Jennifer Lopez Juan Lopez Jennifer Lopez "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' '.'
+ '\n' '\n']" ", the actress who plays the role of the mother of the main character in the movie .
+
+" False Columbia Records, Jennifer Lopez was signed 4 [' Columbia', ' Records', ',', ' Jennifer', ' Lopez']
+153 44 Name of father of x -1 Name of father of Eugène Delacroix Charles-François Delacroix Eugène Delacroix "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Eug' 'è' 'ne' ' Del' 'ac'
+ 'ro' 'ix' ' is' ':' '\n' '\n' 'E' 'ug']" "
+
+ The name of father of Eug è ne Del ac ro ix is :
+
+ E ug" False world, including Eugène Delacroix and George Sand, 9 [' world', ',', ' including', ' Eug', 'è', 'ne', ' Del', 'ac', 'ro', 'ix']
+154 44 Name of father of x -1 Name of father of Eugène Delacroix Charles-François Delacroix Eugène Delacroix "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Eug' 'è' 'ne' ' Del' 'ac'
+ 'ro' 'ix' ' is' ':' '\n' '\n' 'E' 'ug']" "
+
+ The name of father of Eug è ne Del ac ro ix is :
+
+ E ug" False notably the painter Eugène Delacroix (1798 – 1863), 9 [' notably', ' the', ' painter', ' Eug', 'è', 'ne', ' Del', 'ac', 'ro', 'ix']
+155 44 Name of father of x -1 Name of father of Eugène Delacroix Charles-François Delacroix Eugène Delacroix "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Eug' 'è' 'ne' ' Del' 'ac'
+ 'ro' 'ix' ' is' ':' '\n' '\n' 'E' 'ug']" "
+
+ The name of father of Eug è ne Del ac ro ix is :
+
+ E ug" False colors influenced Eugène Delacroix and Édouard 8 [' colors', ' influenced', ' Eug', 'è', 'ne', ' Del', 'ac', 'ro', 'ix']
+156 44 Name of father of x -1 Name of father of Eugène Delacroix Charles-François Delacroix Eugène Delacroix "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Eug' 'è' 'ne' ' Del' 'ac'
+ 'ro' 'ix' ' is' ':' '\n' '\n' 'E' 'ug']" "
+
+ The name of father of Eug è ne Del ac ro ix is :
+
+ E ug" False artistic world, including Eugène Delacroix and George Sand, from 10 [' artistic', ' world', ',', ' including', ' Eug', 'è', 'ne', ' Del', 'ac', 'ro', 'ix']
+157 44 Name of father of x -1 Name of father of Eugène Delacroix Charles-François Delacroix Eugène Delacroix "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Eug' 'è' 'ne' ' Del' 'ac'
+ 'ro' 'ix' ' is' ':' '\n' '\n' 'E' 'ug']" "
+
+ The name of father of Eug è ne Del ac ro ix is :
+
+ E ug" False world, including Eugène Delacroix and George Sand, from 9 [' world', ',', ' including', ' Eug', 'è', 'ne', ' Del', 'ac', 'ro', 'ix']
+158 45 Name of father of x -1 Name of father of Napoleon Carlo Bonaparte Napoleon "[',' ' the' ' great' ' Napoleon' ',' ' the' ' Emperor' ' of' ' the'
+ ' French' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Napoleon , the Emperor of the French , the
+ " False squadron from Port Napoleon under Commodore Jacques 3 [' squadron', ' from', ' Port', ' Napoleon']
+159 45 Name of father of x -1 Name of father of Napoleon Carlo Bonaparte Napoleon "[',' ' the' ' great' ' Napoleon' ',' ' the' ' Emperor' ' of' ' the'
+ ' French' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Napoleon , the Emperor of the French , the
+ " False Wheatley complemented Napoleon Kaufman in 4 [' Wheat', 'ley', ' comple', 'mented', ' Napoleon']
+160 45 Name of father of x -1 Name of father of Napoleon Carlo Bonaparte Napoleon "[',' ' the' ' great' ' Napoleon' ',' ' the' ' Emperor' ' of' ' the'
+ ' French' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Napoleon , the Emperor of the French , the
+ " False the diaries of Napoleon's valet, Louis Marchand, 4 [' the', ' di', 'aries', ' of', ' Napoleon']
+161 45 Name of father of x -1 Name of father of Napoleon Carlo Bonaparte Napoleon "[',' ' the' ' great' ' Napoleon' ',' ' the' ' Emperor' ' of' ' the'
+ ' French' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Napoleon , the Emperor of the French , the
+ " False government, Napoleon sent an army to invade 2 [' government', ',', ' Napoleon']
+162 45 Name of father of x -1 Name of father of Napoleon Carlo Bonaparte Napoleon "[',' ' the' ' great' ' Napoleon' ',' ' the' ' Emperor' ' of' ' the'
+ ' French' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Napoleon , the Emperor of the French , the
+ " False followers, Napoleon dictated his memoirs 2 [' followers', ',', ' Napoleon']
+163 46 Name of father of x -1 Name of father of Paul Gauguin Clovis Gauguin Paul Gauguin "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '48' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the painter , who was born in Paris in 18 48 .
+
+ The name of the father" False Dedicated to Paul Gauguin (1888). In a letter 6 [' Ded', 'icated', ' to', ' Paul', ' Gau', 'gu', 'in']
+164 46 Name of father of x -1 Name of father of Paul Gauguin Clovis Gauguin Paul Gauguin "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '48' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the painter , who was born in Paris in 18 48 .
+
+ The name of the father" False his friends Paul Gauguin and Émile Bernard. 5 [' his', ' friends', ' Paul', ' Gau', 'gu', 'in']
+165 46 Name of father of x -1 Name of father of Paul Gauguin Clovis Gauguin Paul Gauguin "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '48' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the painter , who was born in Paris in 18 48 .
+
+ The name of the father" False demonstrates how Flora and Paul Gauguin were unable to find 8 [' demonstrates', ' how', ' Flor', 'a', ' and', ' Paul', ' Gau', 'gu', 'in']
+166 46 Name of father of x -1 Name of father of Paul Gauguin Clovis Gauguin Paul Gauguin "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '48' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the painter , who was born in Paris in 18 48 .
+
+ The name of the father" False Post-Impressionist French artists Paul Gauguin and Paul Cézanne 10 [' Post', '-', 'Imp', 'ression', 'ist', ' French', ' artists', ' Paul', ' Gau', 'gu', 'in']
+167 46 Name of father of x -1 Name of father of Paul Gauguin Clovis Gauguin Paul Gauguin "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '48' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the painter , who was born in Paris in 18 48 .
+
+ The name of the father" False (The Royal End) by Paul Gauguin (1892). The 9 [' (', 'The', ' Royal', ' End', ')', ' by', ' Paul', ' Gau', 'gu', 'in']
+168 47 Name of father of x -1 Name of father of Camille Pissarro Frédéric-Abraham-Gabriel Pissarro Camille Pissarro "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' painter'
+ ""'s"" ' father' ' is' ' not' ' mentioned' ' in' ' the' ' book' '.']" ", the painter .
+
+ The name of the painter 's father is not mentioned in the book ." False Vincent van Gogh, Camille Pissarro and Monet incorporated 9 [' Vincent', ' van', ' Go', 'gh', ',', ' Cam', 'ille', ' P', 'iss', 'arro']
+169 47 Name of father of x -1 Name of father of Camille Pissarro Frédéric-Abraham-Gabriel Pissarro Camille Pissarro "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' painter'
+ ""'s"" ' father' ' is' ' not' ' mentioned' ' in' ' the' ' book' '.']" ", the painter .
+
+ The name of the painter 's father is not mentioned in the book ." False " (1839 – 1906) and Camille Pissarro (1830 – 1903).
+" 11 [' (', '18', '39', ' –', ' 1906', ')', ' and', ' Cam', 'ille', ' P', 'iss', 'arro']
+170 47 Name of father of x -1 Name of father of Camille Pissarro Frédéric-Abraham-Gabriel Pissarro Camille Pissarro "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' painter'
+ ""'s"" ' father' ' is' ' not' ' mentioned' ' in' ' the' ' book' '.']" ", the painter .
+
+ The name of the painter 's father is not mentioned in the book ." False Alfred Sisley and Camille Pissarro painted hundreds of 9 [' Alfred', ' S', 'is', 'ley', ' and', ' Cam', 'ille', ' P', 'iss', 'arro']
+171 47 Name of father of x -1 Name of father of Camille Pissarro Frédéric-Abraham-Gabriel Pissarro Camille Pissarro "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' painter'
+ ""'s"" ' father' ' is' ' not' ' mentioned' ' in' ' the' ' book' '.']" ", the painter .
+
+ The name of the painter 's father is not mentioned in the book ." False Alfred Sisley and Camille Pissarro painted hundreds 9 [' Alfred', ' S', 'is', 'ley', ' and', ' Cam', 'ille', ' P', 'iss', 'arro']
+172 47 Name of father of x -1 Name of father of Camille Pissarro Frédéric-Abraham-Gabriel Pissarro Camille Pissarro "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' painter'
+ ""'s"" ' father' ' is' ' not' ' mentioned' ' in' ' the' ' book' '.']" ", the painter .
+
+ The name of the painter 's father is not mentioned in the book ." False Impressionist painter Camille Pissarro and the novelist 8 [' Imp', 'ression', 'ist', ' painter', ' Cam', 'ille', ' P', 'iss', 'arro']
+173 48 Name of father of x -1 Name of father of Victor Hugo Joseph Léopold Sigisbert Hugo Victor Hugo "[',' ' the' ' great' ' poet' ',' ' was' ' born' ' in' ' 18' '02' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' great' ' French' ' poets']" ", the great poet , was born in 18 02 .
+
+ The first of the great French poets" False " Mazeppa, after Victor Hugo (1851)
+" 5 [' Maze', 'ppa', ',', ' after', ' Victor', ' Hugo']
+174 48 Name of father of x -1 Name of father of Victor Hugo Joseph Léopold Sigisbert Hugo Victor Hugo "[',' ' the' ' great' ' poet' ',' ' was' ' born' ' in' ' 18' '02' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' great' ' French' ' poets']" ", the great poet , was born in 18 02 .
+
+ The first of the great French poets" False the writer Victor Hugo and the theorist 3 [' the', ' writer', ' Victor', ' Hugo']
+175 48 Name of father of x -1 Name of father of Victor Hugo Joseph Léopold Sigisbert Hugo Victor Hugo "[',' ' the' ' great' ' poet' ',' ' was' ' born' ' in' ' 18' '02' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' great' ' French' ' poets']" ", the great poet , was born in 18 02 .
+
+ The first of the great French poets" False Gustave Flaubert, Victor Hugo and Guy de Maupassant, 7 [' Gust', 'ave', ' Fl', 'au', 'bert', ',', ' Victor', ' Hugo']
+176 48 Name of father of x -1 Name of father of Victor Hugo Joseph Léopold Sigisbert Hugo Victor Hugo "[',' ' the' ' great' ' poet' ',' ' was' ' born' ' in' ' 18' '02' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' great' ' French' ' poets']" ", the great poet , was born in 18 02 .
+
+ The first of the great French poets" False when both met. Victor Hugo told the Emperor: 5 [' when', ' both', ' met', '.', ' Victor', ' Hugo']
+177 48 Name of father of x -1 Name of father of Victor Hugo Joseph Léopold Sigisbert Hugo Victor Hugo "[',' ' the' ' great' ' poet' ',' ' was' ' born' ' in' ' 18' '02' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' great' ' French' ' poets']" ", the great poet , was born in 18 02 .
+
+ The first of the great French poets" False " his memorial service, Victor Hugo pronounced ""Today" 5 [' his', ' memorial', ' service', ',', ' Victor', ' Hugo']
+178 49 Name of father of x -1 Name of father of Beyoncé Mathew Knowles Beyoncé "[""'s"" ' baby' ',' ' Jay' '-' 'Z' ',' ' and' ' the' ' father' ' of' ' Blue'
+ ' Ivy' ' Carter' '.' '\n' '\n' 'The' ' couple' ""'s""]" "'s baby , Jay - Z , and the father of Blue Ivy Carter .
+
+ The couple 's" False this. It also became Beyoncé's sixth number-one 6 [' this', '.', ' It', ' also', ' became', ' Beyon', 'cé']
+179 49 Name of father of x -1 Name of father of Beyoncé Mathew Knowles Beyoncé "[""'s"" ' baby' ',' ' Jay' '-' 'Z' ',' ' and' ' the' ' father' ' of' ' Blue'
+ ' Ivy' ' Carter' '.' '\n' '\n' 'The' ' couple' ""'s""]" "'s baby , Jay - Z , and the father of Blue Ivy Carter .
+
+ The couple 's" False to listen to on Beyoncé's official 5 [' to', ' listen', ' to', ' on', ' Beyon', 'cé']
+180 49 Name of father of x -1 Name of father of Beyoncé Mathew Knowles Beyoncé "[""'s"" ' baby' ',' ' Jay' '-' 'Z' ',' ' and' ' the' ' father' ' of' ' Blue'
+ ' Ivy' ' Carter' '.' '\n' '\n' 'The' ' couple' ""'s""]" "'s baby , Jay - Z , and the father of Blue Ivy Carter .
+
+ The couple 's" False entirely only to Beyoncé and Kelly Rowland, 4 [' entirely', ' only', ' to', ' Beyon', 'cé']
+181 49 Name of father of x -1 Name of father of Beyoncé Mathew Knowles Beyoncé "[""'s"" ' baby' ',' ' Jay' '-' 'Z' ',' ' and' ' the' ' father' ' of' ' Blue'
+ ' Ivy' ' Carter' '.' '\n' '\n' 'The' ' couple' ""'s""]" "'s baby , Jay - Z , and the father of Blue Ivy Carter .
+
+ The couple 's" False Billboard Music Awards, Beyoncé initially received 5 [' Billboard', ' Music', ' Awards', ',', ' Beyon', 'cé']
+182 49 Name of father of x -1 Name of father of Beyoncé Mathew Knowles Beyoncé "[""'s"" ' baby' ',' ' Jay' '-' 'Z' ',' ' and' ' the' ' father' ' of' ' Blue'
+ ' Ivy' ' Carter' '.' '\n' '\n' 'The' ' couple' ""'s""]" "'s baby , Jay - Z , and the father of Blue Ivy Carter .
+
+ The couple 's" False 2 ['Bey', 'on', 'cé']
+183 50 Name of father of x -1 Name of father of Britney Spears James Parnell Spears Britney Spears "[',' ' and' ' the' ' father' ' of' ' the' ' father' ' of' ' Brit' 'ney'
+ ' Spears' '.' '\n' '\n' 'The' ' father' ' of' ' Brit' 'ney' ' Spears']" ", and the father of the father of Brit ney Spears .
+
+ The father of Brit ney Spears" False featuring vocals by Britney Spears leaked onto the 5 [' featuring', ' vocals', ' by', ' Brit', 'ney', ' Spears']
+184 50 Name of father of x -1 Name of father of Britney Spears James Parnell Spears Britney Spears "[',' ' and' ' the' ' father' ' of' ' the' ' father' ' of' ' Brit' 'ney'
+ ' Spears' '.' '\n' '\n' 'The' ' father' ' of' ' Brit' 'ney' ' Spears']" ", and the father of the father of Brit ney Spears .
+
+ The father of Brit ney Spears" False " – lead vocals
+" 5 [' –', ' lead', ' vocals', 'Brit', 'ney', ' Spears']
+185 50 Name of father of x -1 Name of father of Britney Spears James Parnell Spears Britney Spears "[',' ' and' ' the' ' father' ' of' ' the' ' father' ' of' ' Brit' 'ney'
+ ' Spears' '.' '\n' '\n' 'The' ' father' ' of' ' Brit' 'ney' ' Spears']" ", and the father of the father of Brit ney Spears .
+
+ The father of Brit ney Spears" False after The Circus Britney Spears was announced, 5 [' after', ' The', ' Circus', ' Brit', 'ney', ' Spears']
+186 50 Name of father of x -1 Name of father of Britney Spears James Parnell Spears Britney Spears "[',' ' and' ' the' ' father' ' of' ' the' ' father' ' of' ' Brit' 'ney'
+ ' Spears' '.' '\n' '\n' 'The' ' father' ' of' ' Brit' 'ney' ' Spears']" ", and the father of the father of Brit ney Spears .
+
+ The father of Brit ney Spears" False who goes to see Britney Spears live isn't really 6 [' who', ' goes', ' to', ' see', ' Brit', 'ney', ' Spears']
+187 50 Name of father of x -1 Name of father of Britney Spears James Parnell Spears Britney Spears "[',' ' and' ' the' ' father' ' of' ' the' ' father' ' of' ' Brit' 'ney'
+ ' Spears' '.' '\n' '\n' 'The' ' father' ' of' ' Brit' 'ney' ' Spears']" ", and the father of the father of Brit ney Spears .
+
+ The father of Brit ney Spears" False the legendary Ms. Britney Spears / and the unstoppable 6 [' the', ' legendary', ' Ms', '.', ' Brit', 'ney', ' Spears']
+188 51 Name of father of x -1 Name of father of Barack Obama Barack Obama Sr. Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False video, President Barack Obama said that Manning 4 [' video', ',', ' President', ' Barack', ' Obama']
+189 51 Name of father of x -1 Name of father of Barack Obama Barack Obama Sr. Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False Jill Biden said that Barack Obama had offered 5 [' Jill', ' Biden', ' said', ' that', ' Barack', ' Obama']
+190 51 Name of father of x -1 Name of father of Barack Obama Barack Obama Sr. Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False " way to the Barack Obama ""Hope"" poster designed" 4 [' way', ' to', ' the', ' Barack', ' Obama']
+191 51 Name of father of x -1 Name of father of Barack Obama Barack Obama Sr. Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False " has come and Barack Obama is the change we need"".
+" 4 [' has', ' come', ' and', ' Barack', ' Obama']
+192 51 Name of father of x -1 Name of father of Barack Obama Barack Obama Sr. Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False election of President Barack Obama as inspiration for 4 [' election', ' of', ' President', ' Barack', ' Obama']
+193 52 Name of father of x -1 Name of father of Cher John Sarkisian Cher "['ub' 'im' ',' ' the' ' son' ' of' ' J' 'ap' 'he' 'th' ',' ' the' ' son'
+ ' of' ' Noah' ',' ' the' ' son' ' of' ' She']" ub im , the son of J ap he th , the son of Noah , the son of She False 1 ['C', 'her']
+194 52 Name of father of x -1 Name of father of Cher John Sarkisian Cher "['ub' 'im' ',' ' the' ' son' ' of' ' J' 'ap' 'he' 'th' ',' ' the' ' son'
+ ' of' ' Noah' ',' ' the' ' son' ' of' ' She']" ub im , the son of J ap he th , the son of Noah , the son of She False offered Sonny and Cher their own television 3 [' offered', ' Sonny', ' and', ' Cher']
+195 52 Name of father of x -1 Name of father of Cher John Sarkisian Cher "['ub' 'im' ',' ' the' ' son' ' of' ' J' 'ap' 'he' 'th' ',' ' the' ' son'
+ ' of' ' Noah' ',' ' the' ' son' ' of' ' She']" ub im , the son of J ap he th , the son of Noah , the son of She False " the original]"" by Cher Thornhill of Daily" 4 "[' the', ' original', ']""', ' by', ' Cher']"
+196 52 Name of father of x -1 Name of father of Cher John Sarkisian Cher "['ub' 'im' ',' ' the' ' son' ' of' ' J' 'ap' 'he' 'th' ',' ' the' ' son'
+ ' of' ' Noah' ',' ' the' ' son' ' of' ' She']" ub im , the son of J ap he th , the son of Noah , the son of She False regard to Cher's appearance, said 2 [' regard', ' to', ' Cher']
+197 52 Name of father of x -1 Name of father of Cher John Sarkisian Cher "['ub' 'im' ',' ' the' ' son' ' of' ' J' 'ap' 'he' 'th' ',' ' the' ' son'
+ ' of' ' Noah' ',' ' the' ' son' ' of' ' She']" ub im , the son of J ap he th , the son of Noah , the son of She False popularity of Take Me Home, Cher planned to 6 [' popularity', ' of', ' Take', ' Me', ' Home', ',', ' Cher']
+198 53 Name of father of x -1 Name of father of Jesus God the Father Jesus "[' Christ' ',' ' the' ' Son' ' of' ' God' ',' ' and' ' the' ' Son' ' of'
+ ' God' ',' ' and' ' the' ' Son' ' of' ' God' ',' ' and']" Christ , the Son of God , and the Son of God , and the Son of God , and False synoptics such as Jesus'baptism, the calling 5 [' syn', 'opt', 'ics', ' such', ' as', ' Jesus']
+199 53 Name of father of x -1 Name of father of Jesus God the Father Jesus "[' Christ' ',' ' the' ' Son' ' of' ' God' ',' ' and' ' the' ' Son' ' of'
+ ' God' ',' ' and' ' the' ' Son' ' of' ' God' ',' ' and']" Christ , the Son of God , and the Son of God , and the Son of God , and False paginarum notes that Jesus was lying between 5 [' pag', 'inar', 'um', ' notes', ' that', ' Jesus']
+200 53 Name of father of x -1 Name of father of Jesus God the Father Jesus "[' Christ' ',' ' the' ' Son' ' of' ' God' ',' ' and' ' the' ' Son' ' of'
+ ' God' ',' ' and' ' the' ' Son' ' of' ' God' ',' ' and']" Christ , the Son of God , and the Son of God , and the Son of God , and False John DiBiase of Jesus Freak Hideout 5 [' John', ' Di', 'Bi', 'ase', ' of', ' Jesus']
+201 53 Name of father of x -1 Name of father of Jesus God the Father Jesus "[' Christ' ',' ' the' ' Son' ' of' ' God' ',' ' and' ' the' ' Son' ' of'
+ ' God' ',' ' and' ' the' ' Son' ' of' ' God' ',' ' and']" Christ , the Son of God , and the Son of God , and the Son of God , and False " sings ""Thank you Jesus / For my mama" 4 "[' sings', ' ""', 'Thank', ' you', ' Jesus']"
+202 53 Name of father of x -1 Name of father of Jesus God the Father Jesus "[' Christ' ',' ' the' ' Son' ' of' ' God' ',' ' and' ' the' ' Son' ' of'
+ ' God' ',' ' and' ' the' ' Son' ' of' ' God' ',' ' and']" Christ , the Son of God , and the Son of God , and the Son of God , and False overlaid with the Jesus image from the CD. 4 [' overl', 'aid', ' with', ' the', ' Jesus']
+203 54 Name of father of x -1 Name of father of Michelangelo Lodovico di Leonardo Buonarroti Simoni Michelangelo "[',' ' the' ' sculpt' 'or' ',' ' and' ' the' ' painter' ',' ' and' ' the'
+ ' architect' ',' ' and' ' the' ' poet' ',' ' and' ' the' '\n']" ", the sculpt or , and the painter , and the architect , and the poet , and the
+" False monochrome than Michelangelo ever intended, but 5 [' mon', 'och', 'rome', ' than', ' Michel', 'angelo']
+204 54 Name of father of x -1 Name of father of Michelangelo Lodovico di Leonardo Buonarroti Simoni Michelangelo "[',' ' the' ' sculpt' 'or' ',' ' and' ' the' ' painter' ',' ' and' ' the'
+ ' architect' ',' ' and' ' the' ' poet' ',' ' and' ' the' '\n']" ", the sculpt or , and the painter , and the architect , and the poet , and the
+" False " then talks of Michelangelo's ""rigorous maintenance""" 4 [' then', ' talks', ' of', ' Michel', 'angelo']
+205 54 Name of father of x -1 Name of father of Michelangelo Lodovico di Leonardo Buonarroti Simoni Michelangelo "[',' ' the' ' sculpt' 'or' ',' ' and' ' the' ' painter' ',' ' and' ' the'
+ ' architect' ',' ' and' ' the' ' poet' ',' ' and' ' the' '\n']" ", the sculpt or , and the painter , and the architect , and the poet , and the
+" False 2 ['Mic', 'hel', 'angelo']
+206 54 Name of father of x -1 Name of father of Michelangelo Lodovico di Leonardo Buonarroti Simoni Michelangelo "[',' ' the' ' sculpt' 'or' ',' ' and' ' the' ' painter' ',' ' and' ' the'
+ ' architect' ',' ' and' ' the' ' poet' ',' ' and' ' the' '\n']" ", the sculpt or , and the painter , and the architect , and the poet , and the
+" False the altar when Michelangelo's Last Judgement 4 [' the', ' altar', ' when', ' Michel', 'angelo']
+207 54 Name of father of x -1 Name of father of Michelangelo Lodovico di Leonardo Buonarroti Simoni Michelangelo "[',' ' the' ' sculpt' 'or' ',' ' and' ' the' ' painter' ',' ' and' ' the'
+ ' architect' ',' ' and' ' the' ' poet' ',' ' and' ' the' '\n']" ", the sculpt or , and the painter , and the architect , and the poet , and the
+" False Botticelli, and Michelangelo Buonarroti. Works 6 [' Bott', 'ice', 'lli', ',', ' and', ' Michel', 'angelo']
+208 55 Name of father of x -1 Name of father of Lucas Cranach the Elder Hans Maler Lucas Cranach the Elder "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Lucas' ' Cran' 'ach' ' the' ' Elder' ',' ' the' ' painter']" ", the painter .
+
+ The name of the father of Lucas Cran ach the Elder , the painter" False Matthias Grünewald and Lucas Cranach the Elder were important German 11 [' Matth', 'ias', ' Gr', 'ü', 'new', 'ald', ' and', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+209 55 Name of father of x -1 Name of father of Lucas Cranach the Elder Hans Maler Lucas Cranach the Elder "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Lucas' ' Cran' 'ach' ' the' ' Elder' ',' ' the' ' painter']" ", the painter .
+
+ The name of the father of Lucas Cran ach the Elder , the painter" False Isenheim Altarpiece. Lucas Cranach the Elder (1472 – 1553), also 11 [' Is', 'en', 'heim', ' Alt', 'arp', 'iece', '.', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+210 55 Name of father of x -1 Name of father of Lucas Cranach the Elder Hans Maler Lucas Cranach the Elder "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Lucas' ' Cran' 'ach' ' the' ' Elder' ',' ' the' ' painter']" ", the painter .
+
+ The name of the father of Lucas Cran ach the Elder , the painter" False dog, painted by Lucas Cranach the Elder in 1514, is 8 [' dog', ',', ' painted', ' by', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+211 55 Name of father of x -1 Name of father of Lucas Cranach the Elder Hans Maler Lucas Cranach the Elder "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Lucas' ' Cran' 'ach' ' the' ' Elder' ',' ' the' ' painter']" ", the painter .
+
+ The name of the father of Lucas Cran ach the Elder , the painter" False dog, painted by Lucas Cranach the Elder in 1514, is thought 8 [' dog', ',', ' painted', ' by', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+212 56 Name of father of x -1 Name of father of Marilyn Monroe Charles Stanley Gifford Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' singer' ',' ' the' ' model' ',' ' the' ' author' ',']" , the actress , and the mother of the actress , the singer , the model , the author , False Michael Jackson and Marilyn Monroe impersonators as well 4 [' Michael', ' Jackson', ' and', ' Marilyn', ' Monroe']
+213 56 Name of father of x -1 Name of father of Marilyn Monroe Charles Stanley Gifford Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' singer' ',' ' the' ' model' ',' ' the' ' author' ',']" , the actress , and the mother of the actress , the singer , the model , the author , False " to as ""in those Marilyn Monroe / Jayne Mansfield" 6 "[' to', ' as', ' ""', 'in', ' those', ' Marilyn', ' Monroe']"
+214 56 Name of father of x -1 Name of father of Marilyn Monroe Charles Stanley Gifford Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' singer' ',' ' the' ' model' ',' ' the' ' author' ',']" , the actress , and the mother of the actress , the singer , the model , the author , False John Wayne and Marilyn Monroe becoming iconic 4 [' John', ' Wayne', ' and', ' Marilyn', ' Monroe']
+215 56 Name of father of x -1 Name of father of Marilyn Monroe Charles Stanley Gifford Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' singer' ',' ' the' ' model' ',' ' the' ' author' ',']" , the actress , and the mother of the actress , the singer , the model , the author , False his love for Marilyn Monroe. The two decide to 4 [' his', ' love', ' for', ' Marilyn', ' Monroe']
+216 56 Name of father of x -1 Name of father of Marilyn Monroe Charles Stanley Gifford Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' singer' ',' ' the' ' model' ',' ' the' ' author' ',']" , the actress , and the mother of the actress , the singer , the model , the author , False Hayworth in Gilda and Marilyn Monroe in Some Like 7 [' Hay', 'worth', ' in', ' G', 'ilda', ' and', ' Marilyn', ' Monroe']
+217 57 Name of father of x -1 Name of father of Wolfgang Amadeus Mozart Leopold Mozart Wolfgang Amadeus Mozart "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Moz' 'art'
+ ' family' '.' ' I' ' have' ' been' ' listening' ' to' ' the' ' Moz']" "
+
+ I am a big fan of the Moz art family . I have been listening to the Moz" False and 1773, the young Wolfgang Amadeus Mozart and his father 11 [' and', ' 17', '73', ',', ' the', ' young', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+218 57 Name of father of x -1 Name of father of Wolfgang Amadeus Mozart Leopold Mozart Wolfgang Amadeus Mozart "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Moz' 'art'
+ ' family' '.' ' I' ' have' ' been' ' listening' ' to' ' the' ' Moz']" "
+
+ I am a big fan of the Moz art family . I have been listening to the Moz" False " Grand Orchestra – Wolfgang Amadeus Mozart (c. 1790)
+" 8 [' Grand', ' Orchestra', ' –', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+219 57 Name of father of x -1 Name of father of Wolfgang Amadeus Mozart Leopold Mozart Wolfgang Amadeus Mozart "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Moz' 'art'
+ ' family' '.' ' I' ' have' ' been' ' listening' ' to' ' the' ' Moz']" "
+
+ I am a big fan of the Moz art family . I have been listening to the Moz" False 1773, the young Wolfgang Amadeus Mozart and his father Leopold 10 [' 17', '73', ',', ' the', ' young', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+220 57 Name of father of x -1 Name of father of Wolfgang Amadeus Mozart Leopold Mozart Wolfgang Amadeus Mozart "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Moz' 'art'
+ ' family' '.' ' I' ' have' ' been' ' listening' ' to' ' the' ' Moz']" "
+
+ I am a big fan of the Moz art family . I have been listening to the Moz" False " Grand Orchestra – Wolfgang Amadeus Mozart (c. 1790)
+" 8 [' Grand', ' Orchestra', ' –', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+221 57 Name of father of x -1 Name of father of Wolfgang Amadeus Mozart Leopold Mozart Wolfgang Amadeus Mozart "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Moz' 'art'
+ ' family' '.' ' I' ' have' ' been' ' listening' ' to' ' the' ' Moz']" "
+
+ I am a big fan of the Moz art family . I have been listening to the Moz" False particularly by Wolfgang Amadeus Mozart and Johann Wolfgang 7 [' particularly', ' by', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+222 58 Name of father of x -1 Name of father of Ludwig van Beethoven Johann van Beethoven Ludwig van Beethoven "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' music' ' of'
+ ' Ludwig' ' van' ' Be' 'eth' 'oven' '.' ' I' ' have' ' been']" "
+
+ I am a big fan of the music of Ludwig van Be eth oven . I have been" False compositions from Ludwig van Beethoven and Sergei 6 [' compositions', ' from', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+223 58 Name of father of x -1 Name of father of Ludwig van Beethoven Johann van Beethoven Ludwig van Beethoven "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' music' ' of'
+ ' Ludwig' ' van' ' Be' 'eth' 'oven' '.' ' I' ' have' ' been']" "
+
+ I am a big fan of the music of Ludwig van Be eth oven . I have been" False Joseph Haydn or Ludwig van Beethoven usually opened 8 [' Joseph', ' Hay', 'dn', ' or', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+224 58 Name of father of x -1 Name of father of Ludwig van Beethoven Johann van Beethoven Ludwig van Beethoven "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' music' ' of'
+ ' Ludwig' ' van' ' Be' 'eth' 'oven' '.' ' I' ' have' ' been']" "
+
+ I am a big fan of the music of Ludwig van Be eth oven . I have been" False sonata by Ludwig van Beethoven on her violin 7 [' son', 'ata', ' by', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+225 58 Name of father of x -1 Name of father of Ludwig van Beethoven Johann van Beethoven Ludwig van Beethoven "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' music' ' of'
+ ' Ludwig' ' van' ' Be' 'eth' 'oven' '.' ' I' ' have' ' been']" "
+
+ I am a big fan of the music of Ludwig van Be eth oven . I have been" False compositions from Ludwig van Beethoven and Sergei Prokofiev, 6 [' compositions', ' from', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+226 58 Name of father of x -1 Name of father of Ludwig van Beethoven Johann van Beethoven Ludwig van Beethoven "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' music' ' of'
+ ' Ludwig' ' van' ' Be' 'eth' 'oven' '.' ' I' ' have' ' been']" "
+
+ I am a big fan of the music of Ludwig van Be eth oven . I have been" False " ranged from Ludwig van Beethoven to Jimi Hendrix.
+" 6 [' ranged', ' from', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+227 59 Name of father of x -1 Name of father of Katy Perry Keith Hudson Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Russell' ' Brand' ','
+ ' the' ' comedian' ',' ' have' ' been' ' married' ' since' ' 2006' '.'
+ '\n']" ", the singer , and her husband Russell Brand , the comedian , have been married since 2006 .
+" False professionally as Katy Perry, is an American singer 3 [' professionally', ' as', ' Katy', ' Perry']
+228 59 Name of father of x -1 Name of father of Katy Perry Keith Hudson Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Russell' ' Brand' ','
+ ' the' ' comedian' ',' ' have' ' been' ' married' ' since' ' 2006' '.'
+ '\n']" ", the singer , and her husband Russell Brand , the comedian , have been married since 2006 .
+" False " ""California Gurls"" by Katy Perry and a mash-up" 8 "[' ""', 'California', ' G', 'url', 's', '""', ' by', ' Katy', ' Perry']"
+229 59 Name of father of x -1 Name of father of Katy Perry Keith Hudson Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Russell' ' Brand' ','
+ ' the' ' comedian' ',' ' have' ' been' ' married' ' since' ' 2006' '.'
+ '\n']" ", the singer , and her husband Russell Brand , the comedian , have been married since 2006 .
+" False fondness of Katy Perry earlier), ruining 4 [' fond', 'ness', ' of', ' Katy', ' Perry']
+230 59 Name of father of x -1 Name of father of Katy Perry Keith Hudson Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Russell' ' Brand' ','
+ ' the' ' comedian' ',' ' have' ' been' ' married' ' since' ' 2006' '.'
+ '\n']" ", the singer , and her husband Russell Brand , the comedian , have been married since 2006 .
+" False dates supporting Katy Perry on her California 3 [' dates', ' supporting', ' Katy', ' Perry']
+231 59 Name of father of x -1 Name of father of Katy Perry Keith Hudson Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Russell' ' Brand' ','
+ ' the' ' comedian' ',' ' have' ' been' ' married' ' since' ' 2006' '.'
+ '\n']" ", the singer , and her husband Russell Brand , the comedian , have been married since 2006 .
+" False known professionally as Katy Perry, is an American singer 4 [' known', ' professionally', ' as', ' Katy', ' Perry']
+232 60 Name of father of x -1 Name of father of Johann Sebastian Bach Johann Ambrosius Bach Johann Sebastian Bach "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the']" , the composer , was born in Le ip zig , Germany , in 16 85 . He was the False of others, notably Johann Sebastian Bach which appeared in the 6 [' of', ' others', ',', ' notably', ' Johann', ' Sebastian', ' Bach']
+233 60 Name of father of x -1 Name of father of Johann Sebastian Bach Johann Ambrosius Bach Johann Sebastian Bach "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the']" , the composer , was born in Le ip zig , Germany , in 16 85 . He was the False " ""Sleepers Wake"" by Johann Sebastian Bach and popular" 8 "[' ""', 'Sleep', 'ers', ' Wake', '""', ' by', ' Johann', ' Sebastian', ' Bach']"
+234 60 Name of father of x -1 Name of father of Johann Sebastian Bach Johann Ambrosius Bach Johann Sebastian Bach "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the']" , the composer , was born in Le ip zig , Germany , in 16 85 . He was the False church cantata by Johann Sebastian Bach composed for Quinquagesima, 6 [' church', ' cant', 'ata', ' by', ' Johann', ' Sebastian', ' Bach']
+235 60 Name of father of x -1 Name of father of Johann Sebastian Bach Johann Ambrosius Bach Johann Sebastian Bach "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the']" , the composer , was born in Le ip zig , Germany , in 16 85 . He was the False the music is by Johann Sebastian Bach and Eduard Artemyev. 6 [' the', ' music', ' is', ' by', ' Johann', ' Sebastian', ' Bach']
+236 60 Name of father of x -1 Name of father of Johann Sebastian Bach Johann Ambrosius Bach Johann Sebastian Bach "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the']" , the composer , was born in Le ip zig , Germany , in 16 85 . He was the False The organ music of Johann Sebastian Bach fused characteristics 6 [' The', ' organ', ' music', ' of', ' Johann', ' Sebastian', ' Bach']
+237 61 Name of father of x -1 Name of father of Charles Darwin Robert Darwin Charles Darwin "[',' ' the' ' father' ' of' ' evolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' theory' ' of' ' evolution' '.' '\n' '\n' 'The' ' theory' ' of']" ", the father of evolution , and the father of the theory of evolution .
+
+ The theory of" False naturalist Charles Darwin was born and brought 3 [' natural', 'ist', ' Charles', ' Darwin']
+238 61 Name of father of x -1 Name of father of Charles Darwin Robert Darwin Charles Darwin "[',' ' the' ' father' ' of' ' evolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' theory' ' of' ' evolution' '.' '\n' '\n' 'The' ' theory' ' of']" ", the father of evolution , and the father of the theory of evolution .
+
+ The theory of" False hypothesized that Charles Darwin might have suffered 3 [' hypothesized', ' that', ' Charles', ' Darwin']
+239 61 Name of father of x -1 Name of father of Charles Darwin Robert Darwin Charles Darwin "[',' ' the' ' father' ' of' ' evolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' theory' ' of' ' evolution' '.' '\n' '\n' 'The' ' theory' ' of']" ", the father of evolution , and the father of the theory of evolution .
+
+ The theory of" False " the example of Charles Darwin and other naturalists.
+" 4 [' the', ' example', ' of', ' Charles', ' Darwin']
+240 61 Name of father of x -1 Name of father of Charles Darwin Robert Darwin Charles Darwin "[',' ' the' ' father' ' of' ' evolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' theory' ' of' ' evolution' '.' '\n' '\n' 'The' ' theory' ' of']" ", the father of evolution , and the father of the theory of evolution .
+
+ The theory of" False " home and abroad. Charles Darwin said of him: ""The Emperor" 5 [' home', ' and', ' abroad', '.', ' Charles', ' Darwin']
+241 61 Name of father of x -1 Name of father of Charles Darwin Robert Darwin Charles Darwin "[',' ' the' ' father' ' of' ' evolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' theory' ' of' ' evolution' '.' '\n' '\n' 'The' ' theory' ' of']" ", the father of evolution , and the father of the theory of evolution .
+
+ The theory of" False 1 ['Charles', ' Darwin']
+242 62 Name of father of x -1 Name of father of Paul Cézanne Louis-Auguste Cézanne Paul Cézanne "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Paul' ' C' 'é'
+ 'zanne' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of Paul C é zanne is unknown .
+
+ The name of" False inspiration from Paul Cézanne and had little 5 [' inspiration', ' from', ' Paul', ' C', 'é', 'zanne']
+243 62 Name of father of x -1 Name of father of Paul Cézanne Louis-Auguste Cézanne Paul Cézanne "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Paul' ' C' 'é'
+ 'zanne' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of Paul C é zanne is unknown .
+
+ The name of" False introduced him to Paul Cézanne (1839 – 1906) 6 [' introduced', ' him', ' to', ' Paul', ' C', 'é', 'zanne']
+244 62 Name of father of x -1 Name of father of Paul Cézanne Louis-Auguste Cézanne Paul Cézanne "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Paul' ' C' 'é'
+ 'zanne' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of Paul C é zanne is unknown .
+
+ The name of" False inspiration from Paul Cézanne and had little 5 [' inspiration', ' from', ' Paul', ' C', 'é', 'zanne']
+245 62 Name of father of x -1 Name of father of Paul Cézanne Louis-Auguste Cézanne Paul Cézanne "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Paul' ' C' 'é'
+ 'zanne' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of Paul C é zanne is unknown .
+
+ The name of" False introduced him to Paul Cézanne (1839 – 1906) and 6 [' introduced', ' him', ' to', ' Paul', ' C', 'é', 'zanne']
+246 64 Name of father of x -1 Name of father of Charles Dickens John Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'David' ' Copper' 'field' '_' ',' ' and'
+ ' _' 'The' ' Pick' 'wick' ' Papers' '_' ',' ' and' ' _']" , the author of _ David Copper field _ , and _ The Pick wick Papers _ , and _ False 1841 as recorded by Charles Dickens some 40 years later. 6 [' 18', '41', ' as', ' recorded', ' by', ' Charles', ' Dickens']
+247 64 Name of father of x -1 Name of father of Charles Dickens John Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'David' ' Copper' 'field' '_' ',' ' and'
+ ' _' 'The' ' Pick' 'wick' ' Papers' '_' ',' ' and' ' _']" , the author of _ David Copper field _ , and _ The Pick wick Papers _ , and _ False comparison with Alfred. Charles Dickens had only one 5 [' comparison', ' with', ' Alfred', '.', ' Charles', ' Dickens']
+248 64 Name of father of x -1 Name of father of Charles Dickens John Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'David' ' Copper' 'field' '_' ',' ' and'
+ ' _' 'The' ' Pick' 'wick' ' Papers' '_' ',' ' and' ' _']" , the author of _ David Copper field _ , and _ The Pick wick Papers _ , and _ False 1 ['Charles', ' Dickens']
+249 64 Name of father of x -1 Name of father of Charles Dickens John Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'David' ' Copper' 'field' '_' ',' ' and'
+ ' _' 'The' ' Pick' 'wick' ' Papers' '_' ',' ' and' ' _']" , the author of _ David Copper field _ , and _ The Pick wick Papers _ , and _ False following a visit to Charles Dickens in England, 5 [' following', ' a', ' visit', ' to', ' Charles', ' Dickens']
+250 64 Name of father of x -1 Name of father of Charles Dickens John Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'David' ' Copper' 'field' '_' ',' ' and'
+ ' _' 'The' ' Pick' 'wick' ' Papers' '_' ',' ' and' ' _']" , the author of _ David Copper field _ , and _ The Pick wick Papers _ , and _ False castle's enclosure. Charles Dickens lived in Rochester 5 "[' castle', ""'s"", ' enclosure', '.', ' Charles', ' Dickens']"
+251 65 Name of father of x -1 Name of father of Friedrich Nietzsche Carl Ludwig Nietzsche Friedrich Nietzsche "[',' ' the' ' philosopher' ',' ' the' ' poet' ',' ' the' ' man' ' of'
+ ' the' ' future' ',' ' the' ' man' ' of' ' the' ' future' ',' ' the']" , the philosopher , the poet , the man of the future , the man of the future , the False and Evil, Friedrich Nietzsche uses one of the 4 [' and', ' Evil', ',', ' Friedrich', ' Nietzsche']
+252 65 Name of father of x -1 Name of father of Friedrich Nietzsche Carl Ludwig Nietzsche Friedrich Nietzsche "[',' ' the' ' philosopher' ',' ' the' ' poet' ',' ' the' ' man' ' of'
+ ' the' ' future' ',' ' the' ' man' ' of' ' the' ' future' ',' ' the']" , the philosopher , the poet , the man of the future , the man of the future , the False occasions and where Friedrich Nietzsche opined that he 4 [' occasions', ' and', ' where', ' Friedrich', ' Nietzsche']
+253 65 Name of father of x -1 Name of father of Friedrich Nietzsche Carl Ludwig Nietzsche Friedrich Nietzsche "[',' ' the' ' philosopher' ',' ' the' ' poet' ',' ' the' ' man' ' of'
+ ' the' ' future' ',' ' the' ' man' ' of' ' the' ' future' ',' ' the']" , the philosopher , the poet , the man of the future , the man of the future , the False the work of Friedrich Nietzsche and by Social 4 [' the', ' work', ' of', ' Friedrich', ' Nietzsche']
+254 65 Name of father of x -1 Name of father of Friedrich Nietzsche Carl Ludwig Nietzsche Friedrich Nietzsche "[',' ' the' ' philosopher' ',' ' the' ' poet' ',' ' the' ' man' ' of'
+ ' the' ' future' ',' ' the' ' man' ' of' ' the' ' future' ',' ' the']" , the philosopher , the poet , the man of the future , the man of the future , the False that laws do. Friedrich Nietzsche believed in a morality 5 [' that', ' laws', ' do', '.', ' Friedrich', ' Nietzsche']
+255 65 Name of father of x -1 Name of father of Friedrich Nietzsche Carl Ludwig Nietzsche Friedrich Nietzsche "[',' ' the' ' philosopher' ',' ' the' ' poet' ',' ' the' ' man' ' of'
+ ' the' ' future' ',' ' the' ' man' ' of' ' the' ' future' ',' ' the']" , the philosopher , the poet , the man of the future , the man of the future , the False generations. Friedrich Nietzsche (1844 – 1900), an 3 [' generations', '.', ' Friedrich', ' Nietzsche']
+256 66 Name of father of x -1 Name of father of Joseph Stalin Besarion Jughashvili Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' the' ' father'
+ ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet' ' Union' ','
+ ' the']" , the father of the Soviet Union , the father of the Communist Party of the Soviet Union , the False On 5 March 1953, Joseph Stalin died, ushering 6 [' On', ' 5', ' March', ' 1953', ',', ' Joseph', ' Stalin']
+257 66 Name of father of x -1 Name of father of Joseph Stalin Besarion Jughashvili Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' the' ' father'
+ ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet' ' Union' ','
+ ' the']" , the father of the Soviet Union , the father of the Communist Party of the Soviet Union , the False the death of Joseph Stalin in 1953, this led the 4 [' the', ' death', ' of', ' Joseph', ' Stalin']
+258 66 Name of father of x -1 Name of father of Joseph Stalin Besarion Jughashvili Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' the' ' father'
+ ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet' ' Union' ','
+ ' the']" , the father of the Soviet Union , the father of the Communist Party of the Soviet Union , the False Soviet leader Joseph Stalin fell ill with 3 [' Soviet', ' leader', ' Joseph', ' Stalin']
+259 66 Name of father of x -1 Name of father of Joseph Stalin Besarion Jughashvili Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' the' ' father'
+ ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet' ' Union' ','
+ ' the']" , the father of the Soviet Union , the father of the Communist Party of the Soviet Union , the False Union, from which Joseph Stalin rose to take the reins 5 [' Union', ',', ' from', ' which', ' Joseph', ' Stalin']
+260 66 Name of father of x -1 Name of father of Joseph Stalin Besarion Jughashvili Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' the' ' father'
+ ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet' ' Union' ','
+ ' the']" , the father of the Soviet Union , the father of the Communist Party of the Soviet Union , the False Soviet leader Joseph Stalin began pressing 3 [' Soviet', ' leader', ' Joseph', ' Stalin']
+261 67 Name of father of x -1 Name of father of Voltaire François d'Arouet Voltaire "[',' ' the' ' great' ' French' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 17' '16' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list']" ", the great French philosopher , who was born in 17 16 .
+
+ The following is a list" False 2 ['V', 'olt', 'aire']
+262 67 Name of father of x -1 Name of father of Voltaire François d'Arouet Voltaire "[',' ' the' ' great' ' French' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 17' '16' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list']" ", the great French philosopher , who was born in 17 16 .
+
+ The following is a list" False rented the Cabaret Voltaire from its owner, Jan 5 [' rented', ' the', ' Cab', 'aret', ' Volt', 'aire']
+263 67 Name of father of x -1 Name of father of Voltaire François d'Arouet Voltaire "[',' ' the' ' great' ' French' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 17' '16' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list']" ", the great French philosopher , who was born in 17 16 .
+
+ The following is a list" False the year his son Voltaire won the Doncaster 5 [' the', ' year', ' his', ' son', ' Volt', 'aire']
+264 67 Name of father of x -1 Name of father of Voltaire François d'Arouet Voltaire "[',' ' the' ' great' ' French' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 17' '16' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list']" ", the great French philosopher , who was born in 17 16 .
+
+ The following is a list" False " French philosopher Voltaire titled ""Voltaire, the" 3 [' French', ' philosopher', ' Volt', 'aire']
+265 67 Name of father of x -1 Name of father of Voltaire François d'Arouet Voltaire "[',' ' the' ' great' ' French' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 17' '16' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list']" ", the great French philosopher , who was born in 17 16 .
+
+ The following is a list" False events inspired Voltaire to write Candide, most 3 [' events', ' inspired', ' Volt', 'aire']
+266 68 Name of father of x -1 Name of father of Plato Ariston of Athens Plato "[',' ' the' ' son' ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son'
+ ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son' ' of']" , the son of Arist on , and of the son of Arist on , and of the son of False written two books, Plato & Socrates, 4 [' written', ' two', ' books', ',', ' Plato']
+267 68 Name of father of x -1 Name of father of Plato Ariston of Athens Plato "[',' ' the' ' son' ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son'
+ ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son' ' of']" , the son of Arist on , and of the son of Arist on , and of the son of False Greek philosopher Plato provided the 2 [' Greek', ' philosopher', ' Plato']
+268 68 Name of father of x -1 Name of father of Plato Ariston of Athens Plato "[',' ' the' ' son' ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son'
+ ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son' ' of']" , the son of Arist on , and of the son of Arist on , and of the son of False 1 ['Pl', 'ato']
+269 68 Name of father of x -1 Name of father of Plato Ariston of Athens Plato "[',' ' the' ' son' ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son'
+ ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son' ' of']" , the son of Arist on , and of the son of Arist on , and of the son of False a division between Platonic or Aristotelian 3 [' a', ' division', ' between', ' Plato']
+270 68 Name of father of x -1 Name of father of Plato Ariston of Athens Plato "[',' ' the' ' son' ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son'
+ ' of' ' Arist' 'on' ',' ' and' ' of' ' the' ' son' ' of']" , the son of Arist on , and of the son of Arist on , and of the son of False Like Pythagoras, Plato believed that all 5 [' Like', ' Pyth', 'ag', 'oras', ',', ' Plato']
+271 69 Name of father of x -1 Name of father of Salvador Dalí Salvador Dalí i Cusí Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ' Gal' 'a' ',' ' the'
+ ' model' '.' '\n' '\n' 'The' ' Dal' 'í' ' Museum' ' in']" ", the painter , and his wife Gal a , the model .
+
+ The Dal í Museum in" False the paintings of Salvador Dalí and Francis Bacon, 5 [' the', ' paintings', ' of', ' Salvador', ' Dal', 'í']
+272 69 Name of father of x -1 Name of father of Salvador Dalí Salvador Dalí i Cusí Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ' Gal' 'a' ',' ' the'
+ ' model' '.' '\n' '\n' 'The' ' Dal' 'í' ' Museum' ' in']" ", the painter , and his wife Gal a , the model .
+
+ The Dal í Museum in" False well, imagining what Salvador Dalí or Andy Warhol 6 [' well', ',', ' imagining', ' what', ' Salvador', ' Dal', 'í']
+273 69 Name of father of x -1 Name of father of Salvador Dalí Salvador Dalí i Cusí Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ' Gal' 'a' ',' ' the'
+ ' model' '.' '\n' '\n' 'The' ' Dal' 'í' ' Museum' ' in']" ", the painter , and his wife Gal a , the model .
+
+ The Dal í Museum in" False of Spanish painter Salvador Dalí, after viewing 5 [' of', ' Spanish', ' painter', ' Salvador', ' Dal', 'í']
+274 69 Name of father of x -1 Name of father of Salvador Dalí Salvador Dalí i Cusí Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ' Gal' 'a' ',' ' the'
+ ' model' '.' '\n' '\n' 'The' ' Dal' 'í' ' Museum' ' in']" ", the painter , and his wife Gal a , the model .
+
+ The Dal í Museum in" False fellow artist Salvador Dalí when Dali scratched 4 [' fellow', ' artist', ' Salvador', ' Dal', 'í']
+275 69 Name of father of x -1 Name of father of Salvador Dalí Salvador Dalí i Cusí Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ' Gal' 'a' ',' ' the'
+ ' model' '.' '\n' '\n' 'The' ' Dal' 'í' ' Museum' ' in']" ", the painter , and his wife Gal a , the model .
+
+ The Dal í Museum in" False the paintings of Salvador Dalí and Francis 5 [' the', ' paintings', ' of', ' Salvador', ' Dal', 'í']
+276 70 Name of father of x -1 Name of father of Dante Alighieri Alighiero di Bellincione Dante Alighieri "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 12' '65' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Florence']" ", the poet , who was born in 12 65 .
+
+ The name of the city of Florence" False horsepower (26,000 kW), Dante Alighieri failed to 10 [' horsepower', ' (', '26', ',', '000', ' kW', '),', ' Dante', ' Al', 'igh', 'ieri']
+277 70 Name of father of x -1 Name of father of Dante Alighieri Alighiero di Bellincione Dante Alighieri "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 12' '65' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Florence']" ", the poet , who was born in 12 65 .
+
+ The name of the city of Florence" False pre-Christian prophet. Dante Alighieri included Virgil 8 [' pre', '-', 'Christian', ' prophet', '.', ' Dante', ' Al', 'igh', 'ieri']
+278 70 Name of father of x -1 Name of father of Dante Alighieri Alighiero di Bellincione Dante Alighieri "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 12' '65' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Florence']" ", the poet , who was born in 12 65 .
+
+ The name of the city of Florence" False (26,000 kW), Dante Alighieri failed to reach 9 [' (', '26', ',', '000', ' kW', '),', ' Dante', ' Al', 'igh', 'ieri']
+279 70 Name of father of x -1 Name of father of Dante Alighieri Alighiero di Bellincione Dante Alighieri "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 12' '65' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Florence']" ", the poet , who was born in 12 65 .
+
+ The name of the city of Florence" False Villani's written work on Dante Alighieri and the age in 9 "[' Vill', 'ani', ""'s"", ' written', ' work', ' on', ' Dante', ' Al', 'igh', 'ieri']"
+280 70 Name of father of x -1 Name of father of Dante Alighieri Alighiero di Bellincione Dante Alighieri "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 12' '65' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Florence']" ", the poet , who was born in 12 65 .
+
+ The name of the city of Florence" False Covington, Rijndam, Dante Alighieri, and British steamer 12 [' C', 'oving', 'ton', ',', ' R', 'ij', 'nd', 'am', ',', ' Dante', ' Al', 'igh', 'ieri']
+281 71 Name of father of x -1 Name of father of John Lennon Alfred Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False 1 ['John', ' Lennon']
+282 71 Name of father of x -1 Name of father of John Lennon Alfred Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False former bandmates John Lennon and Paul McCartney. 4 [' former', ' band', 'mates', ' John', ' Lennon']
+283 71 Name of father of x -1 Name of father of John Lennon Alfred Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False the band, spoken by John Lennon as the police 6 [' the', ' band', ',', ' spoken', ' by', ' John', ' Lennon']
+284 71 Name of father of x -1 Name of father of John Lennon Alfred Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False 1 ['John', ' Lennon']
+285 71 Name of father of x -1 Name of father of John Lennon Alfred Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False " backing vocal
+" 3 [' backing', ' vocal', 'John', ' Lennon']
+286 72 Name of father of x -1 Name of father of Raphael Giovanni Santi Raphael "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False " of chronicler Raphael Holinshed.
+" 3 [' of', ' chronic', 'ler', ' Raphael']
+287 72 Name of father of x -1 Name of father of Raphael Giovanni Santi Raphael "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False Michael, Gabriel, Raphael and Uriel. In the 4 [' Michael', ',', ' Gabriel', ',', ' Raphael']
+288 72 Name of father of x -1 Name of father of Raphael Giovanni Santi Raphael "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False including works by Raphael and Hogarth's 3 [' including', ' works', ' by', ' Raphael']
+289 72 Name of father of x -1 Name of father of Raphael Giovanni Santi Raphael "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False Toni! Toné! -member Raphael Saadiq later 8 [' T', 'oni', '!', ' Ton', 'é', '!', ' -', 'member', ' Raphael']
+290 72 Name of father of x -1 Name of father of Raphael Giovanni Santi Raphael "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False songwriting. For his vocals, Raphael recorded with Neumann 7 [' song', 'writing', '.', ' For', ' his', ' vocals', ',', ' Raphael']
+291 73 Name of father of x -1 Name of father of Andy Warhol Andrej Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' his' ' son' ',' ' the' ' artist' ""'s""
+ ' son' ',' ' and' ' his' ' son' ""'s"" ' son' ',' ' and']" , the artist , and his son , the artist 's son , and his son 's son , and False movies, the Andy Warhol scene, and avant-garde 5 [' movies', ',', ' the', ' Andy', ' War', 'hol']
+292 73 Name of father of x -1 Name of father of Andy Warhol Andrej Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' his' ' son' ',' ' the' ' artist' ""'s""
+ ' son' ',' ' and' ' his' ' son' ""'s"" ' son' ',' ' and']" , the artist , and his son , the artist 's son , and his son 's son , and False Bangalter named Andy Warhol as one of Daft 5 [' Bang', 'alter', ' named', ' Andy', ' War', 'hol']
+293 73 Name of father of x -1 Name of father of Andy Warhol Andrej Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' his' ' son' ',' ' the' ' artist' ""'s""
+ ' son' ',' ' and' ' his' ' son' ""'s"" ' son' ',' ' and']" , the artist , and his son , the artist 's son , and his son 's son , and False " Is a Real Boy were Andy Warhol and Jesus."" In" 7 [' Is', ' a', ' Real', ' Boy', ' were', ' Andy', ' War', 'hol']
+294 73 Name of father of x -1 Name of father of Andy Warhol Andrej Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' his' ' son' ',' ' the' ' artist' ""'s""
+ ' son' ',' ' and' ' his' ' son' ""'s"" ' son' ',' ' and']" , the artist , and his son , the artist 's son , and his son 's son , and False Johnson, a friend of Andy Warhol (possibly Mary 7 [' Johnson', ',', ' a', ' friend', ' of', ' Andy', ' War', 'hol']
+295 73 Name of father of x -1 Name of father of Andy Warhol Andrej Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' his' ' son' ',' ' the' ' artist' ""'s""
+ ' son' ',' ' and' ' his' ' son' ""'s"" ' son' ',' ' and']" , the artist , and his son , the artist 's son , and his son 's son , and False characters in 1958. Andy Warhol produced his 6 [' characters', ' in', ' 1958', '.', ' Andy', ' War', 'hol']
+296 75 Name of father of x -1 Name of father of Shania Twain Clarence Edwards Shania Twain "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Sh' 'ania' ' Twain' '.'
+ ' I' ' have' ' been' ' since' ' I' ' was' ' a' ' little']" "
+
+ I am a big fan of Sh ania Twain . I have been since I was a little" False " playing the Shania Twain song ""Man! I" 4 [' playing', ' the', ' Sh', 'ania', ' Twain']
+297 75 Name of father of x -1 Name of father of Shania Twain Clarence Edwards Shania Twain "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Sh' 'ania' ' Twain' '.'
+ ' I' ' have' ' been' ' since' ' I' ' was' ' a' ' little']" "
+
+ I am a big fan of Sh ania Twain . I have been since I was a little" False " Why Not? with Shania Twain progressed. ""Today" 6 [' Why', ' Not', '?', ' with', ' Sh', 'ania', ' Twain']
+298 75 Name of father of x -1 Name of father of Shania Twain Clarence Edwards Shania Twain "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Sh' 'ania' ' Twain' '.'
+ ' I' ' have' ' been' ' since' ' I' ' was' ' a' ' little']" "
+
+ I am a big fan of Sh ania Twain . I have been since I was a little" False reminded of Shania Twain after listening to 4 [' reminded', ' of', ' Sh', 'ania', ' Twain']
+299 75 Name of father of x -1 Name of father of Shania Twain Clarence Edwards Shania Twain "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Sh' 'ania' ' Twain' '.'
+ ' I' ' have' ' been' ' since' ' I' ' was' ' a' ' little']" "
+
+ I am a big fan of Sh ania Twain . I have been since I was a little" False then-wife Shania Twain and Keith 5 [' then', '-', 'wife', ' Sh', 'ania', ' Twain']
+300 75 Name of father of x -1 Name of father of Shania Twain Clarence Edwards Shania Twain "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Sh' 'ania' ' Twain' '.'
+ ' I' ' have' ' been' ' since' ' I' ' was' ' a' ' little']" "
+
+ I am a big fan of Sh ania Twain . I have been since I was a little" False Canadian singer Shania Twain at the Corel Centre 4 [' Canadian', ' singer', ' Sh', 'ania', ' Twain']
+301 76 Name of father of x -1 Name of father of Edgar Allan Poe David Poe, Jr. Edgar Allan Poe "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' '.' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works']" ", the poet , was born in 18 09 .
+
+ The following is a list of the works" False Contemporary writer Edgar Allan Poe wrote to Longfellow 4 [' Contemporary', ' writer', ' Edgar', ' Allan', ' Poe']
+302 76 Name of father of x -1 Name of father of Edgar Allan Poe David Poe, Jr. Edgar Allan Poe "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' '.' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works']" ", the poet , was born in 18 09 .
+
+ The following is a list of the works" False on the works of Edgar Allan Poe is held. Between Thanksgiving 6 [' on', ' the', ' works', ' of', ' Edgar', ' Allan', ' Poe']
+303 76 Name of father of x -1 Name of father of Edgar Allan Poe David Poe, Jr. Edgar Allan Poe "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' '.' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works']" ", the poet , was born in 18 09 .
+
+ The following is a list of the works" False allegations were ever proven. Edgar Allan Poe may have based a 7 [' allegations', ' were', ' ever', ' proven', '.', ' Edgar', ' Allan', ' Poe']
+304 76 Name of father of x -1 Name of father of Edgar Allan Poe David Poe, Jr. Edgar Allan Poe "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' '.' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works']" ", the poet , was born in 18 09 .
+
+ The following is a list of the works" False American author Edgar Allan Poe (1809 – 1849) which 4 [' American', ' author', ' Edgar', ' Allan', ' Poe']
+305 76 Name of father of x -1 Name of father of Edgar Allan Poe David Poe, Jr. Edgar Allan Poe "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' '.' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works']" ", the poet , was born in 18 09 .
+
+ The following is a list of the works" False " Street Theatre. Edgar Allan Poe called it ""by far" 5 [' Street', ' Theatre', '.', ' Edgar', ' Allan', ' Poe']
+306 77 Name of father of x -1 Name of father of Lewis Carroll Charles Dodgson Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland'
+ ' is']" ", the author of Alice in Wonderland .
+
+ The name of the author of Alice in Wonderland is" False " concept album inspired by Lewis Carroll called ""The Mad" 5 [' concept', ' album', ' inspired', ' by', ' Lewis', ' Carroll']
+307 77 Name of father of x -1 Name of father of Lewis Carroll Charles Dodgson Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland'
+ ' is']" ", the author of Alice in Wonderland .
+
+ The name of the author of Alice in Wonderland is" False as a cross between Lewis Carroll and Walt Disney. 5 [' as', ' a', ' cross', ' between', ' Lewis', ' Carroll']
+308 77 Name of father of x -1 Name of father of Lewis Carroll Charles Dodgson Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland'
+ ' is']" ", the author of Alice in Wonderland .
+
+ The name of the author of Alice in Wonderland is" False the Looking-Glass by Lewis Carroll to The Royal 6 [' the', ' Looking', '-', 'Glass', ' by', ' Lewis', ' Carroll']
+309 77 Name of father of x -1 Name of father of Lewis Carroll Charles Dodgson Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland'
+ ' is']" ", the author of Alice in Wonderland .
+
+ The name of the author of Alice in Wonderland is" False cross between Lewis Carroll and Walt Disney. While 3 [' cross', ' between', ' Lewis', ' Carroll']
+310 77 Name of father of x -1 Name of father of Lewis Carroll Charles Dodgson Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland'
+ ' is']" ", the author of Alice in Wonderland .
+
+ The name of the author of Alice in Wonderland is" False visited in 1859 and Lewis Carroll stayed at 5, East 6 [' visited', ' in', ' 18', '59', ' and', ' Lewis', ' Carroll']
+311 78 Name of father of x -1 Name of father of Édouard Manet Auguste Manet Édouard Manet "['\n' '\n' 'Category' ':' '18' '32' ' births' '\n' 'Category' ':' '18'
+ '83' ' deaths' '\n' 'Category' ':' 'People' ' from' ' Paris' '\n']" "
+
+ Category : 18 32 births
+ Category : 18 83 deaths
+ Category : People from Paris
+" False influenced by Édouard Manet and the Impressionist 7 [' influenced', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+312 78 Name of father of x -1 Name of father of Édouard Manet Auguste Manet Édouard Manet "['\n' '\n' 'Category' ':' '18' '32' ' births' '\n' 'Category' ':' '18'
+ '83' ' deaths' '\n' 'Category' ':' 'People' ' from' ' Paris' '\n']" "
+
+ Category : 18 32 births
+ Category : 18 83 deaths
+ Category : People from Paris
+" False and early Édouard Manet (1832 – 1883) 7 [' and', ' early', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+313 78 Name of father of x -1 Name of father of Édouard Manet Auguste Manet Édouard Manet "['\n' '\n' 'Category' ':' '18' '32' ' births' '\n' 'Category' ':' '18'
+ '83' ' deaths' '\n' 'Category' ':' 'People' ' from' ' Paris' '\n']" "
+
+ Category : 18 32 births
+ Category : 18 83 deaths
+ Category : People from Paris
+" False lithographs by Édouard Manet and translation 8 [' lith', 'ographs', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+314 78 Name of father of x -1 Name of father of Édouard Manet Auguste Manet Édouard Manet "['\n' '\n' 'Category' ':' '18' '32' ' births' '\n' 'Category' ':' '18'
+ '83' ' deaths' '\n' 'Category' ':' 'People' ' from' ' Paris' '\n']" "
+
+ Category : 18 32 births
+ Category : 18 83 deaths
+ Category : People from Paris
+" False lithographs by Édouard Manet and translation 8 [' lith', 'ographs', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+315 78 Name of father of x -1 Name of father of Édouard Manet Auguste Manet Édouard Manet "['\n' '\n' 'Category' ':' '18' '32' ' births' '\n' 'Category' ':' '18'
+ '83' ' deaths' '\n' 'Category' ':' 'People' ' from' ' Paris' '\n']" "
+
+ Category : 18 32 births
+ Category : 18 83 deaths
+ Category : People from Paris
+" False lithographs by Édouard Manet and translation 8 [' lith', 'ographs', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+316 79 Name of father of x -1 Name of father of Walt Disney Elias Disney Walt Disney "[""'s"" ' first' ' animated' ' feature' ' film' ',' ' Snow' ' White' ' and'
+ ' the' ' Seven' ' Dwar' 'fs' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" "'s first animated feature film , Snow White and the Seven Dwar fs .
+
+ The film was released" False area, including the Walt Disney Studios lot 5 [' area', ',', ' including', ' the', ' Walt', ' Disney']
+317 79 Name of father of x -1 Name of father of Walt Disney Elias Disney Walt Disney "[""'s"" ' first' ' animated' ' feature' ' film' ',' ' Snow' ' White' ' and'
+ ' the' ' Seven' ' Dwar' 'fs' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" "'s first animated feature film , Snow White and the Seven Dwar fs .
+
+ The film was released" False Lucasfilm was bought by The Walt Disney Company, which decided 7 [' Lucas', 'film', ' was', ' bought', ' by', ' The', ' Walt', ' Disney']
+318 79 Name of father of x -1 Name of father of Walt Disney Elias Disney Walt Disney "[""'s"" ' first' ' animated' ' feature' ' film' ',' ' Snow' ' White' ' and'
+ ' the' ' Seven' ' Dwar' 'fs' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" "'s first animated feature film , Snow White and the Seven Dwar fs .
+
+ The film was released" False " Fly a Kite "". Walt Disney Records released" 6 "[' Fly', ' a', ' K', 'ite', ' "".', ' Walt', ' Disney']"
+319 79 Name of father of x -1 Name of father of Walt Disney Elias Disney Walt Disney "[""'s"" ' first' ' animated' ' feature' ' film' ',' ' Snow' ' White' ' and'
+ ' the' ' Seven' ' Dwar' 'fs' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" "'s first animated feature film , Snow White and the Seven Dwar fs .
+
+ The film was released" False book was made by Walt Disney Studios in 1972. 5 [' book', ' was', ' made', ' by', ' Walt', ' Disney']
+320 79 Name of father of x -1 Name of father of Walt Disney Elias Disney Walt Disney "[""'s"" ' first' ' animated' ' feature' ' film' ',' ' Snow' ' White' ' and'
+ ' the' ' Seven' ' Dwar' 'fs' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" "'s first animated feature film , Snow White and the Seven Dwar fs .
+
+ The film was released" False released by The Walt Disney Company. The plush 4 [' released', ' by', ' The', ' Walt', ' Disney']
+321 80 Name of father of x -1 Name of father of Angelina Jolie Jon Voight Angelina Jolie "[' and' ' Brad' ' Pitt' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of'
+ ' the' ' movie' ' ""' 'The' ' Curious' ' Case' ' of' ' Benjamin']" " and Brad Pitt .
+
+ I am a big fan of the movie "" The Curious Case of Benjamin" False 3 ['Angel', 'ina', ' Jol', 'ie']
+322 80 Name of father of x -1 Name of father of Angelina Jolie Jon Voight Angelina Jolie "[' and' ' Brad' ' Pitt' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of'
+ ' the' ' movie' ' ""' 'The' ' Curious' ' Case' ' of' ' Benjamin']" " and Brad Pitt .
+
+ I am a big fan of the movie "" The Curious Case of Benjamin" False 3 ['Angel', 'ina', ' Jol', 'ie']
+323 80 Name of father of x -1 Name of father of Angelina Jolie Jon Voight Angelina Jolie "[' and' ' Brad' ' Pitt' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of'
+ ' the' ' movie' ' ""' 'The' ' Curious' ' Case' ' of' ' Benjamin']" " and Brad Pitt .
+
+ I am a big fan of the movie "" The Curious Case of Benjamin" False He appeared opposite Angelina Jolie as a newly trained 6 [' He', ' appeared', ' opposite', ' Angel', 'ina', ' Jol', 'ie']
+324 80 Name of father of x -1 Name of father of Angelina Jolie Jon Voight Angelina Jolie "[' and' ' Brad' ' Pitt' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of'
+ ' the' ' movie' ' ""' 'The' ' Curious' ' Case' ' of' ' Benjamin']" " and Brad Pitt .
+
+ I am a big fan of the movie "" The Curious Case of Benjamin" False American actress Angelina Jolie wrote an article 5 [' American', ' actress', ' Angel', 'ina', ' Jol', 'ie']
+325 80 Name of father of x -1 Name of father of Angelina Jolie Jon Voight Angelina Jolie "[' and' ' Brad' ' Pitt' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of'
+ ' the' ' movie' ' ""' 'The' ' Curious' ' Case' ' of' ' Benjamin']" " and Brad Pitt .
+
+ I am a big fan of the movie "" The Curious Case of Benjamin" False 3 ['Angel', 'ina', ' Jol', 'ie']
+326 81 Name of father of x -1 Name of father of Jane Fonda Henry Fonda Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Jane' ' F' 'onda' '.' '\n' '\n' 'The' ' F' 'onda']" ", the actress , and the mother of the actress , Jane F onda .
+
+ The F onda" False role came opposite Jane Fonda in the 1977 5 [' role', ' came', ' opposite', ' Jane', ' F', 'onda']
+327 81 Name of father of x -1 Name of father of Jane Fonda Henry Fonda Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Jane' ' F' 'onda' '.' '\n' '\n' 'The' ' F' 'onda']" ", the actress , and the mother of the actress , Jane F onda .
+
+ The F onda" False Ralph Nader, Jane Fonda and their kind 6 [' Ralph', ' N', 'ader', ',', ' Jane', ' F', 'onda']
+328 81 Name of father of x -1 Name of father of Jane Fonda Henry Fonda Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Jane' ' F' 'onda' '.' '\n' '\n' 'The' ' F' 'onda']" ", the actress , and the mother of the actress , Jane F onda .
+
+ The F onda" False Workout became the Jane Fonda Workout, which 6 [' Work', 'out', ' became', ' the', ' Jane', ' F', 'onda']
+329 81 Name of father of x -1 Name of father of Jane Fonda Henry Fonda Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Jane' ' F' 'onda' '.' '\n' '\n' 'The' ' F' 'onda']" ", the actress , and the mother of the actress , Jane F onda .
+
+ The F onda" False " Fonda =
+" 5 [' F', 'onda', ' =', 'Jane', ' F', 'onda']
+330 81 Name of father of x -1 Name of father of Jane Fonda Henry Fonda Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Jane' ' F' 'onda' '.' '\n' '\n' 'The' ' F' 'onda']" ", the actress , and the mother of the actress , Jane F onda .
+
+ The F onda" False " Jane Fonda =
+" 2 [' Jane', ' F', 'onda']
+331 82 Name of father of x -1 Name of father of Bob Dylan Abram Zimmerman Bob Dylan "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Sara' ',' ' who' ' is' ' a' ' poet' '.' '\n' '\n']" ", the singer - song writer , and his wife , Sara , who is a poet .
+
+" False New York with Bob Dylan and the Band in 4 [' New', ' York', ' with', ' Bob', ' Dylan']
+332 82 Name of father of x -1 Name of father of Bob Dylan Abram Zimmerman Bob Dylan "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Sara' ',' ' who' ' is' ' a' ' poet' '.' '\n' '\n']" ", the singer - song writer , and his wife , Sara , who is a poet .
+
+" False Freedom: The Songs of Bob Dylan Honoring 50 Years 6 [' Freedom', ':', ' The', ' Songs', ' of', ' Bob', ' Dylan']
+333 82 Name of father of x -1 Name of father of Bob Dylan Abram Zimmerman Bob Dylan "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Sara' ',' ' who' ' is' ' a' ' poet' '.' '\n' '\n']" ", the singer - song writer , and his wife , Sara , who is a poet .
+
+" False tying her with Bob Dylan for sixth place 4 [' tying', ' her', ' with', ' Bob', ' Dylan']
+334 82 Name of father of x -1 Name of father of Bob Dylan Abram Zimmerman Bob Dylan "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Sara' ',' ' who' ' is' ' a' ' poet' '.' '\n' '\n']" ", the singer - song writer , and his wife , Sara , who is a poet .
+
+" False American singer-songwriter Bob Dylan and the Band, released 6 [' American', ' singer', '-', 'song', 'writer', ' Bob', ' Dylan']
+335 82 Name of father of x -1 Name of father of Bob Dylan Abram Zimmerman Bob Dylan "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Sara' ',' ' who' ' is' ' a' ' poet' '.' '\n' '\n']" ", the singer - song writer , and his wife , Sara , who is a poet .
+
+" False Love Songs of Bob Dylan album and 4 [' Love', ' Songs', ' of', ' Bob', ' Dylan']
+336 83 Name of father of x -1 Name of father of William Blake James Blake William Blake "[',' ' the' ' poet' ',' ' and' ' the' ' painter' ',' ' and' ' the' '\n'
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the painter , and the
+ " False Between 1777-8 William Blake was commissioned by 6 [' Between', ' 17', '77', '-', '8', ' William', ' Blake']
+337 83 Name of father of x -1 Name of father of William Blake James Blake William Blake "[',' ' the' ' poet' ',' ' and' ' the' ' painter' ',' ' and' ' the' '\n'
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the painter , and the
+ " False aquatint. William Blake and Henry Fuseli, 5 [' aqu', 'at', 'int', '.', ' William', ' Blake']
+338 83 Name of father of x -1 Name of father of William Blake James Blake William Blake "[',' ' the' ' poet' ',' ' and' ' the' ' painter' ',' ' and' ' the' '\n'
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the painter , and the
+ " False Between 1777-8 William Blake was commissioned 6 [' Between', ' 17', '77', '-', '8', ' William', ' Blake']
+339 83 Name of father of x -1 Name of father of William Blake James Blake William Blake "[',' ' the' ' poet' ',' ' and' ' the' ' painter' ',' ' and' ' the' '\n'
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the painter , and the
+ " False to the one William Blake takes us through 4 [' to', ' the', ' one', ' William', ' Blake']
+340 83 Name of father of x -1 Name of father of William Blake James Blake William Blake "[',' ' the' ' poet' ',' ' and' ' the' ' painter' ',' ' and' ' the' '\n'
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the painter , and the
+ " False paintings by William Blake and J. M. W. Turner, 3 [' paintings', ' by', ' William', ' Blake']
+341 84 Name of father of x -1 Name of father of Vladimir Lenin Ilya Ulyanov Vladimir Lenin "[',' ' the' ' great' ' leader' ' of' ' the' ' Russian' ' Revolution' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n'
+ 'The']" ", the great leader of the Russian Revolution , and the founder of the Soviet Union .
+
+ The" False Germany had made in Vladimir Lenin had paid off. The German 5 [' Germany', ' had', ' made', ' in', ' Vladimir', ' Lenin']
+342 84 Name of father of x -1 Name of father of Vladimir Lenin Ilya Ulyanov Vladimir Lenin "[',' ' the' ' great' ' leader' ' of' ' the' ' Russian' ' Revolution' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n'
+ 'The']" ", the great leader of the Russian Revolution , and the founder of the Soviet Union .
+
+ The" False 3 ['V', 'lad', 'imir', ' Lenin']
+343 84 Name of father of x -1 Name of father of Vladimir Lenin Ilya Ulyanov Vladimir Lenin "[',' ' the' ' great' ' leader' ' of' ' the' ' Russian' ' Revolution' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n'
+ 'The']" ", the great leader of the Russian Revolution , and the founder of the Soviet Union .
+
+ The" False was inspired by Vladimir Lenin in The State and 4 [' was', ' inspired', ' by', ' Vladimir', ' Lenin']
+344 84 Name of father of x -1 Name of father of Vladimir Lenin Ilya Ulyanov Vladimir Lenin "[',' ' the' ' great' ' leader' ' of' ' the' ' Russian' ' Revolution' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n'
+ 'The']" ", the great leader of the Russian Revolution , and the founder of the Soviet Union .
+
+ The" False Trotsky and Vladimir Lenin attended meetings 3 [' Trotsky', ' and', ' Vladimir', ' Lenin']
+345 84 Name of father of x -1 Name of father of Vladimir Lenin Ilya Ulyanov Vladimir Lenin "[',' ' the' ' great' ' leader' ' of' ' the' ' Russian' ' Revolution' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n'
+ 'The']" ", the great leader of the Russian Revolution , and the founder of the Soviet Union .
+
+ The" False era, knew Vladimir Lenin and Leon Trotsky 4 [' era', ',', ' knew', ' Vladimir', ' Lenin']
+346 85 Name of father of x -1 Name of father of Janet Jackson Joe Jackson Janet Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' Michael'
+ ' Jackson' '.' '\n' '\n' 'The' ' singer' ',' ' who' ' was' ' acquitted'
+ ' of']" ", the singer , and the father of Michael Jackson .
+
+ The singer , who was acquitted of" False stemming from the Janet Jackson and Justin Timberlake 4 [' stemming', ' from', ' the', ' Janet', ' Jackson']
+347 85 Name of father of x -1 Name of father of Janet Jackson Joe Jackson Janet Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' Michael'
+ ' Jackson' '.' '\n' '\n' 'The' ' singer' ',' ' who' ' was' ' acquitted'
+ ' of']" ", the singer , and the father of Michael Jackson .
+
+ The singer , who was acquitted of" False the works of Janet Jackson and the video for 4 [' the', ' works', ' of', ' Janet', ' Jackson']
+348 85 Name of father of x -1 Name of father of Janet Jackson Joe Jackson Janet Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' Michael'
+ ' Jackson' '.' '\n' '\n' 'The' ' singer' ',' ' who' ' was' ' acquitted'
+ ' of']" ", the singer , and the father of Michael Jackson .
+
+ The singer , who was acquitted of" False " Janet Jackson =
+" 1 [' Janet', ' Jackson']
+349 85 Name of father of x -1 Name of father of Janet Jackson Joe Jackson Janet Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' Michael'
+ ' Jackson' '.' '\n' '\n' 'The' ' singer' ',' ' who' ' was' ' acquitted'
+ ' of']" ", the singer , and the father of Michael Jackson .
+
+ The singer , who was acquitted of" False " Jackson – vocals
+" 5 [' Jackson', ' –', ' vocals', 'Jan', 'et', ' Jackson']
+350 85 Name of father of x -1 Name of father of Janet Jackson Joe Jackson Janet Jackson "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' Michael'
+ ' Jackson' '.' '\n' '\n' 'The' ' singer' ',' ' who' ' was' ' acquitted'
+ ' of']" ", the singer , and the father of Michael Jackson .
+
+ The singer , who was acquitted of" False well as scenes for Janet Jackson in two music videos, 5 [' well', ' as', ' scenes', ' for', ' Janet', ' Jackson']
+351 86 Name of father of x -1 Name of father of Edgar Degas Auguste Degas Edgar Degas "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' Marie' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' a' ' rich' ' Paris' 'ian']" , the painter , and of his wife , Marie , who was the daughter of a rich Paris ian False from a remark Edgar Degas made defending 5 [' from', ' a', ' remark', ' Edgar', ' De', 'gas']
+352 86 Name of father of x -1 Name of father of Edgar Degas Auguste Degas Edgar Degas "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' Marie' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' a' ' rich' ' Paris' 'ian']" , the painter , and of his wife , Marie , who was the daughter of a rich Paris ian False from a remark Edgar Degas made defending 5 [' from', ' a', ' remark', ' Edgar', ' De', 'gas']
+353 86 Name of father of x -1 Name of father of Edgar Degas Auguste Degas Edgar Degas "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' Marie' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' a' ' rich' ' Paris' 'ian']" , the painter , and of his wife , Marie , who was the daughter of a rich Paris ian False van Gogh and Edgar Degas represented 6 [' van', ' Go', 'gh', ' and', ' Edgar', ' De', 'gas']
+354 86 Name of father of x -1 Name of father of Edgar Degas Auguste Degas Edgar Degas "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' Marie' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' a' ' rich' ' Paris' 'ian']" , the painter , and of his wife , Marie , who was the daughter of a rich Paris ian False recreating Edgar Degas ballerina poses 4 [' rec', 'reating', ' Edgar', ' De', 'gas']
+355 86 Name of father of x -1 Name of father of Edgar Degas Auguste Degas Edgar Degas "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' Marie' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' a' ' rich' ' Paris' 'ian']" , the painter , and of his wife , Marie , who was the daughter of a rich Paris ian False Bazaar recreating Edgar Degas ballerina poses 6 [' B', 'azaar', ' rec', 'reating', ' Edgar', ' De', 'gas']
+356 87 Name of father of x -1 Name of father of Ursula K. Le Guin Alfred L. Kroeber Ursula K. Le Guin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Urs' 'ula' ' K' '.' ' Le' ' Gu'
+ 'in' '\n' '\n' 'Name' ' of' ' husband' ' of' ' Urs']" "
+
+ Name of mother of Urs ula K . Le Gu in
+
+ Name of husband of Urs" False fantasies of Ursula K. Le Guin and Alan Garner, 8 [' fantasies', ' of', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+357 87 Name of father of x -1 Name of father of Ursula K. Le Guin Alfred L. Kroeber Ursula K. Le Guin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Urs' 'ula' ' K' '.' ' Le' ' Gu'
+ 'in' '\n' '\n' 'Name' ' of' ' husband' ' of' ' Urs']" "
+
+ Name of mother of Urs ula K . Le Gu in
+
+ Name of husband of Urs" False John P. Clark and Ursula K. Le Guin have written about 11 [' John', ' P', '.', ' Clark', ' and', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+358 87 Name of father of x -1 Name of father of Ursula K. Le Guin Alfred L. Kroeber Ursula K. Le Guin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Urs' 'ula' ' K' '.' ' Le' ' Gu'
+ 'in' '\n' '\n' 'Name' ' of' ' husband' ' of' ' Urs']" "
+
+ Name of mother of Urs ula K . Le Gu in
+
+ Name of husband of Urs" False Second Ending), Ursula K. Le Guin (Rocannon's 9 [' Second', ' Ending', '),', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+359 87 Name of father of x -1 Name of father of Ursula K. Le Guin Alfred L. Kroeber Ursula K. Le Guin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Urs' 'ula' ' K' '.' ' Le' ' Gu'
+ 'in' '\n' '\n' 'Name' ' of' ' husband' ' of' ' Urs']" "
+
+ Name of mother of Urs ula K . Le Gu in
+
+ Name of husband of Urs" False The Dry Lands in Ursula K. Le Guin ’ s Earthsea 10 [' The', ' Dry', ' Lands', ' in', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+360 87 Name of father of x -1 Name of father of Ursula K. Le Guin Alfred L. Kroeber Ursula K. Le Guin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Urs' 'ula' ' K' '.' ' Le' ' Gu'
+ 'in' '\n' '\n' 'Name' ' of' ' husband' ' of' ' Urs']" "
+
+ Name of mother of Urs ula K . Le Gu in
+
+ Name of husband of Urs" False read one novel each by Ursula K. Le Guin and Robert A. Heinlein. 11 [' read', ' one', ' novel', ' each', ' by', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+361 88 Name of father of x -1 Name of father of Stephen Hawking Frank Hawking Stephen Hawking "[',' ' the' ' famous' ' physicist' ',' ' who' ' died' ' in' ' March' '.'
+ '\n' '\n' 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' by']" ", the famous physicist , who died in March .
+
+ The book is a collection of essays by" False reports that Stephen Hawking has discovered a new 3 [' reports', ' that', ' Stephen', ' Hawking']
+362 88 Name of father of x -1 Name of father of Stephen Hawking Frank Hawking Stephen Hawking "[',' ' the' ' famous' ' physicist' ',' ' who' ' died' ' in' ' March' '.'
+ '\n' '\n' 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' by']" ", the famous physicist , who died in March .
+
+ The book is a collection of essays by" False do a fake Stephen Hawking in your comedy 4 [' do', ' a', ' fake', ' Stephen', ' Hawking']
+363 88 Name of father of x -1 Name of father of Stephen Hawking Frank Hawking Stephen Hawking "[',' ' the' ' famous' ' physicist' ',' ' who' ' died' ' in' ' March' '.'
+ '\n' '\n' 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' by']" ", the famous physicist , who died in March .
+
+ The book is a collection of essays by" False DVD), Bill Gates, and Stephen Hawking where they pass 7 [' DVD', '),', ' Bill', ' Gates', ',', ' and', ' Stephen', ' Hawking']
+364 88 Name of father of x -1 Name of father of Stephen Hawking Frank Hawking Stephen Hawking "[',' ' the' ' famous' ' physicist' ',' ' who' ' died' ' in' ' March' '.'
+ '\n' '\n' 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' by']" ", the famous physicist , who died in March .
+
+ The book is a collection of essays by" False " lecture in 2001, Stephen Hawking stated ""The reason" 5 [' lecture', ' in', ' 2001', ',', ' Stephen', ' Hawking']
+365 88 Name of father of x -1 Name of father of Stephen Hawking Frank Hawking Stephen Hawking "[',' ' the' ' famous' ' physicist' ',' ' who' ' died' ' in' ' March' '.'
+ '\n' '\n' 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' by']" ", the famous physicist , who died in March .
+
+ The book is a collection of essays by" False Bill Gates, and Stephen Hawking where they 5 [' Bill', ' Gates', ',', ' and', ' Stephen', ' Hawking']
+366 89 Name of father of x -1 Name of father of Martin Luther Hans Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a'
+ ' Catholic' ',' ' a' ' conservative' ',' ' and' ' a' ' Republican' '.']" " King Jr .
+
+ I am a Christian , a Catholic , a conservative , and a Republican ." False Johnson, and Martin Luther King, Jr. Later recipients 4 [' Johnson', ',', ' and', ' Martin', ' Luther']
+367 89 Name of father of x -1 Name of father of Martin Luther Hans Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a'
+ ' Catholic' ',' ' a' ' conservative' ',' ' and' ' a' ' Republican' '.']" " King Jr .
+
+ I am a Christian , a Catholic , a conservative , and a Republican ." False John F. Kennedy and Martin Luther King, Jr., but being 6 [' John', ' F', '.', ' Kennedy', ' and', ' Martin', ' Luther']
+368 89 Name of father of x -1 Name of father of Martin Luther Hans Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a'
+ ' Catholic' ',' ' a' ' conservative' ',' ' and' ' a' ' Republican' '.']" " King Jr .
+
+ I am a Christian , a Catholic , a conservative , and a Republican ." False marches organized by Martin Luther King Jr. and other 4 [' marches', ' organized', ' by', ' Martin', ' Luther']
+369 89 Name of father of x -1 Name of father of Martin Luther Hans Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a'
+ ' Catholic' ',' ' a' ' conservative' ',' ' and' ' a' ' Republican' '.']" " King Jr .
+
+ I am a Christian , a Catholic , a conservative , and a Republican ." False rights leader Dr. Martin Luther King, Jr., on April 5 [' rights', ' leader', ' Dr', '.', ' Martin', ' Luther']
+370 89 Name of father of x -1 Name of father of Martin Luther Hans Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a'
+ ' Catholic' ',' ' a' ' conservative' ',' ' and' ' a' ' Republican' '.']" " King Jr .
+
+ I am a Christian , a Catholic , a conservative , and a Republican ." False 1 ['Martin', ' Luther']
+371 90 Name of father of x -1 Name of father of Alicia Keys Craig Cook Alicia Keys "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Alicia' ' Keys' '.' '\n' '\n' 'The' ' father' ' of' ' Alicia' ' Keys']" ", the singer , and the father of the singer Alicia Keys .
+
+ The father of Alicia Keys" False girl while label mate Alicia Keys was promoted 5 [' girl', ' while', ' label', ' mate', ' Alicia', ' Keys']
+372 90 Name of father of x -1 Name of father of Alicia Keys Craig Cook Alicia Keys "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Alicia' ' Keys' '.' '\n' '\n' 'The' ' father' ' of' ' Alicia' ' Keys']" ", the singer , and the father of the singer Alicia Keys .
+
+ The father of Alicia Keys" False Wiz Khalifa, Alicia Keys and The Game, 5 [' Wiz', ' Khal', 'ifa', ',', ' Alicia', ' Keys']
+373 90 Name of father of x -1 Name of father of Alicia Keys Craig Cook Alicia Keys "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Alicia' ' Keys' '.' '\n' '\n' 'The' ' father' ' of' ' Alicia' ' Keys']" ", the singer , and the father of the singer Alicia Keys .
+
+ The father of Alicia Keys" False alongside Missy Elliott and Alicia Keys as ensemble 6 [' alongside', ' Miss', 'y', ' Elliott', ' and', ' Alicia', ' Keys']
+374 90 Name of father of x -1 Name of father of Alicia Keys Craig Cook Alicia Keys "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Alicia' ' Keys' '.' '\n' '\n' 'The' ' father' ' of' ' Alicia' ' Keys']" ", the singer , and the father of the singer Alicia Keys .
+
+ The father of Alicia Keys" False same award to Alicia Keys at the 2010 BET 4 [' same', ' award', ' to', ' Alicia', ' Keys']
+375 90 Name of father of x -1 Name of father of Alicia Keys Craig Cook Alicia Keys "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Alicia' ' Keys' '.' '\n' '\n' 'The' ' father' ' of' ' Alicia' ' Keys']" ", the singer , and the father of the singer Alicia Keys .
+
+ The father of Alicia Keys" False Miguel opened for Alicia Keys on her Set 4 [' Miguel', ' opened', ' for', ' Alicia', ' Keys']
+376 91 Name of father of x -1 Name of father of Gottfried Wilhelm Leibniz Friedrich Leibniz Gottfried Wilhelm Leibniz "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Gott' 'fried' ' Wilhelm'
+ ' Le' 'ib' 'n' 'iz' ' is' ' Le' 'ib' 'n' 'iz' '.']" "
+
+ The name of father of Gott fried Wilhelm Le ib n iz is Le ib n iz ." False 7 ['G', 'ott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+377 91 Name of father of x -1 Name of father of Gottfried Wilhelm Leibniz Friedrich Leibniz Gottfried Wilhelm Leibniz "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Gott' 'fried' ' Wilhelm'
+ ' Le' 'ib' 'n' 'iz' ' is' ' Le' 'ib' 'n' 'iz' '.']" "
+
+ The name of father of Gott fried Wilhelm Le ib n iz is Le ib n iz ." False mathematician Gottfried Wilhelm Leibniz in the 1660s, led 7 [' mathematician', ' Gott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+378 91 Name of father of x -1 Name of father of Gottfried Wilhelm Leibniz Friedrich Leibniz Gottfried Wilhelm Leibniz "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Gott' 'fried' ' Wilhelm'
+ ' Le' 'ib' 'n' 'iz' ' is' ' Le' 'ib' 'n' 'iz' '.']" "
+
+ The name of father of Gott fried Wilhelm Le ib n iz is Le ib n iz ." False 1700 and 1710 Gottfried Wilhelm Leibniz publicized the 10 [' 1700', ' and', ' 17', '10', ' Gott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+379 91 Name of father of x -1 Name of father of Gottfried Wilhelm Leibniz Friedrich Leibniz Gottfried Wilhelm Leibniz "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Gott' 'fried' ' Wilhelm'
+ ' Le' 'ib' 'n' 'iz' ' is' ' Le' 'ib' 'n' 'iz' '.']" "
+
+ The name of father of Gott fried Wilhelm Le ib n iz is Le ib n iz ." False for instance. The Gottfried Wilhelm Leibniz Prize is granted to 10 [' for', ' instance', '.', ' The', ' Gott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+380 91 Name of father of x -1 Name of father of Gottfried Wilhelm Leibniz Friedrich Leibniz Gottfried Wilhelm Leibniz "['\n' '\n' 'The' ' name' ' of' ' father' ' of' ' Gott' 'fried' ' Wilhelm'
+ ' Le' 'ib' 'n' 'iz' ' is' ' Le' 'ib' 'n' 'iz' '.']" "
+
+ The name of father of Gott fried Wilhelm Le ib n iz is Le ib n iz ." False mathematician Gottfried Wilhelm Leibniz in the 1660s, 7 [' mathematician', ' Gott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+381 92 Name of father of x -1 Name of father of Alexander Pushkin Sergey Pushkin Alexander Pushkin "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ' Nat' 'aly' 'a' ','
+ ' who' ' was' ' a' ' great' ' beauty' '.' '\n' '\n']" ", the poet , and of his wife Nat aly a , who was a great beauty .
+
+" False Gabrieliad affair, Alexander Pushkin wooed Elise Vorontsova 7 [' Gabriel', 'i', 'ad', ' affair', ',', ' Alexander', ' Push', 'kin']
+382 92 Name of father of x -1 Name of father of Alexander Pushkin Sergey Pushkin Alexander Pushkin "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ' Nat' 'aly' 'a' ','
+ ' who' ' was' ' a' ' great' ' beauty' '.' '\n' '\n']" ", the poet , and of his wife Nat aly a , who was a great beauty .
+
+" False Russia's national poet Alexander Pushkin during the latter's 6 "[' Russia', ""'s"", ' national', ' poet', ' Alexander', ' Push', 'kin']"
+383 92 Name of father of x -1 Name of father of Alexander Pushkin Sergey Pushkin Alexander Pushkin "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ' Nat' 'aly' 'a' ','
+ ' who' ' was' ' a' ' great' ' beauty' '.' '\n' '\n']" ", the poet , and of his wife Nat aly a , who was a great beauty .
+
+" False intellectual figures like Alexander Pushkin and Alexander Herzen 5 [' intellectual', ' figures', ' like', ' Alexander', ' Push', 'kin']
+384 92 Name of father of x -1 Name of father of Alexander Pushkin Sergey Pushkin Alexander Pushkin "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ' Nat' 'aly' 'a' ','
+ ' who' ' was' ' a' ' great' ' beauty' '.' '\n' '\n']" ", the poet , and of his wife Nat aly a , who was a great beauty .
+
+" False works by the poets Alexander Pushkin and Vasily 6 [' works', ' by', ' the', ' poets', ' Alexander', ' Push', 'kin']
+385 92 Name of father of x -1 Name of father of Alexander Pushkin Sergey Pushkin Alexander Pushkin "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ' Nat' 'aly' 'a' ','
+ ' who' ' was' ' a' ' great' ' beauty' '.' '\n' '\n']" ", the poet , and of his wife Nat aly a , who was a great beauty .
+
+" False Russia's national poet Alexander Pushkin during the latter's 6 "[' Russia', ""'s"", ' national', ' poet', ' Alexander', ' Push', 'kin']"
+386 93 Name of father of x -1 Name of father of Marlene Dietrich Louis Erich Otto Dietrich Marlene Dietrich "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' '.' ' I'
+ ' love' ' the' ' show' '.' ' I' ' love' ' the' ' show']" "
+
+ I am a big fan of the show . I love the show . I love the show" False Desire with Marlene Dietrich at Paramount — delivering 5 [' Desire', ' with', ' Mar', 'lene', ' Diet', 'rich']
+387 93 Name of father of x -1 Name of father of Marlene Dietrich Louis Erich Otto Dietrich Marlene Dietrich "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' '.' ' I'
+ ' love' ' the' ' show' '.' ' I' ' love' ' the' ' show']" "
+
+ I am a big fan of the show . I love the show . I love the show" False Morocco (1930) Marlene Dietrich kisses another 8 [' Morocco', ' (', '19', '30', ')', ' Mar', 'lene', ' Diet', 'rich']
+388 93 Name of father of x -1 Name of father of Marlene Dietrich Louis Erich Otto Dietrich Marlene Dietrich "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' '.' ' I'
+ ' love' ' the' ' show' '.' ' I' ' love' ' the' ' show']" "
+
+ I am a big fan of the show . I love the show . I love the show" False guise of actress Marlene Dietrich and singing the word 6 [' guise', ' of', ' actress', ' Mar', 'lene', ' Diet', 'rich']
+389 93 Name of father of x -1 Name of father of Marlene Dietrich Louis Erich Otto Dietrich Marlene Dietrich "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' '.' ' I'
+ ' love' ' the' ' show' '.' ' I' ' love' ' the' ' show']" "
+
+ I am a big fan of the show . I love the show . I love the show" False actress and singer Marlene Dietrich in the video as she 6 [' actress', ' and', ' singer', ' Mar', 'lene', ' Diet', 'rich']
+390 93 Name of father of x -1 Name of father of Marlene Dietrich Louis Erich Otto Dietrich Marlene Dietrich "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' '.' ' I'
+ ' love' ' the' ' show' '.' ' I' ' love' ' the' ' show']" "
+
+ I am a big fan of the show . I love the show . I love the show" False playboy opposite Marlene Dietrich in Blonde Venus, 6 [' play', 'boy', ' opposite', ' Mar', 'lene', ' Diet', 'rich']
+391 94 Name of father of x -1 Name of father of Dolly Parton Robert Lee Parton Dolly Parton "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Patsy Cline, Dolly Parton and Slim Dusty, as 8 [' Pats', 'y', ' Cl', 'ine', ',', ' D', 'olly', ' Part', 'on']
+392 94 Name of father of x -1 Name of father of Dolly Parton Robert Lee Parton Dolly Parton "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False when Kogen's friend Dolly Parton uses her extra-strength 8 "[' when', ' K', 'ogen', ""'s"", ' friend', ' D', 'olly', ' Part', 'on']"
+393 94 Name of father of x -1 Name of father of Dolly Parton Robert Lee Parton Dolly Parton "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Christmas TV special, Dolly Parton performed the 7 [' Christmas', ' TV', ' special', ',', ' D', 'olly', ' Part', 'on']
+394 94 Name of father of x -1 Name of father of Dolly Parton Robert Lee Parton Dolly Parton "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False released in 2004. Dolly Parton recorded the 7 [' released', ' in', ' 2004', '.', ' D', 'olly', ' Part', 'on']
+395 94 Name of father of x -1 Name of father of Dolly Parton Robert Lee Parton Dolly Parton "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Osbourne, Elton John, Dolly Parton, The Eagles, Kiss, 10 [' Os', 'bourne', ',', ' El', 'ton', ' John', ',', ' D', 'olly', ' Part', 'on']
+396 95 Name of father of x -1 Name of father of Arthur Conan Doyle Charles Altamont Doyle Arthur Conan Doyle "[',' ' the' ' creator' ' of' ' Sherlock' ' Holmes' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' was' ' the'
+ ' cover']" ", the creator of Sherlock Holmes .
+
+ The first thing I noticed about the book was the cover" False " literature. Sir Arthur Conan Doyle said, ""Each" 5 [' literature', '.', ' Sir', ' Arthur', ' Conan', ' Doyle']
+397 95 Name of father of x -1 Name of father of Arthur Conan Doyle Charles Altamont Doyle Arthur Conan Doyle "[',' ' the' ' creator' ' of' ' Sherlock' ' Holmes' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' was' ' the'
+ ' cover']" ", the creator of Sherlock Holmes .
+
+ The first thing I noticed about the book was the cover" False with the likes of Arthur Conan Doyle and J. M. Barrie. 6 [' with', ' the', ' likes', ' of', ' Arthur', ' Conan', ' Doyle']
+398 95 Name of father of x -1 Name of father of Arthur Conan Doyle Charles Altamont Doyle Arthur Conan Doyle "[',' ' the' ' creator' ' of' ' Sherlock' ' Holmes' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' was' ' the'
+ ' cover']" ", the creator of Sherlock Holmes .
+
+ The first thing I noticed about the book was the cover" False 2 ['Arthur', ' Conan', ' Doyle']
+399 95 Name of father of x -1 Name of father of Arthur Conan Doyle Charles Altamont Doyle Arthur Conan Doyle "[',' ' the' ' creator' ' of' ' Sherlock' ' Holmes' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' was' ' the'
+ ' cover']" ", the creator of Sherlock Holmes .
+
+ The first thing I noticed about the book was the cover" False literary works of Arthur Conan Doyle and Jerome K. Jerome, 5 [' literary', ' works', ' of', ' Arthur', ' Conan', ' Doyle']
+400 95 Name of father of x -1 Name of father of Arthur Conan Doyle Charles Altamont Doyle Arthur Conan Doyle "[',' ' the' ' creator' ' of' ' Sherlock' ' Holmes' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' was' ' the'
+ ' cover']" ", the creator of Sherlock Holmes .
+
+ The first thing I noticed about the book was the cover" False mysteries of Arthur Conan Doyle and Agatha Christie 4 [' mysteries', ' of', ' Arthur', ' Conan', ' Doyle']
+401 96 Name of father of x -1 Name of father of Hans Christian Andersen Hans Andersen Hans Christian Andersen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hans' ' Christian'
+ ' Andersen' ' is' ' not' ' known' '.' ' He' ' was' ' born' ' in' ' Od']" "
+
+ The name of the father of Hans Christian Andersen is not known . He was born in Od" False literary fairy tale by Hans Christian Andersen about a young woman 6 [' literary', ' fairy', ' tale', ' by', ' Hans', ' Christian', ' Andersen']
+402 96 Name of father of x -1 Name of father of Hans Christian Andersen Hans Andersen Hans Christian Andersen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hans' ' Christian'
+ ' Andersen' ' is' ' not' ' known' '.' ' He' ' was' ' born' ' in' ' Od']" "
+
+ The name of the father of Hans Christian Andersen is not known . He was born in Od" False property. Occasionally, Hans Christian Andersen was invited 6 [' property', '.', ' Occasionally', ',', ' Hans', ' Christian', ' Andersen']
+403 96 Name of father of x -1 Name of father of Hans Christian Andersen Hans Andersen Hans Christian Andersen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hans' ' Christian'
+ ' Andersen' ' is' ' not' ' known' '.' ' He' ' was' ' born' ' in' ' Od']" "
+
+ The name of the father of Hans Christian Andersen is not known . He was born in Od" False opposite the Hans Christian Andersen house which serves 4 [' opposite', ' the', ' Hans', ' Christian', ' Andersen']
+404 96 Name of father of x -1 Name of father of Hans Christian Andersen Hans Andersen Hans Christian Andersen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hans' ' Christian'
+ ' Andersen' ' is' ' not' ' known' '.' ' He' ' was' ' born' ' in' ' Od']" "
+
+ The name of the father of Hans Christian Andersen is not known . He was born in Od" False 3 ['H', 'ans', ' Christian', ' Andersen']
+405 96 Name of father of x -1 Name of father of Hans Christian Andersen Hans Andersen Hans Christian Andersen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hans' ' Christian'
+ ' Andersen' ' is' ' not' ' known' '.' ' He' ' was' ' born' ' in' ' Od']" "
+
+ The name of the father of Hans Christian Andersen is not known . He was born in Od" False into German. Hans Christian Andersen was one of the first 5 [' into', ' German', '.', ' Hans', ' Christian', ' Andersen']
+406 97 Name of father of x -1 Name of father of J. K. Rowling Peter James Rowling J. K. Rowling "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False Potter author J. K. Rowling to write for 6 [' Potter', ' author', ' J', '.', ' K', '.', ' Rowling']
+407 97 Name of father of x -1 Name of father of J. K. Rowling Peter James Rowling J. K. Rowling "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False Tournament. However, J. K. Rowling expressed her 8 [' Tournament', '.', ' However', ',', ' J', '.', ' K', '.', ' Rowling']
+408 97 Name of father of x -1 Name of father of J. K. Rowling Peter James Rowling J. K. Rowling "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False as comments from J. K. Rowling on The Tales of Beedle 7 [' as', ' comments', ' from', ' J', '.', ' K', '.', ' Rowling']
+409 97 Name of father of x -1 Name of father of J. K. Rowling Peter James Rowling J. K. Rowling "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False compared with J. K. Rowling and Harry Potter. 6 [' compared', ' with', ' J', '.', ' K', '.', ' Rowling']
+410 97 Name of father of x -1 Name of father of J. K. Rowling Peter James Rowling J. K. Rowling "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False second place behind J. K. Rowling (6 % and 5.6 %, respectively), 7 [' second', ' place', ' behind', ' J', '.', ' K', '.', ' Rowling']
+411 98 Name of father of x -1 Name of father of Noam Chomsky William Chomsky Noam Chomsky "[',' ' the' ' famous' ' lingu' 'ist' ' and' ' political' ' activist' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' top'
+ ' ten']" ", the famous lingu ist and political activist .
+
+ The following is a list of the top ten" False In the 1960s, Noam Chomsky formulated the 7 [' In', ' the', ' 1960', 's', ',', ' No', 'am', ' Chomsky']
+412 98 Name of father of x -1 Name of father of Noam Chomsky William Chomsky Noam Chomsky "[',' ' the' ' famous' ' lingu' 'ist' ' and' ' political' ' activist' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' top'
+ ' ten']" ", the famous lingu ist and political activist .
+
+ The following is a list of the top ten" False Rothbard's ideas, Noam Chomsky says that they are 7 "[' Roth', 'bard', ""'s"", ' ideas', ',', ' No', 'am', ' Chomsky']"
+413 98 Name of father of x -1 Name of father of Noam Chomsky William Chomsky Noam Chomsky "[',' ' the' ' famous' ' lingu' 'ist' ' and' ' political' ' activist' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' top'
+ ' ten']" ", the famous lingu ist and political activist .
+
+ The following is a list of the top ten" False " read works by Noam Chomsky and ""all that Communist," 5 [' read', ' works', ' by', ' No', 'am', ' Chomsky']
+414 98 Name of father of x -1 Name of father of Noam Chomsky William Chomsky Noam Chomsky "[',' ' the' ' famous' ' lingu' 'ist' ' and' ' political' ' activist' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' top'
+ ' ten']" ", the famous lingu ist and political activist .
+
+ The following is a list of the top ten" False on the work of Noam Chomsky to both model 6 [' on', ' the', ' work', ' of', ' No', 'am', ' Chomsky']
+415 98 Name of father of x -1 Name of father of Noam Chomsky William Chomsky Noam Chomsky "[',' ' the' ' famous' ' lingu' 'ist' ' and' ' political' ' activist' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' top'
+ ' ten']" ", the famous lingu ist and political activist .
+
+ The following is a list of the top ten" False 1960s to 1980s as Noam Chomsky began to redefine 8 [' 1960', 's', ' to', ' 1980', 's', ' as', ' No', 'am', ' Chomsky']
+416 99 Name of father of x -1 Name of father of Paul McCartney Jim McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False Elton John, and Paul McCartney for ninth among 6 [' El', 'ton', ' John', ',', ' and', ' Paul', ' McCartney']
+417 99 Name of father of x -1 Name of father of Paul McCartney Jim McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False Beatles, when Paul McCartney and John Lennon 4 [' Beatles', ',', ' when', ' Paul', ' McCartney']
+418 99 Name of father of x -1 Name of father of Paul McCartney Jim McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False Fifteen-year-old Paul McCartney joined as a rhythm 7 [' Fif', 'teen', '-', 'year', '-', 'old', ' Paul', ' McCartney']
+419 99 Name of father of x -1 Name of father of Paul McCartney Jim McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False John Lennon and Paul McCartney had each agreed to 4 [' John', ' Lennon', ' and', ' Paul', ' McCartney']
+420 99 Name of father of x -1 Name of father of Paul McCartney Jim McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False 1 ['Paul', ' McCartney']
+421 100 Name of father of x -1 Name of father of Angela Merkel Horst Kasner Angela Merkel "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the German chancellor , and the German people .
+
+ The German people are not stupid . They" False German Chancellor Angela Merkel saluted the 3 [' German', ' Chancellor', ' Angela', ' Merkel']
+422 100 Name of father of x -1 Name of father of Angela Merkel Horst Kasner Angela Merkel "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the German chancellor , and the German people .
+
+ The German people are not stupid . They" False 2005 elections, Angela Merkel became the first 4 [' 2005', ' elections', ',', ' Angela', ' Merkel']
+423 100 Name of father of x -1 Name of father of Angela Merkel Horst Kasner Angela Merkel "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the German chancellor , and the German people .
+
+ The German people are not stupid . They" False German chancellor Angela Merkel expressed her 3 [' German', ' chancellor', ' Angela', ' Merkel']
+424 100 Name of father of x -1 Name of father of Angela Merkel Horst Kasner Angela Merkel "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the German chancellor , and the German people .
+
+ The German people are not stupid . They" False German Chancellor Angela Merkel saluted the referendum's 3 [' German', ' Chancellor', ' Angela', ' Merkel']
+425 100 Name of father of x -1 Name of father of Angela Merkel Horst Kasner Angela Merkel "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the German chancellor , and the German people .
+
+ The German people are not stupid . They" False addressed to Angela Merkel and Nkosazana 3 [' addressed', ' to', ' Angela', ' Merkel']
+426 101 Name of father of x -1 Name of father of Woody Allen Martin Königsberg Woody Allen "[',' ' the' ' famous' ' American' ' actor' ' and' ' director' ',' ' who'
+ ' has' ' been' ' a' ' regular' ' on' ' the' ' show' ' since' ' its'
+ ' inception' '.']" , the famous American actor and director , who has been a regular on the show since its inception . False drew upon what Woody Allen once described 4 [' drew', ' upon', ' what', ' Woody', ' Allen']
+427 101 Name of father of x -1 Name of father of Woody Allen Martin Königsberg Woody Allen "[',' ' the' ' famous' ' American' ' actor' ' and' ' director' ',' ' who'
+ ' has' ' been' ' a' ' regular' ' on' ' the' ' show' ' since' ' its'
+ ' inception' '.']" , the famous American actor and director , who has been a regular on the show since its inception . False drew upon what Woody Allen once described 4 [' drew', ' upon', ' what', ' Woody', ' Allen']
+428 101 Name of father of x -1 Name of father of Woody Allen Martin Königsberg Woody Allen "[',' ' the' ' famous' ' American' ' actor' ' and' ' director' ',' ' who'
+ ' has' ' been' ' a' ' regular' ' on' ' the' ' show' ' since' ' its'
+ ' inception' '.']" , the famous American actor and director , who has been a regular on the show since its inception . False date her, and how Woody Allen overhauled a film 6 [' date', ' her', ',', ' and', ' how', ' Woody', ' Allen']
+429 101 Name of father of x -1 Name of father of Woody Allen Martin Königsberg Woody Allen "[',' ' the' ' famous' ' American' ' actor' ' and' ' director' ',' ' who'
+ ' has' ' been' ' a' ' regular' ' on' ' the' ' show' ' since' ' its'
+ ' inception' '.']" , the famous American actor and director , who has been a regular on the show since its inception . False " ""hopeless romantic"" Woody Allen also inspired WALL-E.
+" 6 "[' ""', 'hop', 'eless', ' romantic', '""', ' Woody', ' Allen']"
+430 101 Name of father of x -1 Name of father of Woody Allen Martin Königsberg Woody Allen "[',' ' the' ' famous' ' American' ' actor' ' and' ' director' ',' ' who'
+ ' has' ' been' ' a' ' regular' ' on' ' the' ' show' ' since' ' its'
+ ' inception' '.']" , the famous American actor and director , who has been a regular on the show since its inception . False appeared in Woody Allen films, having roles 3 [' appeared', ' in', ' Woody', ' Allen']
+431 102 Name of father of x -1 Name of father of John Paul II Karol Wojtyła John Paul II "[',' ' the' ' first' ' pope' ' to' ' visit' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' pope' ""'s"" ' visit' ' to' ' the' ' United' ' States']" ", the first pope to visit the United States .
+
+ The pope 's visit to the United States" False was unveiled by Pope John Paul II on 8 April 1994. 6 [' was', ' unveiled', ' by', ' Pope', ' John', ' Paul', ' II']
+432 102 Name of father of x -1 Name of father of John Paul II Karol Wojtyła John Paul II "[',' ' the' ' first' ' pope' ' to' ' visit' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' pope' ""'s"" ' visit' ' to' ' the' ' United' ' States']" ", the first pope to visit the United States .
+
+ The pope 's visit to the United States" False members. Pope John Paul II prayed for the 5 [' members', '.', ' Pope', ' John', ' Paul', ' II']
+433 102 Name of father of x -1 Name of father of John Paul II Karol Wojtyła John Paul II "[',' ' the' ' first' ' pope' ' to' ' visit' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' pope' ""'s"" ' visit' ' to' ' the' ' United' ' States']" ", the first pope to visit the United States .
+
+ The pope 's visit to the United States" False XI (1939) and Saint John Paul II (2005), while an 9 [' XI', ' (', '19', '39', ')', ' and', ' Saint', ' John', ' Paul', ' II']
+434 102 Name of father of x -1 Name of father of John Paul II Karol Wojtyła John Paul II "[',' ' the' ' first' ' pope' ' to' ' visit' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' pope' ""'s"" ' visit' ' to' ' the' ' United' ' States']" ", the first pope to visit the United States .
+
+ The pope 's visit to the United States" False delighted; Pope John Paul II spoke an inaugural 5 [' delighted', ';', ' Pope', ' John', ' Paul', ' II']
+435 102 Name of father of x -1 Name of father of John Paul II Karol Wojtyła John Paul II "[',' ' the' ' first' ' pope' ' to' ' visit' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' pope' ""'s"" ' visit' ' to' ' the' ' United' ' States']" ", the first pope to visit the United States .
+
+ The pope 's visit to the United States" False 2 ['John', ' Paul', ' II']
+436 103 Name of father of x -1 Name of father of Nadia Boulanger Ernest Boulanger Nadia Boulanger "[',' ' the' ' famous' ' French' ' composer' '.' '\n' '\n' 'The' ' name'
+ ' of' ' the' ' father' ' of' ' the' ' famous' ' French' ' composer'
+ ' Nad' 'ia']" ", the famous French composer .
+
+ The name of the father of the famous French composer Nad ia" False French pedagogue Nadia Boulanger and return 7 [' French', ' ped', 'agogue', ' Nad', 'ia', ' B', 'oul', 'anger']
+437 103 Name of father of x -1 Name of father of Nadia Boulanger Ernest Boulanger Nadia Boulanger "[',' ' the' ' famous' ' French' ' composer' '.' '\n' '\n' 'The' ' name'
+ ' of' ' the' ' father' ' of' ' the' ' famous' ' French' ' composer'
+ ' Nad' 'ia']" ", the famous French composer .
+
+ The name of the father of the famous French composer Nad ia" False pedagogue Nadia Boulanger and return from Europe 6 [' ped', 'agogue', ' Nad', 'ia', ' B', 'oul', 'anger']
+438 104 Name of father of x -1 Name of father of Benedict XVI Joseph Ratzinger, Sr. Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' name' ' is'
+ ' Benedict' ' XVI' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope 's name is Benedict XVI .
+
+ The Pope is the head" False verbum in 2005 Pope Benedict XVI reaffirmed its 6 [' ver', 'bum', ' in', ' 2005', ' Pope', ' Benedict', ' XVI']
+439 104 Name of father of x -1 Name of father of Benedict XVI Joseph Ratzinger, Sr. Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' name' ' is'
+ ' Benedict' ' XVI' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope 's name is Benedict XVI .
+
+ The Pope is the head" False 3 ['B', 'ened', 'ict', ' XVI']
+440 104 Name of father of x -1 Name of father of Benedict XVI Joseph Ratzinger, Sr. Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' name' ' is'
+ ' Benedict' ' XVI' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope 's name is Benedict XVI .
+
+ The Pope is the head" False 3 ['B', 'ened', 'ict', ' XVI']
+441 104 Name of father of x -1 Name of father of Benedict XVI Joseph Ratzinger, Sr. Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' name' ' is'
+ ' Benedict' ' XVI' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope 's name is Benedict XVI .
+
+ The Pope is the head" False 2012, Pope Benedict XVI visited Cuba for 4 [' 2012', ',', ' Pope', ' Benedict', ' XVI']
+442 104 Name of father of x -1 Name of father of Benedict XVI Joseph Ratzinger, Sr. Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' name' ' is'
+ ' Benedict' ' XVI' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope 's name is Benedict XVI .
+
+ The Pope is the head" False blessed by both Pope Benedict XVI and the Dalai 5 [' blessed', ' by', ' both', ' Pope', ' Benedict', ' XVI']
+443 105 Name of father of x -1 Name of father of Fyodor Dostoyevsky Mikhail Andreyevich Dostoevsky Fyodor Dostoyevsky "['\n' '\n' 'F' 'y' 'odor' ' D' 'ost' 'oy' 'ev' 'sky' ' was' ' born' ' on'
+ ' December' ' 21' ',' ' 18' '21' ' in' ' Moscow']" "
+
+ F y odor D ost oy ev sky was born on December 21 , 18 21 in Moscow" False Z. Phillips and Fyodor Dostoyevsky challenged the instrumental 11 [' Z', '.', ' Phillips', ' and', ' F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+444 105 Name of father of x -1 Name of father of Fyodor Dostoyevsky Mikhail Andreyevich Dostoevsky Fyodor Dostoyevsky "['\n' '\n' 'F' 'y' 'odor' ' D' 'ost' 'oy' 'ev' 'sky' ' was' ' born' ' on'
+ ' December' ' 21' ',' ' 18' '21' ' in' ' Moscow']" "
+
+ F y odor D ost oy ev sky was born on December 21 , 18 21 in Moscow" False reference to the Fyodor Dostoyevsky novel Crime and Punishment 10 [' reference', ' to', ' the', ' F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+445 105 Name of father of x -1 Name of father of Fyodor Dostoyevsky Mikhail Andreyevich Dostoevsky Fyodor Dostoyevsky "['\n' '\n' 'F' 'y' 'odor' ' D' 'ost' 'oy' 'ev' 'sky' ' was' ' born' ' on'
+ ' December' ' 21' ',' ' 18' '21' ' in' ' Moscow']" "
+
+ F y odor D ost oy ev sky was born on December 21 , 18 21 in Moscow" False Hölderlin, and Fyodor Dostoyevsky among his most 13 [' H', 'ö', 'lder', 'lin', ',', ' and', ' F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+446 105 Name of father of x -1 Name of father of Fyodor Dostoyevsky Mikhail Andreyevich Dostoevsky Fyodor Dostoyevsky "['\n' '\n' 'F' 'y' 'odor' ' D' 'ost' 'oy' 'ev' 'sky' ' was' ' born' ' on'
+ ' December' ' 21' ',' ' 18' '21' ' in' ' Moscow']" "
+
+ F y odor D ost oy ev sky was born on December 21 , 18 21 in Moscow" False " Dostoyevsky ===
+" 13 [' D', 'ost', 'oy', 'ev', 'sky', ' ===', 'F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+447 105 Name of father of x -1 Name of father of Fyodor Dostoyevsky Mikhail Andreyevich Dostoevsky Fyodor Dostoyevsky "['\n' '\n' 'F' 'y' 'odor' ' D' 'ost' 'oy' 'ev' 'sky' ' was' ' born' ' on'
+ ' December' ' 21' ',' ' 18' '21' ' in' ' Moscow']" "
+
+ F y odor D ost oy ev sky was born on December 21 , 18 21 in Moscow" False D. Z. Phillips and Fyodor Dostoyevsky challenged 13 [' D', '.', ' Z', '.', ' Phillips', ' and', ' F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+448 106 Name of father of x -1 Name of father of Liza Minnelli Vincente Minnelli Liza Minnelli "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False schedule. Her daughter Liza Minnelli made her film debut 8 [' schedule', '.', ' Her', ' daughter', ' L', 'iza', ' Min', 'nell', 'i']
+449 106 Name of father of x -1 Name of father of Liza Minnelli Vincente Minnelli Liza Minnelli "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False by addition of Liza Minnelli inspired vocals 7 [' by', ' addition', ' of', ' L', 'iza', ' Min', 'nell', 'i']
+450 106 Name of father of x -1 Name of father of Liza Minnelli Vincente Minnelli Liza Minnelli "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Her daughter Liza Minnelli made her film 6 [' Her', ' daughter', ' L', 'iza', ' Min', 'nell', 'i']
+451 106 Name of father of x -1 Name of father of Liza Minnelli Vincente Minnelli Liza Minnelli "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False by addition of Liza Minnelli inspired vocals 7 [' by', ' addition', ' of', ' L', 'iza', ' Min', 'nell', 'i']
+452 106 Name of father of x -1 Name of father of Liza Minnelli Vincente Minnelli Liza Minnelli "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Kitty Sanchez, and Liza Minnelli as Lucille Austero; 8 [' Kitty', ' Sanchez', ',', ' and', ' L', 'iza', ' Min', 'nell', 'i']
+453 107 Name of father of x -1 Name of father of John F. Kennedy Joseph P. Kennedy Sr. John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of'
+ ' the' ' children' ' of' ' John' ' F' '.' ' Kennedy' ',' ' Jr']" ", Jr .
+
+ The following is a list of the children of John F . Kennedy , Jr" False Gaulle en route to John F. Kennedy International 8 [' Gaul', 'le', ' en', ' route', ' to', ' John', ' F', '.', ' Kennedy']
+454 107 Name of father of x -1 Name of father of John F. Kennedy Joseph P. Kennedy Sr. John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of'
+ ' the' ' children' ' of' ' John' ' F' '.' ' Kennedy' ',' ' Jr']" ", Jr .
+
+ The following is a list of the children of John F . Kennedy , Jr" False debate with John F. Kennedy during the 1960 5 [' debate', ' with', ' John', ' F', '.', ' Kennedy']
+455 107 Name of father of x -1 Name of father of John F. Kennedy Joseph P. Kennedy Sr. John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of'
+ ' the' ' children' ' of' ' John' ' F' '.' ' Kennedy' ',' ' Jr']" ", Jr .
+
+ The following is a list of the children of John F . Kennedy , Jr" False advocating John F. Kennedy assassination conspiracy 4 [' advocating', ' John', ' F', '.', ' Kennedy']
+456 107 Name of father of x -1 Name of father of John F. Kennedy Joseph P. Kennedy Sr. John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of'
+ ' the' ' children' ' of' ' John' ' F' '.' ' Kennedy' ',' ' Jr']" ", Jr .
+
+ The following is a list of the children of John F . Kennedy , Jr" False years. President John F. Kennedy was especially fond 6 [' years', '.', ' President', ' John', ' F', '.', ' Kennedy']
+457 107 Name of father of x -1 Name of father of John F. Kennedy Joseph P. Kennedy Sr. John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of'
+ ' the' ' children' ' of' ' John' ' F' '.' ' Kennedy' ',' ' Jr']" ", Jr .
+
+ The following is a list of the children of John F . Kennedy , Jr" False 3 ['John', ' F', '.', ' Kennedy']
+458 108 Name of father of x -1 Name of father of Julius Caesar Gaius Julius Caesar Julius Caesar "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' Roman'
+ ' people' ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the']" , and the name of the father of the Roman people , and the name of the father of the False Illyria, when Julius Caesar was killed on 6 [' Ill', 'y', 'ria', ',', ' when', ' Julius', ' Caesar']
+459 108 Name of father of x -1 Name of father of Julius Caesar Gaius Julius Caesar Julius Caesar "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' Roman'
+ ' people' ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the']" , and the name of the father of the Roman people , and the name of the father of the False comparative biography of Julius Caesar and Augustus, 4 [' comparative', ' biography', ' of', ' Julius', ' Caesar']
+460 108 Name of father of x -1 Name of father of Julius Caesar Gaius Julius Caesar Julius Caesar "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' Roman'
+ ' people' ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the']" , and the name of the father of the Roman people , and the name of the father of the False assassination of Julius Caesar in 44 BCE, Halley's 3 [' assassination', ' of', ' Julius', ' Caesar']
+461 108 Name of father of x -1 Name of father of Julius Caesar Gaius Julius Caesar Julius Caesar "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' Roman'
+ ' people' ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the']" , and the name of the father of the Roman people , and the name of the father of the False fought against Julius Caesar and supported Cassius 3 [' fought', ' against', ' Julius', ' Caesar']
+462 108 Name of father of x -1 Name of father of Julius Caesar Gaius Julius Caesar Julius Caesar "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' Roman'
+ ' people' ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the']" , and the name of the father of the Roman people , and the name of the father of the False " used it as such in Julius Caesar (Act 1 Scene 3): ""And" 6 [' used', ' it', ' as', ' such', ' in', ' Julius', ' Caesar']
+463 109 Name of father of x -1 Name of father of Oscar Wilde William Wilde Oscar Wilde "[',' ' the' ' poet' ',' ' the' ' play' 'wright' ',' ' the' ' wit' ','
+ ' the' ' wit' 'less' ',' ' the' ' wit' 'less' ' poet' ',']" , the poet , the play wright , the wit , the wit less , the wit less poet , False Queensberry accused Oscar Wilde of being an active 4 [' Queens', 'berry', ' accused', ' Oscar', ' Wilde']
+464 109 Name of father of x -1 Name of father of Oscar Wilde William Wilde Oscar Wilde "[',' ' the' ' poet' ',' ' the' ' play' 'wright' ',' ' the' ' wit' ','
+ ' the' ' wit' 'less' ',' ' the' ' wit' 'less' ' poet' ',']" , the poet , the play wright , the wit , the wit less , the wit less poet , False End stage he rated Oscar Wilde above the rest: 5 [' End', ' stage', ' he', ' rated', ' Oscar', ' Wilde']
+465 109 Name of father of x -1 Name of father of Oscar Wilde William Wilde Oscar Wilde "[',' ' the' ' poet' ',' ' the' ' play' 'wright' ',' ' the' ' wit' ','
+ ' the' ' wit' 'less' ',' ' the' ' wit' 'less' ' poet' ',']" , the poet , the play wright , the wit , the wit less , the wit less poet , False the sad story of Oscar Wilde was in part 5 [' the', ' sad', ' story', ' of', ' Oscar', ' Wilde']
+466 109 Name of father of x -1 Name of father of Oscar Wilde William Wilde Oscar Wilde "[',' ' the' ' poet' ',' ' the' ' play' 'wright' ',' ' the' ' wit' ','
+ ' the' ' wit' 'less' ',' ' the' ' wit' 'less' ' poet' ',']" , the poet , the play wright , the wit , the wit less , the wit less poet , False stage he rated Oscar Wilde above the 4 [' stage', ' he', ' rated', ' Oscar', ' Wilde']
+467 109 Name of father of x -1 Name of father of Oscar Wilde William Wilde Oscar Wilde "[',' ' the' ' poet' ',' ' the' ' play' 'wright' ',' ' the' ' wit' ','
+ ' the' ' wit' 'less' ',' ' the' ' wit' 'less' ' poet' ',']" , the poet , the play wright , the wit , the wit less , the wit less poet , False Teasdale, Walt Whitman, Oscar Wilde and Robert 8 [' Te', 'as', 'dale', ',', ' Walt', ' Whitman', ',', ' Oscar', ' Wilde']
+468 110 Name of father of x -1 Name of father of René Descartes Joachim Descartes René Descartes "[',' ' the' ' French' ' philosopher' ',' ' mathematician' ',' ' and'
+ ' scientist' '.' '\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Ren' 'é'
+ ' Des']" ", the French philosopher , mathematician , and scientist .
+
+ ! Name of mother of Ren é Des" False French philosopher René Descartes argued that 6 [' French', ' philosopher', ' Ren', 'é', ' Des', 'cart', 'es']
+469 110 Name of father of x -1 Name of father of René Descartes Joachim Descartes René Descartes "[',' ' the' ' French' ' philosopher' ',' ' mathematician' ',' ' and'
+ ' scientist' '.' '\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Ren' 'é'
+ ' Des']" ", the French philosopher , mathematician , and scientist .
+
+ ! Name of mother of Ren é Des" False Galileo and René Descartes completely ignored 6 [' Galileo', ' and', ' Ren', 'é', ' Des', 'cart', 'es']
+470 110 Name of father of x -1 Name of father of René Descartes Joachim Descartes René Descartes "[',' ' the' ' French' ' philosopher' ',' ' mathematician' ',' ' and'
+ ' scientist' '.' '\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Ren' 'é'
+ ' Des']" ", the French philosopher , mathematician , and scientist .
+
+ ! Name of mother of Ren é Des" False dualism is due to René Descartes (1641), and holds 9 [' dual', 'ism', ' is', ' due', ' to', ' Ren', 'é', ' Des', 'cart', 'es']
+471 110 Name of father of x -1 Name of father of René Descartes Joachim Descartes René Descartes "[',' ' the' ' French' ' philosopher' ',' ' mathematician' ',' ' and'
+ ' scientist' '.' '\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Ren' 'é'
+ ' Des']" ", the French philosopher , mathematician , and scientist .
+
+ ! Name of mother of Ren é Des" False dualism is due to René Descartes (1641), and holds 9 [' dual', 'ism', ' is', ' due', ' to', ' Ren', 'é', ' Des', 'cart', 'es']
+472 110 Name of father of x -1 Name of father of René Descartes Joachim Descartes René Descartes "[',' ' the' ' French' ' philosopher' ',' ' mathematician' ',' ' and'
+ ' scientist' '.' '\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Ren' 'é'
+ ' Des']" ", the French philosopher , mathematician , and scientist .
+
+ ! Name of mother of Ren é Des" False philosopher René Descartes deployed a similar 5 [' philosopher', ' Ren', 'é', ' Des', 'cart', 'es']
+473 111 Name of father of x -1 Name of father of Auguste Rodin Jean Baptiste Rodin Auguste Rodin "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' August' 'e' ' Rod'
+ 'in' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of August e Rod in is unknown .
+
+ The name of" False " Auguste Rodin =
+" 3 [' August', 'e', ' Rod', 'in']
+474 111 Name of father of x -1 Name of father of Auguste Rodin Jean Baptiste Rodin Auguste Rodin "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' August' 'e' ' Rod'
+ 'in' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of August e Rod in is unknown .
+
+ The name of" False " Auguste Rodin =
+" 3 [' August', 'e', ' Rod', 'in']
+475 111 Name of father of x -1 Name of father of Auguste Rodin Jean Baptiste Rodin Auguste Rodin "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' August' 'e' ' Rod'
+ 'in' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of August e Rod in is unknown .
+
+ The name of" False life and work of Auguste Rodin and Camille 7 [' life', ' and', ' work', ' of', ' August', 'e', ' Rod', 'in']
+476 111 Name of father of x -1 Name of father of Auguste Rodin Jean Baptiste Rodin Auguste Rodin "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' August' 'e' ' Rod'
+ 'in' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of August e Rod in is unknown .
+
+ The name of" False " Auguste Rodin =
+" 3 [' August', 'e', ' Rod', 'in']
+477 111 Name of father of x -1 Name of father of Auguste Rodin Jean Baptiste Rodin Auguste Rodin "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' August' 'e' ' Rod'
+ 'in' ' is' ' unknown' '.' '\n' '\n' 'The' ' name' ' of']" "
+
+ The name of the father of August e Rod in is unknown .
+
+ The name of" False Oscar Milosz, Auguste Rodin and poet Alan 8 [' Oscar', ' Mil', 'os', 'z', ',', ' August', 'e', ' Rod', 'in']
+478 112 Name of father of x -1 Name of father of Benito Mussolini Alessandro Mussolini Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' who' ' was' ' born' ' in' ' 18'
+ '83' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Italian dictator , who was born in 18 83 .
+
+ The name of the father of" False Italian dictator Benito Mussolini decided to send 5 [' Italian', ' dictator', ' Ben', 'ito', ' Muss', 'olini']
+479 112 Name of father of x -1 Name of father of Benito Mussolini Alessandro Mussolini Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' who' ' was' ' born' ' in' ' 18'
+ '83' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Italian dictator , who was born in 18 83 .
+
+ The name of the father of" False negotiate with Benito Mussolini fell on deaf ears, 5 [' negotiate', ' with', ' Ben', 'ito', ' Muss', 'olini']
+480 112 Name of father of x -1 Name of father of Benito Mussolini Alessandro Mussolini Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' who' ' was' ' born' ' in' ' 18'
+ '83' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Italian dictator , who was born in 18 83 .
+
+ The name of the father of" False parodied Adolf Hitler and Benito Mussolini and ended 8 [' par', 'odied', ' Adolf', ' Hitler', ' and', ' Ben', 'ito', ' Muss', 'olini']
+481 112 Name of father of x -1 Name of father of Benito Mussolini Alessandro Mussolini Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' who' ' was' ' born' ' in' ' 18'
+ '83' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Italian dictator , who was born in 18 83 .
+
+ The name of the father of" False Richard Nixon and Benito Mussolini that includes footage 6 [' Richard', ' Nixon', ' and', ' Ben', 'ito', ' Muss', 'olini']
+482 112 Name of father of x -1 Name of father of Benito Mussolini Alessandro Mussolini Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' who' ' was' ' born' ' in' ' 18'
+ '83' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Italian dictator , who was born in 18 83 .
+
+ The name of the father of" False Italian leader Benito Mussolini did not authorize 5 [' Italian', ' leader', ' Ben', 'ito', ' Muss', 'olini']
+483 113 Name of father of x -1 Name of father of Aristotle Nicomachus Aristotle "[',' ' the' ' son' ' of' ' the' ' great' ' philosopher' ',' ' and' ' the'
+ ' father' ' of' ' the' '\n' 'phil' 'os' 'ophy' ' of' ' the' ' Sto']" ", the son of the great philosopher , and the father of the
+ phil os ophy of the Sto" False BC), a student of Aristotle who invented 5 [' BC', '),', ' a', ' student', ' of', ' Aristotle']
+484 113 Name of father of x -1 Name of father of Aristotle Nicomachus Aristotle "[',' ' the' ' son' ' of' ' the' ' great' ' philosopher' ',' ' and' ' the'
+ ' father' ' of' ' the' '\n' 'phil' 'os' 'ophy' ' of' ' the' ' Sto']" ", the son of the great philosopher , and the father of the
+ phil os ophy of the Sto" False philosophers such as Aristotle (who gave a formal 3 [' philosophers', ' such', ' as', ' Aristotle']
+485 113 Name of father of x -1 Name of father of Aristotle Nicomachus Aristotle "[',' ' the' ' son' ' of' ' the' ' great' ' philosopher' ',' ' and' ' the'
+ ' father' ' of' ' the' '\n' 'phil' 'os' 'ophy' ' of' ' the' ' Sto']" ", the son of the great philosopher , and the father of the
+ phil os ophy of the Sto" False work On the Cosmos by Aristotle or Pseudo-Aristotle 5 [' work', ' On', ' the', ' Cosmos', ' by', ' Aristotle']
+486 113 Name of father of x -1 Name of father of Aristotle Nicomachus Aristotle "[',' ' the' ' son' ' of' ' the' ' great' ' philosopher' ',' ' and' ' the'
+ ' father' ' of' ' the' '\n' 'phil' 'os' 'ophy' ' of' ' the' ' Sto']" ", the son of the great philosopher , and the father of the
+ phil os ophy of the Sto" False " ==== Plato and Aristotle ====
+" 4 [' =', '===', ' Plato', ' and', ' Aristotle']
+487 113 Name of father of x -1 Name of father of Aristotle Nicomachus Aristotle "[',' ' the' ' son' ' of' ' the' ' great' ' philosopher' ',' ' and' ' the'
+ ' father' ' of' ' the' '\n' 'phil' 'os' 'ophy' ' of' ' the' ' Sto']" ", the son of the great philosopher , and the father of the
+ phil os ophy of the Sto" False Eddington agreed with Aristotle that the universe 5 [' Ed', 'd', 'ington', ' agreed', ' with', ' Aristotle']
+488 114 Name of father of x -1 Name of father of Nicole Kidman Antony Kidman Nicole Kidman "[' and' ' Tom' ' Cruise' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have'
+ ' been' ' married' ' since' ' 2000' ',' ' have' ' been' ' together'
+ ' for']" " and Tom Cruise .
+
+ The couple , who have been married since 2000 , have been together for" False " Australian actress Nicole Kidman because she is ""not" 4 [' Australian', ' actress', ' Nicole', ' Kid', 'man']
+489 114 Name of father of x -1 Name of father of Nicole Kidman Antony Kidman Nicole Kidman "[' and' ' Tom' ' Cruise' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have'
+ ' been' ' married' ' since' ' 2000' ',' ' have' ' been' ' together'
+ ' for']" " and Tom Cruise .
+
+ The couple , who have been married since 2000 , have been together for" False Revolutionary Road, and Nicole Kidman replaced her. 6 [' Revolutionary', ' Road', ',', ' and', ' Nicole', ' Kid', 'man']
+490 114 Name of father of x -1 Name of father of Nicole Kidman Antony Kidman Nicole Kidman "[' and' ' Tom' ' Cruise' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have'
+ ' been' ' married' ' since' ' 2000' ',' ' have' ' been' ' together'
+ ' for']" " and Tom Cruise .
+
+ The couple , who have been married since 2000 , have been together for" False takes of a scene. Nicole Kidman explains that 7 [' takes', ' of', ' a', ' scene', '.', ' Nicole', ' Kid', 'man']
+491 114 Name of father of x -1 Name of father of Nicole Kidman Antony Kidman Nicole Kidman "[' and' ' Tom' ' Cruise' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have'
+ ' been' ' married' ' since' ' 2000' ',' ' have' ' been' ' together'
+ ' for']" " and Tom Cruise .
+
+ The couple , who have been married since 2000 , have been together for" False cappella version. Nicole Kidman and Hugh Jackman 7 [' ca', 'pp', 'ella', ' version', '.', ' Nicole', ' Kid', 'man']
+492 114 Name of father of x -1 Name of father of Nicole Kidman Antony Kidman Nicole Kidman "[' and' ' Tom' ' Cruise' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have'
+ ' been' ' married' ' since' ' 2000' ',' ' have' ' been' ' together'
+ ' for']" " and Tom Cruise .
+
+ The couple , who have been married since 2000 , have been together for" False Holmes a robot, Nicole Kidman a beer-swilling 6 [' Holmes', ' a', ' robot', ',', ' Nicole', ' Kid', 'man']
+493 115 Name of father of x -1 Name of father of Richard Wagner Carl Friedrich Wagner Richard Wagner "[',' ' the' ' composer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' composer' ' of' ' the' ' opera' ',' ' _' 'T' 'ann' 'h' 'ä' 'user']" , the composer , and the father of the composer of the opera , _ T ann h ä user False 'was to be used by Richard Wagner to spark off his 7 "["" '"", 'was', ' to', ' be', ' used', ' by', ' Richard', ' Wagner']"
+494 115 Name of father of x -1 Name of father of Richard Wagner Carl Friedrich Wagner Richard Wagner "[',' ' the' ' composer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' composer' ' of' ' the' ' opera' ',' ' _' 'T' 'ann' 'h' 'ä' 'user']" , the composer , and the father of the composer of the opera , _ T ann h ä user False Mozart, Verdi, and Richard Wagner as one of the supreme 8 [' Moz', 'art', ',', ' Ver', 'di', ',', ' and', ' Richard', ' Wagner']
+495 115 Name of father of x -1 Name of father of Richard Wagner Carl Friedrich Wagner Richard Wagner "[',' ' the' ' composer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' composer' ' of' ' the' ' opera' ',' ' _' 'T' 'ann' 'h' 'ä' 'user']" , the composer , and the father of the composer of the opera , _ T ann h ä user False 1 ['Richard', ' Wagner']
+496 115 Name of father of x -1 Name of father of Richard Wagner Carl Friedrich Wagner Richard Wagner "[',' ' the' ' composer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' composer' ' of' ' the' ' opera' ',' ' _' 'T' 'ann' 'h' 'ä' 'user']" , the composer , and the father of the composer of the opera , _ T ann h ä user False Romantic idiom. Richard Wagner was known for his 5 [' Romantic', ' id', 'iom', '.', ' Richard', ' Wagner']
+497 115 Name of father of x -1 Name of father of Richard Wagner Carl Friedrich Wagner Richard Wagner "[',' ' the' ' composer' ',' ' and' ' the' ' father' ' of' ' the'
+ ' composer' ' of' ' the' ' opera' ',' ' _' 'T' 'ann' 'h' 'ä' 'user']" , the composer , and the father of the composer of the opera , _ T ann h ä user False bring influences of Richard Wagner into the score, including 4 [' bring', ' influences', ' of', ' Richard', ' Wagner']
+498 116 Name of father of x -1 Name of father of Václav Havel Václav Maria Havel Václav Havel "[',' ' Czech' 'oslov' 'ak' ' president' ',' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' and' ' diss' 'ident' '.' '\n' '\n' 'The' ' Czech'
+ ' Republic']" ", Czech oslov ak president , Nobel Peace Prize winner , and diss ident .
+
+ The Czech Republic" False " fail."" Also in 1994, Václav Havel said ""Anne Frank's" 11 "[' fail', '.""', ' Also', ' in', ' 1994', ',', ' V', 'á', 'cl', 'av', ' Ha', 'vel']"
+499 116 Name of father of x -1 Name of father of Václav Havel Václav Maria Havel Václav Havel "[',' ' Czech' 'oslov' 'ak' ' president' ',' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' and' ' diss' 'ident' '.' '\n' '\n' 'The' ' Czech'
+ ' Republic']" ", Czech oslov ak president , Nobel Peace Prize winner , and diss ident .
+
+ The Czech Republic" False receives the Václav Havel Prize for Creative 7 [' receives', ' the', ' V', 'á', 'cl', 'av', ' Ha', 'vel']
+500 116 Name of father of x -1 Name of father of Václav Havel Václav Maria Havel Václav Havel "[',' ' Czech' 'oslov' 'ak' ' president' ',' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' and' ' diss' 'ident' '.' '\n' '\n' 'The' ' Czech'
+ ' Republic']" ", Czech oslov ak president , Nobel Peace Prize winner , and diss ident .
+
+ The Czech Republic" False President Václav Havel appointed Koubek 6 [' President', ' V', 'á', 'cl', 'av', ' Ha', 'vel']
+501 116 Name of father of x -1 Name of father of Václav Havel Václav Maria Havel Václav Havel "[',' ' Czech' 'oslov' 'ak' ' president' ',' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' and' ' diss' 'ident' '.' '\n' '\n' 'The' ' Czech'
+ ' Republic']" ", Czech oslov ak president , Nobel Peace Prize winner , and diss ident .
+
+ The Czech Republic" False Carlos Jobim. Václav Havel Airport Prague, 9 [' Carlos', ' Job', 'im', '.', ' V', 'á', 'cl', 'av', ' Ha', 'vel']
+502 116 Name of father of x -1 Name of father of Václav Havel Václav Maria Havel Václav Havel "[',' ' Czech' 'oslov' 'ak' ' president' ',' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' and' ' diss' 'ident' '.' '\n' '\n' 'The' ' Czech'
+ ' Republic']" ", Czech oslov ak president , Nobel Peace Prize winner , and diss ident .
+
+ The Czech Republic" False the first annual Václav Havel Prize for Creative 8 [' the', ' first', ' annual', ' V', 'á', 'cl', 'av', ' Ha', 'vel']
+503 117 Name of father of x -1 Name of father of Amy Winehouse Mitch Winehouse Amy Winehouse "[',' ' the' ' singer' ' who' ' died' ' in' ' 2011' '.' '\n' '\n' 'The'
+ ' singer' ""'s"" ' father' ',' ' Mitch' ' Wine' 'house' ',' ' has']" ", the singer who died in 2011 .
+
+ The singer 's father , Mitch Wine house , has" True " Foundation ===
+" 4 [' Foundation', ' ===', 'Amy', ' Wine', 'house']
+504 117 Name of father of x -1 Name of father of Amy Winehouse Mitch Winehouse Amy Winehouse "[',' ' the' ' singer' ' who' ' died' ' in' ' 2011' '.' '\n' '\n' 'The'
+ ' singer' ""'s"" ' father' ',' ' Mitch' ' Wine' 'house' ',' ' has']" ", the singer who died in 2011 .
+
+ The singer 's father , Mitch Wine house , has" True simply entitled as Amy Winehouse is in production, 5 [' simply', ' entitled', ' as', ' Amy', ' Wine', 'house']
+505 117 Name of father of x -1 Name of father of Amy Winehouse Mitch Winehouse Amy Winehouse "[',' ' the' ' singer' ' who' ' died' ' in' ' 2011' '.' '\n' '\n' 'The'
+ ' singer' ""'s"" ' father' ',' ' Mitch' ' Wine' 'house' ',' ' has']" ", the singer who died in 2011 .
+
+ The singer 's father , Mitch Wine house , has" True " Foundation ===
+" 4 [' Foundation', ' ===', 'Amy', ' Wine', 'house']
+506 117 Name of father of x -1 Name of father of Amy Winehouse Mitch Winehouse Amy Winehouse "[',' ' the' ' singer' ' who' ' died' ' in' ' 2011' '.' '\n' '\n' 'The'
+ ' singer' ""'s"" ' father' ',' ' Mitch' ' Wine' 'house' ',' ' has']" ", the singer who died in 2011 .
+
+ The singer 's father , Mitch Wine house , has" True parents set up The Amy Winehouse Foundation to 6 [' parents', ' set', ' up', ' The', ' Amy', ' Wine', 'house']
+507 117 Name of father of x -1 Name of father of Amy Winehouse Mitch Winehouse Amy Winehouse "[',' ' the' ' singer' ' who' ' died' ' in' ' 2011' '.' '\n' '\n' 'The'
+ ' singer' ""'s"" ' father' ',' ' Mitch' ' Wine' 'house' ',' ' has']" ", the singer who died in 2011 .
+
+ The singer 's father , Mitch Wine house , has" True mainstream artists such as Amy Winehouse and Adele, independent 6 [' mainstream', ' artists', ' such', ' as', ' Amy', ' Wine', 'house']
+508 118 Name of father of x -1 Name of father of Karl Marx Heinrich Marx Karl Marx "[',' ' the' ' father' ' of' ' communism' ',' ' and' ' the' ' father' ' of'
+ ' the' ' proletariat' '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two']" ", the father of communism , and the father of the proletariat .
+
+ The first of these two" False 1 ['Karl', ' Marx']
+509 118 Name of father of x -1 Name of father of Karl Marx Heinrich Marx Karl Marx "[',' ' the' ' father' ' of' ' communism' ',' ' and' ' the' ' father' ' of'
+ ' the' ' proletariat' '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two']" ", the father of communism , and the father of the proletariat .
+
+ The first of these two" False visitors. Karl Marx and Frederick 3 [' visitors', '.', ' Karl', ' Marx']
+510 118 Name of father of x -1 Name of father of Karl Marx Heinrich Marx Karl Marx "[',' ' the' ' father' ' of' ' communism' ',' ' and' ' the' ' father' ' of'
+ ' the' ' proletariat' '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two']" ", the father of communism , and the father of the proletariat .
+
+ The first of these two" False Vladimir Lenin, Karl Marx and Friedrich Engels. 4 [' Vladimir', ' Lenin', ',', ' Karl', ' Marx']
+511 118 Name of father of x -1 Name of father of Karl Marx Heinrich Marx Karl Marx "[',' ' the' ' father' ' of' ' communism' ',' ' and' ' the' ' father' ' of'
+ ' the' ' proletariat' '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two']" ", the father of communism , and the father of the proletariat .
+
+ The first of these two" False Castro describes Karl Marx and Cuban nationalist 3 [' Castro', ' describes', ' Karl', ' Marx']
+512 118 Name of father of x -1 Name of father of Karl Marx Heinrich Marx Karl Marx "[',' ' the' ' father' ' of' ' communism' ',' ' and' ' the' ' father' ' of'
+ ' the' ' proletariat' '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two']" ", the father of communism , and the father of the proletariat .
+
+ The first of these two" False previous defense of Karl Marx and other radical 4 [' previous', ' defense', ' of', ' Karl', ' Marx']
+513 121 Name of father of x -1 Name of father of Cicero Marcus Tullius Cicero Cicero "[',' ' the' ' Roman' ' or' 'ator' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' law' '.' '\n' '\n' 'The' ' Roman' ' law' ' was']" ", the Roman or ator , and the father of the Roman law .
+
+ The Roman law was" False which recalled Cicero from his exile. 3 [' which', ' recalled', ' Cic', 'ero']
+514 121 Name of father of x -1 Name of father of Cicero Marcus Tullius Cicero Cicero "[',' ' the' ' Roman' ' or' 'ator' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' law' '.' '\n' '\n' 'The' ' Roman' ' law' ' was']" ", the Roman or ator , and the father of the Roman law .
+
+ The Roman law was" False now lost letter, Cicero, though an admirer 5 [' now', ' lost', ' letter', ',', ' Cic', 'ero']
+515 121 Name of father of x -1 Name of father of Cicero Marcus Tullius Cicero Cicero "[',' ' the' ' Roman' ' or' 'ator' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' law' '.' '\n' '\n' 'The' ' Roman' ' law' ' was']" ", the Roman or ator , and the father of the Roman law .
+
+ The Roman law was" False the former consul Cicero for this very 5 [' the', ' former', ' cons', 'ul', ' Cic', 'ero']
+516 121 Name of father of x -1 Name of father of Cicero Marcus Tullius Cicero Cicero "[',' ' the' ' Roman' ' or' 'ator' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' law' '.' '\n' '\n' 'The' ' Roman' ' law' ' was']" ", the Roman or ator , and the father of the Roman law .
+
+ The Roman law was" False variations thereof, by Cicero and others. Roman 5 [' variations', ' thereof', ',', ' by', ' Cic', 'ero']
+517 121 Name of father of x -1 Name of father of Cicero Marcus Tullius Cicero Cicero "[',' ' the' ' Roman' ' or' 'ator' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' law' '.' '\n' '\n' 'The' ' Roman' ' law' ' was']" ", the Roman or ator , and the father of the Roman law .
+
+ The Roman law was" False However, in 447 BC, Cicero recorded that the Quaestors 8 [' However', ',', ' in', ' 4', '47', ' BC', ',', ' Cic', 'ero']
+518 122 Name of father of x -1 Name of father of Elizabeth Taylor Francis Lenn Taylor Elizabeth Taylor "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ' Elizabeth' ' Taylor' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , and the mother of the actress Elizabeth Taylor .
+
+ The name of the father" False marriages to give Elizabeth Taylor a run for her money. 4 [' marriages', ' to', ' give', ' Elizabeth', ' Taylor']
+519 122 Name of father of x -1 Name of father of Elizabeth Taylor Francis Lenn Taylor Elizabeth Taylor "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ' Elizabeth' ' Taylor' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , and the mother of the actress Elizabeth Taylor .
+
+ The name of the father" False " Natalie Wood, Elizabeth Taylor and Grace Kelly.
+" 4 [' Natalie', ' Wood', ',', ' Elizabeth', ' Taylor']
+520 122 Name of father of x -1 Name of father of Elizabeth Taylor Francis Lenn Taylor Elizabeth Taylor "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ' Elizabeth' ' Taylor' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , and the mother of the actress Elizabeth Taylor .
+
+ The name of the father" False Sabbath), Guns 1 ' Roses, Elizabeth Taylor, George Michael, 8 "[' Sabbath', '),', ' Guns', ' 1', "" '"", ' Roses', ',', ' Elizabeth', ' Taylor']"
+521 122 Name of father of x -1 Name of father of Elizabeth Taylor Francis Lenn Taylor Elizabeth Taylor "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ' Elizabeth' ' Taylor' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , and the mother of the actress Elizabeth Taylor .
+
+ The name of the father" False He also dated Elizabeth Taylor for a year between 4 [' He', ' also', ' dated', ' Elizabeth', ' Taylor']
+522 122 Name of father of x -1 Name of father of Elizabeth Taylor Francis Lenn Taylor Elizabeth Taylor "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ' Elizabeth' ' Taylor' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , and the mother of the actress Elizabeth Taylor .
+
+ The name of the father" False committed to star opposite Elizabeth Taylor in the film Two 5 [' committed', ' to', ' star', ' opposite', ' Elizabeth', ' Taylor']
+523 123 Name of father of x -1 Name of father of Scarlett Johansson Karsten Olaf Johansson Scarlett Johansson "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the'
+ ' movie' '.' '\n' '\n' 'The' ' movie' ' is' ' based' ' on' ' the' ' book']" ", the actress who plays the lead in the movie .
+
+ The movie is based on the book" False 3 ['Scar', 'lett', ' Joh', 'ansson']
+524 123 Name of father of x -1 Name of father of Scarlett Johansson Karsten Olaf Johansson Scarlett Johansson "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the'
+ ' movie' '.' '\n' '\n' 'The' ' movie' ' is' ' based' ' on' ' the' ' book']" ", the actress who plays the lead in the movie .
+
+ The movie is based on the book" False Chris Evans and Scarlett Johansson auctioned off 5 [' Chris', ' Evans', ' and', ' Scarlett', ' Joh', 'ansson']
+525 123 Name of father of x -1 Name of father of Scarlett Johansson Karsten Olaf Johansson Scarlett Johansson "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the'
+ ' movie' '.' '\n' '\n' 'The' ' movie' ' is' ' based' ' on' ' the' ' book']" ", the actress who plays the lead in the movie .
+
+ The movie is based on the book" False 3 ['Scar', 'lett', ' Joh', 'ansson']
+526 123 Name of father of x -1 Name of father of Scarlett Johansson Karsten Olaf Johansson Scarlett Johansson "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the'
+ ' movie' '.' '\n' '\n' 'The' ' movie' ' is' ' based' ' on' ' the' ' book']" ", the actress who plays the lead in the movie .
+
+ The movie is based on the book" False England opposite Scarlett Johansson and Natalie Portman. 4 [' England', ' opposite', ' Scarlett', ' Joh', 'ansson']
+527 123 Name of father of x -1 Name of father of Scarlett Johansson Karsten Olaf Johansson Scarlett Johansson "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the'
+ ' movie' '.' '\n' '\n' 'The' ' movie' ' is' ' based' ' on' ' the' ' book']" ", the actress who plays the lead in the movie .
+
+ The movie is based on the book" False 1. Chris Evans and Scarlett Johansson auctioned off passes 7 [' 1', '.', ' Chris', ' Evans', ' and', ' Scarlett', ' Joh', 'ansson']
+528 124 Name of father of x -1 Name of father of Max Ernst Philipp Ernst Max Ernst "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False Impressionism and Max Ernst of Surrealism. 5 [' Imp', 'ression', 'ism', ' and', ' Max', ' Ernst']
+529 124 Name of father of x -1 Name of father of Max Ernst Philipp Ernst Max Ernst "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False repeat Year 10 at Max Ernst Gymnasium, but graduated 5 [' repeat', ' Year', ' 10', ' at', ' Max', ' Ernst']
+530 124 Name of father of x -1 Name of father of Max Ernst Philipp Ernst Max Ernst "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False controlled chaos, Max Ernst painted Lissajous 4 [' controlled', ' chaos', ',', ' Max', ' Ernst']
+531 124 Name of father of x -1 Name of father of Max Ernst Philipp Ernst Max Ernst "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False visual artists. Max Ernst depicts him as the 4 [' visual', ' artists', '.', ' Max', ' Ernst']
+532 124 Name of father of x -1 Name of father of Max Ernst Philipp Ernst Max Ernst "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False Surrealist artist Max Ernst made the silent 5 [' Sur', 'real', 'ist', ' artist', ' Max', ' Ernst']
+533 125 Name of father of x -1 Name of father of Thomas Mann Thomas Johann Heinrich Mann Thomas Mann "[',' ' the' ' author' ' of' ' _' 'The' ' Magic' ' Mountain' '_' ',' ' and'
+ ' _' 'The' ' Conf' 'essions' ' of' ' Felix' ' Kru' 'll' '_']" , the author of _ The Magic Mountain _ , and _ The Conf essions of Felix Kru ll _ False " lived"", while Thomas Mann and Marcel Proust" 4 "[' lived', '"",', ' while', ' Thomas', ' Mann']"
+534 125 Name of father of x -1 Name of father of Thomas Mann Thomas Johann Heinrich Mann Thomas Mann "[',' ' the' ' author' ' of' ' _' 'The' ' Magic' ' Mountain' '_' ',' ' and'
+ ' _' 'The' ' Conf' 'essions' ' of' ' Felix' ' Kru' 'll' '_']" , the author of _ The Magic Mountain _ , and _ The Conf essions of Felix Kru ll _ False " lived"", while Thomas Mann and Marcel Proust" 4 "[' lived', '"",', ' while', ' Thomas', ' Mann']"
+535 125 Name of father of x -1 Name of father of Thomas Mann Thomas Johann Heinrich Mann Thomas Mann "[',' ' the' ' author' ' of' ' _' 'The' ' Magic' ' Mountain' '_' ',' ' and'
+ ' _' 'The' ' Conf' 'essions' ' of' ' Felix' ' Kru' 'll' '_']" , the author of _ The Magic Mountain _ , and _ The Conf essions of Felix Kru ll _ False displaced scholars, Thomas Mann and Albert 4 [' displaced', ' scholars', ',', ' Thomas', ' Mann']
+536 125 Name of father of x -1 Name of father of Thomas Mann Thomas Johann Heinrich Mann Thomas Mann "[',' ' the' ' author' ' of' ' _' 'The' ' Magic' ' Mountain' '_' ',' ' and'
+ ' _' 'The' ' Conf' 'essions' ' of' ' Felix' ' Kru' 'll' '_']" , the author of _ The Magic Mountain _ , and _ The Conf essions of Felix Kru ll _ False " for books by Thomas Mann and E. T. A. Hoffmann.
+" 4 [' for', ' books', ' by', ' Thomas', ' Mann']
+537 125 Name of father of x -1 Name of father of Thomas Mann Thomas Johann Heinrich Mann Thomas Mann "[',' ' the' ' author' ' of' ' _' 'The' ' Magic' ' Mountain' '_' ',' ' and'
+ ' _' 'The' ' Conf' 'essions' ' of' ' Felix' ' Kru' 'll' '_']" , the author of _ The Magic Mountain _ , and _ The Conf essions of Felix Kru ll _ False interventions of Thomas Mann and others. Alfred, 3 [' interventions', ' of', ' Thomas', ' Mann']
+538 126 Name of father of x -1 Name of father of Chulalongkorn Mongkut Chulalongkorn "[' University' '\n' '\n' 'The' ' name' ' of' ' the' ' university' ' is'
+ ' Ch' 'ul' 'along' 'k' 'orn' ' University' '.' '\n' '\n' 'The' ' name']" " University
+
+ The name of the university is Ch ul along k orn University .
+
+ The name" False original Prince Chulalongkorn but left the cast 6 [' original', ' Prince', ' Ch', 'ul', 'along', 'k', 'orn']
+539 126 Name of father of x -1 Name of father of Chulalongkorn Mongkut Chulalongkorn "[' University' '\n' '\n' 'The' ' name' ' of' ' the' ' university' ' is'
+ ' Ch' 'ul' 'along' 'k' 'orn' ' University' '.' '\n' '\n' 'The' ' name']" " University
+
+ The name of the university is Ch ul along k orn University .
+
+ The name" False original Prince Chulalongkorn but left the cast 6 [' original', ' Prince', ' Ch', 'ul', 'along', 'k', 'orn']
+540 126 Name of father of x -1 Name of father of Chulalongkorn Mongkut Chulalongkorn "[' University' '\n' '\n' 'The' ' name' ' of' ' the' ' university' ' is'
+ ' Ch' 'ul' 'along' 'k' 'orn' ' University' '.' '\n' '\n' 'The' ' name']" " University
+
+ The name of the university is Ch ul along k orn University .
+
+ The name" False leaving Siam. Chulalongkorn arrives with a letter 8 [' leaving', ' S', 'iam', '.', ' Ch', 'ul', 'along', 'k', 'orn']
+541 126 Name of father of x -1 Name of father of Chulalongkorn Mongkut Chulalongkorn "[' University' '\n' '\n' 'The' ' name' ' of' ' the' ' university' ' is'
+ ' Ch' 'ul' 'along' 'k' 'orn' ' University' '.' '\n' '\n' 'The' ' name']" " University
+
+ The name of the university is Ch ul along k orn University .
+
+ The name" False In 1896 King Chulalongkorn of Siam visited 7 [' In', ' 1896', ' King', ' Ch', 'ul', 'along', 'k', 'orn']
+542 126 Name of father of x -1 Name of father of Chulalongkorn Mongkut Chulalongkorn "[' University' '\n' '\n' 'The' ' name' ' of' ' the' ' university' ' is'
+ ' Ch' 'ul' 'along' 'k' 'orn' ' University' '.' '\n' '\n' 'The' ' name']" " University
+
+ The name of the university is Ch ul along k orn University .
+
+ The name" False Thai, and English at Chulalongkorn University, 9 [' Thai', ',', ' and', ' English', ' at', ' Ch', 'ul', 'along', 'k', 'orn']
+543 127 Name of father of x -1 Name of father of J. R. R. Tolkien Arthur Reuel Tolkien J. R. R. Tolkien "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' J' '.' ' R' '.'
+ ' R' '.' ' Tolkien' ' was' ' John' ' Ronald' ' Re' 'uel']" "
+
+ The name of the father of J . R . R . Tolkien was John Ronald Re uel" False the Rings, J. R. R. Tolkien writes that 9 [' the', ' Rings', ',', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+544 127 Name of father of x -1 Name of father of J. R. R. Tolkien Arthur Reuel Tolkien J. R. R. Tolkien "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' J' '.' ' R' '.'
+ ' R' '.' ' Tolkien' ' was' ' John' ' Ronald' ' Re' 'uel']" "
+
+ The name of the father of J . R . R . Tolkien was John Ronald Re uel" False of the Rings, J. R. R. Tolkien writes that 10 [' of', ' the', ' Rings', ',', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+545 127 Name of father of x -1 Name of father of J. R. R. Tolkien Arthur Reuel Tolkien J. R. R. Tolkien "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' J' '.' ' R' '.'
+ ' R' '.' ' Tolkien' ' was' ' John' ' Ronald' ' Re' 'uel']" "
+
+ The name of the father of J . R . R . Tolkien was John Ronald Re uel" False Shakespeare, J. R. R. Tolkien and Poet Laureate 8 [' Shakespeare', ',', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+546 127 Name of father of x -1 Name of father of J. R. R. Tolkien Arthur Reuel Tolkien J. R. R. Tolkien "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' J' '.' ' R' '.'
+ ' R' '.' ' Tolkien' ' was' ' John' ' Ronald' ' Re' 'uel']" "
+
+ The name of the father of J . R . R . Tolkien was John Ronald Re uel" False Rings author J. R. R. Tolkien who lived in 8 [' Rings', ' author', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+547 127 Name of father of x -1 Name of father of J. R. R. Tolkien Arthur Reuel Tolkien J. R. R. Tolkien "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' J' '.' ' R' '.'
+ ' R' '.' ' Tolkien' ' was' ' John' ' Ronald' ' Re' 'uel']" "
+
+ The name of the father of J . R . R . Tolkien was John Ronald Re uel" False in the works of J. R. R. Tolkien and C. S. Lewis; 10 [' in', ' the', ' works', ' of', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+548 128 Name of father of x -1 Name of father of Zeus Kronos Zeus "[',' ' the' ' son' ' of' ' Cron' 'us' ',' ' and' ' of' ' R' 'hea' ','
+ ' the' ' daughter' ' of' ' Cron' 'us' ',' ' and' ' of']" , the son of Cron us , and of R hea , the daughter of Cron us , and of False bringing the Zeus within its 2 [' bringing', ' the', ' Zeus']
+549 128 Name of father of x -1 Name of father of Zeus Kronos Zeus "[',' ' the' ' son' ' of' ' Cron' 'us' ',' ' and' ' of' ' R' 'hea' ','
+ ' the' ' daughter' ' of' ' Cron' 'us' ',' ' and' ' of']" , the son of Cron us , and of R hea , the daughter of Cron us , and of False 1 ['Z', 'eus']
+550 128 Name of father of x -1 Name of father of Zeus Kronos Zeus "[',' ' the' ' son' ' of' ' Cron' 'us' ',' ' and' ' of' ' R' 'hea' ','
+ ' the' ' daughter' ' of' ' Cron' 'us' ',' ' and' ' of']" , the son of Cron us , and of R hea , the daughter of Cron us , and of False new DM-15B Nike Zeus B (the earlier 6 [' new', ' DM', '-', '15', 'B', ' Nike', ' Zeus']
+551 128 Name of father of x -1 Name of father of Zeus Kronos Zeus "[',' ' the' ' son' ' of' ' Cron' 'us' ',' ' and' ' of' ' R' 'hea' ','
+ ' the' ' daughter' ' of' ' Cron' 'us' ',' ' and' ' of']" , the son of Cron us , and of R hea , the daughter of Cron us , and of False " that in comparison to Zeus and Ares, ""the Furies" 4 [' that', ' in', ' comparison', ' to', ' Zeus']
+552 128 Name of father of x -1 Name of father of Zeus Kronos Zeus "[',' ' the' ' son' ' of' ' Cron' 'us' ',' ' and' ' of' ' R' 'hea' ','
+ ' the' ' daughter' ' of' ' Cron' 'us' ',' ' and' ' of']" , the son of Cron us , and of R hea , the daughter of Cron us , and of False Throughout, Zeus was the focus 2 [' Throughout', ',', ' Zeus']
+553 129 Name of father of x -1 Name of father of Frédéric Chopin Nicolas Chopin Frédéric Chopin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Fr' 'é' 'd' 'é' 'ric' ' Chop'
+ 'in' '\n' '\n' 'Name' ' of' ' wife' ' of' ' Fr']" "
+
+ Name of mother of Fr é d é ric Chop in
+
+ Name of wife of Fr" False bankruptcy. Frédéric Chopin and Auguste Franchomme 8 [' bankruptcy', '.', ' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']
+554 129 Name of father of x -1 Name of father of Frédéric Chopin Nicolas Chopin Frédéric Chopin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Fr' 'é' 'd' 'é' 'ric' ' Chop'
+ 'in' '\n' '\n' 'Name' ' of' ' wife' ' of' ' Fr']" "
+
+ Name of mother of Fr é d é ric Chop in
+
+ Name of wife of Fr" False " Frédéric Chopin =
+" 6 [' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']
+555 129 Name of father of x -1 Name of father of Frédéric Chopin Nicolas Chopin Frédéric Chopin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Fr' 'é' 'd' 'é' 'ric' ' Chop'
+ 'in' '\n' '\n' 'Name' ' of' ' wife' ' of' ' Fr']" "
+
+ Name of mother of Fr é d é ric Chop in
+
+ Name of wife of Fr" False friends and colleagues Frédéric Chopin and Franz Liszt, 9 [' friends', ' and', ' colleagues', ' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']
+556 129 Name of father of x -1 Name of father of Frédéric Chopin Nicolas Chopin Frédéric Chopin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Fr' 'é' 'd' 'é' 'ric' ' Chop'
+ 'in' '\n' '\n' 'Name' ' of' ' wife' ' of' ' Fr']" "
+
+ Name of mother of Fr é d é ric Chop in
+
+ Name of wife of Fr" False " 'Africaine"". Frédéric Chopin and Auguste Franchomme" 11 "["" '"", 'Af', 'ric', 'aine', '"".', ' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']"
+557 129 Name of father of x -1 Name of father of Frédéric Chopin Nicolas Chopin Frédéric Chopin "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Fr' 'é' 'd' 'é' 'ric' ' Chop'
+ 'in' '\n' '\n' 'Name' ' of' ' wife' ' of' ' Fr']" "
+
+ Name of mother of Fr é d é ric Chop in
+
+ Name of wife of Fr" False " l'opéra L 'Africaine"". Frédéric Chopin and Auguste" 17 "[' l', ""'"", 'op', 'é', 'ra', ' L', "" '"", 'Af', 'ric', 'aine', '"".', ' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']"
+558 130 Name of father of x -1 Name of father of Titian Gregorio Vecellio Titian "[',' ' the' ' painter' ',' ' and' ' of' ' the' ' painter' ""'s"" ' father'
+ ',' ' and' ' of' ' the' ' painter' ""'s"" ' father' ""'s"" ' father' ',']" , the painter , and of the painter 's father , and of the painter 's father 's father , False Venetian artists such as Titian and Veronese whom 7 [' Ven', 'et', 'ian', ' artists', ' such', ' as', ' Tit', 'ian']
+559 130 Name of father of x -1 Name of father of Titian Gregorio Vecellio Titian "[',' ' the' ' painter' ',' ' and' ' of' ' the' ' painter' ""'s"" ' father'
+ ',' ' and' ' of' ' the' ' painter' ""'s"" ' father' ""'s"" ' father' ',']" , the painter , and of the painter 's father , and of the painter 's father 's father , False influenced by the works of Titian and Rubens, Etty 6 [' influenced', ' by', ' the', ' works', ' of', ' Tit', 'ian']
+560 130 Name of father of x -1 Name of father of Titian Gregorio Vecellio Titian "[',' ' the' ' painter' ',' ' and' ' of' ' the' ' painter' ""'s"" ' father'
+ ',' ' and' ' of' ' the' ' painter' ""'s"" ' father' ""'s"" ' father' ',']" , the painter , and of the painter 's father , and of the painter 's father 's father , False years ago stole a Titian painting from 5 [' years', ' ago', ' stole', ' a', ' Tit', 'ian']
+561 130 Name of father of x -1 Name of father of Titian Gregorio Vecellio Titian "[',' ' the' ' painter' ',' ' and' ' of' ' the' ' painter' ""'s"" ' father'
+ ',' ' and' ' of' ' the' ' painter' ""'s"" ' father' ""'s"" ' father' ',']" , the painter , and of the painter 's father , and of the painter 's father 's father , False paintings by Titian, 1510 – 1545. 3 [' paintings', ' by', ' Tit', 'ian']
+562 130 Name of father of x -1 Name of father of Titian Gregorio Vecellio Titian "[',' ' the' ' painter' ',' ' and' ' of' ' the' ' painter' ""'s"" ' father'
+ ',' ' and' ' of' ' the' ' painter' ""'s"" ' father' ""'s"" ' father' ',']" , the painter , and of the painter 's father , and of the painter 's father 's father , False 1962) and Titian ’ s Death of Actaeon 4 [' 1962', ')', ' and', ' Tit', 'ian']
+563 131 Name of father of x -1 Name of father of Lord Byron John Byron Lord Byron "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' poet' ' of' ' the' '\n' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the poet of the
+ " False initial idea by Lord Byron that they each 4 [' initial', ' idea', ' by', ' Lord', ' Byron']
+564 131 Name of father of x -1 Name of father of Lord Byron John Byron Lord Byron "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' poet' ' of' ' the' '\n' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the poet of the
+ " False word by the works of Lord Byron or Henry Wadsworth 6 [' word', ' by', ' the', ' works', ' of', ' Lord', ' Byron']
+565 131 Name of father of x -1 Name of father of Lord Byron John Byron Lord Byron "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' poet' ' of' ' the' '\n' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the poet of the
+ " False while playing Lord Byron in the Spanish 3 [' while', ' playing', ' Lord', ' Byron']
+566 131 Name of father of x -1 Name of father of Lord Byron John Byron Lord Byron "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' poet' ' of' ' the' '\n' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the poet of the
+ " False Lovelace, daughter of Lord Byron and sponsor of 7 [' Lo', 'vel', 'ace', ',', ' daughter', ' of', ' Lord', ' Byron']
+567 131 Name of father of x -1 Name of father of Lord Byron John Byron Lord Byron "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' poet' ' of' ' the' '\n' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the poet of the
+ " False the command of Lord Byron occupied the area 4 [' the', ' command', ' of', ' Lord', ' Byron']
+568 132 Name of father of x -1 Name of father of Jules Verne Pierre Gabriel Verne Jules Verne "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'Around' ' the'
+ ' World' ' in' ' E' 'ighty' ' Days' '""' ' and' ' ""' 'Twenty' ' Thousand']" ", the author of the famous novel "" Around the World in E ighty Days "" and "" Twenty Thousand" False popularized by Jules Verne and Adolphe d 'Ennery 6 [' popular', 'ized', ' by', ' J', 'ules', ' Ver', 'ne']
+569 132 Name of father of x -1 Name of father of Jules Verne Pierre Gabriel Verne Jules Verne "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'Around' ' the'
+ ' World' ' in' ' E' 'ighty' ' Days' '""' ' and' ' ""' 'Twenty' ' Thousand']" ", the author of the famous novel "" Around the World in E ighty Days "" and "" Twenty Thousand" False in part, by the 1870 Jules Verne novel Twenty Thousand 9 [' in', ' part', ',', ' by', ' the', ' 1870', ' J', 'ules', ' Ver', 'ne']
+570 132 Name of father of x -1 Name of father of Jules Verne Pierre Gabriel Verne Jules Verne "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'Around' ' the'
+ ' World' ' in' ' E' 'ighty' ' Days' '""' ' and' ' ""' 'Twenty' ' Thousand']" ", the author of the famous novel "" Around the World in E ighty Days "" and "" Twenty Thousand" False science-fiction writers Jules Verne and H. G. Wells. 7 [' science', '-', 'fiction', ' writers', ' J', 'ules', ' Ver', 'ne']
+571 132 Name of father of x -1 Name of father of Jules Verne Pierre Gabriel Verne Jules Verne "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'Around' ' the'
+ ' World' ' in' ' E' 'ighty' ' Days' '""' ' and' ' ""' 'Twenty' ' Thousand']" ", the author of the famous novel "" Around the World in E ighty Days "" and "" Twenty Thousand" False throughout Europe. Jules Verne mentioned her 6 [' throughout', ' Europe', '.', ' J', 'ules', ' Ver', 'ne']
+572 132 Name of father of x -1 Name of father of Jules Verne Pierre Gabriel Verne Jules Verne "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'Around' ' the'
+ ' World' ' in' ' E' 'ighty' ' Days' '""' ' and' ' ""' 'Twenty' ' Thousand']" ", the author of the famous novel "" Around the World in E ighty Days "" and "" Twenty Thousand" False published in 1870, Jules Verne reported the 7 [' published', ' in', ' 1870', ',', ' J', 'ules', ' Ver', 'ne']
+573 133 Name of father of x -1 Name of father of Augustine of Hippo Patricius Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' born' ' in'
+ ' 354' ',' ' and' ' died' ' in' ' 430' '.' '\n' '\n' 'The']" ", the great theolog ian , who was born in 354 , and died in 430 .
+
+ The" False John 4: 4-12, Augustine of Hippo gave a similar 10 [' John', ' 4', ':', ' 4', '-', '12', ',', ' Augustine', ' of', ' Hipp', 'o']
+574 133 Name of father of x -1 Name of father of Augustine of Hippo Patricius Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' born' ' in'
+ ' 354' ',' ' and' ' died' ' in' ' 430' '.' '\n' '\n' 'The']" ", the great theolog ian , who was born in 354 , and died in 430 .
+
+ The" False Pelagius and Augustine of Hippo on the subject 7 [' Pel', 'ag', 'ius', ' and', ' Augustine', ' of', ' Hipp', 'o']
+575 133 Name of father of x -1 Name of father of Augustine of Hippo Patricius Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' born' ' in'
+ ' 354' ',' ' and' ' died' ' in' ' 430' '.' '\n' '\n' 'The']" ", the great theolog ian , who was born in 354 , and died in 430 .
+
+ The" False 4 ['August', 'ine', ' of', ' Hipp', 'o']
+576 133 Name of father of x -1 Name of father of Augustine of Hippo Patricius Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' born' ' in'
+ ' 354' ',' ' and' ' died' ' in' ' 430' '.' '\n' '\n' 'The']" ", the great theolog ian , who was born in 354 , and died in 430 .
+
+ The" False the time of Augustine of Hippo (AD 354 – 430). 6 [' the', ' time', ' of', ' Augustine', ' of', ' Hipp', 'o']
+577 133 Name of father of x -1 Name of father of Augustine of Hippo Patricius Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' born' ' in'
+ ' 354' ',' ' and' ' died' ' in' ' 430' '.' '\n' '\n' 'The']" ", the great theolog ian , who was born in 354 , and died in 430 .
+
+ The" False even farther than Augustine of Hippo in arguing that 6 [' even', ' farther', ' than', ' Augustine', ' of', ' Hipp', 'o']
+578 134 Name of father of x -1 Name of father of Franz Liszt Adam Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Marie' ' d' ""'""
+ 'Ag' 'oult' ',' ' who' ' was' ' the' ' mother' ' of' ' the']" , the composer , and his wife , Marie d ' Ag oult , who was the mother of the False to his friend Franz Liszt to have it staged 6 [' to', ' his', ' friend', ' Franz', ' Lis', 'z', 't']
+579 134 Name of father of x -1 Name of father of Franz Liszt Adam Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Marie' ' d' ""'""
+ 'Ag' 'oult' ',' ' who' ' was' ' the' ' mother' ' of' ' the']" , the composer , and his wife , Marie d ' Ag oult , who was the mother of the False pianist and composer Franz Liszt met Marie d 'Agoult, 7 [' pian', 'ist', ' and', ' composer', ' Franz', ' Lis', 'z', 't']
+580 134 Name of father of x -1 Name of father of Franz Liszt Adam Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Marie' ' d' ""'""
+ 'Ag' 'oult' ',' ' who' ' was' ' the' ' mother' ' of' ' the']" , the composer , and his wife , Marie d ' Ag oult , who was the mother of the False championing the music of Franz Liszt and other music 8 [' champion', 'ing', ' the', ' music', ' of', ' Franz', ' Lis', 'z', 't']
+581 134 Name of father of x -1 Name of father of Franz Liszt Adam Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Marie' ' d' ""'""
+ 'Ag' 'oult' ',' ' who' ' was' ' the' ' mother' ' of' ' the']" , the composer , and his wife , Marie d ' Ag oult , who was the mother of the False championing the music of Franz Liszt and other music 8 [' champion', 'ing', ' the', ' music', ' of', ' Franz', ' Lis', 'z', 't']
+582 134 Name of father of x -1 Name of father of Franz Liszt Adam Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Marie' ' d' ""'""
+ 'Ag' 'oult' ',' ' who' ' was' ' the' ' mother' ' of' ' the']" , the composer , and his wife , Marie d ' Ag oult , who was the mother of the False the music of Franz Liszt and Richard Wagner 6 [' the', ' music', ' of', ' Franz', ' Lis', 'z', 't']
+583 135 Name of father of x -1 Name of father of Agatha Christie Frederick Alvah Miller Agatha Christie "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' daughter' ',']" "
+
+ I am a mother of two , a wife , a grandmother of two , a daughter ," False " ""oddly like an Agatha Christie thriller with all" 7 "[' ""', 'odd', 'ly', ' like', ' an', ' Ag', 'atha', ' Christie']"
+584 135 Name of father of x -1 Name of father of Agatha Christie Frederick Alvah Miller Agatha Christie "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' daughter' ',']" "
+
+ I am a mother of two , a wife , a grandmother of two , a daughter ," False Doyle and Agatha Christie were both obvious influences 4 [' Doyle', ' and', ' Ag', 'atha', ' Christie']
+585 135 Name of father of x -1 Name of father of Agatha Christie Frederick Alvah Miller Agatha Christie "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' daughter' ',']" "
+
+ I am a mother of two , a wife , a grandmother of two , a daughter ," False Adventure Company's Agatha Christie series. As with 5 "[' Adventure', ' Company', ""'s"", ' Ag', 'atha', ' Christie']"
+586 135 Name of father of x -1 Name of father of Agatha Christie Frederick Alvah Miller Agatha Christie "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' daughter' ',']" "
+
+ I am a mother of two , a wife , a grandmother of two , a daughter ," False for example Agatha Christie produced a 4 [' for', ' example', ' Ag', 'atha', ' Christie']
+587 135 Name of father of x -1 Name of father of Agatha Christie Frederick Alvah Miller Agatha Christie "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' daughter' ',']" "
+
+ I am a mother of two , a wife , a grandmother of two , a daughter ," False production of the Agatha Christie murder mystery 5 [' production', ' of', ' the', ' Ag', 'atha', ' Christie']
+588 136 Name of father of x -1 Name of father of Olivia Newton-John Brinley Newton-John Olivia Newton-John "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Special guest star Olivia Newton-John appears as herself, 6 [' Special', ' guest', ' star', ' Olivia', ' Newton', '-', 'John']
+589 136 Name of father of x -1 Name of father of Olivia Newton-John Brinley Newton-John Olivia Newton-John "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Groban and Olivia Newton-John reprise their roles 6 [' Gro', 'ban', ' and', ' Olivia', ' Newton', '-', 'John']
+590 136 Name of father of x -1 Name of father of Olivia Newton-John Brinley Newton-John Olivia Newton-John "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " mash-up of ""Magic"" by Olivia Newton-John and ""You Raise Me" 11 "[' mash', '-', 'up', ' of', ' ""', 'Magic', '""', ' by', ' Olivia', ' Newton', '-', 'John']"
+591 136 Name of father of x -1 Name of father of Olivia Newton-John Brinley Newton-John Olivia Newton-John "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Delta Goodrem, Olivia Newton-John and Kylie Minogue 7 [' Delta', ' Good', 'rem', ',', ' Olivia', ' Newton', '-', 'John']
+592 136 Name of father of x -1 Name of father of Olivia Newton-John Brinley Newton-John Olivia Newton-John "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Special guest star Olivia Newton-John appears as herself, 6 [' Special', ' guest', ' star', ' Olivia', ' Newton', '-', 'John']
+593 137 Name of father of x -1 Name of father of Lionel Messi Jorge Messi Lionel Messi "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' game' ' of'
+ ' football' '.' ' I' ' am' ' a' ' fan' ' of' ' the' ' game']" "
+
+ I am a big fan of the game of football . I am a fan of the game" False " beating the record of Lionel Messi by 14 months.
+" 5 [' beating', ' the', ' record', ' of', ' Lionel', ' Messi']
+594 137 Name of father of x -1 Name of father of Lionel Messi Jorge Messi Lionel Messi "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' game' ' of'
+ ' football' '.' ' I' ' am' ' a' ' fan' ' of' ' the' ' game']" "
+
+ I am a big fan of the game of football . I am a fan of the game" False to two goals from Lionel Messi and goals from 5 [' to', ' two', ' goals', ' from', ' Lionel', ' Messi']
+595 137 Name of father of x -1 Name of father of Lionel Messi Jorge Messi Lionel Messi "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' game' ' of'
+ ' football' '.' ' I' ' am' ' a' ' fan' ' of' ' the' ' game']" "
+
+ I am a big fan of the game of football . I am a fan of the game" False the fitness of Lionel Messi going into 4 [' the', ' fitness', ' of', ' Lionel', ' Messi']
+596 137 Name of father of x -1 Name of father of Lionel Messi Jorge Messi Lionel Messi "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' game' ' of'
+ ' football' '.' ' I' ' am' ' a' ' fan' ' of' ' the' ' game']" "
+
+ I am a big fan of the game of football . I am a fan of the game" False Ballon d'Or winners Lionel Messi and Cristiano Ronaldo. 7 "[' Ball', 'on', ' d', ""'"", 'Or', ' winners', ' Lionel', ' Messi']"
+597 137 Name of father of x -1 Name of father of Lionel Messi Jorge Messi Lionel Messi "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' game' ' of'
+ ' football' '.' ' I' ' am' ' a' ' fan' ' of' ' the' ' game']" "
+
+ I am a big fan of the game of football . I am a fan of the game" False nominated alongside Lionel Messi and Cristiano Ronaldo 3 [' nominated', ' alongside', ' Lionel', ' Messi']
+598 139 Name of father of x -1 Name of father of Batman Thomas Wayne Batman "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' father' ' to' ' me' ','
+ ' and' ' a' ' friend' ' to' ' all' ' of' ' us' '.' '\n']" ", the man who has been a father to me , and a friend to all of us .
+" False " Batman solo ===
+" 0 [' Batman']
+599 139 Name of father of x -1 Name of father of Batman Thomas Wayne Batman "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' father' ' to' ' me' ','
+ ' and' ' a' ' friend' ' to' ' all' ' of' ' us' '.' '\n']" ", the man who has been a father to me , and a friend to all of us .
+" False 0 ['Batman']
+600 139 Name of father of x -1 Name of father of Batman Thomas Wayne Batman "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' father' ' to' ' me' ','
+ ' and' ' a' ' friend' ' to' ' all' ' of' ' us' '.' '\n']" ", the man who has been a father to me , and a friend to all of us .
+" False 0 ['Batman']
+601 139 Name of father of x -1 Name of father of Batman Thomas Wayne Batman "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' father' ' to' ' me' ','
+ ' and' ' a' ' friend' ' to' ' all' ' of' ' us' '.' '\n']" ", the man who has been a father to me , and a friend to all of us .
+" False Kane in 2011 for his Batman Incorporated series. 5 [' Kane', ' in', ' 2011', ' for', ' his', ' Batman']
+602 139 Name of father of x -1 Name of father of Batman Thomas Wayne Batman "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' father' ' to' ' me' ','
+ ' and' ' a' ' friend' ' to' ' all' ' of' ' us' '.' '\n']" ", the man who has been a father to me , and a friend to all of us .
+" False Bruce Wayne / Batman in a favor to the casting 3 [' Bruce', ' Wayne', ' /', ' Batman']
+603 140 Name of father of x -1 Name of father of Sigmund Freud Jacob Freud Sigmund Freud "[',' ' the' ' founder' ' of' ' psycho' 'analysis' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' S' 'igm']" ", the founder of psycho analysis .
+
+ The name of the father of psycho analysis , S igm" False 3 ['S', 'igm', 'und', ' Freud']
+604 140 Name of father of x -1 Name of father of Sigmund Freud Jacob Freud Sigmund Freud "[',' ' the' ' founder' ' of' ' psycho' 'analysis' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' S' 'igm']" ", the founder of psycho analysis .
+
+ The name of the father of psycho analysis , S igm" False " destroy my ego ... Sigmund Freud ... Analyze this""" 7 [' destroy', ' my', ' ego', '...', ' S', 'igm', 'und', ' Freud']
+605 140 Name of father of x -1 Name of father of Sigmund Freud Jacob Freud Sigmund Freud "[',' ' the' ' founder' ' of' ' psycho' 'analysis' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' S' 'igm']" ", the founder of psycho analysis .
+
+ The name of the father of psycho analysis , S igm" False even name checks Sigmund Freud and throws 6 [' even', ' name', ' checks', ' S', 'igm', 'und', ' Freud']
+606 140 Name of father of x -1 Name of father of Sigmund Freud Jacob Freud Sigmund Freud "[',' ' the' ' founder' ' of' ' psycho' 'analysis' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' S' 'igm']" ", the founder of psycho analysis .
+
+ The name of the father of psycho analysis , S igm" False 3 ['S', 'igm', 'und', ' Freud']
+607 140 Name of father of x -1 Name of father of Sigmund Freud Jacob Freud Sigmund Freud "[',' ' the' ' founder' ' of' ' psycho' 'analysis' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' S' 'igm']" ", the founder of psycho analysis .
+
+ The name of the father of psycho analysis , S igm" False Lane, Wilhelm Reich, Sigmund Freud and others influenced 8 [' Lane', ',', ' Wilhelm', ' Reich', ',', ' S', 'igm', 'und', ' Freud']
+608 141 Name of father of x -1 Name of father of Guido Gezelle Pieter Jan Gezelle Guido Gezelle "[',' ' the' ' famous' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", the famous
+ " False and priest Guido Gezelle produced lyric poems 6 [' and', ' priest', ' Gu', 'ido', ' Ge', 'z', 'elle']
+609 141 Name of father of x -1 Name of father of Guido Gezelle Pieter Jan Gezelle Guido Gezelle "[',' ' the' ' famous' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", the famous
+ " False poet and priest Guido Gezelle produced lyric poems 7 [' poet', ' and', ' priest', ' Gu', 'ido', ' Ge', 'z', 'elle']
+610 142 Name of father of x -1 Name of father of Ronald Reagan Jack Reagan Ronald Reagan "[',' ' the' ' actor' ',' ' and' ' mother' ' of' ' Nancy' ' Reagan' ','
+ ' the' ' former' ' First' ' Lady' '.' '\n' '\n' 'The' ' Reagan' ' family']" ", the actor , and mother of Nancy Reagan , the former First Lady .
+
+ The Reagan family" False 3, 2007, the Ronald Reagan Presidential Foundation 6 [' 3', ',', ' 2007', ',', ' the', ' Ronald', ' Reagan']
+611 142 Name of father of x -1 Name of father of Ronald Reagan Jack Reagan Ronald Reagan "[',' ' the' ' actor' ',' ' and' ' mother' ' of' ' Nancy' ' Reagan' ','
+ ' the' ' former' ' First' ' Lady' '.' '\n' '\n' 'The' ' Reagan' ' family']" ", the actor , and mother of Nancy Reagan , the former First Lady .
+
+ The Reagan family" False 2 ['Ron', 'ald', ' Reagan']
+612 142 Name of father of x -1 Name of father of Ronald Reagan Jack Reagan Ronald Reagan "[',' ' the' ' actor' ',' ' and' ' mother' ' of' ' Nancy' ' Reagan' ','
+ ' the' ' former' ' First' ' Lady' '.' '\n' '\n' 'The' ' Reagan' ' family']" ", the actor , and mother of Nancy Reagan , the former First Lady .
+
+ The Reagan family" False 1988, President Ronald Reagan extended that 4 [' 1988', ',', ' President', ' Ronald', ' Reagan']
+613 142 Name of father of x -1 Name of father of Ronald Reagan Jack Reagan Ronald Reagan "[',' ' the' ' actor' ',' ' and' ' mother' ' of' ' Nancy' ' Reagan' ','
+ ' the' ' former' ' First' ' Lady' '.' '\n' '\n' 'The' ' Reagan' ' family']" ", the actor , and mother of Nancy Reagan , the former First Lady .
+
+ The Reagan family" False admitted to Ronald Reagan UCLA Medical Center 3 [' admitted', ' to', ' Ronald', ' Reagan']
+614 142 Name of father of x -1 Name of father of Ronald Reagan Jack Reagan Ronald Reagan "[',' ' the' ' actor' ',' ' and' ' mother' ' of' ' Nancy' ' Reagan' ','
+ ' the' ' former' ' First' ' Lady' '.' '\n' '\n' 'The' ' Reagan' ' family']" ", the actor , and mother of Nancy Reagan , the former First Lady .
+
+ The Reagan family" False McCain persuaded Ronald Reagan to support it. However, 3 [' McCain', ' persuaded', ' Ronald', ' Reagan']
+615 143 Name of father of x -1 Name of father of Mark Twain John Marshall Clemens Mark Twain "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' and' ' the' ' man' ' who' ' wrote']" , the author of the book , and the man who wrote the book , and the man who wrote False 2006; and the Mark Twain Award (Missouri 5 [' 2006', ';', ' and', ' the', ' Mark', ' Twain']
+616 143 Name of father of x -1 Name of father of Mark Twain John Marshall Clemens Mark Twain "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' and' ' the' ' man' ' who' ' wrote']" , the author of the book , and the man who wrote the book , and the man who wrote False went along with Mark Twain Zephyr trainset 9903 4 [' went', ' along', ' with', ' Mark', ' Twain']
+617 143 Name of father of x -1 Name of father of Mark Twain John Marshall Clemens Mark Twain "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' and' ' the' ' man' ' who' ' wrote']" , the author of the book , and the man who wrote the book , and the man who wrote False " hold. As writer Mark Twain said, ""It took" 5 [' hold', '.', ' As', ' writer', ' Mark', ' Twain']
+618 143 Name of father of x -1 Name of father of Mark Twain John Marshall Clemens Mark Twain "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' and' ' the' ' man' ' who' ' wrote']" , the author of the book , and the man who wrote the book , and the man who wrote False and the 2009 Mark Twain Award. It sold over 4 [' and', ' the', ' 2009', ' Mark', ' Twain']
+619 143 Name of father of x -1 Name of father of Mark Twain John Marshall Clemens Mark Twain "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' and' ' the' ' man' ' who' ' wrote']" , the author of the book , and the man who wrote the book , and the man who wrote False " to it were common. Mark Twain wrote that ""The" 6 [' to', ' it', ' were', ' common', '.', ' Mark', ' Twain']
+620 144 Name of father of x -1 Name of father of Franz Kafka Hermann Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' Met']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The Met" False Brandenfeld, to Franz Kafka and Felice Bauer. 6 [' Brand', 'en', 'feld', ',', ' to', ' Franz', ' Kafka']
+621 144 Name of father of x -1 Name of father of Franz Kafka Hermann Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' Met']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The Met" False 2 ['Fran', 'z', ' Kafka']
+622 144 Name of father of x -1 Name of father of Franz Kafka Hermann Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' Met']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The Met" False 2 ['Fran', 'z', ' Kafka']
+623 144 Name of father of x -1 Name of father of Franz Kafka Hermann Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' Met']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The Met" False 2 ['Fran', 'z', ' Kafka']
+624 144 Name of father of x -1 Name of father of Franz Kafka Hermann Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' Met']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The Met" False " businessman"" and by Franz Kafka as ""a true Kafka" 5 "[' businessman', '""', ' and', ' by', ' Franz', ' Kafka']"
+625 145 Name of father of x -1 Name of father of Hilary Duff Robert Erhard Duff Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False releases of Hilary Duff and Dignity, 4 [' releases', ' of', ' Hil', 'ary', ' Duff']
+626 145 Name of father of x -1 Name of father of Hilary Duff Robert Erhard Duff Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False with comparisons with Hilary Duff and Avril Lavigne. 5 [' with', ' comparisons', ' with', ' Hil', 'ary', ' Duff']
+627 145 Name of father of x -1 Name of father of Hilary Duff Robert Erhard Duff Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False longevity of the Hilary Duff album and the 5 [' longevity', ' of', ' the', ' Hil', 'ary', ' Duff']
+628 145 Name of father of x -1 Name of father of Hilary Duff Robert Erhard Duff Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Actress and singer Hilary Duff was rumored to have 5 [' Actress', ' and', ' singer', ' Hil', 'ary', ' Duff']
+629 145 Name of father of x -1 Name of father of Hilary Duff Robert Erhard Duff Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False sound to those of Hilary Duff and Avril 6 [' sound', ' to', ' those', ' of', ' Hil', 'ary', ' Duff']
+630 146 Name of father of x -1 Name of father of Alexander von Humboldt Alexander Georg von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' the' ' great' ' German' ' natural'
+ 'ist']" ", the great German natural ist .
+
+ The name of the father of the great German natural ist" False homeland. In 1807, Alexander von Humboldt argued that national 11 [' homeland', '.', ' In', ' 18', '07', ',', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+631 146 Name of father of x -1 Name of father of Alexander von Humboldt Alexander Georg von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' the' ' great' ' German' ' natural'
+ 'ist']" ", the great German natural ist .
+
+ The name of the father of the great German natural ist" False homeland. In 1807, Alexander von Humboldt argued that national 11 [' homeland', '.', ' In', ' 18', '07', ',', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+632 146 Name of father of x -1 Name of father of Alexander von Humboldt Alexander Georg von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' the' ' great' ' German' ' natural'
+ 'ist']" ", the great German natural ist .
+
+ The name of the father of the great German natural ist" False astronomical term by Alexander von Humboldt in 1845. It was 8 [' astronomical', ' term', ' by', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+633 146 Name of father of x -1 Name of father of Alexander von Humboldt Alexander Georg von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' the' ' great' ' German' ' natural'
+ 'ist']" ", the great German natural ist .
+
+ The name of the father of the great German natural ist" False collections made by Alexander von Humboldt and Aimé Bonpland 8 [' collections', ' made', ' by', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+634 146 Name of father of x -1 Name of father of Alexander von Humboldt Alexander Georg von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' the' ' great' ' German' ' natural'
+ 'ist']" ", the great German natural ist .
+
+ The name of the father of the great German natural ist" False naturalist Alexander von Humboldt noted in his work 7 [' natural', 'ist', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+635 147 Name of father of x -1 Name of father of Yukio Mishima Azusa Hiraoka Yukio Mishima "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False 'ichirō Tanizaki, Yukio Mishima and, more recently, 11 "["" '"", 'ich', 'ir', 'ō', ' Tan', 'iz', 'aki', ',', ' Yuk', 'io', ' Mish', 'ima']"
+636 147 Name of father of x -1 Name of father of Yukio Mishima Azusa Hiraoka Yukio Mishima "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False 'ichirō Tanizaki, Yukio Mishima and, more recently, 11 "["" '"", 'ich', 'ir', 'ō', ' Tan', 'iz', 'aki', ',', ' Yuk', 'io', ' Mish', 'ima']"
+637 147 Name of father of x -1 Name of father of Yukio Mishima Azusa Hiraoka Yukio Mishima "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False Suzuki and authors Yukio Mishima and Kōbō Abe 6 [' Suzuki', ' and', ' authors', ' Yuk', 'io', ' Mish', 'ima']
+638 147 Name of father of x -1 Name of father of Yukio Mishima Azusa Hiraoka Yukio Mishima "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False and authors Yukio Mishima and Kōbō Abe testified 5 [' and', ' authors', ' Yuk', 'io', ' Mish', 'ima']
+639 147 Name of father of x -1 Name of father of Yukio Mishima Azusa Hiraoka Yukio Mishima "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False 'ichirō Tanizaki, Yukio Mishima and, more recently, 11 "["" '"", 'ich', 'ir', 'ō', ' Tan', 'iz', 'aki', ',', ' Yuk', 'io', ' Mish', 'ima']"
+640 148 Name of father of x -1 Name of father of Frans Hals Franchois Fransz. Hals van Mechelen Frans Hals "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his daughter , the painter 's False Rommelpotspeler after Frans Hals. The Frans Hals catalogue 9 [' R', 'ommel', 'pots', 'pel', 'er', ' after', ' Fr', 'ans', ' H', 'als']
+641 148 Name of father of x -1 Name of father of Frans Hals Franchois Fransz. Hals van Mechelen Frans Hals "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his daughter , the painter 's False the style of Frans Hals and the school of 6 [' the', ' style', ' of', ' Fr', 'ans', ' H', 'als']
+642 148 Name of father of x -1 Name of father of Frans Hals Franchois Fransz. Hals van Mechelen Frans Hals "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his daughter , the painter 's False broadly based on the Frans Hals painting, 7 [' broadly', ' based', ' on', ' the', ' Fr', 'ans', ' H', 'als']
+643 148 Name of father of x -1 Name of father of Frans Hals Franchois Fransz. Hals van Mechelen Frans Hals "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his daughter , the painter 's False eight Vermeer and Frans Hals paintings which 8 [' eight', ' Ver', 'me', 'er', ' and', ' Fr', 'ans', ' H', 'als']
+644 148 Name of father of x -1 Name of father of Frans Hals Franchois Fransz. Hals van Mechelen Frans Hals "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his daughter , the painter 's False Rommelpotspeler after Frans Hals. The Frans Hals 9 [' R', 'ommel', 'pots', 'pel', 'er', ' after', ' Fr', 'ans', ' H', 'als']
+645 149 Name of father of x -1 Name of father of Alfred Sisley William Sisley Alfred Sisley "[',' ' the' ' painter' ',' ' who' ' died' ' in' ' 18' '39' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' painter' ' of' ' the' ' ""']" ", the painter , who died in 18 39 .
+
+ The name of the painter of the """ False landscape painters Alfred Sisley (1839 – 1899), Claude 6 [' landscape', ' pain', 'ters', ' Alfred', ' S', 'is', 'ley']
+646 150 Name of father of x -1 Name of father of Isaac Newton Isaac Newton Sr. Isaac Newton "[',' ' the' ' great' ' mathematician' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' science' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three'
+ ' great']" ", the great mathematician , and the father of modern science .
+
+ The first of the three great" False book Opticks, Isaac Newton described the mechanism 5 [' book', ' Opt', 'icks', ',', ' Isaac', ' Newton']
+647 150 Name of father of x -1 Name of father of Isaac Newton Isaac Newton Sr. Isaac Newton "[',' ' the' ' great' ' mathematician' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' science' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three'
+ ' great']" ", the great mathematician , and the father of modern science .
+
+ The first of the three great" False Johannes Kepler, and Isaac Newton developed an understanding 5 [' Johannes', ' Kepler', ',', ' and', ' Isaac', ' Newton']
+648 150 Name of father of x -1 Name of father of Isaac Newton Isaac Newton Sr. Isaac Newton "[',' ' the' ' great' ' mathematician' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' science' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three'
+ ' great']" ", the great mathematician , and the father of modern science .
+
+ The first of the three great" False hemisphere (now known as the Isaac Newton Group of Telescopes, 7 [' hemisphere', ' (', 'now', ' known', ' as', ' the', ' Isaac', ' Newton']
+649 150 Name of father of x -1 Name of father of Isaac Newton Isaac Newton Sr. Isaac Newton "[',' ' the' ' great' ' mathematician' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' science' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three'
+ ' great']" ", the great mathematician , and the father of modern science .
+
+ The first of the three great" False " Newton (1691 – 1727)
+" 10 [' Newton', ' (', '16', '91', ' –', ' 17', '27', ')', 'Isa', 'ac', ' Newton']
+650 150 Name of father of x -1 Name of father of Isaac Newton Isaac Newton Sr. Isaac Newton "[',' ' the' ' great' ' mathematician' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' science' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three'
+ ' great']" ", the great mathematician , and the father of modern science .
+
+ The first of the three great" False tutoring (including Isaac Newton as a tutor), and he 5 [' tut', 'oring', ' (', 'including', ' Isaac', ' Newton']
+651 151 Name of father of x -1 Name of father of Charles III Prince Philip, Duke of Edinburgh Charles III "['.' '\n' '\n' 'The' ' King' ' of' ' Spain' ',' ' who' ' was' ' a' ' very'
+ ' good' ' man' ',' ' and' ' a' ' very' ' good' ' king']" ".
+
+ The King of Spain , who was a very good man , and a very good king" False 1 ['Charles', ' III']
+652 151 Name of father of x -1 Name of father of Charles III Prince Philip, Duke of Edinburgh Charles III "['.' '\n' '\n' 'The' ' King' ' of' ' Spain' ',' ' who' ' was' ' a' ' very'
+ ' good' ' man' ',' ' and' ' a' ' very' ' good' ' king']" ".
+
+ The King of Spain , who was a very good man , and a very good king" False April 1, 1778, King Charles III signed a document 8 [' April', ' 1', ',', ' 17', '78', ',', ' King', ' Charles', ' III']
+653 151 Name of father of x -1 Name of father of Charles III Prince Philip, Duke of Edinburgh Charles III "['.' '\n' '\n' 'The' ' King' ' of' ' Spain' ',' ' who' ' was' ' a' ' very'
+ ' good' ' man' ',' ' and' ' a' ' very' ' good' ' king']" ".
+
+ The King of Spain , who was a very good man , and a very good king" False America. King Charles III created a public 4 [' America', '.', ' King', ' Charles', ' III']
+654 151 Name of father of x -1 Name of father of Charles III Prince Philip, Duke of Edinburgh Charles III "['.' '\n' '\n' 'The' ' King' ' of' ' Spain' ',' ' who' ' was' ' a' ' very'
+ ' good' ' man' ',' ' and' ' a' ' very' ' good' ' king']" ".
+
+ The King of Spain , who was a very good man , and a very good king" False 1 ['Charles', ' III']
+655 151 Name of father of x -1 Name of father of Charles III Prince Philip, Duke of Edinburgh Charles III "['.' '\n' '\n' 'The' ' King' ' of' ' Spain' ',' ' who' ' was' ' a' ' very'
+ ' good' ' man' ',' ' and' ' a' ' very' ' good' ' king']" ".
+
+ The King of Spain , who was a very good man , and a very good king" False " Cross of the Order of Charles III of Spain
+" 6 [' Cross', ' of', ' the', ' Order', ' of', ' Charles', ' III']
+656 152 Name of father of x -1 Name of father of Honoré de Balzac Bernard-François Balzac Honoré de Balzac "[',' ' the' ' author' ' of' ' ""' 'La' ' Com' 'é' 'die' ' hum' 'aine' '""'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the author of "" La Com é die hum aine ""
+
+ The name of the father of" False novelist and playwright Honoré de Balzac (1799 – 1850), 8 [' novelist', ' and', ' play', 'wright', ' Honor', 'é', ' de', ' Bal', 'zac']
+657 152 Name of father of x -1 Name of father of Honoré de Balzac Bernard-François Balzac Honoré de Balzac "[',' ' the' ' author' ' of' ' ""' 'La' ' Com' 'é' 'die' ' hum' 'aine' '""'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the author of "" La Com é die hum aine ""
+
+ The name of the father of" False 5 ['Hon', 'or', 'é', ' de', ' Bal', 'zac']
+658 152 Name of father of x -1 Name of father of Honoré de Balzac Bernard-François Balzac Honoré de Balzac "[',' ' the' ' author' ' of' ' ""' 'La' ' Com' 'é' 'die' ' hum' 'aine' '""'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the author of "" La Com é die hum aine ""
+
+ The name of the father of" False " to be lucky."" Honoré de Balzac (in his novella" 8 "[' to', ' be', ' lucky', '.""', ' Honor', 'é', ' de', ' Bal', 'zac']"
+659 152 Name of father of x -1 Name of father of Honoré de Balzac Bernard-François Balzac Honoré de Balzac "[',' ' the' ' author' ' of' ' ""' 'La' ' Com' 'é' 'die' ' hum' 'aine' '""'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the author of "" La Com é die hum aine ""
+
+ The name of the father of" False Alexandre Dumas and Honoré de Balzac in France, Leo Tolstoy 9 [' Alexand', 're', ' Dum', 'as', ' and', ' Honor', 'é', ' de', ' Bal', 'zac']
+660 152 Name of father of x -1 Name of father of Honoré de Balzac Bernard-François Balzac Honoré de Balzac "[',' ' the' ' author' ' of' ' ""' 'La' ' Com' 'é' 'die' ' hum' 'aine' '""'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the author of "" La Com é die hum aine ""
+
+ The name of the father of" False 5 ['Hon', 'or', 'é', ' de', ' Bal', 'zac']
+661 153 Name of father of x -1 Name of father of Thomas Aquinas Landulphe d'Aquino Thomas Aquinas "[',' ' the' ' great' ' theolog' 'ian' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great theolog ian , and the
+ " False Florida's St. Thomas Aquinas High School 6 "[' Florida', ""'s"", ' St', '.', ' Thomas', ' Aqu', 'inas']"
+662 153 Name of father of x -1 Name of father of Thomas Aquinas Landulphe d'Aquino Thomas Aquinas "[',' ' the' ' great' ' theolog' 'ian' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great theolog ian , and the
+ " False 2 ['Thomas', ' Aqu', 'inas']
+663 153 Name of father of x -1 Name of father of Thomas Aquinas Landulphe d'Aquino Thomas Aquinas "[',' ' the' ' great' ' theolog' 'ian' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great theolog ian , and the
+ " False 2 ['Thomas', ' Aqu', 'inas']
+664 153 Name of father of x -1 Name of father of Thomas Aquinas Landulphe d'Aquino Thomas Aquinas "[',' ' the' ' great' ' theolog' 'ian' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great theolog ian , and the
+ " False 2 ['Thomas', ' Aqu', 'inas']
+665 153 Name of father of x -1 Name of father of Thomas Aquinas Landulphe d'Aquino Thomas Aquinas "[',' ' the' ' great' ' theolog' 'ian' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great theolog ian , and the
+ " False defined groups. Thomas Aquinas (c. 1225 – 1274) 5 [' defined', ' groups', '.', ' Thomas', ' Aqu', 'inas']
+666 154 Name of father of x -1 Name of father of Halle Berry Jerome Jesse Berry Halle Berry "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' Olivier'
+ ' Martinez' ',' ' the' ' actor' '.' '\n' '\n' 'The' ' couple' ',' ' who']" ", the actress , and her husband , Olivier Martinez , the actor .
+
+ The couple , who" False Hugh Jackman, Halle Berry and Kelsey Grammer 6 [' Hugh', ' Jack', 'man', ',', ' H', 'alle', ' Berry']
+667 154 Name of father of x -1 Name of father of Halle Berry Jerome Jesse Berry Halle Berry "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' Olivier'
+ ' Martinez' ',' ' the' ' actor' '.' '\n' '\n' 'The' ' couple' ',' ' who']" ", the actress , and her husband , Olivier Martinez , the actor .
+
+ The couple , who" False Academy Award winner Halle Berry guest starred as herself. 5 [' Academy', ' Award', ' winner', ' H', 'alle', ' Berry']
+668 154 Name of father of x -1 Name of father of Halle Berry Jerome Jesse Berry Halle Berry "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' Olivier'
+ ' Martinez' ',' ' the' ' actor' '.' '\n' '\n' 'The' ' couple' ',' ' who']" ", the actress , and her husband , Olivier Martinez , the actor .
+
+ The couple , who" False " ""Congratulations to Halle Berry and Denzel" 5 "[' ""', 'Congratulations', ' to', ' H', 'alle', ' Berry']"
+669 154 Name of father of x -1 Name of father of Halle Berry Jerome Jesse Berry Halle Berry "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' Olivier'
+ ' Martinez' ',' ' the' ' actor' '.' '\n' '\n' 'The' ' couple' ',' ' who']" ", the actress , and her husband , Olivier Martinez , the actor .
+
+ The couple , who" False even usually tasteful Halle Berry succumbed to 6 [' even', ' usually', ' tast', 'eful', ' H', 'alle', ' Berry']
+670 154 Name of father of x -1 Name of father of Halle Berry Jerome Jesse Berry Halle Berry "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' Olivier'
+ ' Martinez' ',' ' the' ' actor' '.' '\n' '\n' 'The' ' couple' ',' ' who']" ", the actress , and her husband , Olivier Martinez , the actor .
+
+ The couple , who" False shrift. As Storm, Halle Berry can do neat 8 [' shr', 'ift', '.', ' As', ' Storm', ',', ' H', 'alle', ' Berry']
+671 155 Name of father of x -1 Name of father of Brigitte Bardot Louis Bardot Brigitte Bardot "[',' ' the' ' French' ' actress' ' and' ' animal' ' rights' ' activist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' actress' ' and'
+ ' animal' ' rights']" ", the French actress and animal rights activist .
+
+ The name of the French actress and animal rights" False generation was fascinated by Brigitte Bardot who came to 7 [' generation', ' was', ' fascinated', ' by', ' Brig', 'itte', ' Bard', 'ot']
+672 155 Name of father of x -1 Name of father of Brigitte Bardot Louis Bardot Brigitte Bardot "[',' ' the' ' French' ' actress' ' and' ' animal' ' rights' ' activist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' actress' ' and'
+ ' animal' ' rights']" ", the French actress and animal rights activist .
+
+ The name of the French actress and animal rights" False was fascinated by Brigitte Bardot who came to 6 [' was', ' fascinated', ' by', ' Brig', 'itte', ' Bard', 'ot']
+673 155 Name of father of x -1 Name of father of Brigitte Bardot Louis Bardot Brigitte Bardot "[',' ' the' ' French' ' actress' ' and' ' animal' ' rights' ' activist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' actress' ' and'
+ ' animal' ' rights']" ", the French actress and animal rights activist .
+
+ The name of the French actress and animal rights" False neophyte Lazenby. Brigitte Bardot was invited, but 10 [' ne', 'ophy', 'te', ' Laz', 'en', 'by', '.', ' Brig', 'itte', ' Bard', 'ot']
+674 155 Name of father of x -1 Name of father of Brigitte Bardot Louis Bardot Brigitte Bardot "[',' ' the' ' French' ' actress' ' and' ' animal' ' rights' ' activist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' actress' ' and'
+ ' animal' ' rights']" ", the French actress and animal rights activist .
+
+ The name of the French actress and animal rights" False his work. When Brigitte Bardot wanted to 7 [' his', ' work', '.', ' When', ' Brig', 'itte', ' Bard', 'ot']
+675 155 Name of father of x -1 Name of father of Brigitte Bardot Louis Bardot Brigitte Bardot "[',' ' the' ' French' ' actress' ' and' ' animal' ' rights' ' activist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' actress' ' and'
+ ' animal' ' rights']" ", the French actress and animal rights activist .
+
+ The name of the French actress and animal rights" False Faithfull and Brigitte Bardot set about destroying 6 [' Faith', 'full', ' and', ' Brig', 'itte', ' Bard', 'ot']
+676 156 Name of father of x -1 Name of father of August Strindberg Carl Oscar Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in' ' 18'
+ '49' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Swedish writer , who was born in 18 49 .
+
+ The name of the father of" False " Strindberg ===
+" 7 [' Str', 'ind', 'berg', ' ===', 'August', ' Str', 'ind', 'berg']
+677 156 Name of father of x -1 Name of father of August Strindberg Carl Oscar Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in' ' 18'
+ '49' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Swedish writer , who was born in 18 49 .
+
+ The name of the father of" False " Strindberg ===
+" 7 [' Str', 'ind', 'berg', ' ===', 'August', ' Str', 'ind', 'berg']
+678 156 Name of father of x -1 Name of father of August Strindberg Carl Oscar Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in' ' 18'
+ '49' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Swedish writer , who was born in 18 49 .
+
+ The name of the father of" False Swedish playwright August Strindberg and Norwegian 6 [' Swedish', ' play', 'wright', ' August', ' Str', 'ind', 'berg']
+679 156 Name of father of x -1 Name of father of August Strindberg Carl Oscar Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in' ' 18'
+ '49' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Swedish writer , who was born in 18 49 .
+
+ The name of the father of" False Swedish playwright August Strindberg and Norwegian 6 [' Swedish', ' play', 'wright', ' August', ' Str', 'ind', 'berg']
+680 156 Name of father of x -1 Name of father of August Strindberg Carl Oscar Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in' ' 18'
+ '49' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of']" ", the Swedish writer , who was born in 18 49 .
+
+ The name of the father of" False playwright August Strindberg and Norwegian 5 [' play', 'wright', ' August', ' Str', 'ind', 'berg']
+681 157 Name of father of x -1 Name of father of Rabindranath Tagore Debendranath Tagore Rabindranath Tagore "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Rab' 'ind' 'ran' 'ath'
+ ' Tag' 'ore' '.' ' I' ' have' ' read' ' all' ' his']" "
+
+ I am a big fan of Rab ind ran ath Tag ore . I have read all his" False Sen is one of them. Rabindranath Tagore made this district 11 [' Sen', ' is', ' one', ' of', ' them', '.', ' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore']
+682 157 Name of father of x -1 Name of father of Rabindranath Tagore Debendranath Tagore Rabindranath Tagore "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Rab' 'ind' 'ran' 'ath'
+ ' Tag' 'ore' '.' ' I' ' have' ' read' ' all' ' his']" "
+
+ I am a big fan of Rab ind ran ath Tag ore . I have read all his" False including works by Rabindranath Tagore in Bengali 8 [' including', ' works', ' by', ' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore']
+683 157 Name of father of x -1 Name of father of Rabindranath Tagore Debendranath Tagore Rabindranath Tagore "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Rab' 'ind' 'ran' 'ath'
+ ' Tag' 'ore' '.' ' I' ' have' ' read' ' all' ' his']" "
+
+ I am a big fan of Rab ind ran ath Tag ore . I have read all his" False anthem writer Rabindranath Tagore have a widespread 7 [' anthem', ' writer', ' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore']
+684 157 Name of father of x -1 Name of father of Rabindranath Tagore Debendranath Tagore Rabindranath Tagore "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Rab' 'ind' 'ran' 'ath'
+ ' Tag' 'ore' '.' ' I' ' have' ' read' ' all' ' his']" "
+
+ I am a big fan of Rab ind ran ath Tag ore . I have read all his" False (Japan), and Rabindranath Tagore (India). These 9 [' (', 'Japan', '),', ' and', ' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore']
+685 157 Name of father of x -1 Name of father of Rabindranath Tagore Debendranath Tagore Rabindranath Tagore "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Rab' 'ind' 'ran' 'ath'
+ ' Tag' 'ore' '.' ' I' ' have' ' read' ' all' ' his']" "
+
+ I am a big fan of Rab ind ran ath Tag ore . I have read all his" False personality of Indian poet Rabindranath Tagore through his 9 [' personality', ' of', ' Indian', ' poet', ' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore']
+686 158 Name of father of x -1 Name of father of Friedrich Schiller Johann Kaspar Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '59' ',' ' and'
+ ' died' ' in' ' 18' '05' '.' '\n' '\n' 'The']" ", the poet , who was born in 17 59 , and died in 18 05 .
+
+ The" False Napoleon's encouragement. Friedrich Schiller would also deal with 6 "[' Napoleon', ""'s"", ' encouragement', '.', ' Friedrich', ' Sch', 'iller']"
+687 158 Name of father of x -1 Name of father of Friedrich Schiller Johann Kaspar Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '59' ',' ' and'
+ ' died' ' in' ' 18' '05' '.' '\n' '\n' 'The']" ", the poet , who was born in 17 59 , and died in 18 05 .
+
+ The" False 4 ['F', 'ried', 'rich', ' Sch', 'iller']
+688 158 Name of father of x -1 Name of father of Friedrich Schiller Johann Kaspar Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '59' ',' ' and'
+ ' died' ' in' ' 18' '05' '.' '\n' '\n' 'The']" ", the poet , who was born in 17 59 , and died in 18 05 .
+
+ The" False encouragement. Friedrich Schiller would also deal 4 [' encouragement', '.', ' Friedrich', ' Sch', 'iller']
+689 158 Name of father of x -1 Name of father of Friedrich Schiller Johann Kaspar Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '59' ',' ' and'
+ ' died' ' in' ' 18' '05' '.' '\n' '\n' 'The']" ", the poet , who was born in 17 59 , and died in 18 05 .
+
+ The" False 4 ['F', 'ried', 'rich', ' Sch', 'iller']
+690 158 Name of father of x -1 Name of father of Friedrich Schiller Johann Kaspar Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '59' ',' ' and'
+ ' died' ' in' ' 18' '05' '.' '\n' '\n' 'The']" ", the poet , who was born in 17 59 , and died in 18 05 .
+
+ The" False Napoleon's encouragement. Friedrich Schiller would also deal with 6 "[' Napoleon', ""'s"", ' encouragement', '.', ' Friedrich', ' Sch', 'iller']"
+691 159 Name of father of x -1 Name of father of George Frideric Handel Georg Händel George Frideric Handel "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' George' ' Fr' 'ider' 'ic' ' Hand' 'el']" ", the composer .
+
+ The name of the father of the composer George Fr ider ic Hand el" False operas: in 1719 George Frideric Handel used Richard's invasion 11 [' oper', 'as', ':', ' in', ' 17', '19', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+692 159 Name of father of x -1 Name of father of George Frideric Handel Georg Händel George Frideric Handel "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' George' ' Fr' 'ider' 'ic' ' Hand' 'el']" ", the composer .
+
+ The name of the father of the composer George Fr ider ic Hand el" False the Baroque era, George Frideric Handel composed the 11 [' the', ' Bar', 'o', 'que', ' era', ',', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+693 159 Name of father of x -1 Name of father of George Frideric Handel Georg Händel George Frideric Handel "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' George' ' Fr' 'ider' 'ic' ' Hand' 'el']" ", the composer .
+
+ The name of the father of the composer George Fr ider ic Hand el" False the funeral. George Frideric Handel composed an anthem 8 [' the', ' funeral', '.', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+694 159 Name of father of x -1 Name of father of George Frideric Handel Georg Händel George Frideric Handel "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' George' ' Fr' 'ider' 'ic' ' Hand' 'el']" ", the composer .
+
+ The name of the father of the composer George Fr ider ic Hand el" False 25 operas by George Frideric Handel premièred here. 9 [' 25', ' oper', 'as', ' by', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+695 159 Name of father of x -1 Name of father of George Frideric Handel Georg Händel George Frideric Handel "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' George' ' Fr' 'ider' 'ic' ' Hand' 'el']" ", the composer .
+
+ The name of the father of the composer George Fr ider ic Hand el" False Opera House. Young George Frideric Handel produced his 9 [' Opera', ' House', '.', ' Young', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+696 160 Name of father of x -1 Name of father of Charlie Chaplin Charles Chaplin Sr. Charlie Chaplin "[',' ' the' ' famous' ' actor' ',' ' was' ' born' ' in' ' this' ' house'
+ '.' '\n' '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',']" ", the famous actor , was born in this house .
+
+ The house is now a museum ," False " Hutchinson, he admired Charlie Chaplin ""to the point" 6 [' Hutchinson', ',', ' he', ' admired', ' Charlie', ' Cha', 'plin']
+697 160 Name of father of x -1 Name of father of Charlie Chaplin Charles Chaplin Sr. Charlie Chaplin "[',' ' the' ' famous' ' actor' ',' ' was' ' born' ' in' ' this' ' house'
+ '.' '\n' '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',']" ", the famous actor , was born in this house .
+
+ The house is now a museum ," False seven days at the Charlie Chaplin Studios, Bronson Canyon 6 [' seven', ' days', ' at', ' the', ' Charlie', ' Cha', 'plin']
+698 160 Name of father of x -1 Name of father of Charlie Chaplin Charles Chaplin Sr. Charlie Chaplin "[',' ' the' ' famous' ' actor' ',' ' was' ' born' ' in' ' this' ' house'
+ '.' '\n' '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',']" ", the famous actor , was born in this house .
+
+ The house is now a museum ," False later compared her to Charlie Chaplin in her ability 6 [' later', ' compared', ' her', ' to', ' Charlie', ' Cha', 'plin']
+699 160 Name of father of x -1 Name of father of Charlie Chaplin Charles Chaplin Sr. Charlie Chaplin "[',' ' the' ' famous' ' actor' ',' ' was' ' born' ' in' ' this' ' house'
+ '.' '\n' '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',']" ", the famous actor , was born in this house .
+
+ The house is now a museum ," False costumes to a young Charlie Chaplin when the latter 6 [' costumes', ' to', ' a', ' young', ' Charlie', ' Cha', 'plin']
+700 160 Name of father of x -1 Name of father of Charlie Chaplin Charles Chaplin Sr. Charlie Chaplin "[',' ' the' ' famous' ' actor' ',' ' was' ' born' ' in' ' this' ' house'
+ '.' '\n' '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',']" ", the famous actor , was born in this house .
+
+ The house is now a museum ," False Bodley was hired by Charlie Chaplin to adapt the D.L. 7 [' Bod', 'ley', ' was', ' hired', ' by', ' Charlie', ' Cha', 'plin']
+701 161 Name of father of x -1 Name of father of Bertrand Russell John Russell, Viscount Amberley Bertrand Russell "[',' ' the' ' philosopher' ',' ' and' ' of' ' his' ' daughter' ',' ' the'
+ ' writer' ' A' 'lys' ' Russell' ',' ' who' ' was' ' born' ' in' ' 18']" , the philosopher , and of his daughter , the writer A lys Russell , who was born in 18 False such as that of Bertrand Russell and the early 6 [' such', ' as', ' that', ' of', ' Bert', 'rand', ' Russell']
+702 161 Name of father of x -1 Name of father of Bertrand Russell John Russell, Viscount Amberley Bertrand Russell "[',' ' the' ' philosopher' ',' ' and' ' of' ' his' ' daughter' ',' ' the'
+ ' writer' ' A' 'lys' ' Russell' ',' ' who' ' was' ' born' ' in' ' 18']" , the philosopher , and of his daughter , the writer A lys Russell , who was born in 18 False briefly adopted by Bertrand Russell and many of the 5 [' briefly', ' adopted', ' by', ' Bert', 'rand', ' Russell']
+703 161 Name of father of x -1 Name of father of Bertrand Russell John Russell, Viscount Amberley Bertrand Russell "[',' ' the' ' philosopher' ',' ' and' ' of' ' his' ' daughter' ',' ' the'
+ ' writer' ' A' 'lys' ' Russell' ',' ' who' ' was' ' born' ' in' ' 18']" , the philosopher , and of his daughter , the writer A lys Russell , who was born in 18 False philosopher Bertrand Russell strongly influenced 3 [' philosopher', ' Bert', 'rand', ' Russell']
+704 161 Name of father of x -1 Name of father of Bertrand Russell John Russell, Viscount Amberley Bertrand Russell "[',' ' the' ' philosopher' ',' ' and' ' of' ' his' ' daughter' ',' ' the'
+ ' writer' ' A' 'lys' ' Russell' ',' ' who' ' was' ' born' ' in' ' 18']" , the philosopher , and of his daughter , the writer A lys Russell , who was born in 18 False Analysis by Bertrand Russell (1st imp. London 4 [' Analysis', ' by', ' Bert', 'rand', ' Russell']
+705 161 Name of father of x -1 Name of father of Bertrand Russell John Russell, Viscount Amberley Bertrand Russell "[',' ' the' ' philosopher' ',' ' and' ' of' ' his' ' daughter' ',' ' the'
+ ' writer' ' A' 'lys' ' Russell' ',' ' who' ' was' ' born' ' in' ' 18']" , the philosopher , and of his daughter , the writer A lys Russell , who was born in 18 False 3 ['B', 'ert', 'rand', ' Russell']
+706 162 Name of father of x -1 Name of father of George Washington Augustine Washington George Washington "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' George']" ", the first president of the United States .
+
+ The first president of the United States , George" False north to serve under George Washington in the Philadelphia 5 [' north', ' to', ' serve', ' under', ' George', ' Washington']
+707 162 Name of father of x -1 Name of father of George Washington Augustine Washington George Washington "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' George']" ", the first president of the United States .
+
+ The first president of the United States , George" False Lovingood by George Washington Harris; inside 4 [' Loving', 'ood', ' by', ' George', ' Washington']
+708 162 Name of father of x -1 Name of father of George Washington Augustine Washington George Washington "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' George']" ", the first president of the United States .
+
+ The first president of the United States , George" False 1 ['George', ' Washington']
+709 162 Name of father of x -1 Name of father of George Washington Augustine Washington George Washington "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' George']" ", the first president of the United States .
+
+ The first president of the United States , George" False transferred to George Washington University, 3 [' transferred', ' to', ' George', ' Washington']
+710 162 Name of father of x -1 Name of father of George Washington Augustine Washington George Washington "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' George']" ", the first president of the United States .
+
+ The first president of the United States , George" False affairs. In 1784, George Washington visited and dined 7 [' affairs', '.', ' In', ' 17', '84', ',', ' George', ' Washington']
+711 163 Name of father of x -1 Name of father of Jean-Jacques Rousseau Isaac Rousseau Jean-Jacques Rousseau "[',' ' the' ' author' ' of' ' the' '\n' '_' 'Cont' 'rat' ' Social' '_' ','
+ ' and' ' of' ' the' ' _' 'Conf' 'essions' '_' ',']" ", the author of the
+ _ Cont rat Social _ , and of the _ Conf essions _ ," False Swiss philosopher Jean-Jacques Rousseau sent him the message: 7 [' Swiss', ' philosopher', ' Jean', '-', 'Jac', 'ques', ' Rousse', 'au']
+712 163 Name of father of x -1 Name of father of Jean-Jacques Rousseau Isaac Rousseau Jean-Jacques Rousseau "[',' ' the' ' author' ' of' ' the' '\n' '_' 'Cont' 'rat' ' Social' '_' ','
+ ' and' ' of' ' the' ' _' 'Conf' 'essions' '_' ',']" ", the author of the
+ _ Cont rat Social _ , and of the _ Conf essions _ ," False French philosopher Jean-Jacques Rousseau celebrated the 7 [' French', ' philosopher', ' Jean', '-', 'Jac', 'ques', ' Rousse', 'au']
+713 163 Name of father of x -1 Name of father of Jean-Jacques Rousseau Isaac Rousseau Jean-Jacques Rousseau "[',' ' the' ' author' ' of' ' the' '\n' '_' 'Cont' 'rat' ' Social' '_' ','
+ ' and' ' of' ' the' ' _' 'Conf' 'essions' '_' ',']" ", the author of the
+ _ Cont rat Social _ , and of the _ Conf essions _ ," False de Mably and Jean-Jacques Rousseau to offer suggestions 9 [' de', ' M', 'ably', ' and', ' Jean', '-', 'Jac', 'ques', ' Rousse', 'au']
+714 163 Name of father of x -1 Name of father of Jean-Jacques Rousseau Isaac Rousseau Jean-Jacques Rousseau "[',' ' the' ' author' ' of' ' the' '\n' '_' 'Cont' 'rat' ' Social' '_' ','
+ ' and' ' of' ' the' ' _' 'Conf' 'essions' '_' ',']" ", the author of the
+ _ Cont rat Social _ , and of the _ Conf essions _ ," False " of the philosopher Jean-Jacques Rousseau by Allan Ramsay.
+" 8 [' of', ' the', ' philosopher', ' Jean', '-', 'Jac', 'ques', ' Rousse', 'au']
+715 163 Name of father of x -1 Name of father of Jean-Jacques Rousseau Isaac Rousseau Jean-Jacques Rousseau "[',' ' the' ' author' ' of' ' the' '\n' '_' 'Cont' 'rat' ' Social' '_' ','
+ ' and' ' of' ' the' ' _' 'Conf' 'essions' '_' ',']" ", the author of the
+ _ Cont rat Social _ , and of the _ Conf essions _ ," False French-speaking writers were Jean-Jacques Rousseau (1712 – 1778) 10 [' French', '-', 'speaking', ' writers', ' were', ' Jean', '-', 'Jac', 'ques', ' Rousse', 'au']
+716 164 Name of father of x -1 Name of father of Mahatma Gandhi Karamchand Uttamchand Gandhi Mahatma Gandhi "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' great'
+ ' Mah' 'at' 'ma' ' Gandhi' '.' ' I' ' have' ' been' ' following']" "
+
+ I am a great admire r of the great Mah at ma Gandhi . I have been following" False rights movements of Mahatma Gandhi and Periyar 6 [' rights', ' movements', ' of', ' Mah', 'at', 'ma', ' Gandhi']
+717 164 Name of father of x -1 Name of father of Mahatma Gandhi Karamchand Uttamchand Gandhi Mahatma Gandhi "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' great'
+ ' Mah' 'at' 'ma' ' Gandhi' '.' ' I' ' have' ' been' ' following']" "
+
+ I am a great admire r of the great Mah at ma Gandhi . I have been following" False Technology (IIMT), Mahatma Gandhi Kashi Vidyapith, Nav 9 [' Technology', ' (', 'I', 'IM', 'T', '),', ' Mah', 'at', 'ma', ' Gandhi']
+718 164 Name of father of x -1 Name of father of Mahatma Gandhi Karamchand Uttamchand Gandhi Mahatma Gandhi "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' great'
+ ' Mah' 'at' 'ma' ' Gandhi' '.' ' I' ' have' ' been' ' following']" "
+
+ I am a great admire r of the great Mah at ma Gandhi . I have been following" False 10 districts. The Mahatma Gandhi Government Hospital 7 [' 10', ' districts', '.', ' The', ' Mah', 'at', 'ma', ' Gandhi']
+719 164 Name of father of x -1 Name of father of Mahatma Gandhi Karamchand Uttamchand Gandhi Mahatma Gandhi "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' great'
+ ' Mah' 'at' 'ma' ' Gandhi' '.' ' I' ' have' ' been' ' following']" "
+
+ I am a great admire r of the great Mah at ma Gandhi . I have been following" False academic complex at Mahatma Gandhi University 6 [' academic', ' complex', ' at', ' Mah', 'at', 'ma', ' Gandhi']
+720 164 Name of father of x -1 Name of father of Mahatma Gandhi Karamchand Uttamchand Gandhi Mahatma Gandhi "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' great'
+ ' Mah' 'at' 'ma' ' Gandhi' '.' ' I' ' have' ' been' ' following']" "
+
+ I am a great admire r of the great Mah at ma Gandhi . I have been following" False physicists at the Mahatma Gandhi University in Kottayam, 6 [' physicists', ' at', ' the', ' Mah', 'at', 'ma', ' Gandhi']
+721 165 Name of father of x -1 Name of father of Edvard Munch Christian Munch Edvard Munch "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Ed' 'vard' ' M' 'unch' '.' '\n']" ", the painter .
+
+ The name of the father of the painter Ed vard M unch .
+" False Norwegian Symbolist Edvard Munch (1863 – 1944) 6 [' Norwegian', ' Symbol', 'ist', ' Ed', 'vard', ' M', 'unch']
+722 165 Name of father of x -1 Name of father of Edvard Munch Christian Munch Edvard Munch "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Ed' 'vard' ' M' 'unch' '.' '\n']" ", the painter .
+
+ The name of the father of the painter Ed vard M unch .
+" False August Strindberg, Edvard Munch and Paul Gauguin. 8 [' August', ' Str', 'ind', 'berg', ',', ' Ed', 'vard', ' M', 'unch']
+723 165 Name of father of x -1 Name of father of Edvard Munch Christian Munch Edvard Munch "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Ed' 'vard' ' M' 'unch' '.' '\n']" ", the painter .
+
+ The name of the father of the painter Ed vard M unch .
+" False Norwegian Symbolist Edvard Munch (1863 – 1944) 6 [' Norwegian', ' Symbol', 'ist', ' Ed', 'vard', ' M', 'unch']
+724 165 Name of father of x -1 Name of father of Edvard Munch Christian Munch Edvard Munch "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Ed' 'vard' ' M' 'unch' '.' '\n']" ", the painter .
+
+ The name of the father of the painter Ed vard M unch .
+" False August Strindberg, Edvard Munch and Paul Gauguin. 8 [' August', ' Str', 'ind', 'berg', ',', ' Ed', 'vard', ' M', 'unch']
+725 165 Name of father of x -1 Name of father of Edvard Munch Christian Munch Edvard Munch "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Ed' 'vard' ' M' 'unch' '.' '\n']" ", the painter .
+
+ The name of the father of the painter Ed vard M unch .
+" False Norwegian Symbolist Edvard Munch (1863 – 1944) would 6 [' Norwegian', ' Symbol', 'ist', ' Ed', 'vard', ' M', 'unch']
+726 166 Name of father of x -1 Name of father of Carl Linnaeus Nils Ingemarsson Linnaeus Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' bot' 'an' 'ist' '.' '\n']" ", the bot an ist .
+
+ The name of the father of the bot an ist .
+" False Swedish zoologist Carl Linnaeus in 1758. It stands 6 [' Swedish', ' zo', 'ologist', ' Carl', ' Lin', 'na', 'eus']
+727 166 Name of father of x -1 Name of father of Carl Linnaeus Nils Ingemarsson Linnaeus Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' bot' 'an' 'ist' '.' '\n']" ", the bot an ist .
+
+ The name of the father of the bot an ist .
+" False literature by Carl Linnaeus under the name Hydnum 5 [' literature', ' by', ' Carl', ' Lin', 'na', 'eus']
+728 166 Name of father of x -1 Name of father of Carl Linnaeus Nils Ingemarsson Linnaeus Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' bot' 'an' 'ist' '.' '\n']" ", the bot an ist .
+
+ The name of the father of the bot an ist .
+" False originally described by Carl Linnaeus in the 1758 6 [' originally', ' described', ' by', ' Carl', ' Lin', 'na', 'eus']
+729 166 Name of father of x -1 Name of father of Carl Linnaeus Nils Ingemarsson Linnaeus Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' bot' 'an' 'ist' '.' '\n']" ", the bot an ist .
+
+ The name of the father of the bot an ist .
+" False " Carl Linnaeus =
+" 3 [' Carl', ' Lin', 'na', 'eus']
+730 166 Name of father of x -1 Name of father of Carl Linnaeus Nils Ingemarsson Linnaeus Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' bot' 'an' 'ist' '.' '\n']" ", the bot an ist .
+
+ The name of the father of the bot an ist .
+" False classification introduced by Carl Linnaeus in 1735 explicitly 6 [' classification', ' introduced', ' by', ' Carl', ' Lin', 'na', 'eus']
+731 167 Name of father of x -1 Name of father of Miley Cyrus Billy Ray Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' M' 'iley' ' Cyrus' '.' '\n' '\n' 'The' ' father' ' of' ' M']" ", the singer , and the father of the singer M iley Cyrus .
+
+ The father of M" False and actress Miley Cyrus portrays Miley Stewart, 4 [' and', ' actress', ' M', 'iley', ' Cyrus']
+732 167 Name of father of x -1 Name of father of Miley Cyrus Billy Ray Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' M' 'iley' ' Cyrus' '.' '\n' '\n' 'The' ' father' ' of' ' M']" ", the singer , and the father of the singer M iley Cyrus .
+
+ The father of M" False and actress Miley Cyrus and actor and singer 4 [' and', ' actress', ' M', 'iley', ' Cyrus']
+733 167 Name of father of x -1 Name of father of Miley Cyrus Billy Ray Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' M' 'iley' ' Cyrus' '.' '\n' '\n' 'The' ' father' ' of' ' M']" ", the singer , and the father of the singer M iley Cyrus .
+
+ The father of M" False might eventually give Miley Cyrus a real run for 5 [' might', ' eventually', ' give', ' M', 'iley', ' Cyrus']
+734 167 Name of father of x -1 Name of father of Miley Cyrus Billy Ray Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' M' 'iley' ' Cyrus' '.' '\n' '\n' 'The' ' father' ' of' ' M']" ", the singer , and the father of the singer M iley Cyrus .
+
+ The father of M" False transitioning to violins. Miley Cyrus and Nick Jonas 7 [' transitioning', ' to', ' viol', 'ins', '.', ' M', 'iley', ' Cyrus']
+735 167 Name of father of x -1 Name of father of Miley Cyrus Billy Ray Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' M' 'iley' ' Cyrus' '.' '\n' '\n' 'The' ' father' ' of' ' M']" ", the singer , and the father of the singer M iley Cyrus .
+
+ The father of M" False opening act for Miley Cyrus on The Best of Both 5 [' opening', ' act', ' for', ' M', 'iley', ' Cyrus']
+736 168 Name of father of x -1 Name of father of George Bernard Shaw George Carr Shaw George Bernard Shaw "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' Charlotte'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' American'
+ ' businessman']" , the play wright , and his wife , Charlotte , who was the daughter of a wealthy American businessman False " collaborated with George Bernard Shaw had Shaw been willing.
+" 4 [' collaborated', ' with', ' George', ' Bernard', ' Shaw']
+737 168 Name of father of x -1 Name of father of George Bernard Shaw George Carr Shaw George Bernard Shaw "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' Charlotte'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' American'
+ ' businessman']" , the play wright , and his wife , Charlotte , who was the daughter of a wealthy American businessman False a film based on a George Bernard Shaw play of the 7 [' a', ' film', ' based', ' on', ' a', ' George', ' Bernard', ' Shaw']
+738 168 Name of father of x -1 Name of father of George Bernard Shaw George Carr Shaw George Bernard Shaw "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' Charlotte'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' American'
+ ' businessman']" , the play wright , and his wife , Charlotte , who was the daughter of a wealthy American businessman False " literature"", and George Bernard Shaw went so far as to call" 5 "[' literature', '"",', ' and', ' George', ' Bernard', ' Shaw']"
+739 168 Name of father of x -1 Name of father of George Bernard Shaw George Carr Shaw George Bernard Shaw "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' Charlotte'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' American'
+ ' businessman']" , the play wright , and his wife , Charlotte , who was the daughter of a wealthy American businessman False Millionairess, a film based on a George Bernard Shaw play of the same 11 [' Million', 'ai', 'ress', ',', ' a', ' film', ' based', ' on', ' a', ' George', ' Bernard', ' Shaw']
+740 168 Name of father of x -1 Name of father of George Bernard Shaw George Carr Shaw George Bernard Shaw "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' Charlotte'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' American'
+ ' businessman']" , the play wright , and his wife , Charlotte , who was the daughter of a wealthy American businessman False film based on a George Bernard Shaw play of the 6 [' film', ' based', ' on', ' a', ' George', ' Bernard', ' Shaw']
+741 169 Name of father of x -1 Name of father of George Sand Maurice Dupin de Francueil George Sand "[',' ' the' ' author' ' of' ' ""' 'L' 'é' 'lia' ',""' ' ""' 'Cons' 'uel' 'o'
+ ',""' ' ""' 'Indiana' ',""' ' ""' 'E' 'ug']" ", the author of "" L é lia ,"" "" Cons uel o ,"" "" Indiana ,"" "" E ug" False of Robert; and George Sand wrote about 5 [' of', ' Robert', ';', ' and', ' George', ' Sand']
+742 169 Name of father of x -1 Name of father of George Sand Maurice Dupin de Francueil George Sand "[',' ' the' ' author' ' of' ' ""' 'L' 'é' 'lia' ',""' ' ""' 'Cons' 'uel' 'o'
+ ',""' ' ""' 'Indiana' ',""' ' ""' 'E' 'ug']" ", the author of "" L é lia ,"" "" Cons uel o ,"" "" Indiana ,"" "" E ug" False writers including George Sand and Thomas 3 [' writers', ' including', ' George', ' Sand']
+743 169 Name of father of x -1 Name of father of George Sand Maurice Dupin de Francueil George Sand "[',' ' the' ' author' ' of' ' ""' 'L' 'é' 'lia' ',""' ' ""' 'Cons' 'uel' 'o'
+ ',""' ' ""' 'Indiana' ',""' ' ""' 'E' 'ug']" ", the author of "" L é lia ,"" "" Cons uel o ,"" "" Indiana ,"" "" E ug" False the life of George Sand that starred Judy 4 [' the', ' life', ' of', ' George', ' Sand']
+744 169 Name of father of x -1 Name of father of George Sand Maurice Dupin de Francueil George Sand "[',' ' the' ' author' ' of' ' ""' 'L' 'é' 'lia' ',""' ' ""' 'Cons' 'uel' 'o'
+ ',""' ' ""' 'Indiana' ',""' ' ""' 'E' 'ug']" ", the author of "" L é lia ,"" "" Cons uel o ,"" "" Indiana ,"" "" E ug" False translator of George Sand and a controversial 3 [' translator', ' of', ' George', ' Sand']
+745 169 Name of father of x -1 Name of father of George Sand Maurice Dupin de Francueil George Sand "[',' ' the' ' author' ' of' ' ""' 'L' 'é' 'lia' ',""' ' ""' 'Cons' 'uel' 'o'
+ ',""' ' ""' 'Indiana' ',""' ' ""' 'E' 'ug']" ", the author of "" L é lia ,"" "" Cons uel o ,"" "" Indiana ,"" "" E ug" False French author George Sand (born [Amantine] 3 [' French', ' author', ' George', ' Sand']
+746 170 Name of father of x -1 Name of father of Henri de Toulouse-Lautrec Alphonse de Toulouse-Lautrec-Monfa Henri de Toulouse-Lautrec "['\n' '\n' 'Hen' 'ri' ' de' ' T' 'oul' 'ouse' '-' 'L' 'aut' 'rec' ' ('
+ '18' '64' '–' '19' '01' ')' ' was']" "
+
+ Hen ri de T oul ouse - L aut rec ( 18 64 – 19 01 ) was" False Anquetin and Henri de Toulouse-Lautrec – who painted a 12 [' An', 'quet', 'in', ' and', ' Henri', ' de', ' T', 'oul', 'ouse', '-', 'L', 'aut', 'rec']
+747 170 Name of father of x -1 Name of father of Henri de Toulouse-Lautrec Alphonse de Toulouse-Lautrec-Monfa Henri de Toulouse-Lautrec "['\n' '\n' 'Hen' 'ri' ' de' ' T' 'oul' 'ouse' '-' 'L' 'aut' 'rec' ' ('
+ '18' '64' '–' '19' '01' ')' ' was']" "
+
+ Hen ri de T oul ouse - L aut rec ( 18 64 – 19 01 ) was" False Anquetin and Henri de Toulouse-Lautrec – who painted a 12 [' An', 'quet', 'in', ' and', ' Henri', ' de', ' T', 'oul', 'ouse', '-', 'L', 'aut', 'rec']
+748 171 Name of father of x -1 Name of father of Victoria Prince Edward, Duke of Kent and Strathearn Victoria "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' Queen']" ", the Queen of England , and the
+
+ Queen of the United States .
+
+ The Queen" False partnerships together for Victoria and in the match against 3 [' partnerships', ' together', ' for', ' Victoria']
+749 171 Name of father of x -1 Name of father of Victoria Prince Edward, Duke of Kent and Strathearn Victoria "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' Queen']" ", the Queen of England , and the
+
+ Queen of the United States .
+
+ The Queen" False positions for Victoria ’ s first two 2 [' positions', ' for', ' Victoria']
+750 171 Name of father of x -1 Name of father of Victoria Prince Edward, Duke of Kent and Strathearn Victoria "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' Queen']" ", the Queen of England , and the
+
+ Queen of the United States .
+
+ The Queen" False system restricted to Victoria only, but from 3 [' system', ' restricted', ' to', ' Victoria']
+751 171 Name of father of x -1 Name of father of Victoria Prince Edward, Duke of Kent and Strathearn Victoria "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' Queen']" ", the Queen of England , and the
+
+ Queen of the United States .
+
+ The Queen" False Pratt Library, Victoria University, University 3 [' Pratt', ' Library', ',', ' Victoria']
+752 171 Name of father of x -1 Name of father of Victoria Prince Edward, Duke of Kent and Strathearn Victoria "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' Queen']" ", the Queen of England , and the
+
+ Queen of the United States .
+
+ The Queen" False training posts in Victoria and command of 3 [' training', ' posts', ' in', ' Victoria']
+753 172 Name of father of x -1 Name of father of Superman Jonathan Kent Superman "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' father'
+ ' of' ' Superman' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", and the name of the father of the father of Superman .
+
+ The name of the father" False purchased the Superman film rights in 2 [' purchased', ' the', ' Superman']
+754 172 Name of father of x -1 Name of father of Superman Jonathan Kent Superman "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' father'
+ ' of' ' Superman' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", and the name of the father of the father of Superman .
+
+ The name of the father" False The Batman / Superman Hour; the Batman segments 3 [' The', ' Batman', ' /', ' Superman']
+755 172 Name of father of x -1 Name of father of Superman Jonathan Kent Superman "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' father'
+ ' of' ' Superman' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", and the name of the father of the father of Superman .
+
+ The name of the father" False time. It called Superman Returns' win 4 [' time', '.', ' It', ' called', ' Superman']
+756 172 Name of father of x -1 Name of father of Superman Jonathan Kent Superman "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' father'
+ ' of' ' Superman' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", and the name of the father of the father of Superman .
+
+ The name of the father" False " the traditional Superman costume, and ""All-American""" 2 [' the', ' traditional', ' Superman']
+757 172 Name of father of x -1 Name of father of Superman Jonathan Kent Superman "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' father'
+ ' of' ' Superman' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", and the name of the father of the father of Superman .
+
+ The name of the father" False Comics character Superman. It is produced 2 [' Comics', ' character', ' Superman']
+758 173 Name of father of x -1 Name of father of Joshua Reynolds Samuel Reynolds Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of his wife , the
+ " False Watch by Sir Joshua Reynolds because by the 4 [' Watch', ' by', ' Sir', ' Joshua', ' Reynolds']
+759 173 Name of father of x -1 Name of father of Joshua Reynolds Samuel Reynolds Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of his wife , the
+ " False Gainsborough and Joshua Reynolds became two of 5 [' G', 'ains', 'borough', ' and', ' Joshua', ' Reynolds']
+760 173 Name of father of x -1 Name of father of Joshua Reynolds Samuel Reynolds Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of his wife , the
+ " False End. The painter Joshua Reynolds lived at No 5 [' End', '.', ' The', ' painter', ' Joshua', ' Reynolds']
+761 173 Name of father of x -1 Name of father of Joshua Reynolds Samuel Reynolds Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of his wife , the
+ " False 262 and £ 315. Joshua Reynolds at first declined 6 [' 262', ' and', ' £', ' 315', '.', ' Joshua', ' Reynolds']
+762 173 Name of father of x -1 Name of father of Joshua Reynolds Samuel Reynolds Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of his wife , the
+ " False Thomas Gainsborough and Joshua Reynolds became two of England's 6 [' Thomas', ' G', 'ains', 'borough', ' and', ' Joshua', ' Reynolds']
+763 174 Name of father of x -1 Name of father of Aldous Huxley Leonard Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' books']" ", the author of Brave New World .
+
+ The following is a list of the most important books" False the English novelist Aldous Huxley and American 7 [' the', ' English', ' novelist', ' Ald', 'ous', ' H', 'ux', 'ley']
+764 174 Name of father of x -1 Name of father of Aldous Huxley Leonard Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' books']" ", the author of Brave New World .
+
+ The following is a list of the most important books" False addressed by Aldous Huxley in his dystopian 6 [' addressed', ' by', ' Ald', 'ous', ' H', 'ux', 'ley']
+765 174 Name of father of x -1 Name of father of Aldous Huxley Leonard Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' books']" ", the author of Brave New World .
+
+ The following is a list of the most important books" False people like Aldous Huxley and Alfred Adler. 6 [' people', ' like', ' Ald', 'ous', ' H', 'ux', 'ley']
+766 174 Name of father of x -1 Name of father of Aldous Huxley Leonard Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' books']" ", the author of Brave New World .
+
+ The following is a list of the most important books" False famous people like Aldous Huxley and Alfred Adler. 7 [' famous', ' people', ' like', ' Ald', 'ous', ' H', 'ux', 'ley']
+767 174 Name of father of x -1 Name of father of Aldous Huxley Leonard Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' books']" ", the author of Brave New World .
+
+ The following is a list of the most important books" False English novelist Aldous Huxley and American 6 [' English', ' novelist', ' Ald', 'ous', ' H', 'ux', 'ley']
+768 175 Name of father of x -1 Name of father of Jorge Luis Borges Jorge Guillermo Borges Haslam Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' Argentine']" ", the Argentine writer , who died in 1986 .
+
+ The name of the father of the Argentine" False Carlos Fuentes, Jorge Luis Borges and Mario Vargas 8 [' Carlos', ' Fu', 'ent', 'es', ',', ' Jorge', ' Luis', ' Borg', 'es']
+769 175 Name of father of x -1 Name of father of Jorge Luis Borges Jorge Guillermo Borges Haslam Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' Argentine']" ", the Argentine writer , who died in 1986 .
+
+ The name of the father of the Argentine" False in the eponymous Jorge Luis Borges story: a place 7 [' in', ' the', ' ep', 'onymous', ' Jorge', ' Luis', ' Borg', 'es']
+770 175 Name of father of x -1 Name of father of Jorge Luis Borges Jorge Guillermo Borges Haslam Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' Argentine']" ", the Argentine writer , who died in 1986 .
+
+ The name of the father of the Argentine" False history such as Jorge Luis Borges and Manuel Belgrano 6 [' history', ' such', ' as', ' Jorge', ' Luis', ' Borg', 'es']
+771 175 Name of father of x -1 Name of father of Jorge Luis Borges Jorge Guillermo Borges Haslam Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' Argentine']" ", the Argentine writer , who died in 1986 .
+
+ The name of the father of the Argentine" False others .... In 1964 Jorge Luis Borges stated that 'the 7 [' others', '....', ' In', ' 1964', ' Jorge', ' Luis', ' Borg', 'es']
+772 175 Name of father of x -1 Name of father of Jorge Luis Borges Jorge Guillermo Borges Haslam Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' Argentine']" ", the Argentine writer , who died in 1986 .
+
+ The name of the father of the Argentine" False artists such as Jorge Luis Borges and Julio Cortázar. 6 [' artists', ' such', ' as', ' Jorge', ' Luis', ' Borg', 'es']
+773 176 Name of father of x -1 Name of father of Diego Velázquez João Rodrigues da Silva Diego Velázquez "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Diego' ' Vel' 'á' 'z' 'quez' '.']" ", the painter .
+
+ The name of the father of the painter Diego Vel á z quez ." False participated in Diego Velázquez de Cuéllar's and Pánfilo 6 [' participated', ' in', ' Diego', ' Vel', 'á', 'z', 'quez']
+774 176 Name of father of x -1 Name of father of Diego Velázquez João Rodrigues da Silva Diego Velázquez "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Diego' ' Vel' 'á' 'z' 'quez' '.']" ", the painter .
+
+ The name of the father of the painter Diego Vel á z quez ." False participated in Diego Velázquez de Cuéllar's 6 [' participated', ' in', ' Diego', ' Vel', 'á', 'z', 'quez']
+775 176 Name of father of x -1 Name of father of Diego Velázquez João Rodrigues da Silva Diego Velázquez "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Diego' ' Vel' 'á' 'z' 'quez' '.']" ", the painter .
+
+ The name of the father of the painter Diego Vel á z quez ." False participated in Diego Velázquez de Cuéllar's and Pánfilo 6 [' participated', ' in', ' Diego', ' Vel', 'á', 'z', 'quez']
+776 176 Name of father of x -1 Name of father of Diego Velázquez João Rodrigues da Silva Diego Velázquez "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' painter' ' Diego' ' Vel' 'á' 'z' 'quez' '.']" ", the painter .
+
+ The name of the father of the painter Diego Vel á z quez ." False participated in Diego Velázquez de Cuéllar's 6 [' participated', ' in', ' Diego', ' Vel', 'á', 'z', 'quez']
+777 178 Name of father of x -1 Name of father of Barbra Streisand Emanuel Streisand Barbra Streisand "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Brigitte Bardot and Barbra Streisand throughout 9 [' Brig', 'itte', ' Bard', 'ot', ' and', ' Barb', 'ra', ' Stre', 'is', 'and']
+778 178 Name of father of x -1 Name of father of Barbra Streisand Emanuel Streisand Barbra Streisand "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False constant jokes about Barbra Streisand started to grow redundant. 7 [' constant', ' jokes', ' about', ' Barb', 'ra', ' Stre', 'is', 'and']
+779 178 Name of father of x -1 Name of father of Barbra Streisand Emanuel Streisand Barbra Streisand "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False " (shared with Barbra Streisand for Funny Girl)
+" 7 [' (', 'shared', ' with', ' Barb', 'ra', ' Stre', 'is', 'and']
+780 178 Name of father of x -1 Name of father of Barbra Streisand Emanuel Streisand Barbra Streisand "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False " nine-minute ""ode to Barbra Streisand and the devil,""" 10 "[' nine', '-', 'minute', ' ""', 'ode', ' to', ' Barb', 'ra', ' Stre', 'is', 'and']"
+781 178 Name of father of x -1 Name of father of Barbra Streisand Emanuel Streisand Barbra Streisand "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Newton-John and Barbra Streisand at number six on 8 [' Newton', '-', 'John', ' and', ' Barb', 'ra', ' Stre', 'is', 'and']
+782 179 Name of father of x -1 Name of father of Jacques-Louis David Louis Maurice David Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of the
+ " False " French artist Jacques-Louis David (1748 – 1825).
+" 5 [' French', ' artist', ' Jacques', '-', 'Louis', ' David']
+783 179 Name of father of x -1 Name of father of Jacques-Louis David Louis Maurice David Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of the
+ " False The influence of Jacques-Louis David can be seen 6 [' The', ' influence', ' of', ' Jacques', '-', 'Louis', ' David']
+784 179 Name of father of x -1 Name of father of Jacques-Louis David Louis Maurice David Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of the
+ " False " French artist Jacques-Louis David (1748 – 1825).
+" 5 [' French', ' artist', ' Jacques', '-', 'Louis', ' David']
+785 179 Name of father of x -1 Name of father of Jacques-Louis David Louis Maurice David Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of the
+ " False 4 ['Jac', 'ques', '-', 'Louis', ' David']
+786 179 Name of father of x -1 Name of father of Jacques-Louis David Louis Maurice David Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of the
+ " False " French artist Jacques-Louis David (1748 – 1825).
+" 5 [' French', ' artist', ' Jacques', '-', 'Louis', ' David']
+787 180 Name of father of x -1 Name of father of Galileo Galilei Vincenzo Galilei Galileo Galilei "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '64' ','
+ ' and' ' died' ' in' ' 16' '42' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 64 , and died in 16 42 .
+
+ The" False were discovered by Galileo Galilei around January 5 [' were', ' discovered', ' by', ' Galileo', ' Galile', 'i']
+788 180 Name of father of x -1 Name of father of Galileo Galilei Vincenzo Galilei Galileo Galilei "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '64' ','
+ ' and' ' died' ' in' ' 16' '42' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 64 , and died in 16 42 .
+
+ The" False Italian scientist Galileo Galilei was clearly Nordic 4 [' Italian', ' scientist', ' Galileo', ' Galile', 'i']
+789 180 Name of father of x -1 Name of father of Galileo Galilei Vincenzo Galilei Galileo Galilei "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '64' ','
+ ' and' ' died' ' in' ' 16' '42' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 64 , and died in 16 42 .
+
+ The" False by Thomas Harriot, Galileo Galilei and other astronomers. 7 [' by', ' Thomas', ' Har', 'riot', ',', ' Galileo', ' Galile', 'i']
+790 180 Name of father of x -1 Name of father of Galileo Galilei Vincenzo Galilei Galileo Galilei "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '64' ','
+ ' and' ' died' ' in' ' 16' '42' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 64 , and died in 16 42 .
+
+ The" False by Thomas Harriot, Galileo Galilei and other astronomers. 7 [' by', ' Thomas', ' Har', 'riot', ',', ' Galileo', ' Galile', 'i']
+791 180 Name of father of x -1 Name of father of Galileo Galilei Vincenzo Galilei Galileo Galilei "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '64' ','
+ ' and' ' died' ' in' ' 16' '42' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 64 , and died in 16 42 .
+
+ The" False 4 ['Gal', 'ile', 'o', ' Galile', 'i']
+792 181 Name of father of x -1 Name of father of Jane Austen George Austen Jane Austen "[',' ' and' ' the' ' author' ' of' ' the' ' _' 'Pr' 'ide' ' and' ' Pre'
+ 'jud' 'ice' '_' ',' ' and' ' _' 'Sense' ' and' ' Sens']" , and the author of the _ Pr ide and Pre jud ice _ , and _ Sense and Sens False adaptation of the Jane Austen novel, portraying 5 [' adaptation', ' of', ' the', ' Jane', ' Aust', 'en']
+793 181 Name of father of x -1 Name of father of Jane Austen George Austen Jane Austen "[',' ' and' ' the' ' author' ' of' ' the' ' _' 'Pr' 'ide' ' and' ' Pre'
+ 'jud' 'ice' '_' ',' ' and' ' _' 'Sense' ' and' ' Sens']" , and the author of the _ Pr ide and Pre jud ice _ , and _ Sense and Sens False population grew. Jane Austen lived in Bath in 5 [' population', ' grew', '.', ' Jane', ' Aust', 'en']
+794 181 Name of father of x -1 Name of father of Jane Austen George Austen Jane Austen "[',' ' and' ' the' ' author' ' of' ' the' ' _' 'Pr' 'ide' ' and' ' Pre'
+ 'jud' 'ice' '_' ',' ' and' ' _' 'Sense' ' and' ' Sens']" , and the author of the _ Pr ide and Pre jud ice _ , and _ Sense and Sens False " Austen =
+" 5 [' Aust', 'en', ' =', 'Jane', ' Aust', 'en']
+795 181 Name of father of x -1 Name of father of Jane Austen George Austen Jane Austen "[',' ' and' ' the' ' author' ' of' ' the' ' _' 'Pr' 'ide' ' and' ' Pre'
+ 'jud' 'ice' '_' ',' ' and' ' _' 'Sense' ' and' ' Sens']" , and the author of the _ Pr ide and Pre jud ice _ , and _ Sense and Sens False suggested the Jane Austen novel Sense 4 [' suggested', ' the', ' Jane', ' Aust', 'en']
+796 181 Name of father of x -1 Name of father of Jane Austen George Austen Jane Austen "[',' ' and' ' the' ' author' ' of' ' the' ' _' 'Pr' 'ide' ' and' ' Pre'
+ 'jud' 'ice' '_' ',' ' and' ' _' 'Sense' ' and' ' Sens']" , and the author of the _ Pr ide and Pre jud ice _ , and _ Sense and Sens False A Memoir of Jane Austen introduced her 6 [' A', ' Mem', 'oir', ' of', ' Jane', ' Aust', 'en']
+797 182 Name of father of x -1 Name of father of Joe Biden Joseph R. Biden Sr. Joe Biden "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' Jill' ' Biden' ',' ' are' ' in'
+ ' the']" , the former vice president of the United States , and his wife , Jill Biden , are in the False here, why is it that Joe Biden is the first 7 [' here', ',', ' why', ' is', ' it', ' that', ' Joe', ' Biden']
+798 182 Name of father of x -1 Name of father of Joe Biden Joseph R. Biden Sr. Joe Biden "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' Jill' ' Biden' ',' ' are' ' in'
+ ' the']" , the former vice president of the United States , and his wife , Jill Biden , are in the False and met Senator Joe Biden as a senior at Delaware 4 [' and', ' met', ' Senator', ' Joe', ' Biden']
+799 182 Name of father of x -1 Name of father of Joe Biden Joseph R. Biden Sr. Joe Biden "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' Jill' ' Biden' ',' ' are' ' in'
+ ' the']" , the former vice president of the United States , and his wife , Jill Biden , are in the False had never heard of Joe Biden and 17 % had no 5 [' had', ' never', ' heard', ' of', ' Joe', ' Biden']
+800 182 Name of father of x -1 Name of father of Joe Biden Joseph R. Biden Sr. Joe Biden "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' Jill' ' Biden' ',' ' are' ' in'
+ ' the']" , the former vice president of the United States , and his wife , Jill Biden , are in the False Obama, Gore, and Joe Biden in Chicago on December 6 [' Obama', ',', ' Gore', ',', ' and', ' Joe', ' Biden']
+801 182 Name of father of x -1 Name of father of Joe Biden Joseph R. Biden Sr. Joe Biden "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' Jill' ' Biden' ',' ' are' ' in'
+ ' the']" , the former vice president of the United States , and his wife , Jill Biden , are in the False 1 ['Joe', ' Biden']
+802 183 Name of father of x -1 Name of father of H. P. Lovecraft Winfield Scott Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n'
+ 'The' ' Nec' 'ron' 'om' 'icon' ' is' ' a' ' fictional' ' book']" ", the author of the Cthulhu Myth os .
+
+ The Nec ron om icon is a fictional book" False writings influenced both H. P. Lovecraft and A. Merritt, 7 [' writings', ' influenced', ' both', ' H', '.', ' P', '.', ' Lovecraft']
+803 183 Name of father of x -1 Name of father of H. P. Lovecraft Winfield Scott Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n'
+ 'The' ' Nec' 'ron' 'om' 'icon' ' is' ' a' ' fictional' ' book']" ", the author of the Cthulhu Myth os .
+
+ The Nec ron om icon is a fictional book" False fantasy author H. P. Lovecraft in June 1920. 6 [' fantasy', ' author', ' H', '.', ' P', '.', ' Lovecraft']
+804 183 Name of father of x -1 Name of father of H. P. Lovecraft Winfield Scott Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n'
+ 'The' ' Nec' 'ron' 'om' 'icon' ' is' ' a' ' fictional' ' book']" ", the author of the Cthulhu Myth os .
+
+ The Nec ron om icon is a fictional book" False there, as was H. P. Lovecraft (who was born in Providence); 8 [' there', ',', ' as', ' was', ' H', '.', ' P', '.', ' Lovecraft']
+805 183 Name of father of x -1 Name of father of H. P. Lovecraft Winfield Scott Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n'
+ 'The' ' Nec' 'ron' 'om' 'icon' ' is' ' a' ' fictional' ' book']" ", the author of the Cthulhu Myth os .
+
+ The Nec ron om icon is a fictional book" False Robert E. Howard, H. P. Lovecraft and Arthur 9 [' Robert', ' E', '.', ' Howard', ',', ' H', '.', ' P', '.', ' Lovecraft']
+806 183 Name of father of x -1 Name of father of H. P. Lovecraft Winfield Scott Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n'
+ 'The' ' Nec' 'ron' 'om' 'icon' ' is' ' a' ' fictional' ' book']" ", the author of the Cthulhu Myth os .
+
+ The Nec ron om icon is a fictional book" False a bust of H. P. Lovecraft designed by 7 [' a', ' bust', ' of', ' H', '.', ' P', '.', ' Lovecraft']
+807 184 Name of father of x -1 Name of father of T. S. Eliot Henry Ware Eliot T. S. Eliot "[',' ' the' ' poet' ',' ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri'
+ ',' ' U' '.' 'S' '.' 'A' '.' '\n' '\n']" ", the poet , born in St . Louis , Missouri , U . S . A .
+
+" False taken up by T. S. Eliot in his poem 7 [' taken', ' up', ' by', ' T', '.', ' S', '.', ' Eliot']
+808 184 Name of father of x -1 Name of father of T. S. Eliot Henry Ware Eliot T. S. Eliot "[',' ' the' ' poet' ',' ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri'
+ ',' ' U' '.' 'S' '.' 'A' '.' '\n' '\n']" ", the poet , born in St . Louis , Missouri , U . S . A .
+
+" False " this system, however. T. S. Eliot said: ""It" 9 [' this', ' system', ',', ' however', '.', ' T', '.', ' S', '.', ' Eliot']
+809 184 Name of father of x -1 Name of father of T. S. Eliot Henry Ware Eliot T. S. Eliot "[',' ' the' ' poet' ',' ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri'
+ ',' ' U' '.' 'S' '.' 'A' '.' '\n' '\n']" ", the poet , born in St . Louis , Missouri , U . S . A .
+
+" False Hugh Kenner and T. S. Eliot believed they saw between 8 [' Hugh', ' Ken', 'ner', ' and', ' T', '.', ' S', '.', ' Eliot']
+810 184 Name of father of x -1 Name of father of T. S. Eliot Henry Ware Eliot T. S. Eliot "[',' ' the' ' poet' ',' ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri'
+ ',' ' U' '.' 'S' '.' 'A' '.' '\n' '\n']" ", the poet , born in St . Louis , Missouri , U . S . A .
+
+" False writings of poets T. S. Eliot and Ezra Pound. 7 [' writings', ' of', ' poets', ' T', '.', ' S', '.', ' Eliot']
+811 184 Name of father of x -1 Name of father of T. S. Eliot Henry Ware Eliot T. S. Eliot "[',' ' the' ' poet' ',' ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri'
+ ',' ' U' '.' 'S' '.' 'A' '.' '\n' '\n']" ", the poet , born in St . Louis , Missouri , U . S . A .
+
+" False accepted by T. S. Eliot for Faber and Faber, 6 [' accepted', ' by', ' T', '.', ' S', '.', ' Eliot']
+812 185 Name of father of x -1 Name of father of Walter Scott Walter Scott Walter Scott "[',' ' the' ' author' ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author'
+ ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author' ' of']" , the author of W aver ley , and the author of W aver ley , and the author of False " characterised by Walter Scott as ""bonnet lairds""," 4 [' character', 'ised', ' by', ' Walter', ' Scott']
+813 185 Name of father of x -1 Name of father of Walter Scott Walter Scott Walter Scott "[',' ' the' ' author' ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author'
+ ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author' ' of']" , the author of W aver ley , and the author of W aver ley , and the author of False " Morritt, who wrote to Sir Walter Scott of his ""fine" 9 [' Mor', 'r', 'itt', ',', ' who', ' wrote', ' to', ' Sir', ' Walter', ' Scott']
+814 185 Name of father of x -1 Name of father of Walter Scott Walter Scott Walter Scott "[',' ' the' ' author' ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author'
+ ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author' ' of']" , the author of W aver ley , and the author of W aver ley , and the author of False works of Sir Walter Scott were best-sellers in 4 [' works', ' of', ' Sir', ' Walter', ' Scott']
+815 185 Name of father of x -1 Name of father of Walter Scott Walter Scott Walter Scott "[',' ' the' ' author' ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author'
+ ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author' ' of']" , the author of W aver ley , and the author of W aver ley , and the author of False Saskatchewan Premier Walter Scott met with Laurier 3 [' Saskatchewan', ' Premier', ' Walter', ' Scott']
+816 185 Name of father of x -1 Name of father of Walter Scott Walter Scott Walter Scott "[',' ' the' ' author' ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author'
+ ' of' ' W' 'aver' 'ley' ',' ' and' ' the' ' author' ' of']" , the author of W aver ley , and the author of W aver ley , and the author of False same name by Sir Walter Scott whose stories 5 [' same', ' name', ' by', ' Sir', ' Walter', ' Scott']
+817 186 Name of father of x -1 Name of father of Kate Winslet Roger John Winslet Kate Winslet "[' and' ' the' ' father' ' of' ' the' ' bride' ',' ' who' ' is' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ""'s""]" " and the father of the bride , who is the father of the groom .
+
+ The bride 's" False 2 ['Kate', ' Wins', 'let']
+818 186 Name of father of x -1 Name of father of Kate Winslet Roger John Winslet Kate Winslet "[' and' ' the' ' father' ' of' ' the' ' bride' ',' ' who' ' is' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ""'s""]" " and the father of the bride , who is the father of the groom .
+
+ The bride 's" False actors, including Kate Winslet and Orlando Bloom, 5 [' actors', ',', ' including', ' Kate', ' Wins', 'let']
+819 186 Name of father of x -1 Name of father of Kate Winslet Roger John Winslet Kate Winslet "[' and' ' the' ' father' ' of' ' the' ' bride' ',' ' who' ' is' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ""'s""]" " and the father of the bride , who is the father of the groom .
+
+ The bride 's" False James Cameron and Kate Winslet in attendance, and 5 [' James', ' Cameron', ' and', ' Kate', ' Wins', 'let']
+820 186 Name of father of x -1 Name of father of Kate Winslet Roger John Winslet Kate Winslet "[' and' ' the' ' father' ' of' ' the' ' bride' ',' ' who' ' is' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ""'s""]" " and the father of the bride , who is the father of the groom .
+
+ The bride 's" False DiCaprio and Kate Winslet as members of different 6 [' Di', 'Cap', 'rio', ' and', ' Kate', ' Wins', 'let']
+821 186 Name of father of x -1 Name of father of Kate Winslet Roger John Winslet Kate Winslet "[' and' ' the' ' father' ' of' ' the' ' bride' ',' ' who' ' is' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ""'s""]" " and the father of the bride , who is the father of the groom .
+
+ The bride 's" False 2 ['Kate', ' Wins', 'let']
+822 187 Name of father of x -1 Name of father of Nelson Mandela Gadla Henry Mphakanyiswa Nelson Mandela "[',' ' the' ' South' ' African' ' anti' '-' 'ap' 'art' 'heid' ' leader'
+ ',' ' who' ' was' ' imprisoned' ' for' ' 27' ' years' '.' '\n' '\n']" ", the South African anti - ap art heid leader , who was imprisoned for 27 years .
+
+" False government to release Nelson Mandela from prison, after 4 [' government', ' to', ' release', ' Nelson', ' Mandela']
+823 187 Name of father of x -1 Name of father of Nelson Mandela Gadla Henry Mphakanyiswa Nelson Mandela "[',' ' the' ' South' ' African' ' anti' '-' 'ap' 'art' 'heid' ' leader'
+ ',' ' who' ' was' ' imprisoned' ' for' ' 27' ' years' '.' '\n' '\n']" ", the South African anti - ap art heid leader , who was imprisoned for 27 years .
+
+" False Foundation in 1994, Nelson Mandela addressed a crowd 5 [' Foundation', ' in', ' 1994', ',', ' Nelson', ' Mandela']
+824 187 Name of father of x -1 Name of father of Nelson Mandela Gadla Henry Mphakanyiswa Nelson Mandela "[',' ' the' ' South' ' African' ' anti' '-' 'ap' 'art' 'heid' ' leader'
+ ',' ' who' ' was' ' imprisoned' ' for' ' 27' ' years' '.' '\n' '\n']" ", the South African anti - ap art heid leader , who was imprisoned for 27 years .
+
+" False " and a martyr."" Nelson Mandela expressed sadness at" 5 "[' and', ' a', ' martyr', '.""', ' Nelson', ' Mandela']"
+825 187 Name of father of x -1 Name of father of Nelson Mandela Gadla Henry Mphakanyiswa Nelson Mandela "[',' ' the' ' South' ' African' ' anti' '-' 'ap' 'art' 'heid' ' leader'
+ ',' ' who' ' was' ' imprisoned' ' for' ' 27' ' years' '.' '\n' '\n']" ", the South African anti - ap art heid leader , who was imprisoned for 27 years .
+
+" False abroad as well. When Nelson Mandela was released after 6 [' abroad', ' as', ' well', '.', ' When', ' Nelson', ' Mandela']
+826 187 Name of father of x -1 Name of father of Nelson Mandela Gadla Henry Mphakanyiswa Nelson Mandela "[',' ' the' ' South' ' African' ' anti' '-' 'ap' 'art' 'heid' ' leader'
+ ',' ' who' ' was' ' imprisoned' ' for' ' 27' ' years' '.' '\n' '\n']" ", the South African anti - ap art heid leader , who was imprisoned for 27 years .
+
+" False Treason Trial of Nelson Mandela and others in 1956, 5 [' Tre', 'ason', ' Trial', ' of', ' Nelson', ' Mandela']
+827 188 Name of father of x -1 Name of father of Pyotr Ilyich Tchaikovsky Ilya Petrovich Tchaikovsky Pyotr Ilyich Tchaikovsky "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Py' 'ot' 'r'
+ ' Ily' 'ich' ' T' 'cha' 'ik' 'ovsky' ' is' ' unknown' '.']" "
+
+ The name of the father of Py ot r Ily ich T cha ik ovsky is unknown ." False 8 ['Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+828 188 Name of father of x -1 Name of father of Pyotr Ilyich Tchaikovsky Ilya Petrovich Tchaikovsky Pyotr Ilyich Tchaikovsky "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Py' 'ot' 'r'
+ ' Ily' 'ich' ' T' 'cha' 'ik' 'ovsky' ' is' ' unknown' '.']" "
+
+ The name of the father of Py ot r Ily ich T cha ik ovsky is unknown ." False " and The Five =
+" 12 [' and', ' The', ' Five', ' =', 'Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+829 188 Name of father of x -1 Name of father of Pyotr Ilyich Tchaikovsky Ilya Petrovich Tchaikovsky Pyotr Ilyich Tchaikovsky "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Py' 'ot' 'r'
+ ' Ily' 'ich' ' T' 'cha' 'ik' 'ovsky' ' is' ' unknown' '.']" "
+
+ The name of the father of Py ot r Ily ich T cha ik ovsky is unknown ." False Dvořák, Isaac Albéniz, Pyotr Ilyich Tchaikovsky and Sergei Rachmaninoff, 20 [' D', 'vo', '�', '�', 'á', 'k', ',', ' Isaac', ' Alb', 'én', 'iz', ',', ' Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+830 188 Name of father of x -1 Name of father of Pyotr Ilyich Tchaikovsky Ilya Petrovich Tchaikovsky Pyotr Ilyich Tchaikovsky "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Py' 'ot' 'r'
+ ' Ily' 'ich' ' T' 'cha' 'ik' 'ovsky' ' is' ' unknown' '.']" "
+
+ The name of the father of Py ot r Ily ich T cha ik ovsky is unknown ." False 8 ['Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+831 188 Name of father of x -1 Name of father of Pyotr Ilyich Tchaikovsky Ilya Petrovich Tchaikovsky Pyotr Ilyich Tchaikovsky "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Py' 'ot' 'r'
+ ' Ily' 'ich' ' T' 'cha' 'ik' 'ovsky' ' is' ' unknown' '.']" "
+
+ The name of the father of Py ot r Ily ich T cha ik ovsky is unknown ." False 8 ['Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+832 189 Name of father of x -1 Name of father of Miguel de Cervantes Rodrigo de Cervantes Miguel de Cervantes "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Miguel' ' de' ' C'
+ 'erv' 'antes' ' is' ' unknown' '.' ' He' ' was' ' born' ' in']" "
+
+ The name of the father of Miguel de C erv antes is unknown . He was born in" False 6 ['M', 'ig', 'uel', ' de', ' C', 'erv', 'antes']
+833 189 Name of father of x -1 Name of father of Miguel de Cervantes Rodrigo de Cervantes Miguel de Cervantes "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Miguel' ' de' ' C'
+ 'erv' 'antes' ' is' ' unknown' '.' ' He' ' was' ' born' ' in']" "
+
+ The name of the father of Miguel de C erv antes is unknown . He was born in" False selected by author Miguel de Cervantes as the first of 7 [' selected', ' by', ' author', ' Miguel', ' de', ' C', 'erv', 'antes']
+834 189 Name of father of x -1 Name of father of Miguel de Cervantes Rodrigo de Cervantes Miguel de Cervantes "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Miguel' ' de' ' C'
+ 'erv' 'antes' ' is' ' unknown' '.' ' He' ' was' ' born' ' in']" "
+
+ The name of the father of Miguel de C erv antes is unknown . He was born in" False is the 1994 Miguel de Cervantes Prize, considered 7 [' is', ' the', ' 1994', ' Miguel', ' de', ' C', 'erv', 'antes']
+835 189 Name of father of x -1 Name of father of Miguel de Cervantes Rodrigo de Cervantes Miguel de Cervantes "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Miguel' ' de' ' C'
+ 'erv' 'antes' ' is' ' unknown' '.' ' He' ' was' ' born' ' in']" "
+
+ The name of the father of Miguel de C erv antes is unknown . He was born in" False received is the 1994 Miguel de Cervantes Prize, considered 8 [' received', ' is', ' the', ' 1994', ' Miguel', ' de', ' C', 'erv', 'antes']
+836 189 Name of father of x -1 Name of father of Miguel de Cervantes Rodrigo de Cervantes Miguel de Cervantes "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Miguel' ' de' ' C'
+ 'erv' 'antes' ' is' ' unknown' '.' ' He' ' was' ' born' ' in']" "
+
+ The name of the father of Miguel de C erv antes is unknown . He was born in" False awarded the Miguel de Cervantes Prize in 1989. 6 [' awarded', ' the', ' Miguel', ' de', ' C', 'erv', 'antes']
+837 190 Name of father of x -1 Name of father of Catherine Zeta-Jones David James Jones Catherine Zeta-Jones "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False starred opposite Catherine Zeta-Jones as an up-and-coming 6 [' starred', ' opposite', ' Catherine', ' Z', 'eta', '-', 'Jones']
+838 190 Name of father of x -1 Name of father of Catherine Zeta-Jones David James Jones Catherine Zeta-Jones "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False desire to have his wife Catherine Zeta-Jones play Janet van 9 [' desire', ' to', ' have', ' his', ' wife', ' Catherine', ' Z', 'eta', '-', 'Jones']
+839 190 Name of father of x -1 Name of father of Catherine Zeta-Jones David James Jones Catherine Zeta-Jones "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " Catherine Zeta-Jones =
+" 4 [' Catherine', ' Z', 'eta', '-', 'Jones']
+840 190 Name of father of x -1 Name of father of Catherine Zeta-Jones David James Jones Catherine Zeta-Jones "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " Zeta-Jones =
+" 10 [' Z', 'eta', '-', 'Jones', ' =', 'C', 'atherine', ' Z', 'eta', '-', 'Jones']
+841 190 Name of father of x -1 Name of father of Catherine Zeta-Jones David James Jones Catherine Zeta-Jones "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False She replaced Catherine Zeta-Jones in the role. The 6 [' She', ' replaced', ' Catherine', ' Z', 'eta', '-', 'Jones']
+842 191 Name of father of x -1 Name of father of Hans Holbein the Younger Hans Holbein the Elder Hans Holbein the Younger "['\n' '\n' 'H' 'ans' ' Hol' 'be' 'in' ' the' ' Younger' ' (' 'German' ':'
+ ' [' 'ˈ' 'h' '�' '�' 'n' ' ' 'ˈ']" "
+
+ H ans Hol be in the Younger ( German : [ ˈ h � � n ˈ" False Catholicism. Hans Holbein the Younger was dispatched 7 [' Catholicism', '.', ' Hans', ' Hol', 'be', 'in', ' the', ' Younger']
+843 191 Name of father of x -1 Name of father of Hans Holbein the Younger Hans Holbein the Elder Hans Holbein the Younger "['\n' '\n' 'H' 'ans' ' Hol' 'be' 'in' ' the' ' Younger' ' (' 'German' ':'
+ ' [' 'ˈ' 'h' '�' '�' 'n' ' ' 'ˈ']" "
+
+ H ans Hol be in the Younger ( German : [ ˈ h � � n ˈ" False Lutheranism and Catholicism. Hans Holbein the Younger was dispatched to 10 [' Lutheran', 'ism', ' and', ' Catholicism', '.', ' Hans', ' Hol', 'be', 'in', ' the', ' Younger']
+844 191 Name of father of x -1 Name of father of Hans Holbein the Younger Hans Holbein the Elder Hans Holbein the Younger "['\n' '\n' 'H' 'ans' ' Hol' 'be' 'in' ' the' ' Younger' ' (' 'German' ':'
+ ' [' 'ˈ' 'h' '�' '�' 'n' ' ' 'ˈ']" "
+
+ H ans Hol be in the Younger ( German : [ ˈ h � � n ˈ" False Catholicism. Hans Holbein the Younger was dispatched 7 [' Catholicism', '.', ' Hans', ' Hol', 'be', 'in', ' the', ' Younger']
+845 191 Name of father of x -1 Name of father of Hans Holbein the Younger Hans Holbein the Elder Hans Holbein the Younger "['\n' '\n' 'H' 'ans' ' Hol' 'be' 'in' ' the' ' Younger' ' (' 'German' ':'
+ ' [' 'ˈ' 'h' '�' '�' 'n' ' ' 'ˈ']" "
+
+ H ans Hol be in the Younger ( German : [ ˈ h � � n ˈ" False Catholicism. Hans Holbein the Younger was dispatched 7 [' Catholicism', '.', ' Hans', ' Hol', 'be', 'in', ' the', ' Younger']
+846 192 Name of father of x -1 Name of father of Steven Spielberg Arnold Spielberg Steven Spielberg "[',' ' the' ' director' ' of' ' the' ' movie' ',' ' and' ' the'
+ ' director' ' of' ' the' ' movie' ',' ' and' ' the' ' director' ' of'
+ ' the' ' movie']" , the director of the movie , and the director of the movie , and the director of the movie False producers negotiated with Steven Spielberg who planned 4 [' producers', ' negotiated', ' with', ' Steven', ' Spielberg']
+847 192 Name of father of x -1 Name of father of Steven Spielberg Arnold Spielberg Steven Spielberg "[',' ' the' ' director' ' of' ' the' ' movie' ',' ' and' ' the'
+ ' director' ' of' ' the' ' movie' ',' ' and' ' the' ' director' ' of'
+ ' the' ' movie']" , the director of the movie , and the director of the movie , and the director of the movie False approached Steven Spielberg for Raiders 2 [' approached', ' Steven', ' Spielberg']
+848 192 Name of father of x -1 Name of father of Steven Spielberg Arnold Spielberg Steven Spielberg "[',' ' the' ' director' ' of' ' the' ' movie' ',' ' and' ' the'
+ ' director' ' of' ' the' ' movie' ',' ' and' ' the' ' director' ' of'
+ ' the' ' movie']" , the director of the movie , and the director of the movie , and the director of the movie False Stiller, Tom Cruise and Steven Spielberg (who had partially 7 [' St', 'iller', ',', ' Tom', ' Cruise', ' and', ' Steven', ' Spielberg']
+849 192 Name of father of x -1 Name of father of Steven Spielberg Arnold Spielberg Steven Spielberg "[',' ' the' ' director' ' of' ' the' ' movie' ',' ' and' ' the'
+ ' director' ' of' ' the' ' movie' ',' ' and' ' the' ' director' ' of'
+ ' the' ' movie']" , the director of the movie , and the director of the movie , and the director of the movie False Richard Johnson says that Steven Spielberg considers The 5 [' Richard', ' Johnson', ' says', ' that', ' Steven', ' Spielberg']
+850 192 Name of father of x -1 Name of father of Steven Spielberg Arnold Spielberg Steven Spielberg "[',' ' the' ' director' ' of' ' the' ' movie' ',' ' and' ' the'
+ ' director' ' of' ' the' ' movie' ',' ' and' ' the' ' director' ' of'
+ ' the' ' movie']" , the director of the movie , and the director of the movie , and the director of the movie False (2004), and the Steven Spielberg historical drama 6 [' (', '2004', '),', ' and', ' the', ' Steven', ' Spielberg']
+851 193 Name of father of x -1 Name of father of Le Corbusier Georges-Édouard Jeanneret Le Corbusier "[',' ' the' ' French' ' architect' ' and' ' urban' ' planner' ',' ' who'
+ ' was' ' born' ' in' ' the' ' city' ' of' ' D' 'ij' 'on' ' in' ' 18']" , the French architect and urban planner , who was born in the city of D ij on in 18 False better known as Le Corbusier. Pei was inspired 6 [' better', ' known', ' as', ' Le', ' Cor', 'bus', 'ier']
+852 193 Name of father of x -1 Name of father of Le Corbusier Georges-Édouard Jeanneret Le Corbusier "[',' ' the' ' French' ' architect' ' and' ' urban' ' planner' ',' ' who'
+ ' was' ' born' ' in' ' the' ' city' ' of' ' D' 'ij' 'on' ' in' ' 18']" , the French architect and urban planner , who was born in the city of D ij on in 18 False showed it off to Le Corbusier when they met. 7 [' showed', ' it', ' off', ' to', ' Le', ' Cor', 'bus', 'ier']
+853 193 Name of father of x -1 Name of father of Le Corbusier Georges-Édouard Jeanneret Le Corbusier "[',' ' the' ' French' ' architect' ' and' ' urban' ' planner' ',' ' who'
+ ' was' ' born' ' in' ' the' ' city' ' of' ' D' 'ij' 'on' ' in' ' 18']" , the French architect and urban planner , who was born in the city of D ij on in 18 False 3 ['Le', ' Cor', 'bus', 'ier']
+854 193 Name of father of x -1 Name of father of Le Corbusier Georges-Édouard Jeanneret Le Corbusier "[',' ' the' ' French' ' architect' ' and' ' urban' ' planner' ',' ' who'
+ ' was' ' born' ' in' ' the' ' city' ' of' ' D' 'ij' 'on' ' in' ' 18']" , the French architect and urban planner , who was born in the city of D ij on in 18 False " I called her Le Corbusier instead.""
+" 6 [' I', ' called', ' her', ' Le', ' Cor', 'bus', 'ier']
+855 193 Name of father of x -1 Name of father of Le Corbusier Georges-Édouard Jeanneret Le Corbusier "[',' ' the' ' French' ' architect' ' and' ' urban' ' planner' ',' ' who'
+ ' was' ' born' ' in' ' the' ' city' ' of' ' D' 'ij' 'on' ' in' ' 18']" , the French architect and urban planner , who was born in the city of D ij on in 18 False Jeanneret-Gris, better known as Le Corbusier. Pei was inspired 13 [' Je', 'ann', 'eret', '-', 'G', 'ris', ',', ' better', ' known', ' as', ' Le', ' Cor', 'bus', 'ier']
+856 194 Name of father of x -1 Name of father of Marcel Duchamp Justin-Isidore Eugène Duchamp Marcel Duchamp "[',' ' the' ' artist' ',' ' and' ' of' ' his' ' wife' ',' ' Suzanne' ' D'
+ 'uch' 'amp' '-' 'V' 'illon' ',' ' who' ' was' ' the']" , the artist , and of his wife , Suzanne D uch amp - V illon , who was the False John Cage and Marcel Duchamp were significant 6 [' John', ' Cage', ' and', ' Marcel', ' D', 'uch', 'amp']
+857 194 Name of father of x -1 Name of father of Marcel Duchamp Justin-Isidore Eugène Duchamp Marcel Duchamp "[',' ' the' ' artist' ',' ' and' ' of' ' his' ' wife' ',' ' Suzanne' ' D'
+ 'uch' 'amp' '-' 'V' 'illon' ',' ' who' ' was' ' the']" , the artist , and of his wife , Suzanne D uch amp - V illon , who was the False of New York. While Marcel Duchamp caused uproar 8 [' of', ' New', ' York', '.', ' While', ' Marcel', ' D', 'uch', 'amp']
+858 194 Name of father of x -1 Name of father of Marcel Duchamp Justin-Isidore Eugène Duchamp Marcel Duchamp "[',' ' the' ' artist' ',' ' and' ' of' ' his' ' wife' ',' ' Suzanne' ' D'
+ 'uch' 'amp' '-' 'V' 'illon' ',' ' who' ' was' ' the']" , the artist , and of his wife , Suzanne D uch amp - V illon , who was the False people like Marcel Duchamp and Peggy Guggenheim 5 [' people', ' like', ' Marcel', ' D', 'uch', 'amp']
+859 194 Name of father of x -1 Name of father of Marcel Duchamp Justin-Isidore Eugène Duchamp Marcel Duchamp "[',' ' the' ' artist' ',' ' and' ' of' ' his' ' wife' ',' ' Suzanne' ' D'
+ 'uch' 'amp' '-' 'V' 'illon' ',' ' who' ' was' ' the']" , the artist , and of his wife , Suzanne D uch amp - V illon , who was the False Picabia, Man Ray and Marcel Duchamp had earlier set up 10 [' Pic', 'ab', 'ia', ',', ' Man', ' Ray', ' and', ' Marcel', ' D', 'uch', 'amp']
+860 194 Name of father of x -1 Name of father of Marcel Duchamp Justin-Isidore Eugène Duchamp Marcel Duchamp "[',' ' the' ' artist' ',' ' and' ' of' ' his' ' wife' ',' ' Suzanne' ' D'
+ 'uch' 'amp' '-' 'V' 'illon' ',' ' who' ' was' ' the']" , the artist , and of his wife , Suzanne D uch amp - V illon , who was the False Picabia, Man Ray and Marcel Duchamp had earlier set 10 [' Pic', 'ab', 'ia', ',', ' Man', ' Ray', ' and', ' Marcel', ' D', 'uch', 'amp']
+861 196 Name of father of x -1 Name of father of George Harrison Harold Hargreaves Harrison George Harrison "[',' ' the' ' first' ' of' ' the' ' Beatles' ',' ' and' ' the' ' first'
+ ' of' ' the' ' Beatles' ',' ' and' ' the' ' first' ' of' ' the'
+ ' Beatles']" , the first of the Beatles , and the first of the Beatles , and the first of the Beatles False of The Moody Blues); George Harrison played electric 6 [' of', ' The', ' Moody', ' Blues', ');', ' George', ' Harrison']
+862 196 Name of father of x -1 Name of father of George Harrison Harold Hargreaves Harrison George Harrison "[',' ' the' ' first' ' of' ' the' ' Beatles' ',' ' and' ' the' ' first'
+ ' of' ' the' ' Beatles' ',' ' and' ' the' ' first' ' of' ' the'
+ ' Beatles']" , the first of the Beatles , and the first of the Beatles , and the first of the Beatles False " her marriage to George Harrison was already ""in" 4 [' her', ' marriage', ' to', ' George', ' Harrison']
+863 196 Name of father of x -1 Name of father of George Harrison Harold Hargreaves Harrison George Harrison "[',' ' the' ' first' ' of' ' the' ' Beatles' ',' ' and' ' the' ' first'
+ ' of' ' the' ' Beatles' ',' ' and' ' the' ' first' ' of' ' the'
+ ' Beatles']" , the first of the Beatles , and the first of the Beatles , and the first of the Beatles False lead guitarist George Harrison and Indian sitar 3 [' lead', ' guitarist', ' George', ' Harrison']
+864 196 Name of father of x -1 Name of father of George Harrison Harold Hargreaves Harrison George Harrison "[',' ' the' ' first' ' of' ' the' ' Beatles' ',' ' and' ' the' ' first'
+ ' of' ' the' ' Beatles' ',' ' and' ' the' ' first' ' of' ' the'
+ ' Beatles']" , the first of the Beatles , and the first of the Beatles , and the first of the Beatles False " Harrison =
+" 3 [' Harrison', ' =', 'George', ' Harrison']
+865 196 Name of father of x -1 Name of father of George Harrison Harold Hargreaves Harrison George Harrison "[',' ' the' ' first' ' of' ' the' ' Beatles' ',' ' and' ' the' ' first'
+ ' of' ' the' ' Beatles' ',' ' and' ' the' ' first' ' of' ' the'
+ ' Beatles']" , the first of the Beatles , and the first of the Beatles , and the first of the Beatles False he later met with George Harrison and Starr to offer 5 [' he', ' later', ' met', ' with', ' George', ' Harrison']
+866 197 Name of father of x -1 Name of father of Alexandre Dumas Thomas-Alexandre Dumas Alexandre Dumas "[',' ' the' ' author' ' of' ' the' ' famous' ' novels' ' ""' 'The' ' Three'
+ ' Mus' 'ket' 'eers' '""' ' and' ' ""' 'The' ' Count' ' of' ' Monte']" ", the author of the famous novels "" The Three Mus ket eers "" and "" The Count of Monte" False Cressida, and also by Alexandre Dumas in The Man in 10 [' C', 'ress', 'ida', ',', ' and', ' also', ' by', ' Alexand', 're', ' Dum', 'as']
+867 197 Name of father of x -1 Name of father of Alexandre Dumas Thomas-Alexandre Dumas Alexandre Dumas "[',' ' the' ' author' ' of' ' the' ' famous' ' novels' ' ""' 'The' ' Three'
+ ' Mus' 'ket' 'eers' '""' ' and' ' ""' 'The' ' Count' ' of' ' Monte']" ", the author of the famous novels "" The Three Mus ket eers "" and "" The Count of Monte" False and also by Alexandre Dumas in The Man in the 6 [' and', ' also', ' by', ' Alexand', 're', ' Dum', 'as']
+868 197 Name of father of x -1 Name of father of Alexandre Dumas Thomas-Alexandre Dumas Alexandre Dumas "[',' ' the' ' author' ' of' ' the' ' famous' ' novels' ' ""' 'The' ' Three'
+ ' Mus' 'ket' 'eers' '""' ' and' ' ""' 'The' ' Count' ' of' ' Monte']" ", the author of the famous novels "" The Three Mus ket eers "" and "" The Count of Monte" False Cressida, and also by Alexandre Dumas in The Man 10 [' C', 'ress', 'ida', ',', ' and', ' also', ' by', ' Alexand', 're', ' Dum', 'as']
+869 197 Name of father of x -1 Name of father of Alexandre Dumas Thomas-Alexandre Dumas Alexandre Dumas "[',' ' the' ' author' ' of' ' the' ' famous' ' novels' ' ""' 'The' ' Three'
+ ' Mus' 'ket' 'eers' '""' ' and' ' ""' 'The' ' Count' ' of' ' Monte']" ", the author of the famous novels "" The Three Mus ket eers "" and "" The Count of Monte" False adventure stories of Alexandre Dumas and similar authors, 6 [' adventure', ' stories', ' of', ' Alexand', 're', ' Dum', 'as']
+870 197 Name of father of x -1 Name of father of Alexandre Dumas Thomas-Alexandre Dumas Alexandre Dumas "[',' ' the' ' author' ' of' ' the' ' famous' ' novels' ' ""' 'The' ' Three'
+ ' Mus' 'ket' 'eers' '""' ' and' ' ""' 'The' ' Count' ' of' ' Monte']" ", the author of the famous novels "" The Three Mus ket eers "" and "" The Count of Monte" False Honoré Daumier, Alexandre Dumas and Victor Hugo. 9 [' Honor', 'é', ' Da', 'um', 'ier', ',', ' Alexand', 're', ' Dum', 'as']
+871 198 Name of father of x -1 Name of father of Robert Louis Stevenson Thomas Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' ' man'
+ ' who' ' wrote' ' the' ' book' ' that' ' inspired' ' the' ' movie' '.'
+ '\n']" ", the author of Treasure Island , and the man who wrote the book that inspired the movie .
+" False important role of Robert Louis Stevenson in flashback scenes 5 [' important', ' role', ' of', ' Robert', ' Louis', ' Stevenson']
+872 198 Name of father of x -1 Name of father of Robert Louis Stevenson Thomas Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' ' man'
+ ' who' ' wrote' ' the' ' book' ' that' ' inspired' ' the' ' movie' '.'
+ '\n']" ", the author of Treasure Island , and the man who wrote the book that inspired the movie .
+" False 2 ['Robert', ' Louis', ' Stevenson']
+873 198 Name of father of x -1 Name of father of Robert Louis Stevenson Thomas Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' ' man'
+ ' who' ' wrote' ' the' ' book' ' that' ' inspired' ' the' ' movie' '.'
+ '\n']" ", the author of Treasure Island , and the man who wrote the book that inspired the movie .
+" False Livingstone (1864), Robert Louis Stevenson (1870) and Mendelssohn 8 [' Living', 'stone', ' (', '18', '64', '),', ' Robert', ' Louis', ' Stevenson']
+874 198 Name of father of x -1 Name of father of Robert Louis Stevenson Thomas Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' ' man'
+ ' who' ' wrote' ' the' ' book' ' that' ' inspired' ' the' ' movie' '.'
+ '\n']" ", the author of Treasure Island , and the man who wrote the book that inspired the movie .
+" False Livingstone (1864), Robert Louis Stevenson (1870) and Mendelssohn 8 [' Living', 'stone', ' (', '18', '64', '),', ' Robert', ' Louis', ' Stevenson']
+875 198 Name of father of x -1 Name of father of Robert Louis Stevenson Thomas Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' ' man'
+ ' who' ' wrote' ' the' ' book' ' that' ' inspired' ' the' ' movie' '.'
+ '\n']" ", the author of Treasure Island , and the man who wrote the book that inspired the movie .
+" False Livingstone (1864), Robert Louis Stevenson (1870) and Mendelssohn 8 [' Living', 'stone', ' (', '18', '64', '),', ' Robert', ' Louis', ' Stevenson']
+876 199 Name of father of x -1 Name of father of Abraham Lincoln Thomas Lincoln Abraham Lincoln "[',' ' the' ' great' ' American' ' president' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' nation' '.' '\n' '\n' 'The' ' first'
+ ' president' ' of']" ", the great American president , and the father of the American nation .
+
+ The first president of" False Douglas who debated Abraham Lincoln in 1858, and lost to 4 [' Douglas', ' who', ' debated', ' Abraham', ' Lincoln']
+877 199 Name of father of x -1 Name of father of Abraham Lincoln Thomas Lincoln Abraham Lincoln "[',' ' the' ' great' ' American' ' president' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' nation' '.' '\n' '\n' 'The' ' first'
+ ' president' ' of']" ", the great American president , and the father of the American nation .
+
+ The first president of" False 2 ['Ab', 'raham', ' Lincoln']
+878 199 Name of father of x -1 Name of father of Abraham Lincoln Thomas Lincoln Abraham Lincoln "[',' ' the' ' great' ' American' ' president' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' nation' '.' '\n' '\n' 'The' ' first'
+ ' president' ' of']" ", the great American president , and the father of the American nation .
+
+ The first president of" False assassinated president Abraham Lincoln would occur in 3 [' assassinated', ' president', ' Abraham', ' Lincoln']
+879 199 Name of father of x -1 Name of father of Abraham Lincoln Thomas Lincoln Abraham Lincoln "[',' ' the' ' great' ' American' ' president' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' nation' '.' '\n' '\n' 'The' ' first'
+ ' president' ' of']" ", the great American president , and the father of the American nation .
+
+ The first president of" False the Lincoln cent. Abraham Lincoln remained a 5 [' the', ' Lincoln', ' cent', '.', ' Abraham', ' Lincoln']
+880 199 Name of father of x -1 Name of father of Abraham Lincoln Thomas Lincoln Abraham Lincoln "[',' ' the' ' great' ' American' ' president' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' nation' '.' '\n' '\n' 'The' ' first'
+ ' president' ' of']" ", the great American president , and the father of the American nation .
+
+ The first president of" False 15, 1861, President Abraham Lincoln called for 50,000 6 [' 15', ',', ' 1861', ',', ' President', ' Abraham', ' Lincoln']
+881 201 Name of father of x -1 Name of father of Jean-Paul Sartre Alex Pazos Bellon Jean-Paul Sartre "[',' ' the' ' French' ' philosopher' ' and' ' writer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' French' ' philosopher' ' and' ' writer' ' Jean'
+ '-' 'Paul']" ", the French philosopher and writer .
+
+ The name of the French philosopher and writer Jean - Paul" False figures such as Jean-Paul Sartre and John Lennon, who 8 [' figures', ' such', ' as', ' Jean', '-', 'Paul', ' S', 'art', 're']
+882 201 Name of father of x -1 Name of father of Jean-Paul Sartre Alex Pazos Bellon Jean-Paul Sartre "[',' ' the' ' French' ' philosopher' ' and' ' writer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' French' ' philosopher' ' and' ' writer' ' Jean'
+ '-' 'Paul']" ", the French philosopher and writer .
+
+ The name of the French philosopher and writer Jean - Paul" False " freedom"", while Jean-Paul Sartre described him" 8 "[' freedom', '"",', ' while', ' Jean', '-', 'Paul', ' S', 'art', 're']"
+883 201 Name of father of x -1 Name of father of Jean-Paul Sartre Alex Pazos Bellon Jean-Paul Sartre "[',' ' the' ' French' ' philosopher' ' and' ' writer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' French' ' philosopher' ' and' ' writer' ' Jean'
+ '-' 'Paul']" ", the French philosopher and writer .
+
+ The name of the French philosopher and writer Jean - Paul" False existentialist Jean-Paul Sartre and novelist 7 [' existential', 'ist', ' Jean', '-', 'Paul', ' S', 'art', 're']
+884 201 Name of father of x -1 Name of father of Jean-Paul Sartre Alex Pazos Bellon Jean-Paul Sartre "[',' ' the' ' French' ' philosopher' ' and' ' writer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' French' ' philosopher' ' and' ' writer' ' Jean'
+ '-' 'Paul']" ", the French philosopher and writer .
+
+ The name of the French philosopher and writer Jean - Paul" False " Friedrich Hegel, Jean-Paul Sartre and others.
+" 8 [' Friedrich', ' Hegel', ',', ' Jean', '-', 'Paul', ' S', 'art', 're']
+885 201 Name of father of x -1 Name of father of Jean-Paul Sartre Alex Pazos Bellon Jean-Paul Sartre "[',' ' the' ' French' ' philosopher' ' and' ' writer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' French' ' philosopher' ' and' ' writer' ' Jean'
+ '-' 'Paul']" ", the French philosopher and writer .
+
+ The name of the French philosopher and writer Jean - Paul" False we had a clue who Jean-Paul Sartre was. We got inspired 10 [' we', ' had', ' a', ' clue', ' who', ' Jean', '-', 'Paul', ' S', 'art', 're']
+886 202 Name of father of x -1 Name of father of Arnold Schwarzenegger Gustav Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False 3 ['Ar', 'n', 'old', ' Schwarzenegger']
+887 202 Name of father of x -1 Name of father of Arnold Schwarzenegger Gustav Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False counterpart of Arnold Schwarzenegger in the film Conan 3 [' counterpart', ' of', ' Arnold', ' Schwarzenegger']
+888 202 Name of father of x -1 Name of father of Arnold Schwarzenegger Gustav Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False doctor alongside Arnold Schwarzenegger and Danny DeVito 3 [' doctor', ' alongside', ' Arnold', ' Schwarzenegger']
+889 202 Name of father of x -1 Name of father of Arnold Schwarzenegger Gustav Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False Eddie, and Darwin. Actor Arnold Schwarzenegger and his wife 7 [' Eddie', ',', ' and', ' Darwin', '.', ' Actor', ' Arnold', ' Schwarzenegger']
+890 202 Name of father of x -1 Name of father of Arnold Schwarzenegger Gustav Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False James Cameron and Arnold Schwarzenegger into the stratosphere. 4 [' James', ' Cameron', ' and', ' Arnold', ' Schwarzenegger']
+891 203 Name of father of x -1 Name of father of Penélope Cruz Eduardo Cruz Penélope Cruz "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Pen' 'é' 'l' 'ope' ' Cruz'
+ '.' ' I' ' have' ' seen' ' her' ' in' ' many']" "
+
+ I am a big fan of Pen é l ope Cruz . I have seen her in many" False " Penélope Cruz =
+" 4 [' Pen', 'é', 'l', 'ope', ' Cruz']
+892 203 Name of father of x -1 Name of father of Penélope Cruz Eduardo Cruz Penélope Cruz "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Pen' 'é' 'l' 'ope' ' Cruz'
+ '.' ' I' ' have' ' seen' ' her' ' in' ' many']" "
+
+ I am a big fan of Pen é l ope Cruz . I have seen her in many" False " Penélope Cruz =
+" 4 [' Pen', 'é', 'l', 'ope', ' Cruz']
+893 203 Name of father of x -1 Name of father of Penélope Cruz Eduardo Cruz Penélope Cruz "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Pen' 'é' 'l' 'ope' ' Cruz'
+ '.' ' I' ' have' ' seen' ' her' ' in' ' many']" "
+
+ I am a big fan of Pen é l ope Cruz . I have seen her in many" False Blackbeard and Penélope Cruz as Angelica, 7 [' Black', 'beard', ' and', ' Pen', 'é', 'l', 'ope', ' Cruz']
+894 203 Name of father of x -1 Name of father of Penélope Cruz Eduardo Cruz Penélope Cruz "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Pen' 'é' 'l' 'ope' ' Cruz'
+ '.' ' I' ' have' ' seen' ' her' ' in' ' many']" "
+
+ I am a big fan of Pen é l ope Cruz . I have seen her in many" False helper Carmen, Penélope Cruz doesn 't hold a candle 7 [' helper', ' Carmen', ',', ' Pen', 'é', 'l', 'ope', ' Cruz']
+895 203 Name of father of x -1 Name of father of Penélope Cruz Eduardo Cruz Penélope Cruz "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Pen' 'é' 'l' 'ope' ' Cruz'
+ '.' ' I' ' have' ' seen' ' her' ' in' ' many']" "
+
+ I am a big fan of Pen é l ope Cruz . I have seen her in many" False 4 ['Pen', 'é', 'l', 'ope', ' Cruz']
+896 204 Name of father of x -1 Name of father of Elton John Stanley Dwight Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' born' ' in' ' the' ' UK' '.' '\n' '\n']" ", the singer , and his wife , the actress , who was born in the UK .
+
+" False along with Elton John on piano and vocals. 4 [' along', ' with', ' El', 'ton', ' John']
+897 204 Name of father of x -1 Name of father of Elton John Stanley Dwight Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' born' ' in' ' the' ' UK' '.' '\n' '\n']" ", the singer , and his wife , the actress , who was born in the UK .
+
+" False songwriters Elton John and Tim Rice, while 4 [' song', 'writers', ' El', 'ton', ' John']
+898 204 Name of father of x -1 Name of father of Elton John Stanley Dwight Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' born' ' in' ' the' ' UK' '.' '\n' '\n']" ", the singer , and his wife , the actress , who was born in the UK .
+
+" False written by Elton John and Bernie Taupin, 4 [' written', ' by', ' El', 'ton', ' John']
+899 204 Name of father of x -1 Name of father of Elton John Stanley Dwight Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' born' ' in' ' the' ' UK' '.' '\n' '\n']" ", the singer , and his wife , the actress , who was born in the UK .
+
+" False " ""Tiny Dancer"" by Elton John (""The Dundies"") and" 9 "[' ""', 'T', 'iny', ' D', 'ancer', '""', ' by', ' El', 'ton', ' John']"
+900 204 Name of father of x -1 Name of father of Elton John Stanley Dwight Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' born' ' in' ' the' ' UK' '.' '\n' '\n']" ", the singer , and his wife , the actress , who was born in the UK .
+
+" False by composer Elton John and lyricist Tim 4 [' by', ' composer', ' El', 'ton', ' John']
+901 205 Name of father of x -1 Name of father of Rudyard Kipling John Lockwood Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' Jungle' ' Book' '.' '\n' '\n' 'The'
+ ' Jungle' ' Book' ' is' ' a' ' collection' ' of' ' stories' ' about' ' M']" ", the author of the Jungle Book .
+
+ The Jungle Book is a collection of stories about M" False " Kipling (ship) =
+" 10 [' Ki', 'pling', ' (', 'ship', ')', ' =', 'R', 'ud', 'yard', ' Ki', 'pling']
+902 205 Name of father of x -1 Name of father of Rudyard Kipling John Lockwood Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' Jungle' ' Book' '.' '\n' '\n' 'The'
+ ' Jungle' ' Book' ' is' ' a' ' collection' ' of' ' stories' ' about' ' M']" ", the author of the Jungle Book .
+
+ The Jungle Book is a collection of stories about M" False Arthur Conan Doyle, Rudyard Kipling as well as newspaper 7 [' Arthur', ' Conan', ' Doyle', ',', ' Rud', 'yard', ' Ki', 'pling']
+903 205 Name of father of x -1 Name of father of Rudyard Kipling John Lockwood Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' Jungle' ' Book' '.' '\n' '\n' 'The'
+ ' Jungle' ' Book' ' is' ' a' ' collection' ' of' ' stories' ' about' ' M']" ", the author of the Jungle Book .
+
+ The Jungle Book is a collection of stories about M" False " Friend"", which Rudyard Kipling thought ""one of the" 6 "[' Friend', '"",', ' which', ' Rud', 'yard', ' Ki', 'pling']"
+904 205 Name of father of x -1 Name of father of Rudyard Kipling John Lockwood Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' Jungle' ' Book' '.' '\n' '\n' 'The'
+ ' Jungle' ' Book' ' is' ' a' ' collection' ' of' ' stories' ' about' ' M']" ", the author of the Jungle Book .
+
+ The Jungle Book is a collection of stories about M" False glowing reference by Rudyard Kipling following his 1889 6 [' glowing', ' reference', ' by', ' Rud', 'yard', ' Ki', 'pling']
+905 205 Name of father of x -1 Name of father of Rudyard Kipling John Lockwood Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' Jungle' ' Book' '.' '\n' '\n' 'The'
+ ' Jungle' ' Book' ' is' ' a' ' collection' ' of' ' stories' ' about' ' M']" ", the author of the Jungle Book .
+
+ The Jungle Book is a collection of stories about M" False Conan Doyle and Rudyard Kipling were three examples. 6 [' Conan', ' Doyle', ' and', ' Rud', 'yard', ' Ki', 'pling']
+906 206 Name of father of x -1 Name of father of Elvis Presley Vernon Elvis Presley Elvis Presley "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Elvis' '.' ' I' ' have'
+ ' been' ' since' ' I' ' was' ' a' ' little' ' girl' '.']" "
+
+ I am a big fan of Elvis . I have been since I was a little girl ." False albums (behind Elvis Presley with 12 albums in 5 [' albums', ' (', 'behind', ' Elvis', ' Pres', 'ley']
+907 206 Name of father of x -1 Name of father of Elvis Presley Vernon Elvis Presley Elvis Presley "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Elvis' '.' ' I' ' have'
+ ' been' ' since' ' I' ' was' ' a' ' little' ' girl' '.']" "
+
+ I am a big fan of Elvis . I have been since I was a little girl ." False focusing mainly on Elvis Presley in an exhibit in 5 [' focusing', ' mainly', ' on', ' Elvis', ' Pres', 'ley']
+908 206 Name of father of x -1 Name of father of Elvis Presley Vernon Elvis Presley Elvis Presley "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Elvis' '.' ' I' ' have'
+ ' been' ' since' ' I' ' was' ' a' ' little' ' girl' '.']" "
+
+ I am a big fan of Elvis . I have been since I was a little girl ." False 3 ['El', 'vis', ' Pres', 'ley']
+909 206 Name of father of x -1 Name of father of Elvis Presley Vernon Elvis Presley Elvis Presley "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Elvis' '.' ' I' ' have'
+ ' been' ' since' ' I' ' was' ' a' ' little' ' girl' '.']" "
+
+ I am a big fan of Elvis . I have been since I was a little girl ." False white performer Elvis Presley who first 4 [' white', ' performer', ' Elvis', ' Pres', 'ley']
+910 206 Name of father of x -1 Name of father of Elvis Presley Vernon Elvis Presley Elvis Presley "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Elvis' '.' ' I' ' have'
+ ' been' ' since' ' I' ' was' ' a' ' little' ' girl' '.']" "
+
+ I am a big fan of Elvis . I have been since I was a little girl ." False legends such as Elvis Presley and Carl Perkins. 5 [' legends', ' such', ' as', ' Elvis', ' Pres', 'ley']
+911 207 Name of father of x -1 Name of father of Stephen King Donald Edwin King Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' movie' ','
+ ' The' ' Shining' '.' '\n' '\n' 'The' ' Shining' ' is' ' a']" ", the author of the book , and the movie , The Shining .
+
+ The Shining is a" False three adaptations of Stephen King stories: The Langoliers, 4 [' three', ' adaptations', ' of', ' Stephen', ' King']
+912 207 Name of father of x -1 Name of father of Stephen King Donald Edwin King Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' movie' ','
+ ' The' ' Shining' '.' '\n' '\n' 'The' ' Shining' ' is' ' a']" ", the author of the book , and the movie , The Shining .
+
+ The Shining is a" False Parsons, and Stephen King (particularly 4 [' Parsons', ',', ' and', ' Stephen', ' King']
+913 207 Name of father of x -1 Name of father of Stephen King Donald Edwin King Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' movie' ','
+ ' The' ' Shining' '.' '\n' '\n' 'The' ' Shining' ' is' ' a']" ", the author of the book , and the movie , The Shining .
+
+ The Shining is a" False infamous 1986 Stephen King movie Maximum Overdrive, 3 [' infamous', ' 1986', ' Stephen', ' King']
+914 207 Name of father of x -1 Name of father of Stephen King Donald Edwin King Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' movie' ','
+ ' The' ' Shining' '.' '\n' '\n' 'The' ' Shining' ' is' ' a']" ", the author of the book , and the movie , The Shining .
+
+ The Shining is a" False considering that Stephen King would appear to be 3 [' considering', ' that', ' Stephen', ' King']
+915 207 Name of father of x -1 Name of father of Stephen King Donald Edwin King Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' movie' ','
+ ' The' ' Shining' '.' '\n' '\n' 'The' ' Shining' ' is' ' a']" ", the author of the book , and the movie , The Shining .
+
+ The Shining is a" False " than cinema."" Stephen King described the series" 4 "[' than', ' cinema', '.""', ' Stephen', ' King']"
+916 208 Name of father of x -1 Name of father of Leonard Cohen Nathan Bernard Cohen Leonard Cohen "[',' ' the' ' poet' ',' ' the' ' singer' ',' ' the' ' song' 'writer' ','
+ ' the' ' man' ' who' ' wrote' ' �' '�' 'H' 'alle' 'lu']" , the poet , the singer , the song writer , the man who wrote � � H alle lu False to Australia for a Leonard Cohen tribute concert, 5 [' to', ' Australia', ' for', ' a', ' Leonard', ' Cohen']
+917 208 Name of father of x -1 Name of father of Leonard Cohen Nathan Bernard Cohen Leonard Cohen "[',' ' the' ' poet' ',' ' the' ' singer' ',' ' the' ' song' 'writer' ','
+ ' the' ' man' ' who' ' wrote' ' �' '�' 'H' 'alle' 'lu']" , the poet , the singer , the song writer , the man who wrote � � H alle lu False songwriter Leonard Cohen and American 3 [' song', 'writer', ' Leonard', ' Cohen']
+918 208 Name of father of x -1 Name of father of Leonard Cohen Nathan Bernard Cohen Leonard Cohen "[',' ' the' ' poet' ',' ' the' ' singer' ',' ' the' ' song' 'writer' ','
+ ' the' ' man' ' who' ' wrote' ' �' '�' 'H' 'alle' 'lu']" , the poet , the singer , the song writer , the man who wrote � � H alle lu False " Secret Life"" by Leonard Cohen and Sharon Robinson." 5 "[' Secret', ' Life', '""', ' by', ' Leonard', ' Cohen']"
+919 208 Name of father of x -1 Name of father of Leonard Cohen Nathan Bernard Cohen Leonard Cohen "[',' ' the' ' poet' ',' ' the' ' singer' ',' ' the' ' song' 'writer' ','
+ ' the' ' man' ' who' ' wrote' ' �' '�' 'H' 'alle' 'lu']" , the poet , the singer , the song writer , the man who wrote � � H alle lu False " Evancho recorded the Leonard Cohen song ""Hallelujah""" 6 [' Ev', 'anch', 'o', ' recorded', ' the', ' Leonard', ' Cohen']
+920 208 Name of father of x -1 Name of father of Leonard Cohen Nathan Bernard Cohen Leonard Cohen "[',' ' the' ' poet' ',' ' the' ' singer' ',' ' the' ' song' 'writer' ','
+ ' the' ' man' ' who' ' wrote' ' �' '�' 'H' 'alle' 'lu']" , the poet , the singer , the song writer , the man who wrote � � H alle lu False work of authors like Leonard Cohen and Walt Whitman 5 [' work', ' of', ' authors', ' like', ' Leonard', ' Cohen']
+921 210 Name of father of x -1 Name of father of Henry VIII of England Henry VII Henry VIII of England "[',' ' and' ' the' ' first' ' of' ' the' ' Tud' 'or' ' dynasty' '.' '\n'
+ '\n' 'The' ' Tud' 'or' ' dynasty' ' was' ' a' ' series' ' of']" ", and the first of the Tud or dynasty .
+
+ The Tud or dynasty was a series of" False film he played Henry VIII of England opposite Scarlett 6 [' film', ' he', ' played', ' Henry', ' VIII', ' of', ' England']
+922 210 Name of father of x -1 Name of father of Henry VIII of England Henry VII Henry VIII of England "[',' ' and' ' the' ' first' ' of' ' the' ' Tud' 'or' ' dynasty' '.' '\n'
+ '\n' 'The' ' Tud' 'or' ' dynasty' ' was' ' a' ' series' ' of']" ", and the first of the Tud or dynasty .
+
+ The Tud or dynasty was a series of" False seen in portraits of Henry VIII of England and Alessandro 7 [' seen', ' in', ' portraits', ' of', ' Henry', ' VIII', ' of', ' England']
+923 210 Name of father of x -1 Name of father of Henry VIII of England Henry VII Henry VIII of England "[',' ' and' ' the' ' first' ' of' ' the' ' Tud' 'or' ' dynasty' '.' '\n'
+ '\n' 'The' ' Tud' 'or' ' dynasty' ' was' ' a' ' series' ' of']" ", and the first of the Tud or dynasty .
+
+ The Tud or dynasty was a series of" False " VIII of England =
+" 7 [' VIII', ' of', ' England', ' =', 'Henry', ' VIII', ' of', ' England']
+924 210 Name of father of x -1 Name of father of Henry VIII of England Henry VII Henry VIII of England "[',' ' and' ' the' ' first' ' of' ' the' ' Tud' 'or' ' dynasty' '.' '\n'
+ '\n' 'The' ' Tud' 'or' ' dynasty' ' was' ' a' ' series' ' of']" ", and the first of the Tud or dynasty .
+
+ The Tud or dynasty was a series of" False A 1543 statute of Henry VIII of England permits treason 8 [' A', ' 15', '43', ' statute', ' of', ' Henry', ' VIII', ' of', ' England']
+925 210 Name of father of x -1 Name of father of Henry VIII of England Henry VII Henry VIII of England "[',' ' and' ' the' ' first' ' of' ' the' ' Tud' 'or' ' dynasty' '.' '\n'
+ '\n' 'The' ' Tud' 'or' ' dynasty' ' was' ' a' ' series' ' of']" ", and the first of the Tud or dynasty .
+
+ The Tud or dynasty was a series of" False of Cleves, married Henry VIII of England in 1540, and 8 [' of', ' Cle', 'ves', ',', ' married', ' Henry', ' VIII', ' of', ' England']
+926 211 Name of father of x -1 Name of father of Denis Diderot Didier Diderot Denis Diderot "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' author' ' of' ' the'
+ ' _' 'En' 'cyclop' 'é' 'die' '_' ',' ' and' ' of' ' the']" , the philosopher , and of the author of the _ En cyclop é die _ , and of the False largely begun by Denis Diderot and Jean le Rond 6 [' largely', ' begun', ' by', ' Denis', ' D', 'ider', 'ot']
+927 211 Name of father of x -1 Name of father of Denis Diderot Didier Diderot Denis Diderot "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' author' ' of' ' the'
+ ' _' 'En' 'cyclop' 'é' 'die' '_' ',' ' and' ' of' ' the']" , the philosopher , and of the author of the _ En cyclop é die _ , and of the False drew praise from Denis Diderot in the Encyclopédie 6 [' drew', ' praise', ' from', ' Denis', ' D', 'ider', 'ot']
+928 211 Name of father of x -1 Name of father of Denis Diderot Didier Diderot Denis Diderot "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' author' ' of' ' the'
+ ' _' 'En' 'cyclop' 'é' 'die' '_' ',' ' and' ' of' ' the']" , the philosopher , and of the author of the _ En cyclop é die _ , and of the False praise from Denis Diderot in the Encyclopédie 5 [' praise', ' from', ' Denis', ' D', 'ider', 'ot']
+929 211 Name of father of x -1 Name of father of Denis Diderot Didier Diderot Denis Diderot "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' author' ' of' ' the'
+ ' _' 'En' 'cyclop' 'é' 'die' '_' ',' ' and' ' of' ' the']" , the philosopher , and of the author of the _ En cyclop é die _ , and of the False praise from Denis Diderot in the Encyclopédie 5 [' praise', ' from', ' Denis', ' D', 'ider', 'ot']
+930 211 Name of father of x -1 Name of father of Denis Diderot Didier Diderot Denis Diderot "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' author' ' of' ' the'
+ ' _' 'En' 'cyclop' 'é' 'die' '_' ',' ' and' ' of' ' the']" , the philosopher , and of the author of the _ En cyclop é die _ , and of the False French philosopher Denis Diderot was visiting Russia 5 [' French', ' philosopher', ' Denis', ' D', 'ider', 'ot']
+931 212 Name of father of x -1 Name of father of James Joyce John Stanislaus Joyce James Joyce "[',' ' the' ' author' ' of' ' U' 'lyss' 'es' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ' that' ' made' ' him' ' famous' '.' '\n']" ", the author of U lyss es , and the author of the book that made him famous .
+" False Prize winner, James Joyce is widely 4 [' Prize', ' winner', ',', ' James', ' Joyce']
+932 212 Name of father of x -1 Name of father of James Joyce John Stanislaus Joyce James Joyce "[',' ' the' ' author' ' of' ' U' 'lyss' 'es' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ' that' ' made' ' him' ' famous' '.' '\n']" ", the author of U lyss es , and the author of the book that made him famous .
+" False particular attention to James Joyce and D.H. Lawrence. 4 [' particular', ' attention', ' to', ' James', ' Joyce']
+933 212 Name of father of x -1 Name of father of James Joyce John Stanislaus Joyce James Joyce "[',' ' the' ' author' ' of' ' U' 'lyss' 'es' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ' that' ' made' ' him' ' famous' '.' '\n']" ", the author of U lyss es , and the author of the book that made him famous .
+" False Stephan Zweig and James Joyce – the former noting 6 [' Stephan', ' Z', 'we', 'ig', ' and', ' James', ' Joyce']
+934 212 Name of father of x -1 Name of father of James Joyce John Stanislaus Joyce James Joyce "[',' ' the' ' author' ' of' ' U' 'lyss' 'es' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ' that' ' made' ' him' ' famous' '.' '\n']" ", the author of U lyss es , and the author of the book that made him famous .
+" False novel Ulysses, James Joyce mentions the 6 [' novel', ' U', 'lyss', 'es', ',', ' James', ' Joyce']
+935 212 Name of father of x -1 Name of father of James Joyce John Stanislaus Joyce James Joyce "[',' ' the' ' author' ' of' ' U' 'lyss' 'es' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ' that' ' made' ' him' ' famous' '.' '\n']" ", the author of U lyss es , and the author of the book that made him famous .
+" False 1 ['James', ' Joyce']
+936 213 Name of father of x -1 Name of father of Ivan Turgenev Sergey Turgenev Ivan Turgenev "['\n' '\n' 'I' 'van' ' T' 'urg' 'ene' 'v' ' (' '18' '18' '–' '18' '83' ')'
+ ' was' ' a' ' Russian' ' writer' ',']" "
+
+ I van T urg ene v ( 18 18 – 18 83 ) was a Russian writer ," False outlet for both Ivan Turgenev and Leo Tolstoy. 7 [' outlet', ' for', ' both', ' Ivan', ' T', 'urg', 'ene', 'v']
+937 213 Name of father of x -1 Name of father of Ivan Turgenev Sergey Turgenev Ivan Turgenev "['\n' '\n' 'I' 'van' ' T' 'urg' 'ene' 'v' ' (' '18' '18' '–' '18' '83' ')'
+ ' was' ' a' ' Russian' ' writer' ',']" "
+
+ I van T urg ene v ( 18 18 – 18 83 ) was a Russian writer ," False the outlet for both Ivan Turgenev and Leo Tolstoy. 8 [' the', ' outlet', ' for', ' both', ' Ivan', ' T', 'urg', 'ene', 'v']
+938 213 Name of father of x -1 Name of father of Ivan Turgenev Sergey Turgenev Ivan Turgenev "['\n' '\n' 'I' 'van' ' T' 'urg' 'ene' 'v' ' (' '18' '18' '–' '18' '83' ')'
+ ' was' ' a' ' Russian' ' writer' ',']" "
+
+ I van T urg ene v ( 18 18 – 18 83 ) was a Russian writer ," False Nekrasov and Ivan Turgenev worsened Dostoyevsky's 8 [' Nek', 'ras', 'ov', ' and', ' Ivan', ' T', 'urg', 'ene', 'v']
+939 213 Name of father of x -1 Name of father of Ivan Turgenev Sergey Turgenev Ivan Turgenev "['\n' '\n' 'I' 'van' ' T' 'urg' 'ene' 'v' ' (' '18' '18' '–' '18' '83' ')'
+ ' was' ' a' ' Russian' ' writer' ',']" "
+
+ I van T urg ene v ( 18 18 – 18 83 ) was a Russian writer ," False outlet for both Ivan Turgenev and Leo Tolstoy. 7 [' outlet', ' for', ' both', ' Ivan', ' T', 'urg', 'ene', 'v']
+940 213 Name of father of x -1 Name of father of Ivan Turgenev Sergey Turgenev Ivan Turgenev "['\n' '\n' 'I' 'van' ' T' 'urg' 'ene' 'v' ' (' '18' '18' '–' '18' '83' ')'
+ ' was' ' a' ' Russian' ' writer' ',']" "
+
+ I van T urg ene v ( 18 18 – 18 83 ) was a Russian writer ," False Nikolay Nekrasov and Ivan Turgenev worsened Dostoyevsky's 10 [' Nikol', 'ay', ' Nek', 'ras', 'ov', ' and', ' Ivan', ' T', 'urg', 'ene', 'v']
+941 214 Name of father of x -1 Name of father of Frank Sinatra Anthony Martin Sinatra Frank Sinatra "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Sin'
+ 'atra' ' family' '.' ' I' ' have' ' been' ' a']" ", Jr .
+
+ I am a big fan of the Sin atra family . I have been a" False special, guest starring Frank Sinatra and Dean Martin, was 6 [' special', ',', ' guest', ' starring', ' Frank', ' Sin', 'atra']
+942 214 Name of father of x -1 Name of father of Frank Sinatra Anthony Martin Sinatra Frank Sinatra "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Sin'
+ 'atra' ' family' '.' ' I' ' have' ' been' ' a']" ", Jr .
+
+ I am a big fan of the Sin atra family . I have been a" False " a Cold =
+" 5 [' a', ' Cold', ' =', 'Frank', ' Sin', 'atra']
+943 214 Name of father of x -1 Name of father of Frank Sinatra Anthony Martin Sinatra Frank Sinatra "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Sin'
+ 'atra' ' family' '.' ' I' ' have' ' been' ' a']" ", Jr .
+
+ I am a big fan of the Sin atra family . I have been a" False featured on Frank Sinatra Sr. ' s 1958 album, 4 [' featured', ' on', ' Frank', ' Sin', 'atra']
+944 214 Name of father of x -1 Name of father of Frank Sinatra Anthony Martin Sinatra Frank Sinatra "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Sin'
+ 'atra' ' family' '.' ' I' ' have' ' been' ' a']" ", Jr .
+
+ I am a big fan of the Sin atra family . I have been a" False " Manchurian Candidate"", with Frank Sinatra (as Major Marco)" 9 "[' Man', 'ch', 'ur', 'ian', ' Candidate', '"",', ' with', ' Frank', ' Sin', 'atra']"
+945 214 Name of father of x -1 Name of father of Frank Sinatra Anthony Martin Sinatra Frank Sinatra "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Sin'
+ 'atra' ' family' '.' ' I' ' have' ' been' ' a']" ", Jr .
+
+ I am a big fan of the Sin atra family . I have been a" False Bobbie Burns. Frank Sinatra frequented Sir 6 [' Bob', 'bie', ' Burns', '.', ' Frank', ' Sin', 'atra']
+946 215 Name of father of x -1 Name of father of Milla Jovovich Bogdan Jovovich Milla Jovovich "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False James Purefoy and Milla Jovovich were the first of the 9 [' James', ' Pure', 'f', 'oy', ' and', ' M', 'illa', ' J', 'ov', 'ovich']
+947 215 Name of father of x -1 Name of father of Milla Jovovich Bogdan Jovovich Milla Jovovich "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False " Jovovich as Joan of Arc
+" 11 [' J', 'ov', 'ovich', ' as', ' Joan', ' of', ' Arc', 'M', 'illa', ' J', 'ov', 'ovich']
+948 215 Name of father of x -1 Name of father of Milla Jovovich Bogdan Jovovich Milla Jovovich "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False " Joan of Arc
+" 7 [' Joan', ' of', ' Arc', 'M', 'illa', ' J', 'ov', 'ovich']
+949 215 Name of father of x -1 Name of father of Milla Jovovich Bogdan Jovovich Milla Jovovich "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False The film stars Milla Jovovich and Michelle Rodriguez. 7 [' The', ' film', ' stars', ' M', 'illa', ' J', 'ov', 'ovich']
+950 215 Name of father of x -1 Name of father of Milla Jovovich Bogdan Jovovich Milla Jovovich "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False James Purefoy and Milla Jovovich were the first 9 [' James', ' Pure', 'f', 'oy', ' and', ' M', 'illa', ' J', 'ov', 'ovich']
+951 216 Name of father of x -1 Name of father of Otto von Bismarck Ferdinand von Bismarck Otto von Bismarck "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Otto' ' von' ' B'
+ 'ism' 'ar' 'ck' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of Otto von B ism ar ck is not known .
+
+" False debate whether Otto von Bismarck — Minister President 7 [' debate', ' whether', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+952 216 Name of father of x -1 Name of father of Otto von Bismarck Ferdinand von Bismarck Otto von Bismarck "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Otto' ' von' ' B'
+ 'ism' 'ar' 'ck' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of Otto von B ism ar ck is not known .
+
+" False and foreigners: Otto von Bismarck (1911), Ivar Aasen 8 [' and', ' foreigners', ':', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+953 216 Name of father of x -1 Name of father of Otto von Bismarck Ferdinand von Bismarck Otto von Bismarck "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Otto' ' von' ' B'
+ 'ism' 'ar' 'ck' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of Otto von B ism ar ck is not known .
+
+" False named Dash. In 1896, Otto von Bismarck purchased a King 11 [' named', ' Dash', '.', ' In', ' 1896', ',', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+954 216 Name of father of x -1 Name of father of Otto von Bismarck Ferdinand von Bismarck Otto von Bismarck "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Otto' ' von' ' B'
+ 'ism' 'ar' 'ck' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of Otto von B ism ar ck is not known .
+
+" False Germany. Chancellor Otto von Bismarck ordered a court-martial 8 [' Germany', '.', ' Chancellor', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+955 216 Name of father of x -1 Name of father of Otto von Bismarck Ferdinand von Bismarck Otto von Bismarck "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Otto' ' von' ' B'
+ 'ism' 'ar' 'ck' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of Otto von B ism ar ck is not known .
+
+" False Minister President Otto von Bismarck by the mid-1860s. 7 [' Minister', ' President', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+956 217 Name of father of x -1 Name of father of Sean Connery Joseph Connery Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the' ' 00'
+ '7' ' films' '.' '\n' '\n' 'The' ' actor' ',' ' who' ' was']" ", the actor who played James Bond in the 00 7 films .
+
+ The actor , who was" False " film to star Sean Connery as Bond.
+" 5 [' film', ' to', ' star', ' Sean', ' Con', 'nery']
+957 217 Name of father of x -1 Name of father of Sean Connery Joseph Connery Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the' ' 00'
+ '7' ' films' '.' '\n' '\n' 'The' ' actor' ',' ' who' ' was']" ", the actor who played James Bond in the 00 7 films .
+
+ The actor , who was" False 2 ['Sean', ' Con', 'nery']
+958 217 Name of father of x -1 Name of father of Sean Connery Joseph Connery Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the' ' 00'
+ '7' ' films' '.' '\n' '\n' 'The' ' actor' ',' ' who' ' was']" ", the actor who played James Bond in the 00 7 films .
+
+ The actor , who was" False Hopkins and a bit Sean Connery and that went 6 [' Hopkins', ' and', ' a', ' bit', ' Sean', ' Con', 'nery']
+959 217 Name of father of x -1 Name of father of Sean Connery Joseph Connery Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the' ' 00'
+ '7' ' films' '.' '\n' '\n' 'The' ' actor' ',' ' who' ' was']" ", the actor who played James Bond in the 00 7 films .
+
+ The actor , who was" False " ""It's not just that Sean Connery looks a lot more haggard" 8 "[' ""', 'It', ""'s"", ' not', ' just', ' that', ' Sean', ' Con', 'nery']"
+960 217 Name of father of x -1 Name of father of Sean Connery Joseph Connery Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the' ' 00'
+ '7' ' films' '.' '\n' '\n' 'The' ' actor' ',' ' who' ' was']" ", the actor who played James Bond in the 00 7 films .
+
+ The actor , who was" False Productions series; Sean Connery portrayed Bond, 5 [' Productions', ' series', ';', ' Sean', ' Con', 'nery']
+961 218 Name of father of x -1 Name of father of Benjamin Franklin Josiah Franklin Benjamin Franklin "[',' ' the' ' inventor' ' of' ' electricity' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' republic' '.' '\n' '\n' 'The' ' first'
+ ' American' ' to']" ", the inventor of electricity , and the father of the American republic .
+
+ The first American to" False provisional governor Benjamin Franklin Perry objected 3 [' provisional', ' governor', ' Benjamin', ' Franklin']
+962 218 Name of father of x -1 Name of father of Benjamin Franklin Josiah Franklin Benjamin Franklin "[',' ' the' ' inventor' ' of' ' electricity' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' republic' '.' '\n' '\n' 'The' ' first'
+ ' American' ' to']" ", the inventor of electricity , and the father of the American republic .
+
+ The first American to" False intersections with USS Benjamin Franklin Drive opposite St. 4 [' intersections', ' with', ' USS', ' Benjamin', ' Franklin']
+963 218 Name of father of x -1 Name of father of Benjamin Franklin Josiah Franklin Benjamin Franklin "[',' ' the' ' inventor' ' of' ' electricity' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' republic' '.' '\n' '\n' 'The' ' first'
+ ' American' ' to']" ", the inventor of electricity , and the father of the American republic .
+
+ The first American to" False communications with Benjamin Franklin who at times lobbied 3 [' communications', ' with', ' Benjamin', ' Franklin']
+964 218 Name of father of x -1 Name of father of Benjamin Franklin Josiah Franklin Benjamin Franklin "[',' ' the' ' inventor' ' of' ' electricity' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' republic' '.' '\n' '\n' 'The' ' first'
+ ' American' ' to']" ", the inventor of electricity , and the father of the American republic .
+
+ The first American to" False 2 ['Ben', 'jamin', ' Franklin']
+965 218 Name of father of x -1 Name of father of Benjamin Franklin Josiah Franklin Benjamin Franklin "[',' ' the' ' inventor' ' of' ' electricity' ',' ' and' ' the' ' father'
+ ' of' ' the' ' American' ' republic' '.' '\n' '\n' 'The' ' first'
+ ' American' ' to']" ", the inventor of electricity , and the father of the American republic .
+
+ The first American to" False Keller, the widow of Benjamin Franklin Keller, a judge for 6 [' Keller', ',', ' the', ' widow', ' of', ' Benjamin', ' Franklin']
+966 219 Name of father of x -1 Name of father of Ernest Hemingway Clarence Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' ""' 'The' ' Sun' ' Also' ' R' 'ises' '""'
+ ' and' ' ""' 'A' ' Fare' 'well' ' to' ' Arms' '.""' '\n']" ", the author of "" The Sun Also R ises "" and "" A Fare well to Arms .""
+" False friendship with Ernest Hemingway began at Sun Valley 5 [' friendship', ' with', ' Ernest', ' Hem', 'ing', 'way']
+967 219 Name of father of x -1 Name of father of Ernest Hemingway Clarence Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' ""' 'The' ' Sun' ' Also' ' R' 'ises' '""'
+ ' and' ' ""' 'A' ' Fare' 'well' ' to' ' Arms' '.""' '\n']" ", the author of "" The Sun Also R ises "" and "" A Fare well to Arms .""
+" False military officers and Ernest Hemingway after he pranked 6 [' military', ' officers', ' and', ' Ernest', ' Hem', 'ing', 'way']
+968 219 Name of father of x -1 Name of father of Ernest Hemingway Clarence Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' ""' 'The' ' Sun' ' Also' ' R' 'ises' '""'
+ ' and' ' ""' 'A' ' Fare' 'well' ' to' ' Arms' '.""' '\n']" ", the author of "" The Sun Also R ises "" and "" A Fare well to Arms .""
+" False " blinding signs."" Ernest Hemingway wrote in A Moveable" 6 "[' blinding', ' signs', '.""', ' Ernest', ' Hem', 'ing', 'way']"
+969 219 Name of father of x -1 Name of father of Ernest Hemingway Clarence Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' ""' 'The' ' Sun' ' Also' ' R' 'ises' '""'
+ ' and' ' ""' 'A' ' Fare' 'well' ' to' ' Arms' '.""' '\n']" ", the author of "" The Sun Also R ises "" and "" A Fare well to Arms .""
+" False including authors Ernest Hemingway and James A. Michener; 5 [' including', ' authors', ' Ernest', ' Hem', 'ing', 'way']
+970 219 Name of father of x -1 Name of father of Ernest Hemingway Clarence Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' ""' 'The' ' Sun' ' Also' ' R' 'ises' '""'
+ ' and' ' ""' 'A' ' Fare' 'well' ' to' ' Arms' '.""' '\n']" ", the author of "" The Sun Also R ises "" and "" A Fare well to Arms .""
+" False Coco Chanel and Ernest Hemingway who lived at the 7 [' Coco', ' Chan', 'el', ' and', ' Ernest', ' Hem', 'ing', 'way']
+971 220 Name of father of x -1 Name of father of John Singer Sargent FitzWilliam Sargent John Singer Sargent "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' S' 'arg' 'ent' ' family'
+ ' was' ' a' ' prominent' ' one' ' in' ' Boston' '.' ' John' ' Singer']" ", the painter .
+
+ The S arg ent family was a prominent one in Boston . John Singer" False Hart Benton, John Singer Sargent and one of the original 8 [' Hart', ' Bent', 'on', ',', ' John', ' Singer', ' S', 'arg', 'ent']
+972 220 Name of father of x -1 Name of father of John Singer Sargent FitzWilliam Sargent John Singer Sargent "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' S' 'arg' 'ent' ' family'
+ ' was' ' a' ' prominent' ' one' ' in' ' Boston' '.' ' John' ' Singer']" ", the painter .
+
+ The S arg ent family was a prominent one in Boston . John Singer" False " Letter at Newport. John Singer Sargent considered her ""one" 8 [' Letter', ' at', ' Newport', '.', ' John', ' Singer', ' S', 'arg', 'ent']
+973 220 Name of father of x -1 Name of father of John Singer Sargent FitzWilliam Sargent John Singer Sargent "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' S' 'arg' 'ent' ' family'
+ ' was' ' a' ' prominent' ' one' ' in' ' Boston' '.' ' John' ' Singer']" ", the painter .
+
+ The S arg ent family was a prominent one in Boston . John Singer" False public. In 1879 John Singer Sargent painted a small-scale 9 [' public', '.', ' In', ' 18', '79', ' John', ' Singer', ' S', 'arg', 'ent']
+974 220 Name of father of x -1 Name of father of John Singer Sargent FitzWilliam Sargent John Singer Sargent "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' S' 'arg' 'ent' ' family'
+ ' was' ' a' ' prominent' ' one' ' in' ' Boston' '.' ' John' ' Singer']" ", the painter .
+
+ The S arg ent family was a prominent one in Boston . John Singer" False while Renoir and John Singer Sargent visited in the 8 [' while', ' Ren', 'oir', ' and', ' John', ' Singer', ' S', 'arg', 'ent']
+975 220 Name of father of x -1 Name of father of John Singer Sargent FitzWilliam Sargent John Singer Sargent "[',' ' the' ' painter' '.' '\n' '\n' 'The' ' S' 'arg' 'ent' ' family'
+ ' was' ' a' ' prominent' ' one' ' in' ' Boston' '.' ' John' ' Singer']" ", the painter .
+
+ The S arg ent family was a prominent one in Boston . John Singer" False portrait, painted by John Singer Sargent in 1907, and his own 8 [' portrait', ',', ' painted', ' by', ' John', ' Singer', ' S', 'arg', 'ent']
+976 221 Name of father of x -1 Name of father of Émile Zola Francesco Zola Émile Zola "[',' ' the' ' novelist' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 1840' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' writer']" ", the novelist , who was born in Paris in 1840 .
+
+ The name of the French writer" False those by Leo Tolstoy, Émile Zola and Victor Hugo, all 10 [' those', ' by', ' Leo', ' Tol', 'st', 'oy', ',', ' É', 'mile', ' Z', 'ola']
+977 221 Name of father of x -1 Name of father of Émile Zola Francesco Zola Émile Zola "[',' ' the' ' novelist' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 1840' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' writer']" ", the novelist , who was born in Paris in 1840 .
+
+ The name of the French writer" False and other novels, Émile Zola also reported the briquet: 7 [' and', ' other', ' novels', ',', ' É', 'mile', ' Z', 'ola']
+978 221 Name of father of x -1 Name of father of Émile Zola Francesco Zola Émile Zola "[',' ' the' ' novelist' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 1840' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' writer']" ", the novelist , who was born in Paris in 1840 .
+
+ The name of the French writer" False The French novelist Émile Zola lived in what is now 6 [' The', ' French', ' novelist', ' É', 'mile', ' Z', 'ola']
+979 221 Name of father of x -1 Name of father of Émile Zola Francesco Zola Émile Zola "[',' ' the' ' novelist' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 1840' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' writer']" ", the novelist , who was born in Paris in 1840 .
+
+ The name of the French writer" False literature. Novelist Émile Zola called it an important 7 [' literature', '.', ' Novel', 'ist', ' É', 'mile', ' Z', 'ola']
+980 221 Name of father of x -1 Name of father of Émile Zola Francesco Zola Émile Zola "[',' ' the' ' novelist' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 1840' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' writer']" ", the novelist , who was born in Paris in 1840 .
+
+ The name of the French writer" False working-class girl. Émile Zola also approved, comparing 8 [' working', '-', 'class', ' girl', '.', ' É', 'mile', ' Z', 'ola']
+981 222 Name of father of x -1 Name of father of Giorgio de Chirico Evaristo de Chirico Giorgio de Chirico "['\n' '\n' 'Name' ' of' ' mother' ' of' ' G' 'ior' 'gio' ' de' ' Ch' 'ir'
+ 'ico' '\n' '\n' 'Name' ' of' ' mother' ' of' ' G']" "
+
+ Name of mother of G ior gio de Ch ir ico
+
+ Name of mother of G" False visual artist Giorgio de Chirico and with the 8 [' visual', ' artist', ' G', 'ior', 'gio', ' de', ' Ch', 'ir', 'ico']
+982 222 Name of father of x -1 Name of father of Giorgio de Chirico Evaristo de Chirico Giorgio de Chirico "['\n' '\n' 'Name' ' of' ' mother' ' of' ' G' 'ior' 'gio' ' de' ' Ch' 'ir'
+ 'ico' '\n' '\n' 'Name' ' of' ' mother' ' of' ' G']" "
+
+ Name of mother of G ior gio de Ch ir ico
+
+ Name of mother of G" False surrealist artist Giorgio de Chirico and his work, The 9 [' surreal', 'ist', ' artist', ' G', 'ior', 'gio', ' de', ' Ch', 'ir', 'ico']
+983 222 Name of father of x -1 Name of father of Giorgio de Chirico Evaristo de Chirico Giorgio de Chirico "['\n' '\n' 'Name' ' of' ' mother' ' of' ' G' 'ior' 'gio' ' de' ' Ch' 'ir'
+ 'ico' '\n' '\n' 'Name' ' of' ' mother' ' of' ' G']" "
+
+ Name of mother of G ior gio de Ch ir ico
+
+ Name of mother of G" False Salvador Dalí and Giorgio de Chirico were used as 10 [' Salvador', ' Dal', 'í', ' and', ' G', 'ior', 'gio', ' de', ' Ch', 'ir', 'ico']
+984 222 Name of father of x -1 Name of father of Giorgio de Chirico Evaristo de Chirico Giorgio de Chirico "['\n' '\n' 'Name' ' of' ' mother' ' of' ' G' 'ior' 'gio' ' de' ' Ch' 'ir'
+ 'ico' '\n' '\n' 'Name' ' of' ' mother' ' of' ' G']" "
+
+ Name of mother of G ior gio de Ch ir ico
+
+ Name of mother of G" False Salvador Dalí and Giorgio de Chirico were used as visual 10 [' Salvador', ' Dal', 'í', ' and', ' G', 'ior', 'gio', ' de', ' Ch', 'ir', 'ico']
+985 222 Name of father of x -1 Name of father of Giorgio de Chirico Evaristo de Chirico Giorgio de Chirico "['\n' '\n' 'Name' ' of' ' mother' ' of' ' G' 'ior' 'gio' ' de' ' Ch' 'ir'
+ 'ico' '\n' '\n' 'Name' ' of' ' mother' ' of' ' G']" "
+
+ Name of mother of G ior gio de Ch ir ico
+
+ Name of mother of G" False Italian visual artist Giorgio de Chirico and with the 9 [' Italian', ' visual', ' artist', ' G', 'ior', 'gio', ' de', ' Ch', 'ir', 'ico']
+986 223 Name of father of x -1 Name of father of Jonathan Swift Jonathan Swift Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of']" ", the author of G ull iver 's Travels , and the author of the
+
+ The author of" False Alexander Pope and Jonathan Swift flourished. Henry 4 [' Alexander', ' Pope', ' and', ' Jonathan', ' Swift']
+987 223 Name of father of x -1 Name of father of Jonathan Swift Jonathan Swift Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of']" ", the author of G ull iver 's Travels , and the author of the
+
+ The author of" False wrote to a friend of Jonathan Swift to persuade 6 [' wrote', ' to', ' a', ' friend', ' of', ' Jonathan', ' Swift']
+988 223 Name of father of x -1 Name of father of Jonathan Swift Jonathan Swift Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of']" ", the author of G ull iver 's Travels , and the author of the
+
+ The author of" False – 1800. In 1803, Jonathan Swift — also an Alexandria 8 [' –', ' 1800', '.', ' In', ' 18', '03', ',', ' Jonathan', ' Swift']
+989 223 Name of father of x -1 Name of father of Jonathan Swift Jonathan Swift Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of']" ", the author of G ull iver 's Travels , and the author of the
+
+ The author of" False " = Sermons of Jonathan Swift =
+" 5 [' =', ' Ser', 'mons', ' of', ' Jonathan', ' Swift']
+990 223 Name of father of x -1 Name of father of Jonathan Swift Jonathan Swift Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of']" ", the author of G ull iver 's Travels , and the author of the
+
+ The author of" False 1800. In 1803, Jonathan Swift — also an Alexandria 7 [' 1800', '.', ' In', ' 18', '03', ',', ' Jonathan', ' Swift']
+991 224 Name of father of x -1 Name of father of Petrarch Petracco Petrarch "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' son' ',' ' the' ' poet' ','
+ ' and' ' of' ' his' ' son' ',' ' the' ' poet' ',']" , the poet , and of his son , the poet , and of his son , the poet , False relies on Petrarch's sonnet structure 3 [' relies', ' on', ' Petr', 'arch']
+992 224 Name of father of x -1 Name of father of Petrarch Petracco Petrarch "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' son' ',' ' the' ' poet' ','
+ ' and' ' of' ' his' ' son' ',' ' the' ' poet' ',']" , the poet , and of his son , the poet , and of his son , the poet , False as developed by Petrarch and Dante, 4 [' as', ' developed', ' by', ' Petr', 'arch']
+993 224 Name of father of x -1 Name of father of Petrarch Petracco Petrarch "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' son' ',' ' the' ' poet' ','
+ ' and' ' of' ' his' ' son' ',' ' the' ' poet' ',']" , the poet , and of his son , the poet , and of his son , the poet , False Dante (d. 1321), Petrarch (d. 1374) and 8 [' Dante', ' (', 'd', '.', ' 13', '21', '),', ' Petr', 'arch']
+994 224 Name of father of x -1 Name of father of Petrarch Petracco Petrarch "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' son' ',' ' the' ' poet' ','
+ ' and' ' of' ' his' ' son' ',' ' the' ' poet' ',']" , the poet , and of his son , the poet , and of his son , the poet , False Godscroft, adapted Petrarch as Triumphs of Love: 7 [' God', 'sc', 'ro', 'ft', ',', ' adapted', ' Petr', 'arch']
+995 224 Name of father of x -1 Name of father of Petrarch Petracco Petrarch "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' son' ',' ' the' ' poet' ','
+ ' and' ' of' ' his' ' son' ',' ' the' ' poet' ',']" , the poet , and of his son , the poet , and of his son , the poet , False particular, Keats relies on Petrarch's sonnet structure 7 [' particular', ',', ' Ke', 'ats', ' relies', ' on', ' Petr', 'arch']
+996 225 Name of father of x -1 Name of father of Virginia Woolf Leslie Stephen Virginia Woolf "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Waves'
+ '_' ',' ' and' ' the' ' mother' ' of' ' Quentin' ' Bell' ',' ' the']" , the author of the famous novel _ The Waves _ , and the mother of Quentin Bell , the False from each other as Virginia Woolf and Emma Goldman embraced 6 [' from', ' each', ' other', ' as', ' Virginia', ' Wool', 'f']
+997 225 Name of father of x -1 Name of father of Virginia Woolf Leslie Stephen Virginia Woolf "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Waves'
+ '_' ',' ' and' ' the' ' mother' ' of' ' Quentin' ' Bell' ',' ' the']" , the author of the famous novel _ The Waves _ , and the mother of Quentin Bell , the False Evelyn Waugh, Virginia Woolf and Noël Coward. 7 [' Eve', 'lyn', ' W', 'augh', ',', ' Virginia', ' Wool', 'f']
+998 225 Name of father of x -1 Name of father of Virginia Woolf Leslie Stephen Virginia Woolf "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Waves'
+ '_' ',' ' and' ' the' ' mother' ' of' ' Quentin' ' Bell' ',' ' the']" , the author of the famous novel _ The Waves _ , and the mother of Quentin Bell , the False between 1851 and 1855. Virginia Woolf regularly travelled 9 [' between', ' 18', '51', ' and', ' 18', '55', '.', ' Virginia', ' Wool', 'f']
+999 225 Name of father of x -1 Name of father of Virginia Woolf Leslie Stephen Virginia Woolf "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Waves'
+ '_' ',' ' and' ' the' ' mother' ' of' ' Quentin' ' Bell' ',' ' the']" , the author of the famous novel _ The Waves _ , and the mother of Quentin Bell , the False 2 ['Virginia', ' Wool', 'f']
+1000 225 Name of father of x -1 Name of father of Virginia Woolf Leslie Stephen Virginia Woolf "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Waves'
+ '_' ',' ' and' ' the' ' mother' ' of' ' Quentin' ' Bell' ',' ' the']" , the author of the famous novel _ The Waves _ , and the mother of Quentin Bell , the False of writers such as Virginia Woolf and William Faulkner. 6 [' of', ' writers', ' such', ' as', ' Virginia', ' Wool', 'f']
+1001 226 Name of father of x -1 Name of father of Anton Chekhov Pavel Chekhov Anton Chekhov "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Che' 'kh' 'ov' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' plays' ' and']" "
+
+ I am a big fan of Che kh ov . I have read all of his plays and" False 4 ['Ant', 'on', ' Che', 'kh', 'ov']
+1002 226 Name of father of x -1 Name of father of Anton Chekhov Pavel Chekhov Anton Chekhov "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Che' 'kh' 'ov' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' plays' ' and']" "
+
+ I am a big fan of Che kh ov . I have read all of his plays and" False 4 ['Ant', 'on', ' Che', 'kh', 'ov']
+1003 226 Name of father of x -1 Name of father of Anton Chekhov Pavel Chekhov Anton Chekhov "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Che' 'kh' 'ov' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' plays' ' and']" "
+
+ I am a big fan of Che kh ov . I have read all of his plays and" False and short stories, Anton Chekhov portrayed a 7 [' and', ' short', ' stories', ',', ' Anton', ' Che', 'kh', 'ov']
+1004 226 Name of father of x -1 Name of father of Anton Chekhov Pavel Chekhov Anton Chekhov "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Che' 'kh' 'ov' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' plays' ' and']" "
+
+ I am a big fan of Che kh ov . I have read all of his plays and" False short stories, Anton Chekhov portrayed a variety 6 [' short', ' stories', ',', ' Anton', ' Che', 'kh', 'ov']
+1005 226 Name of father of x -1 Name of father of Anton Chekhov Pavel Chekhov Anton Chekhov "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Che' 'kh' 'ov' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' plays' ' and']" "
+
+ I am a big fan of Che kh ov . I have read all of his plays and" False Baudelaire, Leo Tolstoy, Anton Chekhov and Sigmund Freud, 13 [' B', 'aud', 'el', 'aire', ',', ' Leo', ' Tol', 'st', 'oy', ',', ' Anton', ' Che', 'kh', 'ov']
+1006 227 Name of father of x -1 Name of father of C. S. Lewis Albert James Lewis C. S. Lewis "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' wife' ',' ' a' ' mother'
+ ',' ' a' ' grandmother' ',' ' a' ' sister' ',' ' a']" "
+
+ I am a Christian , a wife , a mother , a grandmother , a sister , a" False to the works of C. S. Lewis and J. R. R. 8 [' to', ' the', ' works', ' of', ' C', '.', ' S', '.', ' Lewis']
+1007 227 Name of father of x -1 Name of father of C. S. Lewis Albert James Lewis C. S. Lewis "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' wife' ',' ' a' ' mother'
+ ',' ' a' ' grandmother' ',' ' a' ' sister' ',' ' a']" "
+
+ I am a Christian , a wife , a mother , a grandmother , a sister , a" False scepticism or antirealism. C. S. Lewis argued that, if 12 [' scept', 'icism', ' or', ' ant', 'ire', 'al', 'ism', '.', ' C', '.', ' S', '.', ' Lewis']
+1008 227 Name of father of x -1 Name of father of C. S. Lewis Albert James Lewis C. S. Lewis "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' wife' ',' ' a' ' mother'
+ ',' ' a' ' grandmother' ',' ' a' ' sister' ',' ' a']" "
+
+ I am a Christian , a wife , a mother , a grandmother , a sister , a" False medieval scholar C. S. Lewis said the character 6 [' medieval', ' scholar', ' C', '.', ' S', '.', ' Lewis']
+1009 227 Name of father of x -1 Name of father of C. S. Lewis Albert James Lewis C. S. Lewis "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' wife' ',' ' a' ' mother'
+ ',' ' a' ' grandmother' ',' ' a' ' sister' ',' ' a']" "
+
+ I am a Christian , a wife , a mother , a grandmother , a sister , a" False " Christian apologist C. S. Lewis and saying, ""The first" 7 [' Christian', ' ap', 'ologist', ' C', '.', ' S', '.', ' Lewis']
+1010 227 Name of father of x -1 Name of father of C. S. Lewis Albert James Lewis C. S. Lewis "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' wife' ',' ' a' ' mother'
+ ',' ' a' ' grandmother' ',' ' a' ' sister' ',' ' a']" "
+
+ I am a Christian , a wife , a mother , a grandmother , a sister , a" False Christian apologist C. S. Lewis and saying, 7 [' Christian', ' ap', 'ologist', ' C', '.', ' S', '.', ' Lewis']
+1011 229 Name of father of x -1 Name of father of William Hogarth Richard Hogarth William Hogarth "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' '1' '\n']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ 1
+" False humour and satire; William Hogarth connects the castrato 6 [' humour', ' and', ' satire', ';', ' William', ' Hog', 'arth']
+1012 229 Name of father of x -1 Name of father of William Hogarth Richard Hogarth William Hogarth "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' '1' '\n']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ 1
+" False " England"" lauded by William Hogarth in his 1748 painting" 6 "[' England', '""', ' lauded', ' by', ' William', ' Hog', 'arth']"
+1013 229 Name of father of x -1 Name of father of William Hogarth Richard Hogarth William Hogarth "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' '1' '\n']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ 1
+" False and artists William Hogarth and Joshua 4 [' and', ' artists', ' William', ' Hog', 'arth']
+1014 229 Name of father of x -1 Name of father of William Hogarth Richard Hogarth William Hogarth "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' '1' '\n']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ 1
+" False public mockery. William Hogarth published Cunicularii, 5 [' public', ' mockery', '.', ' William', ' Hog', 'arth']
+1015 229 Name of father of x -1 Name of father of William Hogarth Richard Hogarth William Hogarth "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' '1' '\n']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ 1
+" False museum housing a 1756 William Hogarth triptych painted 7 [' museum', ' housing', ' a', ' 17', '56', ' William', ' Hog', 'arth']
+1016 230 Name of father of x -1 Name of father of Alanis Morissette Alan Richard Morissette Alanis Morissette "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False music's biggest — as Alanis Morissette would put it 9 "[' music', ""'s"", ' biggest', ' —', ' as', ' Alan', 'is', ' Mor', 'iss', 'ette']"
+1017 230 Name of father of x -1 Name of father of Alanis Morissette Alan Richard Morissette Alanis Morissette "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False massively successful Alanis Morissette fit into this sub 6 [' massively', ' successful', ' Alan', 'is', ' Mor', 'iss', 'ette']
+1018 230 Name of father of x -1 Name of father of Alanis Morissette Alan Richard Morissette Alanis Morissette "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False " for singer-songwriter Alanis Morissette's ""Thank U""" 9 [' for', ' singer', '-', 'song', 'writer', ' Alan', 'is', ' Mor', 'iss', 'ette']
+1019 230 Name of father of x -1 Name of father of Alanis Morissette Alan Richard Morissette Alanis Morissette "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False " opening credits), Alanis Morissette (""Crazy"", Central" 7 [' opening', ' credits', '),', ' Alan', 'is', ' Mor', 'iss', 'ette']
+1020 230 Name of father of x -1 Name of father of Alanis Morissette Alan Richard Morissette Alanis Morissette "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False That's what Alanis Morissette had you motherfuckers 7 "[' That', ""'s"", ' what', ' Alan', 'is', ' Mor', 'iss', 'ette']"
+1021 231 Name of father of x -1 Name of father of Canaletto Bernardo Canal Canaletto "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Venice' ' in'
+ ' 17' '09' ',' ' and' ' died' ' in' ' 17' '68' '.' '\n']" ", the painter , who was born in Venice in 17 09 , and died in 17 68 .
+" False Bellini, Rubens, Canaletto and Claude – and 8 [' Bell', 'ini', ',', ' Rub', 'ens', ',', ' Canal', 'et', 'to']
+1022 231 Name of father of x -1 Name of father of Canaletto Bernardo Canal Canaletto "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Venice' ' in'
+ ' 17' '09' ',' ' and' ' died' ' in' ' 17' '68' '.' '\n']" ", the painter , who was born in Venice in 17 09 , and died in 17 68 .
+" False 3 ['Can', 'al', 'et', 'to']
+1023 231 Name of father of x -1 Name of father of Canaletto Bernardo Canal Canaletto "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Venice' ' in'
+ ' 17' '09' ',' ' and' ' died' ' in' ' 17' '68' '.' '\n']" ", the painter , who was born in Venice in 17 09 , and died in 17 68 .
+" False Leonardo da Vinci, Canaletto and Bacciarelli. 7 [' Leonardo', ' da', ' Vin', 'ci', ',', ' Canal', 'et', 'to']
+1024 231 Name of father of x -1 Name of father of Canaletto Bernardo Canal Canaletto "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Venice' ' in'
+ ' 17' '09' ',' ' and' ' died' ' in' ' 17' '68' '.' '\n']" ", the painter , who was born in Venice in 17 09 , and died in 17 68 .
+" False painted in 1747 by Canaletto during a period of 7 [' painted', ' in', ' 17', '47', ' by', ' Canal', 'et', 'to']
+1025 231 Name of father of x -1 Name of father of Canaletto Bernardo Canal Canaletto "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' Venice' ' in'
+ ' 17' '09' ',' ' and' ' died' ' in' ' 17' '68' '.' '\n']" ", the painter , who was born in Venice in 17 09 , and died in 17 68 .
+" False in paintings by Canaletto and Francesco Guardi, 5 [' in', ' paintings', ' by', ' Canal', 'et', 'to']
+1026 232 Name of father of x -1 Name of father of Alec Baldwin Alexander Baldwin Alec Baldwin "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack' ' Don'
+ 'agh' 'y' ' on' ' the' ' NBC' ' sitcom' ' 30' ' Rock' '.' '\n']" ", the actor who played the role of Jack Don agh y on the NBC sitcom 30 Rock .
+" False as Cable and Alec Baldwin as Billis. Kendrick 4 [' as', ' Cable', ' and', ' Alec', ' Baldwin']
+1027 232 Name of father of x -1 Name of father of Alec Baldwin Alexander Baldwin Alec Baldwin "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack' ' Don'
+ 'agh' 'y' ' on' ' the' ' NBC' ' sitcom' ' 30' ' Rock' '.' '\n']" ", the actor who played the role of Jack Don agh y on the NBC sitcom 30 Rock .
+" False Saturday Night Live. Alec Baldwin has also hosted 5 [' Saturday', ' Night', ' Live', '.', ' Alec', ' Baldwin']
+1028 232 Name of father of x -1 Name of father of Alec Baldwin Alexander Baldwin Alec Baldwin "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack' ' Don'
+ 'agh' 'y' ' on' ' the' ' NBC' ' sitcom' ' 30' ' Rock' '.' '\n']" ", the actor who played the role of Jack Don agh y on the NBC sitcom 30 Rock .
+" False praised by critics, with Alec Baldwin winning the Primetime 6 [' praised', ' by', ' critics', ',', ' with', ' Alec', ' Baldwin']
+1029 232 Name of father of x -1 Name of father of Alec Baldwin Alexander Baldwin Alec Baldwin "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack' ' Don'
+ 'agh' 'y' ' on' ' the' ' NBC' ' sitcom' ' 30' ' Rock' '.' '\n']" ", the actor who played the role of Jack Don agh y on the NBC sitcom 30 Rock .
+" False Nature, which starred Alec Baldwin in the leading 5 [' Nature', ',', ' which', ' starred', ' Alec', ' Baldwin']
+1030 232 Name of father of x -1 Name of father of Alec Baldwin Alexander Baldwin Alec Baldwin "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack' ' Don'
+ 'agh' 'y' ' on' ' the' ' NBC' ' sitcom' ' 30' ' Rock' '.' '\n']" ", the actor who played the role of Jack Don agh y on the NBC sitcom 30 Rock .
+" False Bennett impressionist Alec Baldwin on Saturday Night 4 [' Bennett', ' impression', 'ist', ' Alec', ' Baldwin']
+1031 233 Name of father of x -1 Name of father of Antonín Dvořák František Dvořák Antonín Dvořák "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Anton' 'ín' ' D' 'vo' '�'
+ '�' 'á' 'k' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Anton ín D vo � � á k
+
+ ! Name of" False composed by Antonín Dvořák between August 4 9 [' composed', ' by', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']
+1032 233 Name of father of x -1 Name of father of Antonín Dvořák František Dvořák Antonín Dvořák "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Anton' 'ín' ' D' 'vo' '�'
+ '�' 'á' 'k' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Anton ín D vo � � á k
+
+ ! Name of" False the operas of Antonín Dvořák and other Czech composers, 11 [' the', ' oper', 'as', ' of', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']
+1033 233 Name of father of x -1 Name of father of Antonín Dvořák František Dvořák Antonín Dvořák "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Anton' 'ín' ' D' 'vo' '�'
+ '�' 'á' 'k' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Anton ín D vo � � á k
+
+ ! Name of" False contemporaries such as Antonín Dvořák. His later, mature 10 [' contemporaries', ' such', ' as', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']
+1034 233 Name of father of x -1 Name of father of Antonín Dvořák František Dvořák Antonín Dvořák "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Anton' 'ín' ' D' 'vo' '�'
+ '�' 'á' 'k' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Anton ín D vo � � á k
+
+ ! Name of" False Czech music, but Antonín Dvořák ... was the one who 11 [' Czech', ' music', ',', ' but', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']
+1035 233 Name of father of x -1 Name of father of Antonín Dvořák František Dvořák Antonín Dvořák "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Anton' 'ín' ' D' 'vo' '�'
+ '�' 'á' 'k' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Anton ín D vo � � á k
+
+ ! Name of" False in the operas of Antonín Dvořák and other 12 [' in', ' the', ' oper', 'as', ' of', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']
+1036 235 Name of father of x -1 Name of father of Felix Mendelssohn Abraham Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' composer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Felix' ' Mend' 'els']" "- Bar th old y , the composer .
+
+ The name of the father of Felix Mend els" False When in 1822 young Felix Mendelssohn composed a Magnificat 9 [' When', ' in', ' 18', '22', ' young', ' Felix', ' Mend', 'els', 'so', 'hn']
+1037 235 Name of father of x -1 Name of father of Felix Mendelssohn Abraham Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' composer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Felix' ' Mend' 'els']" "- Bar th old y , the composer .
+
+ The name of the father of Felix Mend els" False Mendelssohn, Felix Mendelssohn was born into a 9 [' Mend', 'els', 'so', 'hn', ',', ' Felix', ' Mend', 'els', 'so', 'hn']
+1038 235 Name of father of x -1 Name of father of Felix Mendelssohn Abraham Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' composer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Felix' ' Mend' 'els']" "- Bar th old y , the composer .
+
+ The name of the father of Felix Mend els" False negotiations with Felix Mendelssohn for a new opera. 6 [' negotiations', ' with', ' Felix', ' Mend', 'els', 'so', 'hn']
+1039 235 Name of father of x -1 Name of father of Felix Mendelssohn Abraham Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' composer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Felix' ' Mend' 'els']" "- Bar th old y , the composer .
+
+ The name of the father of Felix Mend els" False 5 ['Fel', 'ix', ' Mend', 'els', 'so', 'hn']
+1040 235 Name of father of x -1 Name of father of Felix Mendelssohn Abraham Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' composer' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Felix' ' Mend' 'els']" "- Bar th old y , the composer .
+
+ The name of the father of Felix Mend els" False scores by Felix Mendelssohn at the International 6 [' scores', ' by', ' Felix', ' Mend', 'els', 'so', 'hn']
+1041 236 Name of father of x -1 Name of father of Jacob Grimm Philipp Grimm Jacob Grimm "[',' ' the' ' famous' ' German' ' author' ' of' ' the' ' Brothers'
+ ' Grimm' ' fairy' ' tales' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first' ' of']" ", the famous German author of the Brothers Grimm fairy tales .
+
+ The name of the first of" False 1 ['Jacob', ' Grimm']
+1042 236 Name of father of x -1 Name of father of Jacob Grimm Philipp Grimm Jacob Grimm "[',' ' the' ' famous' ' German' ' author' ' of' ' the' ' Brothers'
+ ' Grimm' ' fairy' ' tales' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first' ' of']" ", the famous German author of the Brothers Grimm fairy tales .
+
+ The name of the first of" False century scholar Jacob Grimm observed a connection 3 [' century', ' scholar', ' Jacob', ' Grimm']
+1043 236 Name of father of x -1 Name of father of Jacob Grimm Philipp Grimm Jacob Grimm "[',' ' the' ' famous' ' German' ' author' ' of' ' the' ' Brothers'
+ ' Grimm' ' fairy' ' tales' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first' ' of']" ", the famous German author of the Brothers Grimm fairy tales .
+
+ The name of the first of" False 19th century, Jacob Grimm proposed a cognate 5 [' 19', 'th', ' century', ',', ' Jacob', ' Grimm']
+1044 236 Name of father of x -1 Name of father of Jacob Grimm Philipp Grimm Jacob Grimm "[',' ' the' ' famous' ' German' ' author' ' of' ' the' ' Brothers'
+ ' Grimm' ' fairy' ' tales' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first' ' of']" ", the famous German author of the Brothers Grimm fairy tales .
+
+ The name of the first of" False and he met Jacob Grimm in Karlsruhe 4 [' and', ' he', ' met', ' Jacob', ' Grimm']
+1045 236 Name of father of x -1 Name of father of Jacob Grimm Philipp Grimm Jacob Grimm "[',' ' the' ' famous' ' German' ' author' ' of' ' the' ' Brothers'
+ ' Grimm' ' fairy' ' tales' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first' ' of']" ", the famous German author of the Brothers Grimm fairy tales .
+
+ The name of the first of" False 19th-century scholar Jacob Grimm proposes a reconstruction 6 [' 19', 'th', '-', 'century', ' scholar', ' Jacob', ' Grimm']
+1046 237 Name of father of x -1 Name of father of Hillary Clinton Hugh E. Rodham Hillary Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' the'
+ ' former' ' first']" , the former president of the United States , and the former secretary of state , and the former first False March 21, 1996, Hillary Clinton submitted a deposition 6 [' March', ' 21', ',', ' 1996', ',', ' Hillary', ' Clinton']
+1047 237 Name of father of x -1 Name of father of Hillary Clinton Hugh E. Rodham Hillary Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' the'
+ ' former' ' first']" , the former president of the United States , and the former secretary of state , and the former first False York's City Hall. When Hillary Clinton visited New 7 "[' York', ""'s"", ' City', ' Hall', '.', ' When', ' Hillary', ' Clinton']"
+1048 237 Name of father of x -1 Name of father of Hillary Clinton Hugh E. Rodham Hillary Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' the'
+ ' former' ' first']" , the former president of the United States , and the former secretary of state , and the former first False Secretary of State Hillary Clinton was one of those 4 [' Secretary', ' of', ' State', ' Hillary', ' Clinton']
+1049 237 Name of father of x -1 Name of father of Hillary Clinton Hugh E. Rodham Hillary Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' the'
+ ' former' ' first']" , the former president of the United States , and the former secretary of state , and the former first False Obama and Senator Hillary Clinton after early contests, 4 [' Obama', ' and', ' Senator', ' Hillary', ' Clinton']
+1050 237 Name of father of x -1 Name of father of Hillary Clinton Hugh E. Rodham Hillary Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' the'
+ ' former' ' first']" , the former president of the United States , and the former secretary of state , and the former first False " stories"" — such as Hillary Clinton becoming the" 6 "[' stories', '""', ' —', ' such', ' as', ' Hillary', ' Clinton']"
+1051 238 Name of father of x -1 Name of father of Rita Hayworth Eduardo Cansino, Sr. Rita Hayworth "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie' ' ""' 'G'
+ 'ilda' '""' ' and' ' I' ' love' ' the' ' song' ' ""']" "
+
+ I am a big fan of the movie "" G ilda "" and I love the song """ False Jean Arthur and Rita Hayworth in Hawks's Only 5 [' Jean', ' Arthur', ' and', ' Rita', ' Hay', 'worth']
+1052 238 Name of father of x -1 Name of father of Rita Hayworth Eduardo Cansino, Sr. Rita Hayworth "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie' ' ""' 'G'
+ 'ilda' '""' ' and' ' I' ' love' ' the' ' song' ' ""']" "
+
+ I am a big fan of the movie "" G ilda "" and I love the song """ False Africa. Film stars Rita Hayworth and Tallulah Bankhead 6 [' Africa', '.', ' Film', ' stars', ' Rita', ' Hay', 'worth']
+1053 238 Name of father of x -1 Name of father of Rita Hayworth Eduardo Cansino, Sr. Rita Hayworth "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie' ' ""' 'G'
+ 'ilda' '""' ' and' ' I' ' love' ' the' ' song' ' ""']" "
+
+ I am a big fan of the movie "" G ilda "" and I love the song """ False American actress Rita Hayworth as the Marvel 4 [' American', ' actress', ' Rita', ' Hay', 'worth']
+1054 238 Name of father of x -1 Name of father of Rita Hayworth Eduardo Cansino, Sr. Rita Hayworth "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie' ' ""' 'G'
+ 'ilda' '""' ' and' ' I' ' love' ' the' ' song' ' ""']" "
+
+ I am a big fan of the movie "" G ilda "" and I love the song """ False compared her to Rita Hayworth in Gilda and Marilyn 5 [' compared', ' her', ' to', ' Rita', ' Hay', 'worth']
+1055 238 Name of father of x -1 Name of father of Rita Hayworth Eduardo Cansino, Sr. Rita Hayworth "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie' ' ""' 'G'
+ 'ilda' '""' ' and' ' I' ' love' ' the' ' song' ' ""']" "
+
+ I am a big fan of the movie "" G ilda "" and I love the song """ False gloves like Rita Hayworth in Gilda and points 4 [' gloves', ' like', ' Rita', ' Hay', 'worth']
+1056 239 Name of father of x -1 Name of father of Dmitri Shostakovich Dmitry Shostakovich Dmitri Shostakovich "['\n' '\n' 'I' 'van' ' Sh' 'ost' 'ak' 'ovich' '\n' '\n' 'I' 'van' ' Sh'
+ 'ost' 'ak' 'ovich' ' (' 'Russian' ':' ' �']" "
+
+ I van Sh ost ak ovich
+
+ I van Sh ost ak ovich ( Russian : �" False Kozintsev, the composer Dmitri Shostakovich and Yevgeny 12 [' Koz', 'int', 'se', 'v', ',', ' the', ' composer', ' Dmit', 'ri', ' Sh', 'ost', 'ak', 'ovich']
+1057 239 Name of father of x -1 Name of father of Dmitri Shostakovich Dmitry Shostakovich Dmitri Shostakovich "['\n' '\n' 'I' 'van' ' Sh' 'ost' 'ak' 'ovich' '\n' '\n' 'I' 'van' ' Sh'
+ 'ost' 'ak' 'ovich' ' (' 'Russian' ':' ' �']" "
+
+ I van Sh ost ak ovich
+
+ I van Sh ost ak ovich ( Russian : �" False Soviet composer Dmitri Shostakovich (1906 – 75) completed 7 [' Soviet', ' composer', ' Dmit', 'ri', ' Sh', 'ost', 'ak', 'ovich']
+1058 239 Name of father of x -1 Name of father of Dmitri Shostakovich Dmitry Shostakovich Dmitri Shostakovich "['\n' '\n' 'I' 'van' ' Sh' 'ost' 'ak' 'ovich' '\n' '\n' 'I' 'van' ' Sh'
+ 'ost' 'ak' 'ovich' ' (' 'Russian' ':' ' �']" "
+
+ I van Sh ost ak ovich
+
+ I van Sh ost ak ovich ( Russian : �" False " Dmitri Shostakovich =
+" 5 [' Dmit', 'ri', ' Sh', 'ost', 'ak', 'ovich']
+1059 239 Name of father of x -1 Name of father of Dmitri Shostakovich Dmitry Shostakovich Dmitri Shostakovich "['\n' '\n' 'I' 'van' ' Sh' 'ost' 'ak' 'ovich' '\n' '\n' 'I' 'van' ' Sh'
+ 'ost' 'ak' 'ovich' ' (' 'Russian' ':' ' �']" "
+
+ I van Sh ost ak ovich
+
+ I van Sh ost ak ovich ( Russian : �" False 6 ['D', 'mit', 'ri', ' Sh', 'ost', 'ak', 'ovich']
+1060 239 Name of father of x -1 Name of father of Dmitri Shostakovich Dmitry Shostakovich Dmitri Shostakovich "['\n' '\n' 'I' 'van' ' Sh' 'ost' 'ak' 'ovich' '\n' '\n' 'I' 'van' ' Sh'
+ 'ost' 'ak' 'ovich' ' (' 'Russian' ':' ' �']" "
+
+ I van Sh ost ak ovich
+
+ I van Sh ost ak ovich ( Russian : �" False " Shostakovich =
+" 11 [' Sh', 'ost', 'ak', 'ovich', ' =', 'D', 'mit', 'ri', ' Sh', 'ost', 'ak', 'ovich']
+1061 240 Name of father of x -1 Name of father of Marie Curie Władysław Skłodowski Marie Curie "[',' ' the' ' famous' ' French' ' chemist' ' and' ' physicist' ',' ' who'
+ ' discovered' ' rad' 'ium' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first']" ", the famous French chemist and physicist , who discovered rad ium .
+
+ The name of the first" False " Marie Curie =
+" 2 [' Marie', ' Cur', 'ie']
+1062 240 Name of father of x -1 Name of father of Marie Curie Władysław Skłodowski Marie Curie "[',' ' the' ' famous' ' French' ' chemist' ' and' ' physicist' ',' ' who'
+ ' discovered' ' rad' 'ium' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first']" ", the famous French chemist and physicist , who discovered rad ium .
+
+ The name of the first" False a statue of Marie Curie before Warsaw's 5 [' a', ' statue', ' of', ' Marie', ' Cur', 'ie']
+1063 240 Name of father of x -1 Name of father of Marie Curie Władysław Skłodowski Marie Curie "[',' ' the' ' famous' ' French' ' chemist' ' and' ' physicist' ',' ' who'
+ ' discovered' ' rad' 'ium' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first']" ", the famous French chemist and physicist , who discovered rad ium .
+
+ The name of the first" False of Pierre and Marie Curie University 5 [' of', ' Pierre', ' and', ' Marie', ' Cur', 'ie']
+1064 240 Name of father of x -1 Name of father of Marie Curie Władysław Skłodowski Marie Curie "[',' ' the' ' famous' ' French' ' chemist' ' and' ' physicist' ',' ' who'
+ ' discovered' ' rad' 'ium' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first']" ", the famous French chemist and physicist , who discovered rad ium .
+
+ The name of the first" False was named after Marie Curie and her husband 5 [' was', ' named', ' after', ' Marie', ' Cur', 'ie']
+1065 240 Name of father of x -1 Name of father of Marie Curie Władysław Skłodowski Marie Curie "[',' ' the' ' famous' ' French' ' chemist' ' and' ' physicist' ',' ' who'
+ ' discovered' ' rad' 'ium' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' first']" ", the famous French chemist and physicist , who discovered rad ium .
+
+ The name of the first" False 1944; and of Pierre and Marie Curie University (Paris 8 [' 1944', ';', ' and', ' of', ' Pierre', ' and', ' Marie', ' Cur', 'ie']
+1066 241 Name of father of x -1 Name of father of Molière Jean Poquelin Molière "[',' ' the' ' great' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great French dram at ist , was born at
+ " False and Wilde, Molière and Shaw, Congreve 5 [' and', ' Wilde', ',', ' M', 'oli', 'ère']
+1067 241 Name of father of x -1 Name of father of Molière Jean Poquelin Molière "[',' ' the' ' great' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great French dram at ist , was born at
+ " False 2 ['M', 'oli', 'ère']
+1068 241 Name of father of x -1 Name of father of Molière Jean Poquelin Molière "[',' ' the' ' great' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great French dram at ist , was born at
+ " False " of ""crucified Molière [and] hapless" 7 "[' of', ' ""', 'cru', 'c', 'ified', ' M', 'oli', 'ère']"
+1069 241 Name of father of x -1 Name of father of Molière Jean Poquelin Molière "[',' ' the' ' great' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great French dram at ist , was born at
+ " False krypton is a smaller Molière radius of 4.7 8 [' k', 'rypt', 'on', ' is', ' a', ' smaller', ' M', 'oli', 'ère']
+1070 241 Name of father of x -1 Name of father of Molière Jean Poquelin Molière "[',' ' the' ' great' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great French dram at ist , was born at
+ " False " theater of Racine, Molière and Shakespeare.
+" 7 [' theater', ' of', ' Rac', 'ine', ',', ' M', 'oli', 'ère']
+1071 242 Name of father of x -1 Name of father of Charlotte Brontë Patrick Brontë Charlotte Brontë "[',' ' the' ' author' ' of' ' Jane' ' Ey' 're' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' works']" ", the author of Jane Ey re .
+
+ The following is a list of the most important works" False " Brontë ====
+" 9 [' Br', 'ont', 'ë', ' =', '===', 'Charl', 'otte', ' Br', 'ont', 'ë']
+1072 242 Name of father of x -1 Name of father of Charlotte Brontë Patrick Brontë Charlotte Brontë "[',' ' the' ' author' ' of' ' Jane' ' Ey' 're' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' works']" ", the author of Jane Ey re .
+
+ The following is a list of the most important works" False 4 ['Charl', 'otte', ' Br', 'ont', 'ë']
+1073 242 Name of father of x -1 Name of father of Charlotte Brontë Patrick Brontë Charlotte Brontë "[',' ' the' ' author' ' of' ' Jane' ' Ey' 're' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' works']" ", the author of Jane Ey re .
+
+ The following is a list of the most important works" False 4 ['Charl', 'otte', ' Br', 'ont', 'ë']
+1074 242 Name of father of x -1 Name of father of Charlotte Brontë Patrick Brontë Charlotte Brontë "[',' ' the' ' author' ' of' ' Jane' ' Ey' 're' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' works']" ", the author of Jane Ey re .
+
+ The following is a list of the most important works" False 4 ['Charl', 'otte', ' Br', 'ont', 'ë']
+1075 242 Name of father of x -1 Name of father of Charlotte Brontë Patrick Brontë Charlotte Brontë "[',' ' the' ' author' ' of' ' Jane' ' Ey' 're' '.' '\n' '\n' 'The'
+ ' following' ' is' ' a' ' list' ' of' ' the' ' most' ' important'
+ ' works']" ", the author of Jane Ey re .
+
+ The following is a list of the most important works" False 4 ['Charl', 'otte', ' Br', 'ont', 'ë']
+1076 243 Name of father of x -1 Name of father of Catherine Deneuve Maurice Dorléac Catherine Deneuve "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Catherine' ' D' 'ene' 'uve' '.' '\n' '\n' 'The' ' name']" ", the actress , and the mother of the actress , Catherine D ene uve .
+
+ The name" False (1983), with Catherine Deneuve and Susan Sarandon. 7 [' (', '1983', '),', ' with', ' Catherine', ' D', 'ene', 'uve']
+1077 243 Name of father of x -1 Name of father of Catherine Deneuve Maurice Dorléac Catherine Deneuve "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Catherine' ' D' 'ene' 'uve' '.' '\n' '\n' 'The' ' name']" ", the actress , and the mother of the actress , Catherine D ene uve .
+
+ The name" False claims that Catherine Deneuve wanted to play 5 [' claims', ' that', ' Catherine', ' D', 'ene', 'uve']
+1078 244 Name of father of x -1 Name of father of Piet Mondrian Pieter Cornelis Sr. Piet Mondrian "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ',' ' a' ' husband' ' of'
+ ' one' ',' ' and' ' a' ' grandfather' ' of' ' four' '.']" "
+
+ I am a father of two , a husband of one , and a grandfather of four ." False painters such as Piet Mondrian (Stockhausen 1996a, 6 [' pain', 'ters', ' such', ' as', ' Piet', ' Mond', 'rian']
+1079 244 Name of father of x -1 Name of father of Piet Mondrian Pieter Cornelis Sr. Piet Mondrian "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ',' ' a' ' husband' ' of'
+ ' one' ',' ' and' ' a' ' grandfather' ' of' ' four' '.']" "
+
+ I am a father of two , a husband of one , and a grandfather of four ." False of painters Piet Mondrian and Marc Chagall, 5 [' of', ' pain', 'ters', ' Piet', ' Mond', 'rian']
+1080 244 Name of father of x -1 Name of father of Piet Mondrian Pieter Cornelis Sr. Piet Mondrian "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ',' ' a' ' husband' ' of'
+ ' one' ',' ' and' ' a' ' grandfather' ' of' ' four' '.']" "
+
+ I am a father of two , a husband of one , and a grandfather of four ." False black-and-white Piet Mondrian works. It is 7 [' black', '-', 'and', '-', 'white', ' Piet', ' Mond', 'rian']
+1081 244 Name of father of x -1 Name of father of Piet Mondrian Pieter Cornelis Sr. Piet Mondrian "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ',' ' a' ' husband' ' of'
+ ' one' ',' ' and' ' a' ' grandfather' ' of' ' four' '.']" "
+
+ I am a father of two , a husband of one , and a grandfather of four ." False Doesberg and Piet Mondrian explicitly embraced 5 [' Does', 'berg', ' and', ' Piet', ' Mond', 'rian']
+1082 244 Name of father of x -1 Name of father of Piet Mondrian Pieter Cornelis Sr. Piet Mondrian "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ',' ' a' ' husband' ' of'
+ ' one' ',' ' and' ' a' ' grandfather' ' of' ' four' '.']" "
+
+ I am a father of two , a husband of one , and a grandfather of four ." False portraits of painters Piet Mondrian and Marc Chagall, 6 [' portraits', ' of', ' pain', 'ters', ' Piet', ' Mond', 'rian']
+1083 245 Name of father of x -1 Name of father of Dwight D. Eisenhower David Jacob Eisenhower Dwight D. Eisenhower "['\n' '\n' 'I' 'ke' ' was' ' a' ' great' ' president' '.' ' He' ' was'
+ ' a' ' great' ' president' '.' ' He' ' was' ' a' ' great' ' president']" "
+
+ I ke was a great president . He was a great president . He was a great president" False to President Dwight D. Eisenhower and Prime Minister 5 [' to', ' President', ' Dwight', ' D', '.', ' Eisenhower']
+1084 245 Name of father of x -1 Name of father of Dwight D. Eisenhower David Jacob Eisenhower Dwight D. Eisenhower "['\n' '\n' 'I' 'ke' ' was' ' a' ' great' ' president' '.' ' He' ' was'
+ ' a' ' great' ' president' '.' ' He' ' was' ' a' ' great' ' president']" "
+
+ I ke was a great president . He was a great president . He was a great president" False found that President Dwight D. Eisenhower would use the satellite 6 [' found', ' that', ' President', ' Dwight', ' D', '.', ' Eisenhower']
+1085 245 Name of father of x -1 Name of father of Dwight D. Eisenhower David Jacob Eisenhower Dwight D. Eisenhower "['\n' '\n' 'I' 'ke' ' was' ' a' ' great' ' president' '.' ' He' ' was'
+ ' a' ' great' ' president' '.' ' He' ' was' ' a' ' great' ' president']" "
+
+ I ke was a great president . He was a great president . He was a great president" False 1959, President Dwight D. Eisenhower signed into law the 6 [' 1959', ',', ' President', ' Dwight', ' D', '.', ' Eisenhower']
+1086 245 Name of father of x -1 Name of father of Dwight D. Eisenhower David Jacob Eisenhower Dwight D. Eisenhower "['\n' '\n' 'I' 'ke' ' was' ' a' ' great' ' president' '.' ' He' ' was'
+ ' a' ' great' ' president' '.' ' He' ' was' ' a' ' great' ' president']" "
+
+ I ke was a great president . He was a great president . He was a great president" False 1958, President Dwight D. Eisenhower approved the construction 6 [' 1958', ',', ' President', ' Dwight', ' D', '.', ' Eisenhower']
+1087 245 Name of father of x -1 Name of father of Dwight D. Eisenhower David Jacob Eisenhower Dwight D. Eisenhower "['\n' '\n' 'I' 'ke' ' was' ' a' ' great' ' president' '.' ' He' ' was'
+ ' a' ' great' ' president' '.' ' He' ' was' ' a' ' great' ' president']" "
+
+ I ke was a great president . He was a great president . He was a great president" False to former President Dwight D. Eisenhower — reportedly 6 [' to', ' former', ' President', ' Dwight', ' D', '.', ' Eisenhower']
+1088 246 Name of father of x -1 Name of father of Igor Stravinsky Fyodor Stravinsky Igor Stravinsky "['\n' '\n' 'I' 'gor' ' St' 'rav' 'insky' ' (' 'Russian' ':' ' �' '�' '�'
+ '�' 'о' 'р' 'ь' ' �' '�' 'т']" "
+
+ I gor St rav insky ( Russian : � � � � о р ь � � т" False Hayasaka, was heard by Igor Stravinsky in 1958 during 10 [' Hay', 'as', 'aka', ',', ' was', ' heard', ' by', ' Igor', ' St', 'rav', 'insky']
+1089 246 Name of father of x -1 Name of father of Igor Stravinsky Fyodor Stravinsky Igor Stravinsky "['\n' '\n' 'I' 'gor' ' St' 'rav' 'insky' ' (' 'Russian' ':' ' �' '�' '�'
+ '�' 'о' 'р' 'ь' ' �' '�' 'т']" "
+
+ I gor St rav insky ( Russian : � � � � о р ь � � т" False " Darius Milhaud, Igor Stravinsky and Heitor Villa-Lobos.
+" 8 [' Darius', ' Mil', 'ha', 'ud', ',', ' Igor', ' St', 'rav', 'insky']
+1090 246 Name of father of x -1 Name of father of Igor Stravinsky Fyodor Stravinsky Igor Stravinsky "['\n' '\n' 'I' 'gor' ' St' 'rav' 'insky' ' (' 'Russian' ':' ' �' '�' '�'
+ '�' 'о' 'р' 'ь' ' �' '�' 'т']" "
+
+ I gor St rav insky ( Russian : � � � � о р ь � � т" False " as Aaron Copland, Igor Stravinsky and Vaughan Williams.
+" 8 [' as', ' Aaron', ' Cop', 'land', ',', ' Igor', ' St', 'rav', 'insky']
+1091 246 Name of father of x -1 Name of father of Igor Stravinsky Fyodor Stravinsky Igor Stravinsky "['\n' '\n' 'I' 'gor' ' St' 'rav' 'insky' ' (' 'Russian' ':' ' �' '�' '�'
+ '�' 'о' 'р' 'ь' ' �' '�' 'т']" "
+
+ I gor St rav insky ( Russian : � � � � о р ь � � т" False Russian composer Igor Stravinsky and rock musician 5 [' Russian', ' composer', ' Igor', ' St', 'rav', 'insky']
+1092 246 Name of father of x -1 Name of father of Igor Stravinsky Fyodor Stravinsky Igor Stravinsky "['\n' '\n' 'I' 'gor' ' St' 'rav' 'insky' ' (' 'Russian' ':' ' �' '�' '�'
+ '�' 'о' 'р' 'ь' ' �' '�' 'т']" "
+
+ I gor St rav insky ( Russian : � � � � о р ь � � т" False currently enjoyed by Igor Stravinsky and Arnold Schoenberg, 6 [' currently', ' enjoyed', ' by', ' Igor', ' St', 'rav', 'insky']
+1093 247 Name of father of x -1 Name of father of Søren Kierkegaard Michael Pedersen Kierkegaard Søren Kierkegaard "[',' ' Danish' ' philosopher' ',' ' theolog' 'ian' ',' ' and' ' author'
+ ' of' ' the' ' _' 'Con' 'cluding' ' Un' 'scientific' ' Post' 'script' '_'
+ ' (']" , Danish philosopher , theolog ian , and author of the _ Con cluding Un scientific Post script _ ( False biography about Søren Kierkegaard in 1954 and in 7 [' biography', ' about', ' S', 'ø', 'ren', ' Kier', 'ke', 'gaard']
+1094 247 Name of father of x -1 Name of father of Søren Kierkegaard Michael Pedersen Kierkegaard Søren Kierkegaard "[',' ' Danish' ' philosopher' ',' ' theolog' 'ian' ',' ' and' ' author'
+ ' of' ' the' ' _' 'Con' 'cluding' ' Un' 'scientific' ' Post' 'script' '_'
+ ' (']" , Danish philosopher , theolog ian , and author of the _ Con cluding Un scientific Post script _ ( False partly by reading Søren Kierkegaard and Reinhold Niebuhr; 8 [' partly', ' by', ' reading', ' S', 'ø', 'ren', ' Kier', 'ke', 'gaard']
+1095 247 Name of father of x -1 Name of father of Søren Kierkegaard Michael Pedersen Kierkegaard Søren Kierkegaard "[',' ' Danish' ' philosopher' ',' ' theolog' 'ian' ',' ' and' ' author'
+ ' of' ' the' ' _' 'Con' 'cluding' ' Un' 'scientific' ' Post' 'script' '_'
+ ' (']" , Danish philosopher , theolog ian , and author of the _ Con cluding Un scientific Post script _ ( False which describes Søren Kierkegaard as a Franciscan 7 [' which', ' describes', ' S', 'ø', 'ren', ' Kier', 'ke', 'gaard']
+1096 247 Name of father of x -1 Name of father of Søren Kierkegaard Michael Pedersen Kierkegaard Søren Kierkegaard "[',' ' Danish' ' philosopher' ',' ' theolog' 'ian' ',' ' and' ' author'
+ ' of' ' the' ' _' 'Con' 'cluding' ' Un' 'scientific' ' Post' 'script' '_'
+ ' (']" , Danish philosopher , theolog ian , and author of the _ Con cluding Un scientific Post script _ ( False philosophy of Søren Kierkegaard convinced many 7 [' philosophy', ' of', ' S', 'ø', 'ren', ' Kier', 'ke', 'gaard']
+1097 247 Name of father of x -1 Name of father of Søren Kierkegaard Michael Pedersen Kierkegaard Søren Kierkegaard "[',' ' Danish' ' philosopher' ',' ' theolog' 'ian' ',' ' and' ' author'
+ ' of' ' the' ' _' 'Con' 'cluding' ' Un' 'scientific' ' Post' 'script' '_'
+ ' (']" , Danish philosopher , theolog ian , and author of the _ Con cluding Un scientific Post script _ ( False partly by reading Søren Kierkegaard and Reinhold Niebuhr; 8 [' partly', ' by', ' reading', ' S', 'ø', 'ren', ' Kier', 'ke', 'gaard']
+1098 248 Name of father of x -1 Name of father of Immanuel Kant Johann Georg Kant Immanuel Kant "[',' ' the' ' philosopher' ',' ' was' ' born' ' in' ' 17' '24' ',' ' and'
+ ' died' ' in' ' 17' '81' '.' '\n' '\n' 'The' ' following']" ", the philosopher , was born in 17 24 , and died in 17 81 .
+
+ The following" False by Thomas Wright, Immanuel Kant and others of 6 [' by', ' Thomas', ' Wright', ',', ' Im', 'manuel', ' Kant']
+1099 248 Name of father of x -1 Name of father of Immanuel Kant Johann Georg Kant Immanuel Kant "[',' ' the' ' philosopher' ',' ' was' ' born' ' in' ' 17' '24' ',' ' and'
+ ' died' ' in' ' 17' '81' '.' '\n' '\n' 'The' ' following']" ", the philosopher , was born in 17 24 , and died in 17 81 .
+
+ The following" False Ways. In 1755, Immanuel Kant used the term 8 [' Ways', '.', ' In', ' 17', '55', ',', ' Im', 'manuel', ' Kant']
+1100 248 Name of father of x -1 Name of father of Immanuel Kant Johann Georg Kant Immanuel Kant "[',' ' the' ' philosopher' ',' ' was' ' born' ' in' ' 17' '24' ',' ' and'
+ ' died' ' in' ' 17' '81' '.' '\n' '\n' 'The' ' following']" ", the philosopher , was born in 17 24 , and died in 17 81 .
+
+ The following" False 2 ['Im', 'manuel', ' Kant']
+1101 248 Name of father of x -1 Name of father of Immanuel Kant Johann Georg Kant Immanuel Kant "[',' ' the' ' philosopher' ',' ' was' ' born' ' in' ' 17' '24' ',' ' and'
+ ' died' ' in' ' 17' '81' '.' '\n' '\n' 'The' ' following']" ", the philosopher , was born in 17 24 , and died in 17 81 .
+
+ The following" False " Kant ===
+" 4 [' Kant', ' ===', 'Im', 'manuel', ' Kant']
+1102 248 Name of father of x -1 Name of father of Immanuel Kant Johann Georg Kant Immanuel Kant "[',' ' the' ' philosopher' ',' ' was' ' born' ' in' ' 17' '24' ',' ' and'
+ ' died' ' in' ' 17' '81' '.' '\n' '\n' 'The' ' following']" ", the philosopher , was born in 17 24 , and died in 17 81 .
+
+ The following" False treatise in 1755, Immanuel Kant elaborated on 8 [' treat', 'ise', ' in', ' 17', '55', ',', ' Im', 'manuel', ' Kant']
+1103 249 Name of father of x -1 Name of father of Arnold Schoenberg Samuel Schönberg Arnold Schoenberg "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Arnold' ' Sch' 'oen' 'berg' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Arnold Sch oen berg , the composer ." False Concentus Musicus Wien, Arnold Schoenberg Chor (2 CDs) — 1995 10 [' Concent', 'us', ' Music', 'us', ' W', 'ien', ',', ' Arnold', ' Sch', 'oen', 'berg']
+1104 249 Name of father of x -1 Name of father of Arnold Schoenberg Samuel Schönberg Arnold Schoenberg "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Arnold' ' Sch' 'oen' 'berg' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Arnold Sch oen berg , the composer ." False earliest contacts with Arnold Schoenberg in 1909, he made 6 [' earliest', ' contacts', ' with', ' Arnold', ' Sch', 'oen', 'berg']
+1105 249 Name of father of x -1 Name of father of Arnold Schoenberg Samuel Schönberg Arnold Schoenberg "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Arnold' ' Sch' 'oen' 'berg' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Arnold Sch oen berg , the composer ." False 1927, along with Arnold Schoenberg and Paul Hindemith. 7 [' 1927', ',', ' along', ' with', ' Arnold', ' Sch', 'oen', 'berg']
+1106 249 Name of father of x -1 Name of father of Arnold Schoenberg Samuel Schönberg Arnold Schoenberg "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Arnold' ' Sch' 'oen' 'berg' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Arnold Sch oen berg , the composer ." False contacts with Arnold Schoenberg in 1909, he made 5 [' contacts', ' with', ' Arnold', ' Sch', 'oen', 'berg']
+1107 249 Name of father of x -1 Name of father of Arnold Schoenberg Samuel Schönberg Arnold Schoenberg "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Arnold' ' Sch' 'oen' 'berg' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Arnold Sch oen berg , the composer ." False in 1927, along with Arnold Schoenberg and Paul Hindemith. 8 [' in', ' 1927', ',', ' along', ' with', ' Arnold', ' Sch', 'oen', 'berg']
+1108 250 Name of father of x -1 Name of father of Mario Vargas Llosa Ernesto Vargas Maldonado Mario Vargas Llosa "[',' ' the' ' Per' 'uvian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Lima' ',' ' Peru' ',' ' in' ' 1936' '.' '\n' '\n' 'The']" ", the Per uvian writer , who was born in Lima , Peru , in 1936 .
+
+ The" False in the face by Mario Vargas Llosa in what became 8 [' in', ' the', ' face', ' by', ' Mario', ' Varg', 'as', ' Ll', 'osa']
+1109 250 Name of father of x -1 Name of father of Mario Vargas Llosa Ernesto Vargas Maldonado Mario Vargas Llosa "[',' ' the' ' Per' 'uvian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Lima' ',' ' Peru' ',' ' in' ' 1936' '.' '\n' '\n' 'The']" ", the Per uvian writer , who was born in Lima , Peru , in 1936 .
+
+ The" False the face by Mario Vargas Llosa in what became 7 [' the', ' face', ' by', ' Mario', ' Varg', 'as', ' Ll', 'osa']
+1110 250 Name of father of x -1 Name of father of Mario Vargas Llosa Ernesto Vargas Maldonado Mario Vargas Llosa "[',' ' the' ' Per' 'uvian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Lima' ',' ' Peru' ',' ' in' ' 1936' '.' '\n' '\n' 'The']" ", the Per uvian writer , who was born in Lima , Peru , in 1936 .
+
+ The" False Carlos Fuentes and Mario Vargas Llosa are widely 9 [' Carlos', ' Fu', 'ent', 'es', ' and', ' Mario', ' Varg', 'as', ' Ll', 'osa']
+1111 250 Name of father of x -1 Name of father of Mario Vargas Llosa Ernesto Vargas Maldonado Mario Vargas Llosa "[',' ' the' ' Per' 'uvian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Lima' ',' ' Peru' ',' ' in' ' 1936' '.' '\n' '\n' 'The']" ", the Per uvian writer , who was born in Lima , Peru , in 1936 .
+
+ The" False structure of DNA), Mario Vargas Llosa (writer), 8 [' structure', ' of', ' DNA', '),', ' Mario', ' Varg', 'as', ' Ll', 'osa']
+1112 250 Name of father of x -1 Name of father of Mario Vargas Llosa Ernesto Vargas Maldonado Mario Vargas Llosa "[',' ' the' ' Per' 'uvian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Lima' ',' ' Peru' ',' ' in' ' 1936' '.' '\n' '\n' 'The']" ", the Per uvian writer , who was born in Lima , Peru , in 1936 .
+
+ The" False commonly known as Mario Vargas Llosa (Spanish: [ˈmaɾjo 7 [' commonly', ' known', ' as', ' Mario', ' Varg', 'as', ' Ll', 'osa']
+1113 251 Name of father of x -1 Name of father of Helen Mirren Basil Mirren Helen Mirren "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' Queen'
+ ' Elizabeth' ' II' ' in' ' the' ' BBC' ""'s"" ' The' ' Queen' '.' '\n' '\n']" ", the actress who played the role of Queen Elizabeth II in the BBC 's The Queen .
+
+" False Oscar-winning actress Helen Mirren guest-stars in 6 [' Oscar', '-', 'winning', ' actress', ' Helen', ' Mir', 'ren']
+1114 251 Name of father of x -1 Name of father of Helen Mirren Basil Mirren Helen Mirren "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' Queen'
+ ' Elizabeth' ' II' ' in' ' the' ' BBC' ""'s"" ' The' ' Queen' '.' '\n' '\n']" ", the actress who played the role of Queen Elizabeth II in the BBC 's The Queen .
+
+" False Fu Manchu, alongside Helen Mirren and David Tomlinson. 7 [' Fu', ' Man', 'chu', ',', ' alongside', ' Helen', ' Mir', 'ren']
+1115 251 Name of father of x -1 Name of father of Helen Mirren Basil Mirren Helen Mirren "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' Queen'
+ ' Elizabeth' ' II' ' in' ' the' ' BBC' ""'s"" ' The' ' Queen' '.' '\n' '\n']" ", the actress who played the role of Queen Elizabeth II in the BBC 's The Queen .
+
+" False after seeing Helen Mirren onstage she was 4 [' after', ' seeing', ' Helen', ' Mir', 'ren']
+1116 251 Name of father of x -1 Name of father of Helen Mirren Basil Mirren Helen Mirren "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' Queen'
+ ' Elizabeth' ' II' ' in' ' the' ' BBC' ""'s"" ' The' ' Queen' '.' '\n' '\n']" ", the actress who played the role of Queen Elizabeth II in the BBC 's The Queen .
+
+" False The use of Helen Mirren to voice Becky's 5 [' The', ' use', ' of', ' Helen', ' Mir', 'ren']
+1117 251 Name of father of x -1 Name of father of Helen Mirren Basil Mirren Helen Mirren "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' Queen'
+ ' Elizabeth' ' II' ' in' ' the' ' BBC' ""'s"" ' The' ' Queen' '.' '\n' '\n']" ", the actress who played the role of Queen Elizabeth II in the BBC 's The Queen .
+
+" False 1995, when star Helen Mirren quit. Hooper initially 6 [' 1995', ',', ' when', ' star', ' Helen', ' Mir', 'ren']
+1118 252 Name of father of x -1 Name of father of Alfred Hitchcock William Hitchcock Alfred Hitchcock "[""'s"" ' daughter' ',' ' and' ' the' ' father' ' of' ' the' ' man' ' who'
+ ' was' ' to' ' become' ' the' ' father' ' of' ' the' ' man' ' who' ' was']" 's daughter , and the father of the man who was to become the father of the man who was False " created ""a dumpy, Alfred Hitchcock version of Edna,""" 7 "[' created', ' ""', 'a', ' dump', 'y', ',', ' Alfred', ' Hitchcock']"
+1119 252 Name of father of x -1 Name of father of Alfred Hitchcock William Hitchcock Alfred Hitchcock "[""'s"" ' daughter' ',' ' and' ' the' ' father' ' of' ' the' ' man' ' who'
+ ' was' ' to' ' become' ' the' ' father' ' of' ' the' ' man' ' who' ' was']" 's daughter , and the father of the man who was to become the father of the man who was False Hedren and Toby Jones as Alfred Hitchcock. It is based on 7 [' Hed', 'ren', ' and', ' Toby', ' Jones', ' as', ' Alfred', ' Hitchcock']
+1120 252 Name of father of x -1 Name of father of Alfred Hitchcock William Hitchcock Alfred Hitchcock "[""'s"" ' daughter' ',' ' and' ' the' ' father' ' of' ' the' ' man' ' who'
+ ' was' ' to' ' become' ' the' ' father' ' of' ' the' ' man' ' who' ' was']" 's daughter , and the father of the man who was to become the father of the man who was False the soundtrack of an Alfred Hitchcock film, with 5 [' the', ' soundtrack', ' of', ' an', ' Alfred', ' Hitchcock']
+1121 252 Name of father of x -1 Name of father of Alfred Hitchcock William Hitchcock Alfred Hitchcock "[""'s"" ' daughter' ',' ' and' ' the' ' father' ' of' ' the' ' man' ' who'
+ ' was' ' to' ' become' ' the' ' father' ' of' ' the' ' man' ' who' ' was']" 's daughter , and the father of the man who was to become the father of the man who was False reference to the 1954 Alfred Hitchcock thriller Dial 5 [' reference', ' to', ' the', ' 1954', ' Alfred', ' Hitchcock']
+1122 252 Name of father of x -1 Name of father of Alfred Hitchcock William Hitchcock Alfred Hitchcock "[""'s"" ' daughter' ',' ' and' ' the' ' father' ' of' ' the' ' man' ' who'
+ ' was' ' to' ' become' ' the' ' father' ' of' ' the' ' man' ' who' ' was']" 's daughter , and the father of the man who was to become the father of the man who was False Twilight Zone and Alfred Hitchcock Presents; while these 4 [' Twilight', ' Zone', ' and', ' Alfred', ' Hitchcock']
+1123 253 Name of father of x -1 Name of father of Charlotte Gainsbourg Serge Gainsbourg Charlotte Gainsbourg "[',' ' the' ' French' ' actress' ' and' ' singer' '.' '\n' '\n' 'The'
+ ' name' ' Charlotte' ' is' ' a' ' variant' ' of' ' the' ' name'
+ ' Charlotte' ',']" ", the French actress and singer .
+
+ The name Charlotte is a variant of the name Charlotte ," False collaborated with Charlotte Gainsbourg on her album 5 [' collaborated', ' with', ' Charlotte', ' G', 'ains', 'bourg']
+1124 253 Name of father of x -1 Name of father of Charlotte Gainsbourg Serge Gainsbourg Charlotte Gainsbourg "[',' ' the' ' French' ' actress' ' and' ' singer' '.' '\n' '\n' 'The'
+ ' name' ' Charlotte' ' is' ' a' ' variant' ' of' ' the' ' name'
+ ' Charlotte' ',']" ", the French actress and singer .
+
+ The name Charlotte is a variant of the name Charlotte ," False Anglo-French actress Charlotte Gainsbourg was subsequently 7 [' Anglo', '-', 'French', ' actress', ' Charlotte', ' G', 'ains', 'bourg']
+1125 253 Name of father of x -1 Name of father of Charlotte Gainsbourg Serge Gainsbourg Charlotte Gainsbourg "[',' ' the' ' French' ' actress' ' and' ' singer' '.' '\n' '\n' 'The'
+ ' name' ' Charlotte' ' is' ' a' ' variant' ' of' ' the' ' name'
+ ' Charlotte' ',']" ", the French actress and singer .
+
+ The name Charlotte is a variant of the name Charlotte ," False Beck collaborated with Charlotte Gainsbourg on her album 6 [' Beck', ' collaborated', ' with', ' Charlotte', ' G', 'ains', 'bourg']
+1126 253 Name of father of x -1 Name of father of Charlotte Gainsbourg Serge Gainsbourg Charlotte Gainsbourg "[',' ' the' ' French' ' actress' ' and' ' singer' '.' '\n' '\n' 'The'
+ ' name' ' Charlotte' ' is' ' a' ' variant' ' of' ' the' ' name'
+ ' Charlotte' ',']" ", the French actress and singer .
+
+ The name Charlotte is a variant of the name Charlotte ," False Connor, John's wife. Charlotte Gainsbourg was originally set 9 "[' Connor', ',', ' John', ""'s"", ' wife', '.', ' Charlotte', ' G', 'ains', 'bourg']"
+1127 253 Name of father of x -1 Name of father of Charlotte Gainsbourg Serge Gainsbourg Charlotte Gainsbourg "[',' ' the' ' French' ' actress' ' and' ' singer' '.' '\n' '\n' 'The'
+ ' name' ' Charlotte' ' is' ' a' ' variant' ' of' ' the' ' name'
+ ' Charlotte' ',']" ", the French actress and singer .
+
+ The name Charlotte is a variant of the name Charlotte ," False collaborated with Charlotte Gainsbourg on her album IRM, 5 [' collaborated', ' with', ' Charlotte', ' G', 'ains', 'bourg']
+1128 254 Name of father of x -1 Name of father of Kobe Bryant Joe Bryant Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' and'
+ ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ' in']" , the NBA 's all - time leading scorer , and the NBA 's all - time leading scorer in False only 35 games, Kobe Bryant tore a rotator cuff 5 [' only', ' 35', ' games', ',', ' Kobe', ' Bryant']
+1129 254 Name of father of x -1 Name of father of Kobe Bryant Joe Bryant Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' and'
+ ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ' in']" , the NBA 's all - time leading scorer , and the NBA 's all - time leading scorer in False vote-getter was Kobe Bryant with 2,805,397 votes. 6 [' vote', '-', 'get', 'ter', ' was', ' Kobe', ' Bryant']
+1130 254 Name of father of x -1 Name of father of Kobe Bryant Joe Bryant Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' and'
+ ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ' in']" , the NBA 's all - time leading scorer , and the NBA 's all - time leading scorer in False Association player Kobe Bryant (who was accused 3 [' Association', ' player', ' Kobe', ' Bryant']
+1131 254 Name of father of x -1 Name of father of Kobe Bryant Joe Bryant Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' and'
+ ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ' in']" , the NBA 's all - time leading scorer , and the NBA 's all - time leading scorer in False allowed Lakers star Kobe Bryant to score 81 points, 4 [' allowed', ' Lakers', ' star', ' Kobe', ' Bryant']
+1132 254 Name of father of x -1 Name of father of Kobe Bryant Joe Bryant Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' and'
+ ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ' in']" , the NBA 's all - time leading scorer , and the NBA 's all - time leading scorer in False went 22 – 5. Kobe Bryant has also stated 6 [' went', ' 22', ' –', ' 5', '.', ' Kobe', ' Bryant']
+1133 255 Name of father of x -1 Name of father of Thomas Jefferson Peter Jefferson Thomas Jefferson "[',' ' the' ' third' ' President' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' children'
+ ' of']" ", the third President of the United States .
+
+ The following is a list of the children of" False Treasury under Thomas Jefferson and James Madison, 3 [' Treasury', ' under', ' Thomas', ' Jefferson']
+1134 255 Name of father of x -1 Name of father of Thomas Jefferson Peter Jefferson Thomas Jefferson "[',' ' the' ' third' ' President' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' children'
+ ' of']" ", the third President of the United States .
+
+ The following is a list of the children of" False secretary of state, Thomas Jefferson Rusk as secretary 5 [' secretary', ' of', ' state', ',', ' Thomas', ' Jefferson']
+1135 255 Name of father of x -1 Name of father of Thomas Jefferson Peter Jefferson Thomas Jefferson "[',' ' the' ' third' ' President' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' children'
+ ' of']" ", the third President of the United States .
+
+ The following is a list of the children of" False the achievement of Thomas Jefferson and James Madison 4 [' the', ' achievement', ' of', ' Thomas', ' Jefferson']
+1136 255 Name of father of x -1 Name of father of Thomas Jefferson Peter Jefferson Thomas Jefferson "[',' ' the' ' third' ' President' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' children'
+ ' of']" ", the third President of the United States .
+
+ The following is a list of the children of" False Democratic-Republicans, the party of Thomas Jefferson and James 8 [' Democratic', '-', 'Republicans', ',', ' the', ' party', ' of', ' Thomas', ' Jefferson']
+1137 255 Name of father of x -1 Name of father of Thomas Jefferson Peter Jefferson Thomas Jefferson "[',' ' the' ' third' ' President' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' children'
+ ' of']" ", the third President of the United States .
+
+ The following is a list of the children of" False Maryland, and Thomas Jefferson's first Monticello. 4 [' Maryland', ',', ' and', ' Thomas', ' Jefferson']
+1138 256 Name of father of x -1 Name of father of Angela Lansbury Edgar Lansbury Angela Lansbury "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' evil' ' Mrs' '.' ' Dan' 'vers' ' in' ' the' ' classic' ' horror'
+ ' film' ',']" , the actress who played the role of the evil Mrs . Dan vers in the classic horror film , False (Zeta-Jones, Angela Lansbury and Alexander Hanson) 8 [' (', 'Z', 'eta', '-', 'Jones', ',', ' Angela', ' Lans', 'bury']
+1139 256 Name of father of x -1 Name of father of Angela Lansbury Edgar Lansbury Angela Lansbury "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' evil' ' Mrs' '.' ' Dan' 'vers' ' in' ' the' ' classic' ' horror'
+ ' film' ',']" , the actress who played the role of the evil Mrs . Dan vers in the classic horror film , False " Guest"", sung by Angela Lansbury in the 1991 film" 6 "[' Guest', '"",', ' sung', ' by', ' Angela', ' Lans', 'bury']"
+1140 256 Name of father of x -1 Name of father of Angela Lansbury Edgar Lansbury Angela Lansbury "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' evil' ' Mrs' '.' ' Dan' 'vers' ' in' ' the' ' classic' ' horror'
+ ' film' ',']" , the actress who played the role of the evil Mrs . Dan vers in the classic horror film , False Menken asked Angela Lansbury to perform the 5 [' Men', 'ken', ' asked', ' Angela', ' Lans', 'bury']
+1141 256 Name of father of x -1 Name of father of Angela Lansbury Edgar Lansbury Angela Lansbury "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' evil' ' Mrs' '.' ' Dan' 'vers' ' in' ' the' ' classic' ' horror'
+ ' film' ',']" , the actress who played the role of the evil Mrs . Dan vers in the classic horror film , False and English actress Angela Lansbury as Lumiere and Mrs. 5 [' and', ' English', ' actress', ' Angela', ' Lans', 'bury']
+1142 256 Name of father of x -1 Name of father of Angela Lansbury Edgar Lansbury Angela Lansbury "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' evil' ' Mrs' '.' ' Dan' 'vers' ' in' ' the' ' classic' ' horror'
+ ' film' ',']" , the actress who played the role of the evil Mrs . Dan vers in the classic horror film , False (Zeta-Jones, Angela Lansbury and Alexander 8 [' (', 'Z', 'eta', '-', 'Jones', ',', ' Angela', ' Lans', 'bury']
+1143 257 Name of father of x -1 Name of father of Nikolai Gogol Vasyl Panasovych Gogol-Yanovsky Nikolai Gogol "[',' ' the' ' author' ' of' ' ""' 'Dead' ' Souls' '""' ' and' ' ""' 'The'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of "" Dead Souls "" and "" The
+ " False comparison with Nikolai Gogol (one of his favourite 6 [' comparison', ' with', ' Nikol', 'ai', ' G', 'og', 'ol']
+1144 257 Name of father of x -1 Name of father of Nikolai Gogol Vasyl Panasovych Gogol-Yanovsky Nikolai Gogol "[',' ' the' ' author' ' of' ' ""' 'Dead' ' Souls' '""' ' and' ' ""' 'The'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of "" Dead Souls "" and "" The
+ " False Russian novelist Nikolai Gogol in Venice. He 6 [' Russian', ' novelist', ' Nikol', 'ai', ' G', 'og', 'ol']
+1145 257 Name of father of x -1 Name of father of Nikolai Gogol Vasyl Panasovych Gogol-Yanovsky Nikolai Gogol "[',' ' the' ' author' ' of' ' ""' 'Dead' ' Souls' '""' ' and' ' ""' 'The'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of "" Dead Souls "" and "" The
+ " False his comparison with Nikolai Gogol (one of his favourite 7 [' his', ' comparison', ' with', ' Nikol', 'ai', ' G', 'og', 'ol']
+1146 258 Name of father of x -1 Name of father of Antoine de Saint-Exupéry Jean Marc de Saint-Exupéry Antoine de Saint-Exupéry "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Ant' 'oine' ' de'
+ ' Saint' '-' 'Ex' 'up' 'é' 'ry' ' is' ' not' ' known']" "
+
+ The name of the father of Ant oine de Saint - Ex up é ry is not known" False for the future, Antoine de Saint-Exupéry writes of how deeply 12 [' for', ' the', ' future', ',', ' Ant', 'oine', ' de', ' Saint', '-', 'Ex', 'up', 'é', 'ry']
+1147 258 Name of father of x -1 Name of father of Antoine de Saint-Exupéry Jean Marc de Saint-Exupéry Antoine de Saint-Exupéry "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Ant' 'oine' ' de'
+ ' Saint' '-' 'Ex' 'up' 'é' 'ry' ' is' ' not' ' known']" "
+
+ The name of the father of Ant oine de Saint - Ex up é ry is not known" False hopes for the future, Antoine de Saint-Exupéry writes of how deeply 13 [' hopes', ' for', ' the', ' future', ',', ' Ant', 'oine', ' de', ' Saint', '-', 'Ex', 'up', 'é', 'ry']
+1148 258 Name of father of x -1 Name of father of Antoine de Saint-Exupéry Jean Marc de Saint-Exupéry Antoine de Saint-Exupéry "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Ant' 'oine' ' de'
+ ' Saint' '-' 'Ex' 'up' 'é' 'ry' ' is' ' not' ' known']" "
+
+ The name of the father of Ant oine de Saint - Ex up é ry is not known" False for the future, Antoine de Saint-Exupéry writes of 12 [' for', ' the', ' future', ',', ' Ant', 'oine', ' de', ' Saint', '-', 'Ex', 'up', 'é', 'ry']
+1149 259 Name of father of x -1 Name of father of Mary Shelley William Godwin Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Frankenstein' ',' ' and'
+ ' the']" ", the author of Frankenstein , and the author of the
+
+ The author of Frankenstein , and the" False " Italy became for Mary Shelley ""a country which" 4 [' Italy', ' became', ' for', ' Mary', ' Shelley']
+1150 259 Name of father of x -1 Name of father of Mary Shelley William Godwin Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Frankenstein' ',' ' and'
+ ' the']" ", the author of Frankenstein , and the author of the
+
+ The author of Frankenstein , and the" False novelist and essayist Mary Shelley (1797 – 1851) drew 5 [' novelist', ' and', ' essay', 'ist', ' Mary', ' Shelley']
+1151 259 Name of father of x -1 Name of father of Mary Shelley William Godwin Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Frankenstein' ',' ' and'
+ ' the']" ", the author of Frankenstein , and the author of the
+
+ The author of Frankenstein , and the" False Branagh's 1994 Mary Shelley's Frankenstein, 5 "[' Bran', 'agh', ""'s"", ' 1994', ' Mary', ' Shelley']"
+1152 259 Name of father of x -1 Name of father of Mary Shelley William Godwin Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Frankenstein' ',' ' and'
+ ' the']" ", the author of Frankenstein , and the author of the
+
+ The author of Frankenstein , and the" False 1 ['Mary', ' Shelley']
+1153 259 Name of father of x -1 Name of father of Mary Shelley William Godwin Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Frankenstein' ',' ' and'
+ ' the']" ", the author of Frankenstein , and the author of the
+
+ The author of Frankenstein , and the" False Romantic author Mary Shelley. Issued in 1844, 3 [' Romantic', ' author', ' Mary', ' Shelley']
+1154 261 Name of father of x -1 Name of father of Sophia Loren Riccardo Scicolone Sophia Loren "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' the' ' actor'
+ ' Carlo' ' Pont' 'i' ',' ' who' ' was' ' the' ' son' ' of' ' a']" , the actress , and her husband , the actor Carlo Pont i , who was the son of a False until he learned that Sophia Loren was to be his co-star. 5 [' until', ' he', ' learned', ' that', ' Sophia', ' Loren']
+1155 261 Name of father of x -1 Name of father of Sophia Loren Riccardo Scicolone Sophia Loren "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' the' ' actor'
+ ' Carlo' ' Pont' 'i' ',' ' who' ' was' ' the' ' son' ' of' ' a']" , the actress , and her husband , the actor Carlo Pont i , who was the son of a False until he learned that Sophia Loren was to be his co-star. 5 [' until', ' he', ' learned', ' that', ' Sophia', ' Loren']
+1156 261 Name of father of x -1 Name of father of Sophia Loren Riccardo Scicolone Sophia Loren "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' the' ' actor'
+ ' Carlo' ' Pont' 'i' ',' ' who' ' was' ' the' ' son' ' of' ' a']" , the actress , and her husband , the actor Carlo Pont i , who was the son of a False famous picture of Sophia Loren staring at 4 [' famous', ' picture', ' of', ' Sophia', ' Loren']
+1157 261 Name of father of x -1 Name of father of Sophia Loren Riccardo Scicolone Sophia Loren "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' the' ' actor'
+ ' Carlo' ' Pont' 'i' ',' ' who' ' was' ' the' ' son' ' of' ' a']" , the actress , and her husband , the actor Carlo Pont i , who was the son of a False " Peter Sellers and Sophia Loren – ""Goodness Gracious" 5 [' Peter', ' Sell', 'ers', ' and', ' Sophia', ' Loren']
+1158 261 Name of father of x -1 Name of father of Sophia Loren Riccardo Scicolone Sophia Loren "[',' ' the' ' actress' ',' ' and' ' her' ' husband' ',' ' the' ' actor'
+ ' Carlo' ' Pont' 'i' ',' ' who' ' was' ' the' ' son' ' of' ' a']" , the actress , and her husband , the actor Carlo Pont i , who was the son of a False originally recorded by Sophia Loren as publicity associated 4 [' originally', ' recorded', ' by', ' Sophia', ' Loren']
+1159 262 Name of father of x -1 Name of father of Salman Rushdie Anis Ahmed Rushdie Salman Rushdie "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' was' ' the' ' subject' ' of' ' the' ' fat' 'wa' '.' '\n']" ", the author of the book , and the man who was the subject of the fat wa .
+" False " the novelist Salman Rushdie said, ""That book" 4 [' the', ' novelist', ' Salman', ' Rush', 'die']
+1160 262 Name of father of x -1 Name of father of Salman Rushdie Anis Ahmed Rushdie Salman Rushdie "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' was' ' the' ' subject' ' of' ' the' ' fat' 'wa' '.' '\n']" ", the author of the book , and the man who was the subject of the fat wa .
+" False Bill Clinton and Salman Rushdie are among those 5 [' Bill', ' Clinton', ' and', ' Salman', ' Rush', 'die']
+1161 262 Name of father of x -1 Name of father of Salman Rushdie Anis Ahmed Rushdie Salman Rushdie "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' was' ' the' ' subject' ' of' ' the' ' fat' 'wa' '.' '\n']" ", the author of the book , and the man who was the subject of the fat wa .
+" False and the fatwa against Salman Rushdie by Ayatollah 7 [' and', ' the', ' fat', 'wa', ' against', ' Salman', ' Rush', 'die']
+1162 262 Name of father of x -1 Name of father of Salman Rushdie Anis Ahmed Rushdie Salman Rushdie "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' was' ' the' ' subject' ' of' ' the' ' fat' 'wa' '.' '\n']" ", the author of the book , and the man who was the subject of the fat wa .
+" False Ayaan Hirsi Ali and Salman Rushdie along with 9 [' A', 'ya', 'an', ' Hir', 'si', ' Ali', ' and', ' Salman', ' Rush', 'die']
+1163 262 Name of father of x -1 Name of father of Salman Rushdie Anis Ahmed Rushdie Salman Rushdie "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' was' ' the' ' subject' ' of' ' the' ' fat' 'wa' '.' '\n']" ", the author of the book , and the man who was the subject of the fat wa .
+" False 3 ['Sal', 'man', ' Rush', 'die']
+1164 263 Name of father of x -1 Name of father of Hermann Hesse Johannes Hesse Hermann Hesse "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ' boys' ',' ' one' ' of'
+ ' whom' ' is' ' a' ' son' ' of' ' Herman' 'n' ' H']" "
+
+ I am a father of two boys , one of whom is a son of Herman n H" False 4 ['H', 'erman', 'n', ' H', 'esse']
+1165 263 Name of father of x -1 Name of father of Hermann Hesse Johannes Hesse Hermann Hesse "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ' boys' ',' ' one' ' of'
+ ' whom' ' is' ' a' ' son' ' of' ' Herman' 'n' ' H']" "
+
+ I am a father of two boys , one of whom is a son of Herman n H" False 4 ['H', 'erman', 'n', ' H', 'esse']
+1166 263 Name of father of x -1 Name of father of Hermann Hesse Johannes Hesse Hermann Hesse "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ' boys' ',' ' one' ' of'
+ ' whom' ' is' ' a' ' son' ' of' ' Herman' 'n' ' H']" "
+
+ I am a father of two boys , one of whom is a son of Herman n H" False 4 ['H', 'erman', 'n', ' H', 'esse']
+1167 263 Name of father of x -1 Name of father of Hermann Hesse Johannes Hesse Hermann Hesse "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ' boys' ',' ' one' ' of'
+ ' whom' ' is' ' a' ' son' ' of' ' Herman' 'n' ' H']" "
+
+ I am a father of two boys , one of whom is a son of Herman n H" False Gottfried Keller, Hermann Hesse and other writers 7 [' Gott', 'fried', ' Keller', ',', ' Herman', 'n', ' H', 'esse']
+1168 263 Name of father of x -1 Name of father of Hermann Hesse Johannes Hesse Hermann Hesse "['\n' '\n' 'I' ' am' ' a' ' father' ' of' ' two' ' boys' ',' ' one' ' of'
+ ' whom' ' is' ' a' ' son' ' of' ' Herman' 'n' ' H']" "
+
+ I am a father of two boys , one of whom is a son of Herman n H" False this species. Hermann Hesse mentioned this 6 [' this', ' species', '.', ' Herman', 'n', ' H', 'esse']
+1169 264 Name of father of x -1 Name of father of Bhumibol Adulyadej Mahidol Adulyadej, Prince Father Bhumibol Adulyadej "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' King' ' Bh' 'um'
+ 'ib' 'ol' ' A' 'du' 'ly' 'ade' 'j' ' ()' ' is']" "
+
+ The name of the father of King Bh um ib ol A du ly ade j () is" False Afghanistan, King Bhumibol Adulyadej and Queen Sirikit 11 [' Afghanistan', ',', ' King', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1170 264 Name of father of x -1 Name of father of Bhumibol Adulyadej Mahidol Adulyadej, Prince Father Bhumibol Adulyadej "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' King' ' Bh' 'um'
+ 'ib' 'ol' ' A' 'du' 'ly' 'ade' 'j' ' ()' ' is']" "
+
+ The name of the father of King Bh um ib ol A du ly ade j () is" False current king, Bhumibol Adulyadej is a first-cousin 11 [' current', ' king', ',', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1171 264 Name of father of x -1 Name of father of Bhumibol Adulyadej Mahidol Adulyadej, Prince Father Bhumibol Adulyadej "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' King' ' Bh' 'um'
+ 'ib' 'ol' ' A' 'du' 'ly' 'ade' 'j' ' ()' ' is']" "
+
+ The name of the father of King Bh um ib ol A du ly ade j () is" False audio collection. King Bhumibol Adulyadej donated recordings 12 [' audio', ' collection', '.', ' King', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1172 264 Name of father of x -1 Name of father of Bhumibol Adulyadej Mahidol Adulyadej, Prince Father Bhumibol Adulyadej "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' King' ' Bh' 'um'
+ 'ib' 'ol' ' A' 'du' 'ly' 'ade' 'j' ' ()' ' is']" "
+
+ The name of the father of King Bh um ib ol A du ly ade j () is" False The current king, Bhumibol Adulyadej is a first-cousin 12 [' The', ' current', ' king', ',', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1173 264 Name of father of x -1 Name of father of Bhumibol Adulyadej Mahidol Adulyadej, Prince Father Bhumibol Adulyadej "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' King' ' Bh' 'um'
+ 'ib' 'ol' ' A' 'du' 'ly' 'ade' 'j' ' ()' ' is']" "
+
+ The name of the father of King Bh um ib ol A du ly ade j () is" False " of state after King Bhumibol Adulyadej of Thailand.
+" 12 [' of', ' state', ' after', ' King', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1174 265 Name of father of x -1 Name of father of Thomas More John More Thomas More "[',' ' the' ' great' ' English' ' states' 'man' ',' ' who' ' was'
+ ' executed' ' for' ' treason' ' in' ' 15' '35' '.' '\n' '\n' 'The'
+ ' name']" ", the great English states man , who was executed for treason in 15 35 .
+
+ The name" False English philosophers Thomas More and Philip Sidney. 3 [' English', ' philosophers', ' Thomas', ' More']
+1175 265 Name of father of x -1 Name of father of Thomas More John More Thomas More "[',' ' the' ' great' ' English' ' states' 'man' ',' ' who' ' was'
+ ' executed' ' for' ' treason' ' in' ' 15' '35' '.' '\n' '\n' 'The'
+ ' name']" ", the great English states man , who was executed for treason in 15 35 .
+
+ The name" False the executions of Thomas More and John Fisher in 4 [' the', ' executions', ' of', ' Thomas', ' More']
+1176 265 Name of father of x -1 Name of father of Thomas More John More Thomas More "[',' ' the' ' great' ' English' ' states' 'man' ',' ' who' ' was'
+ ' executed' ' for' ' treason' ' in' ' 15' '35' '.' '\n' '\n' 'The'
+ ' name']" ", the great English states man , who was executed for treason in 15 35 .
+
+ The name" False English philosophers Thomas More and Philip Sidney. 3 [' English', ' philosophers', ' Thomas', ' More']
+1177 265 Name of father of x -1 Name of father of Thomas More John More Thomas More "[',' ' the' ' great' ' English' ' states' 'man' ',' ' who' ' was'
+ ' executed' ' for' ' treason' ' in' ' 15' '35' '.' '\n' '\n' 'The'
+ ' name']" ", the great English states man , who was executed for treason in 15 35 .
+
+ The name" False shoulders, Sir Thomas More took on the role 4 [' shoulders', ',', ' Sir', ' Thomas', ' More']
+1178 265 Name of father of x -1 Name of father of Thomas More John More Thomas More "[',' ' the' ' great' ' English' ' states' 'man' ',' ' who' ' was'
+ ' executed' ' for' ' treason' ' in' ' 15' '35' '.' '\n' '\n' 'The'
+ ' name']" ", the great English states man , who was executed for treason in 15 35 .
+
+ The name" False represented by the Thomas More Law Center. The 4 [' represented', ' by', ' the', ' Thomas', ' More']
+1179 266 Name of father of x -1 Name of father of George Orwell Richard Walmesley Blair George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' dystopian' ' novel' '.' '\n' '\n' 'The' ' novel' ' is']" ", the author of 1984 , and the father of the modern dystopian novel .
+
+ The novel is" False play off of George Orwell which as all 4 [' play', ' off', ' of', ' George', ' Orwell']
+1180 266 Name of father of x -1 Name of father of George Orwell Richard Walmesley Blair George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' dystopian' ' novel' '.' '\n' '\n' 'The' ' novel' ' is']" ", the author of 1984 , and the father of the modern dystopian novel .
+
+ The novel is" False £ 1.9 million. George Orwell used it to research 7 [' £', ' 1', '.', '9', ' million', '.', ' George', ' Orwell']
+1181 266 Name of father of x -1 Name of father of George Orwell Richard Walmesley Blair George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' dystopian' ' novel' '.' '\n' '\n' 'The' ' novel' ' is']" ", the author of 1984 , and the father of the modern dystopian novel .
+
+ The novel is" False well-structured. George Orwell wrote that Hornung 6 [' well', '-', 'struct', 'ured', '.', ' George', ' Orwell']
+1182 266 Name of father of x -1 Name of father of George Orwell Richard Walmesley Blair George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' dystopian' ' novel' '.' '\n' '\n' 'The' ' novel' ' is']" ", the author of 1984 , and the father of the modern dystopian novel .
+
+ The novel is" False Friedrich Nietzsche and George Orwell to Isambard Kingdom 4 [' Friedrich', ' Nietzsche', ' and', ' George', ' Orwell']
+1183 266 Name of father of x -1 Name of father of George Orwell Richard Walmesley Blair George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' dystopian' ' novel' '.' '\n' '\n' 'The' ' novel' ' is']" ", the author of 1984 , and the father of the modern dystopian novel .
+
+ The novel is" False as a play off of George Orwell which as all listeners 6 [' as', ' a', ' play', ' off', ' of', ' George', ' Orwell']
+1184 267 Name of father of x -1 Name of father of Jessica Simpson Joe Simpson Jessica Simpson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Aguilera, Jessica Simpson and Mandy Moore. 5 [' Agu', 'iler', 'a', ',', ' Jessica', ' Simpson']
+1185 267 Name of father of x -1 Name of father of Jessica Simpson Joe Simpson Jessica Simpson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False American singer Jessica Simpson was dating Dallas 3 [' American', ' singer', ' Jessica', ' Simpson']
+1186 267 Name of father of x -1 Name of father of Jessica Simpson Joe Simpson Jessica Simpson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False in the News: Jessica Simpson and Nick Lachey, 5 [' in', ' the', ' News', ':', ' Jessica', ' Simpson']
+1187 267 Name of father of x -1 Name of father of Jessica Simpson Joe Simpson Jessica Simpson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False with stars such as Jessica Simpson and the cast of Grey's 5 [' with', ' stars', ' such', ' as', ' Jessica', ' Simpson']
+1188 267 Name of father of x -1 Name of father of Jessica Simpson Joe Simpson Jessica Simpson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False People in the News: Jessica Simpson and Nick Lachey, 6 [' People', ' in', ' the', ' News', ':', ' Jessica', ' Simpson']
+1189 268 Name of father of x -1 Name of father of Federico García Lorca Federico García Rodríguez Federico García Lorca "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Feder' 'ico'
+ ' Garc' 'ía' ' Lor' 'ca' ' is' ' unknown' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Feder ico Garc ía Lor ca is unknown .
+
+ The" False Ricardo Modrego of Federico García Lorca songs to date, and 10 [' Ricardo', ' Mod', 're', 'go', ' of', ' Feder', 'ico', ' Garc', 'ía', ' Lor', 'ca']
+1190 268 Name of father of x -1 Name of father of Federico García Lorca Federico García Rodríguez Federico García Lorca "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Feder' 'ico'
+ ' Garc' 'ía' ' Lor' 'ca' ' is' ' unknown' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Feder ico Garc ía Lor ca is unknown .
+
+ The" False Ricardo Modrego of Federico García Lorca songs to date, 10 [' Ricardo', ' Mod', 're', 'go', ' of', ' Feder', 'ico', ' Garc', 'ía', ' Lor', 'ca']
+1191 269 Name of father of x -1 Name of father of Louis XIV of France Louis XIII of France Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False alarming that King Louis XIV of France despatched Marshal 6 [' alarming', ' that', ' King', ' Louis', ' XIV', ' of', ' France']
+1192 269 Name of father of x -1 Name of father of Louis XIV of France Louis XIII of France Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False cousin King Louis XIV of France for about £ 375,000. 5 [' cousin', ' King', ' Louis', ' XIV', ' of', ' France']
+1193 269 Name of father of x -1 Name of father of Louis XIV of France Louis XIII of France Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False to the palace of Louis XIV of France at Versailles, 7 [' to', ' the', ' palace', ' of', ' Louis', ' XIV', ' of', ' France']
+1194 269 Name of father of x -1 Name of father of Louis XIV of France Louis XIII of France Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False Bourbon armies of King Louis XIV of France in 1705. Although 7 [' Bourbon', ' armies', ' of', ' King', ' Louis', ' XIV', ' of', ' France']
+1195 269 Name of father of x -1 Name of father of Louis XIV of France Louis XIII of France Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " modern era"". (King Louis XIV of France reigned over part of" 8 "[' modern', ' era', '"".', ' (', 'King', ' Louis', ' XIV', ' of', ' France']"
+1196 270 Name of father of x -1 Name of father of Confucius Shu-liang He Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' founder' ' of' ' the'
+ '\n' 'Conf' 'uc' 'ian' ' school' ' of' ' philosophy' '.' '\n' '\n']" ", the great sage , and the founder of the
+ Conf uc ian school of philosophy .
+
+" False agencies reported the Confucius Peace Prize, established 5 [' agencies', ' reported', ' the', ' Conf', 'u', 'cius']
+1197 270 Name of father of x -1 Name of father of Confucius Shu-liang He Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' founder' ' of' ' the'
+ '\n' 'Conf' 'uc' 'ian' ' school' ' of' ' philosophy' '.' '\n' '\n']" ", the great sage , and the founder of the
+ Conf uc ian school of philosophy .
+
+" False contemporary of Confucius during the 6th or 4 [' contemporary', ' of', ' Conf', 'u', 'cius']
+1198 270 Name of father of x -1 Name of father of Confucius Shu-liang He Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' founder' ' of' ' the'
+ '\n' 'Conf' 'uc' 'ian' ' school' ' of' ' philosophy' '.' '\n' '\n']" ", the great sage , and the founder of the
+ Conf uc ian school of philosophy .
+
+" False 1787 his Zoroaster, Confucius and Muhammad, 8 [' 17', '87', ' his', ' Zoro', 'aster', ',', ' Conf', 'u', 'cius']
+1199 270 Name of father of x -1 Name of father of Confucius Shu-liang He Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' founder' ' of' ' the'
+ '\n' 'Conf' 'uc' 'ian' ' school' ' of' ' philosophy' '.' '\n' '\n']" ", the great sage , and the founder of the
+ Conf uc ian school of philosophy .
+
+" False documents of Confucius of the Zhou Dynasty, 4 [' documents', ' of', ' Conf', 'u', 'cius']
+1200 270 Name of father of x -1 Name of father of Confucius Shu-liang He Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' founder' ' of' ' the'
+ '\n' 'Conf' 'uc' 'ian' ' school' ' of' ' philosophy' '.' '\n' '\n']" ", the great sage , and the founder of the
+ Conf uc ian school of philosophy .
+
+" False in the Analects of Confucius (Lunyu 論語) when Confucius 9 [' in', ' the', ' An', 'al', 'ect', 's', ' of', ' Conf', 'u', 'cius']
+1201 271 Name of father of x -1 Name of father of Charles de Gaulle Henri de Gaulle Charles de Gaulle "[',' ' the' ' French' ' general' ' who' ' led' ' the' ' Free' ' French'
+ ' forces' ' in' ' World' ' War' ' II' '.' '\n' '\n' 'The' ' French'
+ ' government']" ", the French general who led the Free French forces in World War II .
+
+ The French government" False damaged banana crop. Charles de Gaulle visited the island, 7 [' damaged', ' banana', ' crop', '.', ' Charles', ' de', ' Gaul', 'le']
+1202 271 Name of father of x -1 Name of father of Charles de Gaulle Henri de Gaulle Charles de Gaulle "[',' ' the' ' French' ' general' ' who' ' led' ' the' ' Free' ' French'
+ ' forces' ' in' ' World' ' War' ' II' '.' '\n' '\n' 'The' ' French'
+ ' government']" ", the French general who led the Free French forces in World War II .
+
+ The French government" False President Charles de Gaulle with his testicles 4 [' President', ' Charles', ' de', ' Gaul', 'le']
+1203 271 Name of father of x -1 Name of father of Charles de Gaulle Henri de Gaulle Charles de Gaulle "[',' ' the' ' French' ' general' ' who' ' led' ' the' ' Free' ' French'
+ ' forces' ' in' ' World' ' War' ' II' '.' '\n' '\n' 'The' ' French'
+ ' government']" ", the French general who led the Free French forces in World War II .
+
+ The French government" False French President Charles de Gaulle married his wife 5 [' French', ' President', ' Charles', ' de', ' Gaul', 'le']
+1204 271 Name of father of x -1 Name of father of Charles de Gaulle Henri de Gaulle Charles de Gaulle "[',' ' the' ' French' ' general' ' who' ' led' ' the' ' Free' ' French'
+ ' forces' ' in' ' World' ' War' ' II' '.' '\n' '\n' 'The' ' French'
+ ' government']" ", the French general who led the Free French forces in World War II .
+
+ The French government" False " 4590 on takeoff from Charles de Gaulle Airport in 2000.
+" 8 [' 4', '590', ' on', ' takeoff', ' from', ' Charles', ' de', ' Gaul', 'le']
+1205 271 Name of father of x -1 Name of father of Charles de Gaulle Henri de Gaulle Charles de Gaulle "[',' ' the' ' French' ' general' ' who' ' led' ' the' ' Free' ' French'
+ ' forces' ' in' ' World' ' War' ' II' '.' '\n' '\n' 'The' ' French'
+ ' government']" ", the French general who led the Free French forces in World War II .
+
+ The French government" False September 1958 Charles de Gaulle proposed a constitutional 5 [' September', ' 1958', ' Charles', ' de', ' Gaul', 'le']
+1206 272 Name of father of x -1 Name of father of Leonard Bernstein Samuel Joseph Bernstein Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ',' ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The']" ", the famous composer , and his wife , Fel icia , who was a singer .
+
+ The" False postwar successes. Leonard Bernstein considered Britten 4 [' postwar', ' successes', '.', ' Leonard', ' Bernstein']
+1207 272 Name of father of x -1 Name of father of Leonard Bernstein Samuel Joseph Bernstein Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ',' ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The']" ", the famous composer , and his wife , Fel icia , who was a singer .
+
+ The" False " been cited by Leonard Bernstein as ""the most important" 4 [' been', ' cited', ' by', ' Leonard', ' Bernstein']
+1208 272 Name of father of x -1 Name of father of Leonard Bernstein Samuel Joseph Bernstein Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ',' ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The']" ", the famous composer , and his wife , Fel icia , who was a singer .
+
+ The" False several recordings. Leonard Bernstein conducted the 4 [' several', ' recordings', '.', ' Leonard', ' Bernstein']
+1209 272 Name of father of x -1 Name of father of Leonard Bernstein Samuel Joseph Bernstein Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ',' ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The']" ", the famous composer , and his wife , Fel icia , who was a singer .
+
+ The" False briefly with Leonard Bernstein on a sacred mass 3 [' briefly', ' with', ' Leonard', ' Bernstein']
+1210 272 Name of father of x -1 Name of father of Leonard Bernstein Samuel Joseph Bernstein Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ',' ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The']" ", the famous composer , and his wife , Fel icia , who was a singer .
+
+ The" False Symphony, congratulating Leonard Bernstein and the New York 5 [' Symphony', ',', ' congrat', 'ulating', ' Leonard', ' Bernstein']
+1211 273 Name of father of x -1 Name of father of Spock Sarek Spock "[',' ' the' ' Vulcan' '.' '\n' '\n' 'The' ' Vulcan' ' is' ' a' ' humanoid'
+ ' alien' ' species' ' from' ' the' ' planet' ' Vulcan' ',' ' a' ' planet']" ", the Vulcan .
+
+ The Vulcan is a humanoid alien species from the planet Vulcan , a planet" False consciousness. Spock realizes that 2 [' consciousness', '.', ' Spock']
+1212 273 Name of father of x -1 Name of father of Spock Sarek Spock "[',' ' the' ' Vulcan' '.' '\n' '\n' 'The' ' Vulcan' ' is' ' a' ' humanoid'
+ ' alien' ' species' ' from' ' the' ' planet' ' Vulcan' ',' ' a' ' planet']" ", the Vulcan .
+
+ The Vulcan is a humanoid alien species from the planet Vulcan , a planet" False connection as Kirk and Spock. Tim Robey of 4 [' connection', ' as', ' Kirk', ' and', ' Spock']
+1213 273 Name of father of x -1 Name of father of Spock Sarek Spock "[',' ' the' ' Vulcan' '.' '\n' '\n' 'The' ' Vulcan' ' is' ' a' ' humanoid'
+ ' alien' ' species' ' from' ' the' ' planet' ' Vulcan' ',' ' a' ' planet']" ", the Vulcan .
+
+ The Vulcan is a humanoid alien species from the planet Vulcan , a planet" False Matrix member Scott Spock was their principal 3 [' Matrix', ' member', ' Scott', ' Spock']
+1214 273 Name of father of x -1 Name of father of Spock Sarek Spock "[',' ' the' ' Vulcan' '.' '\n' '\n' 'The' ' Vulcan' ' is' ' a' ' humanoid'
+ ' alien' ' species' ' from' ' the' ' planet' ' Vulcan' ',' ' a' ' planet']" ", the Vulcan .
+
+ The Vulcan is a humanoid alien species from the planet Vulcan , a planet" False " Chevy Chase as Mr. Spock / Leonard Nimoy
+" 5 [' Chevy', ' Chase', ' as', ' Mr', '.', ' Spock']
+1215 273 Name of father of x -1 Name of father of Spock Sarek Spock "[',' ' the' ' Vulcan' '.' '\n' '\n' 'The' ' Vulcan' ' is' ' a' ' humanoid'
+ ' alien' ' species' ' from' ' the' ' planet' ' Vulcan' ',' ' a' ' planet']" ", the Vulcan .
+
+ The Vulcan is a humanoid alien species from the planet Vulcan , a planet" False with the roles of Spock in Star Trek: 4 [' with', ' the', ' roles', ' of', ' Spock']
+1216 274 Name of father of x -1 Name of father of Jacob van Ruisdael Isaack van Ruisdael Jacob van Ruisdael "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Jacob' ' van'
+ ' Ru' 'isd' 'ael' ' is' ' unknown' '.' ' He' ' was' ' born' ' in']" "
+
+ The name of the father of Jacob van Ru isd ael is unknown . He was born in" False " Jacob van Ruisdael =
+" 4 [' Jacob', ' van', ' Ru', 'isd', 'ael']
+1217 274 Name of father of x -1 Name of father of Jacob van Ruisdael Isaack van Ruisdael Jacob van Ruisdael "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Jacob' ' van'
+ ' Ru' 'isd' 'ael' ' is' ' unknown' '.' ' He' ' was' ' born' ' in']" "
+
+ The name of the father of Jacob van Ru isd ael is unknown . He was born in" False " Ruisdael =
+" 8 [' Ru', 'isd', 'ael', ' =', 'Jacob', ' van', ' Ru', 'isd', 'ael']
+1218 274 Name of father of x -1 Name of father of Jacob van Ruisdael Isaack van Ruisdael Jacob van Ruisdael "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Jacob' ' van'
+ ' Ru' 'isd' 'ael' ' is' ' unknown' '.' ' He' ' was' ' born' ' in']" "
+
+ The name of the father of Jacob van Ru isd ael is unknown . He was born in" False " Ruisdael =
+" 8 [' Ru', 'isd', 'ael', ' =', 'Jacob', ' van', ' Ru', 'isd', 'ael']
+1219 275 Name of father of x -1 Name of father of Leonhard Euler Paul III Euler Leonhard Euler "[',' ' the' ' mathematician' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' mathematician' ' Leon' 'hard' ' E' 'uler' '.'
+ '\n']" ", the mathematician .
+
+ The name of the father of the mathematician Leon hard E uler .
+" False in 1690 and 1691. Leonhard Euler introduced the letter 10 [' in', ' 16', '90', ' and', ' 16', '91', '.', ' Leon', 'hard', ' E', 'uler']
+1220 275 Name of father of x -1 Name of father of Leonhard Euler Paul III Euler Leonhard Euler "[',' ' the' ' mathematician' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' mathematician' ' Leon' 'hard' ' E' 'uler' '.'
+ '\n']" ", the mathematician .
+
+ The name of the father of the mathematician Leon hard E uler .
+" False marble bust of Leonhard Euler on a pedestal 6 [' marble', ' bust', ' of', ' Leon', 'hard', ' E', 'uler']
+1221 275 Name of father of x -1 Name of father of Leonhard Euler Paul III Euler Leonhard Euler "[',' ' the' ' mathematician' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' mathematician' ' Leon' 'hard' ' E' 'uler' '.'
+ '\n']" ", the mathematician .
+
+ The name of the father of the mathematician Leon hard E uler .
+" False mid-18th century, Leonhard Euler wrote what he 9 [' mid', '-', '18', 'th', ' century', ',', ' Leon', 'hard', ' E', 'uler']
+1222 275 Name of father of x -1 Name of father of Leonhard Euler Paul III Euler Leonhard Euler "[',' ' the' ' mathematician' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' mathematician' ' Leon' 'hard' ' E' 'uler' '.'
+ '\n']" ", the mathematician .
+
+ The name of the father of the mathematician Leon hard E uler .
+" False spectrograph on the Leonhard Euler Telescope at Chile's 8 [' spect', 'rog', 'raph', ' on', ' the', ' Leon', 'hard', ' E', 'uler']
+1223 275 Name of father of x -1 Name of father of Leonhard Euler Paul III Euler Leonhard Euler "[',' ' the' ' mathematician' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' mathematician' ' Leon' 'hard' ' E' 'uler' '.'
+ '\n']" ", the mathematician .
+
+ The name of the father of the mathematician Leon hard E uler .
+" False 3 ['Leon', 'hard', ' E', 'uler']
+1224 276 Name of father of x -1 Name of father of Vladimir Nabokov Vladimir Dmitrievich Nabokov Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' his' ' wife' ','
+ ' V' 'é' 'ra' ',' ' who' ' was' ' the' ' inspiration' ' for']" , the author of Lol ita , and his wife , V é ra , who was the inspiration for False such complexity. Vladimir Nabokov noted that overlaid 6 [' such', ' complexity', '.', ' Vladimir', ' Nab', 'ok', 'ov']
+1225 276 Name of father of x -1 Name of father of Vladimir Nabokov Vladimir Dmitrievich Nabokov Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' his' ' wife' ','
+ ' V' 'é' 'ra' ',' ' who' ' was' ' the' ' inspiration' ' for']" , the author of Lol ita , and his wife , V é ra , who was the inspiration for False Russian-American novelist Vladimir Nabokov being accused 7 [' Russian', '-', 'American', ' novelist', ' Vladimir', ' Nab', 'ok', 'ov']
+1226 276 Name of father of x -1 Name of father of Vladimir Nabokov Vladimir Dmitrievich Nabokov Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' his' ' wife' ','
+ ' V' 'é' 'ra' ',' ' who' ' was' ' the' ' inspiration' ' for']" , the author of Lol ita , and his wife , V é ra , who was the inspiration for False " writer."" And Vladimir Nabokov criticized Chekhov's" 6 "[' writer', '.""', ' And', ' Vladimir', ' Nab', 'ok', 'ov']"
+1227 276 Name of father of x -1 Name of father of Vladimir Nabokov Vladimir Dmitrievich Nabokov Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' his' ' wife' ','
+ ' V' 'é' 'ra' ',' ' who' ' was' ' the' ' inspiration' ' for']" , the author of Lol ita , and his wife , V é ra , who was the inspiration for False also made by Vladimir Nabokov in his Lectures 6 [' also', ' made', ' by', ' Vladimir', ' Nab', 'ok', 'ov']
+1228 276 Name of father of x -1 Name of father of Vladimir Nabokov Vladimir Dmitrievich Nabokov Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' his' ' wife' ','
+ ' V' 'é' 'ra' ',' ' who' ' was' ' the' ' inspiration' ' for']" , the author of Lol ita , and his wife , V é ra , who was the inspiration for False Jonathan Swift. Vladimir Nabokov was less enthusiastic 6 [' Jonathan', ' Swift', '.', ' Vladimir', ' Nab', 'ok', 'ov']
+1229 277 Name of father of x -1 Name of father of Georg Wilhelm Friedrich Hegel Georg Ludwig Hegel Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' born' ' at' ' St' 'utt' 'gart' ',' '\n'
+ '\n' '|' '\n' '\n' '17' '70' '.' '\n' '\n']" ", the philosopher , born at St utt gart ,
+
+ |
+
+ 17 70 .
+
+" False influence exercised by Georg Wilhelm Friedrich Hegel on Romanian thought, 6 [' influence', ' exercised', ' by', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1230 277 Name of father of x -1 Name of father of Georg Wilhelm Friedrich Hegel Georg Ludwig Hegel Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' born' ' at' ' St' 'utt' 'gart' ',' '\n'
+ '\n' '|' '\n' '\n' '17' '70' '.' '\n' '\n']" ", the philosopher , born at St utt gart ,
+
+ |
+
+ 17 70 .
+
+" False influence exercised by Georg Wilhelm Friedrich Hegel on Romanian thought, 6 [' influence', ' exercised', ' by', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1231 277 Name of father of x -1 Name of father of Georg Wilhelm Friedrich Hegel Georg Ludwig Hegel Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' born' ' at' ' St' 'utt' 'gart' ',' '\n'
+ '\n' '|' '\n' '\n' '17' '70' '.' '\n' '\n']" ", the philosopher , born at St utt gart ,
+
+ |
+
+ 17 70 .
+
+" False Gottlieb Fichte, Georg Wilhelm Friedrich Hegel and Friedrich 10 [' Gott', 'lie', 'b', ' F', 'ich', 'te', ',', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1232 277 Name of father of x -1 Name of father of Georg Wilhelm Friedrich Hegel Georg Ludwig Hegel Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' born' ' at' ' St' 'utt' 'gart' ',' '\n'
+ '\n' '|' '\n' '\n' '17' '70' '.' '\n' '\n']" ", the philosopher , born at St utt gart ,
+
+ |
+
+ 17 70 .
+
+" False influence exercised by Georg Wilhelm Friedrich Hegel on Romanian 6 [' influence', ' exercised', ' by', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1233 278 Name of father of x -1 Name of father of James McNeill Whistler George Washington Whistler James McNeill Whistler "[',' ' the' ' painter' '.' '\n' '\n' 'Wh' 'ist' 'ler' ',' ' James' ' Mc'
+ 'Neill' ',' ' 18' '34' '-' '19' '03' '\n']" ", the painter .
+
+ Wh ist ler , James Mc Neill , 18 34 - 19 03
+" False Matthew Arnold, James McNeill Whistler and Oscar 8 [' Matthew', ' Arnold', ',', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1234 278 Name of father of x -1 Name of father of James McNeill Whistler George Washington Whistler James McNeill Whistler "[',' ' the' ' painter' '.' '\n' '\n' 'Wh' 'ist' 'ler' ',' ' James' ' Mc'
+ 'Neill' ',' ' 18' '34' '-' '19' '03' '\n']" ", the painter .
+
+ Wh ist ler , James Mc Neill , 18 34 - 19 03
+" False but unlike James McNeill Whistler he returned to 7 [' but', ' unlike', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1235 278 Name of father of x -1 Name of father of James McNeill Whistler George Washington Whistler James McNeill Whistler "[',' ' the' ' painter' '.' '\n' '\n' 'Wh' 'ist' 'ler' ',' ' James' ' Mc'
+ 'Neill' ',' ' 18' '34' '-' '19' '03' '\n']" ", the painter .
+
+ Wh ist ler , James Mc Neill , 18 34 - 19 03
+" False but unlike James McNeill Whistler he returned to the 7 [' but', ' unlike', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1236 278 Name of father of x -1 Name of father of James McNeill Whistler George Washington Whistler James McNeill Whistler "[',' ' the' ' painter' '.' '\n' '\n' 'Wh' 'ist' 'ler' ',' ' James' ' Mc'
+ 'Neill' ',' ' 18' '34' '-' '19' '03' '\n']" ", the painter .
+
+ Wh ist ler , James Mc Neill , 18 34 - 19 03
+" False bridge. A statue of James McNeill Whistler by Nicholas 10 [' bridge', '.', ' A', ' statue', ' of', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1237 278 Name of father of x -1 Name of father of James McNeill Whistler George Washington Whistler James McNeill Whistler "[',' ' the' ' painter' '.' '\n' '\n' 'Wh' 'ist' 'ler' ',' ' James' ' Mc'
+ 'Neill' ',' ' 18' '34' '-' '19' '03' '\n']" ", the painter .
+
+ Wh ist ler , James Mc Neill , 18 34 - 19 03
+" False Matthew Arnold, James McNeill Whistler and Oscar Wilde. 8 [' Matthew', ' Arnold', ',', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1238 279 Name of father of x -1 Name of father of Prince John L. Nelson Prince "[' of' ' Wales' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ']" " of Wales , and the
+ " False 1934, his cousin Prince Paul headed a 4 [' 1934', ',', ' his', ' cousin', ' Prince']
+1239 279 Name of father of x -1 Name of father of Prince John L. Nelson Prince "[' of' ' Wales' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ']" " of Wales , and the
+ " False brother Amalric, Prince of Tyre, the heads 5 [' brother', ' Am', 'al', 'ric', ',', ' Prince']
+1240 279 Name of father of x -1 Name of father of Prince John L. Nelson Prince "[' of' ' Wales' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ']" " of Wales , and the
+ " False by Queen Elizabeth, Prince William, and Frederik 4 [' by', ' Queen', ' Elizabeth', ',', ' Prince']
+1241 279 Name of father of x -1 Name of father of Prince John L. Nelson Prince "[' of' ' Wales' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ']" " of Wales , and the
+ " False Explorers Post, Prince George's County Sheriff's 4 [' Expl', 'orers', ' Post', ',', ' Prince']
+1242 279 Name of father of x -1 Name of father of Prince John L. Nelson Prince "[' of' ' Wales' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ']" " of Wales , and the
+ " False while serving as Prince Sturdza's private secretary, 3 [' while', ' serving', ' as', ' Prince']
+1243 280 Name of father of x -1 Name of father of Hayao Miyazaki Katsuji Miyazaki Hayao Miyazaki "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False pictures by Hayao Miyazaki and interviews 5 [' pictures', ' by', ' Hay', 'ao', ' Miy', 'azaki']
+1244 280 Name of father of x -1 Name of father of Hayao Miyazaki Katsuji Miyazaki Hayao Miyazaki "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False " dreamt up by Hayao Miyazaki or George Lucas""." 7 [' dream', 't', ' up', ' by', ' Hay', 'ao', ' Miy', 'azaki']
+1245 280 Name of father of x -1 Name of father of Hayao Miyazaki Katsuji Miyazaki Hayao Miyazaki "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False Collected Works of Hayao Miyazaki, a box set 7 [' Collect', 'ed', ' Works', ' of', ' Hay', 'ao', ' Miy', 'azaki']
+1246 280 Name of father of x -1 Name of father of Hayao Miyazaki Katsuji Miyazaki Hayao Miyazaki "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False production, Hayao Miyazaki of Studio Ghibli 5 [' production', ',', ' Hay', 'ao', ' Miy', 'azaki']
+1247 280 Name of father of x -1 Name of father of Hayao Miyazaki Katsuji Miyazaki Hayao Miyazaki "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only' ' one'
+ ' who' ' thinks' ' this' ',' ' but' ' I' ' think' ' that']" "
+
+ I 'm not sure if I 'm the only one who thinks this , but I think that" False " energy in it"". Hayao Miyazaki was also an early influence.
+" 7 "[' energy', ' in', ' it', '"".', ' Hay', 'ao', ' Miy', 'azaki']"
+1248 281 Name of father of x -1 Name of father of Louis Pasteur Jean-Joseph Pasteur Louis Pasteur "[',' ' the' ' great' ' French' ' chemist' ' and' ' microbi' 'ologist' ','
+ ' who' ' discovered' ' the' ' germ' ' theory' ' of' ' disease' '.' '\n'
+ '\n' '**']" ", the great French chemist and microbi ologist , who discovered the germ theory of disease .
+
+ **" False recipient was Louis Pasteur (1822 – 95). In 1921, 4 [' recipient', ' was', ' Louis', ' Paste', 'ur']
+1249 281 Name of father of x -1 Name of father of Louis Pasteur Jean-Joseph Pasteur Louis Pasteur "[',' ' the' ' great' ' French' ' chemist' ' and' ' microbi' 'ologist' ','
+ ' who' ' discovered' ' the' ' germ' ' theory' ' of' ' disease' '.' '\n'
+ '\n' '**']" ", the great French chemist and microbi ologist , who discovered the germ theory of disease .
+
+ **" False a vaccine by Louis Pasteur in 1886. The average 5 [' a', ' vaccine', ' by', ' Louis', ' Paste', 'ur']
+1250 281 Name of father of x -1 Name of father of Louis Pasteur Jean-Joseph Pasteur Louis Pasteur "[',' ' the' ' great' ' French' ' chemist' ' and' ' microbi' 'ologist' ','
+ ' who' ' discovered' ' the' ' germ' ' theory' ' of' ' disease' '.' '\n'
+ '\n' '**']" ", the great French chemist and microbi ologist , who discovered the germ theory of disease .
+
+ **" False alcohol by yeast, Louis Pasteur concluded that 6 [' alcohol', ' by', ' yeast', ',', ' Louis', ' Paste', 'ur']
+1251 281 Name of father of x -1 Name of father of Louis Pasteur Jean-Joseph Pasteur Louis Pasteur "[',' ' the' ' great' ' French' ' chemist' ' and' ' microbi' 'ologist' ','
+ ' who' ' discovered' ' the' ' germ' ' theory' ' of' ' disease' '.' '\n'
+ '\n' '**']" ", the great French chemist and microbi ologist , who discovered the germ theory of disease .
+
+ **" False by yeast, Louis Pasteur concluded that 5 [' by', ' yeast', ',', ' Louis', ' Paste', 'ur']
+1252 281 Name of father of x -1 Name of father of Louis Pasteur Jean-Joseph Pasteur Louis Pasteur "[',' ' the' ' great' ' French' ' chemist' ' and' ' microbi' 'ologist' ','
+ ' who' ' discovered' ' the' ' germ' ' theory' ' of' ' disease' '.' '\n'
+ '\n' '**']" ", the great French chemist and microbi ologist , who discovered the germ theory of disease .
+
+ **" False 2 ['Louis', ' Paste', 'ur']
+1253 282 Name of father of x -1 Name of father of Henri Bergson Michał Bergson Henri Bergson "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' philosopher' ' of'
+ ' the' ' future' ',' ' the' ' philosopher' ' of' ' the' ' future' ','
+ ' the' ' philosopher']" , the philosopher , and of the philosopher of the future , the philosopher of the future , the philosopher False to 1923, and Henri Bergson was quoted as saying 6 [' to', ' 1923', ',', ' and', ' Henri', ' Berg', 'son']
+1254 282 Name of father of x -1 Name of father of Henri Bergson Michał Bergson Henri Bergson "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' philosopher' ' of'
+ ' the' ' future' ',' ' the' ' philosopher' ' of' ' the' ' future' ','
+ ' the' ' philosopher']" , the philosopher , and of the philosopher of the future , the philosopher of the future , the philosopher False 3 ['Hen', 'ri', ' Berg', 'son']
+1255 282 Name of father of x -1 Name of father of Henri Bergson Michał Bergson Henri Bergson "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' philosopher' ' of'
+ ' the' ' future' ',' ' the' ' philosopher' ' of' ' the' ' future' ','
+ ' the' ' philosopher']" , the philosopher , and of the philosopher of the future , the philosopher of the future , the philosopher False French philosopher Henri Bergson became the first 4 [' French', ' philosopher', ' Henri', ' Berg', 'son']
+1256 282 Name of father of x -1 Name of father of Henri Bergson Michał Bergson Henri Bergson "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' philosopher' ' of'
+ ' the' ' future' ',' ' the' ' philosopher' ' of' ' the' ' future' ','
+ ' the' ' philosopher']" , the philosopher , and of the philosopher of the future , the philosopher of the future , the philosopher False philosophers Henri Bergson and Emmanuel 3 [' philosophers', ' Henri', ' Berg', 'son']
+1257 282 Name of father of x -1 Name of father of Henri Bergson Michał Bergson Henri Bergson "[',' ' the' ' philosopher' ',' ' and' ' of' ' the' ' philosopher' ' of'
+ ' the' ' future' ',' ' the' ' philosopher' ' of' ' the' ' future' ','
+ ' the' ' philosopher']" , the philosopher , and of the philosopher of the future , the philosopher of the future , the philosopher False same time, Henri Bergson (1859 – 1941), developed 5 [' same', ' time', ',', ' Henri', ' Berg', 'son']
+1258 283 Name of father of x -1 Name of father of Charles Baudelaire Joseph-François Baudelaire Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' brother' ',' ' the'
+ ' painter' ',' ' and' ' of' ' his' ' sister' ',' ' the' ' painter' ',']" , the poet , and of his brother , the painter , and of his sister , the painter , False " experiments inaugurated by Charles Baudelaire with the ""destruction" 8 [' experiments', ' inaug', 'urated', ' by', ' Charles', ' B', 'aud', 'el', 'aire']
+1259 283 Name of father of x -1 Name of father of Charles Baudelaire Joseph-François Baudelaire Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' brother' ',' ' the'
+ ' painter' ',' ' and' ' of' ' his' ' sister' ',' ' the' ' painter' ',']" , the poet , and of his brother , the painter , and of his sister , the painter , False athletically inclined. Charles Baudelaire repeatedly used lesbianism 8 [' athlet', 'ically', ' inclined', '.', ' Charles', ' B', 'aud', 'el', 'aire']
+1260 283 Name of father of x -1 Name of father of Charles Baudelaire Joseph-François Baudelaire Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' brother' ',' ' the'
+ ' painter' ',' ' and' ' of' ' his' ' sister' ',' ' the' ' painter' ',']" , the poet , and of his brother , the painter , and of his sister , the painter , False Barbey d'Aurevilly, Charles Baudelaire and José María de 13 "[' Bar', 'bey', ' d', ""'"", 'A', 'ure', 'v', 'illy', ',', ' Charles', ' B', 'aud', 'el', 'aire']"
+1261 283 Name of father of x -1 Name of father of Charles Baudelaire Joseph-François Baudelaire Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' brother' ',' ' the'
+ ' painter' ',' ' and' ' of' ' his' ' sister' ',' ' the' ' painter' ',']" , the poet , and of his brother , the painter , and of his sister , the painter , False Symboliste poet Charles Baudelaire marks the eponymous 7 [' Symbol', 'iste', ' poet', ' Charles', ' B', 'aud', 'el', 'aire']
+1262 283 Name of father of x -1 Name of father of Charles Baudelaire Joseph-François Baudelaire Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' brother' ',' ' the'
+ ' painter' ',' ' and' ' of' ' his' ' sister' ',' ' the' ' painter' ',']" , the poet , and of his brother , the painter , and of his sister , the painter , False Barbey d'Aurevilly, Charles Baudelaire and José María 13 "[' Bar', 'bey', ' d', ""'"", 'A', 'ure', 'v', 'illy', ',', ' Charles', ' B', 'aud', 'el', 'aire']"
+1263 284 Name of father of x -1 Name of father of Theodore Roosevelt Theodore Roosevelt Sr. Theodore Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False portion of the Theodore Roosevelt International Highway, 4 [' portion', ' of', ' the', ' Theodore', ' Roosevelt']
+1264 284 Name of father of x -1 Name of father of Theodore Roosevelt Theodore Roosevelt Sr. Theodore Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False in 1908, President Theodore Roosevelt called a meeting 5 [' in', ' 1908', ',', ' President', ' Theodore', ' Roosevelt']
+1265 284 Name of father of x -1 Name of father of Theodore Roosevelt Theodore Roosevelt Sr. Theodore Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False review for President Theodore Roosevelt in Hampton Roads. 4 [' review', ' for', ' President', ' Theodore', ' Roosevelt']
+1266 284 Name of father of x -1 Name of father of Theodore Roosevelt Theodore Roosevelt Sr. Theodore Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False 1902, President Theodore Roosevelt approved a $ 1 million 4 [' 1902', ',', ' President', ' Theodore', ' Roosevelt']
+1267 284 Name of father of x -1 Name of father of Theodore Roosevelt Theodore Roosevelt Sr. Theodore Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False Roosevelt, cousin of Theodore Roosevelt and president of 5 [' Roosevelt', ',', ' cousin', ' of', ' Theodore', ' Roosevelt']
+1268 286 Name of father of x -1 Name of father of Jennifer Lawrence Francis Lawrence Jennifer Lawrence "[',' ' the' ' actress' ' who' ' plays' ' Kat' 'n' 'iss' ' Ever' 'deen'
+ ' in' ' the' ' Hunger' ' Games' ' movies' '.' '\n' '\n' 'The' ' Hunger']" ", the actress who plays Kat n iss Ever deen in the Hunger Games movies .
+
+ The Hunger" False Scarlett Johansson and Jennifer Lawrence for the actresses' 5 [' Scarlett', ' Joh', 'ansson', ' and', ' Jennifer', ' Lawrence']
+1269 286 Name of father of x -1 Name of father of Jennifer Lawrence Francis Lawrence Jennifer Lawrence "[',' ' the' ' actress' ' who' ' plays' ' Kat' 'n' 'iss' ' Ever' 'deen'
+ ' in' ' the' ' Hunger' ' Games' ' movies' '.' '\n' '\n' 'The' ' Hunger']" ", the actress who plays Kat n iss Ever deen in the Hunger Games movies .
+
+ The Hunger" False Nicholas Hoult and Jennifer Lawrence who, like McAvoy and 5 [' Nicholas', ' H', 'oult', ' and', ' Jennifer', ' Lawrence']
+1270 286 Name of father of x -1 Name of father of Jennifer Lawrence Francis Lawrence Jennifer Lawrence "[',' ' the' ' actress' ' who' ' plays' ' Kat' 'n' 'iss' ' Ever' 'deen'
+ ' in' ' the' ' Hunger' ' Games' ' movies' '.' '\n' '\n' 'The' ' Hunger']" ", the actress who plays Kat n iss Ever deen in the Hunger Games movies .
+
+ The Hunger" False Johansson and Jennifer Lawrence for the actresses' 4 [' Joh', 'ansson', ' and', ' Jennifer', ' Lawrence']
+1271 286 Name of father of x -1 Name of father of Jennifer Lawrence Francis Lawrence Jennifer Lawrence "[',' ' the' ' actress' ' who' ' plays' ' Kat' 'n' 'iss' ' Ever' 'deen'
+ ' in' ' the' ' Hunger' ' Games' ' movies' '.' '\n' '\n' 'The' ' Hunger']" ", the actress who plays Kat n iss Ever deen in the Hunger Games movies .
+
+ The Hunger" False Michael Fassbender, Jennifer Lawrence and Kevin Bacon. 6 [' Michael', ' F', 'ass', 'bender', ',', ' Jennifer', ' Lawrence']
+1272 286 Name of father of x -1 Name of father of Jennifer Lawrence Francis Lawrence Jennifer Lawrence "[',' ' the' ' actress' ' who' ' plays' ' Kat' 'n' 'iss' ' Ever' 'deen'
+ ' in' ' the' ' Hunger' ' Games' ' movies' '.' '\n' '\n' 'The' ' Hunger']" ", the actress who plays Kat n iss Ever deen in the Hunger Games movies .
+
+ The Hunger" False Games series, opposite Jennifer Lawrence as Katniss Everdeen. 5 [' Games', ' series', ',', ' opposite', ' Jennifer', ' Lawrence']
+1273 287 Name of father of x -1 Name of father of Shirley MacLaine Ira Owens Beaty Shirley MacLaine "[',' ' the' ' actress' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Shirley' ' Mac' 'L' 'aine' ',' ' the' ' actress' ',' ' is']" ", the actress
+
+ The name of the father of Shirley Mac L aine , the actress , is" False Despite speculation, Shirley MacLaine did not attend 6 [' Despite', ' speculation', ',', ' Shirley', ' Mac', 'L', 'aine']
+1274 287 Name of father of x -1 Name of father of Shirley MacLaine Ira Owens Beaty Shirley MacLaine "[',' ' the' ' actress' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Shirley' ' Mac' 'L' 'aine' ',' ' the' ' actress' ',' ' is']" ", the actress
+
+ The name of the father of Shirley Mac L aine , the actress , is" False considerable. Shirley MacLaine said that 5 [' considerable', '.', ' Shirley', ' Mac', 'L', 'aine']
+1275 287 Name of father of x -1 Name of father of Shirley MacLaine Ira Owens Beaty Shirley MacLaine "[',' ' the' ' actress' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Shirley' ' Mac' 'L' 'aine' ',' ' the' ' actress' ',' ' is']" ", the actress
+
+ The name of the father of Shirley Mac L aine , the actress , is" False spoke against, and Shirley MacLaine also was an opposition 7 [' spoke', ' against', ',', ' and', ' Shirley', ' Mac', 'L', 'aine']
+1276 287 Name of father of x -1 Name of father of Shirley MacLaine Ira Owens Beaty Shirley MacLaine "[',' ' the' ' actress' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Shirley' ' Mac' 'L' 'aine' ',' ' the' ' actress' ',' ' is']" ", the actress
+
+ The name of the father of Shirley Mac L aine , the actress , is" False considerable. Shirley MacLaine said that McCracken 5 [' considerable', '.', ' Shirley', ' Mac', 'L', 'aine']
+1277 287 Name of father of x -1 Name of father of Shirley MacLaine Ira Owens Beaty Shirley MacLaine "[',' ' the' ' actress' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Shirley' ' Mac' 'L' 'aine' ',' ' the' ' actress' ',' ' is']" ", the actress
+
+ The name of the father of Shirley Mac L aine , the actress , is" False starred with Shirley MacLaine in the western Two 5 [' starred', ' with', ' Shirley', ' Mac', 'L', 'aine']
+1278 288 Name of father of x -1 Name of father of Leonid Brezhnev Ilya Yakovlevich Brezhnev Leonid Brezhnev "[',' ' Soviet' ' leader' '\n' '\n' 'Category' ':' '19' '10' ' births' '\n'
+ 'Category' ':' '1982' ' deaths' '\n' 'Category' ':' 'People' ' from']" ", Soviet leader
+
+ Category : 19 10 births
+ Category : 1982 deaths
+ Category : People from" False the Ukrainian-born Leonid Brezhnev was especially 9 [' the', ' Ukrainian', '-', 'born', ' Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1279 288 Name of father of x -1 Name of father of Leonid Brezhnev Ilya Yakovlevich Brezhnev Leonid Brezhnev "[',' ' Soviet' ' leader' '\n' '\n' 'Category' ':' '19' '10' ' births' '\n'
+ 'Category' ':' '1982' ' deaths' '\n' 'Category' ':' 'People' ' from']" ", Soviet leader
+
+ Category : 19 10 births
+ Category : 1982 deaths
+ Category : People from" False 5 ['Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1280 288 Name of father of x -1 Name of father of Leonid Brezhnev Ilya Yakovlevich Brezhnev Leonid Brezhnev "[',' ' Soviet' ' leader' '\n' '\n' 'Category' ':' '19' '10' ' births' '\n'
+ 'Category' ':' '1982' ' deaths' '\n' 'Category' ':' 'People' ' from']" ", Soviet leader
+
+ Category : 19 10 births
+ Category : 1982 deaths
+ Category : People from" False worldwide, Soviet Premier Leonid Brezhnev ordered the destruction 9 [' worldwide', ',', ' Soviet', ' Premier', ' Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1281 288 Name of father of x -1 Name of father of Leonid Brezhnev Ilya Yakovlevich Brezhnev Leonid Brezhnev "[',' ' Soviet' ' leader' '\n' '\n' 'Category' ':' '19' '10' ' births' '\n'
+ 'Category' ':' '1982' ' deaths' '\n' 'Category' ':' 'People' ' from']" ", Soviet leader
+
+ Category : 19 10 births
+ Category : 1982 deaths
+ Category : People from" False proved useful. With Leonid Brezhnev in power, Soviet foreign 9 [' proved', ' useful', '.', ' With', ' Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1282 288 Name of father of x -1 Name of father of Leonid Brezhnev Ilya Yakovlevich Brezhnev Leonid Brezhnev "[',' ' Soviet' ' leader' '\n' '\n' 'Category' ':' '19' '10' ' births' '\n'
+ 'Category' ':' '1982' ' deaths' '\n' 'Category' ':' 'People' ' from']" ", Soviet leader
+
+ Category : 19 10 births
+ Category : 1982 deaths
+ Category : People from" False 1964, Kosygin and Leonid Brezhnev became Premier 11 [' 1964', ',', ' Kos', 'y', 'gin', ' and', ' Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1283 289 Name of father of x -1 Name of father of Bill Clinton William Jefferson Blythe Jr. Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False under President Bill Clinton in January 1993, 3 [' under', ' President', ' Bill', ' Clinton']
+1284 289 Name of father of x -1 Name of father of Bill Clinton William Jefferson Blythe Jr. Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False " the period of the Bill Clinton administration.
+" 5 [' the', ' period', ' of', ' the', ' Bill', ' Clinton']
+1285 289 Name of father of x -1 Name of father of Bill Clinton William Jefferson Blythe Jr. Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False U.S. President Bill Clinton proposed the division 6 [' U', '.', 'S', '.', ' President', ' Bill', ' Clinton']
+1286 289 Name of father of x -1 Name of father of Bill Clinton William Jefferson Blythe Jr. Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False United States Bill Clinton and his wife 3 [' United', ' States', ' Bill', ' Clinton']
+1287 289 Name of father of x -1 Name of father of Bill Clinton William Jefferson Blythe Jr. Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False 26, President Bill Clinton added the counties 4 [' 26', ',', ' President', ' Bill', ' Clinton']
+1288 290 Name of father of x -1 Name of father of H. G. Wells Joseph Wells H. G. Wells "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' H' '.' ' G' '.'
+ ' Wells' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of H . G . Wells is not known .
+
+ The" False Joseph Conrad, H. G. Wells and Ford Madox Ford, 7 [' Joseph', ' Conrad', ',', ' H', '.', ' G', '.', ' Wells']
+1289 290 Name of father of x -1 Name of father of H. G. Wells Joseph Wells H. G. Wells "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' H' '.' ' G' '.'
+ ' Wells' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of H . G . Wells is not known .
+
+ The" False writer and futurist H. G. Wells promoted cosmopolitanism 9 [' writer', ' and', ' fut', 'ur', 'ist', ' H', '.', ' G', '.', ' Wells']
+1290 290 Name of father of x -1 Name of father of H. G. Wells Joseph Wells H. G. Wells "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' H' '.' ' G' '.'
+ ' Wells' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of H . G . Wells is not known .
+
+ The" False During the war H. G. Wells wrote about Britain's 7 [' During', ' the', ' war', ' H', '.', ' G', '.', ' Wells']
+1291 290 Name of father of x -1 Name of father of H. G. Wells Joseph Wells H. G. Wells "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' H' '.' ' G' '.'
+ ' Wells' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of H . G . Wells is not known .
+
+ The" False the writings of H. G. Wells in 1901. The shorter 7 [' the', ' writings', ' of', ' H', '.', ' G', '.', ' Wells']
+1292 290 Name of father of x -1 Name of father of H. G. Wells Joseph Wells H. G. Wells "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' H' '.' ' G' '.'
+ ' Wells' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of H . G . Wells is not known .
+
+ The" False guided the writer H. G. Wells around the trenches 7 [' guided', ' the', ' writer', ' H', '.', ' G', '.', ' Wells']
+1293 291 Name of father of x -1 Name of father of André Gide Paul Gide André Gide "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '69'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' writer']" ", the French writer , who was born in 18 69 .
+
+ The name of the French writer" False by the likes of André Gide and Albert Londres 7 [' by', ' the', ' likes', ' of', ' And', 'ré', ' G', 'ide']
+1294 292 Name of father of x -1 Name of father of Giorgio Vasari Antonio Vasari Giorgio Vasari "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '11' ','
+ ' and' ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 11 , and died in 15 74 .
+
+ The" False " Francesco's family, Giorgio Vasari wrote, ""Leonardo" 9 "[' Frances', 'co', ""'s"", ' family', ',', ' G', 'ior', 'gio', ' Vas', 'ari']"
+1295 292 Name of father of x -1 Name of father of Giorgio Vasari Antonio Vasari Giorgio Vasari "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '11' ','
+ ' and' ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 11 , and died in 15 74 .
+
+ The" False 16th-century art historian Giorgio Vasari proposed that 10 [' 16', 'th', '-', 'century', ' art', ' historian', ' G', 'ior', 'gio', ' Vas', 'ari']
+1296 292 Name of father of x -1 Name of father of Giorgio Vasari Antonio Vasari Giorgio Vasari "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '11' ','
+ ' and' ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 11 , and died in 15 74 .
+
+ The" False the language used by Giorgio Vasari in his work Lives 8 [' the', ' language', ' used', ' by', ' G', 'ior', 'gio', ' Vas', 'ari']
+1297 292 Name of father of x -1 Name of father of Giorgio Vasari Antonio Vasari Giorgio Vasari "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '11' ','
+ ' and' ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 11 , and died in 15 74 .
+
+ The" False resonate with those of Giorgio Vasari who, in the 16th 8 [' resonate', ' with', ' those', ' of', ' G', 'ior', 'gio', ' Vas', 'ari']
+1298 292 Name of father of x -1 Name of father of Giorgio Vasari Antonio Vasari Giorgio Vasari "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '11' ','
+ ' and' ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 11 , and died in 15 74 .
+
+ The" False " Palomino (""the Giorgio Vasari of the Spanish" 9 "[' Pal', 'om', 'ino', ' (""', 'the', ' G', 'ior', 'gio', ' Vas', 'ari']"
+1299 293 Name of father of x -1 Name of father of Henri Poincaré Émile-Léon Poincaré Henri Poincaré "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' Nancy' ','
+ ' France' ',' ' in' ' 18' '54' '.' '\n' '\n' 'The' ' name']" ", the mathematician , who was born in Nancy , France , in 18 54 .
+
+ The name" False the Institut Henri Poincaré in Paris, including 7 [' the', ' Instit', 'ut', ' Henri', ' Po', 'inc', 'ar', 'é']
+1300 293 Name of father of x -1 Name of father of Henri Poincaré Émile-Léon Poincaré Henri Poincaré "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' Nancy' ','
+ ' France' ',' ' in' ' 18' '54' '.' '\n' '\n' 'The' ' name']" ", the mathematician , who was born in Nancy , France , in 18 54 .
+
+ The name" False was adopted by Henri Poincaré in his 1895 7 [' was', ' adopted', ' by', ' Henri', ' Po', 'inc', 'ar', 'é']
+1301 293 Name of father of x -1 Name of father of Henri Poincaré Émile-Léon Poincaré Henri Poincaré "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' Nancy' ','
+ ' France' ',' ' in' ' 18' '54' '.' '\n' '\n' 'The' ' name']" ", the mathematician , who was born in Nancy , France , in 18 54 .
+
+ The name" False 5 ['Hen', 'ri', ' Po', 'inc', 'ar', 'é']
+1302 293 Name of father of x -1 Name of father of Henri Poincaré Émile-Léon Poincaré Henri Poincaré "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' Nancy' ','
+ ' France' ',' ' in' ' 18' '54' '.' '\n' '\n' 'The' ' name']" ", the mathematician , who was born in Nancy , France , in 18 54 .
+
+ The name" False Kronecker and Henri Poincaré and later from 8 [' Kr', 'one', 'cker', ' and', ' Henri', ' Po', 'inc', 'ar', 'é']
+1303 293 Name of father of x -1 Name of father of Henri Poincaré Émile-Léon Poincaré Henri Poincaré "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' Nancy' ','
+ ' France' ',' ' in' ' 18' '54' '.' '\n' '\n' 'The' ' name']" ", the mathematician , who was born in Nancy , France , in 18 54 .
+
+ The name" False For example, Henri Poincaré founded what is 7 [' For', ' example', ',', ' Henri', ' Po', 'inc', 'ar', 'é']
+1304 295 Name of father of x -1 Name of father of Charlemagne Pepin the Short Charlemagne "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' great' ' and' ' good'
+ ' King' ' of' ' France' '.' '\n' '\n' '|' '\n']" ",
+
+ |
+
+ And of the great and good King of France .
+
+ |
+" False 2 ['Char', 'lem', 'agne']
+1305 295 Name of father of x -1 Name of father of Charlemagne Pepin the Short Charlemagne "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' great' ' and' ' good'
+ ' King' ' of' ' France' '.' '\n' '\n' '|' '\n']" ",
+
+ |
+
+ And of the great and good King of France .
+
+ |
+" False used at the court of Charlemagne to suggest 7 [' used', ' at', ' the', ' court', ' of', ' Char', 'lem', 'agne']
+1306 295 Name of father of x -1 Name of father of Charlemagne Pepin the Short Charlemagne "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' great' ' and' ' good'
+ ' King' ' of' ' France' '.' '\n' '\n' '|' '\n']" ",
+
+ |
+
+ And of the great and good King of France .
+
+ |
+" False khagan asked Charlemagne to let his people 5 [' kh', 'agan', ' asked', ' Char', 'lem', 'agne']
+1307 295 Name of father of x -1 Name of father of Charlemagne Pepin the Short Charlemagne "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' great' ' and' ' good'
+ ' King' ' of' ' France' '.' '\n' '\n' '|' '\n']" ",
+
+ |
+
+ And of the great and good King of France .
+
+ |
+" False system instituted by Charlemagne a thousand years 5 [' system', ' instituted', ' by', ' Char', 'lem', 'agne']
+1308 295 Name of father of x -1 Name of father of Charlemagne Pepin the Short Charlemagne "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' great' ' and' ' good'
+ ' King' ' of' ' France' '.' '\n' '\n' '|' '\n']" ",
+
+ |
+
+ And of the great and good King of France .
+
+ |
+" False the southeast, and Charlemagne settled Avar groups 6 [' the', ' southeast', ',', ' and', ' Char', 'lem', 'agne']
+1309 296 Name of father of x -1 Name of father of William Butler Yeats John Butler Yeats William Butler Yeats "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' William' ' Butler'
+ ' Ye' 'ats' ' is' ' not' ' known' '.' '\n' '\n' 'The' ' name']" "
+
+ The name of the father of William Butler Ye ats is not known .
+
+ The name" False 3 ['William', ' Butler', ' Ye', 'ats']
+1310 296 Name of father of x -1 Name of father of William Butler Yeats John Butler Yeats William Butler Yeats "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' William' ' Butler'
+ ' Ye' 'ats' ' is' ' not' ' known' '.' '\n' '\n' 'The' ' name']" "
+
+ The name of the father of William Butler Ye ats is not known .
+
+ The name" False " recitation of the William Butler Yeats poem ""Mother" 7 [' rec', 'itation', ' of', ' the', ' William', ' Butler', ' Ye', 'ats']
+1311 296 Name of father of x -1 Name of father of William Butler Yeats John Butler Yeats William Butler Yeats "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' William' ' Butler'
+ ' Ye' 'ats' ' is' ' not' ' known' '.' '\n' '\n' 'The' ' name']" "
+
+ The name of the father of William Butler Ye ats is not known .
+
+ The name" False Lady Gregory and William Butler Yeats to establish the 6 [' Lady', ' Gregory', ' and', ' William', ' Butler', ' Ye', 'ats']
+1312 296 Name of father of x -1 Name of father of William Butler Yeats John Butler Yeats William Butler Yeats "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' William' ' Butler'
+ ' Ye' 'ats' ' is' ' not' ' known' '.' '\n' '\n' 'The' ' name']" "
+
+ The name of the father of William Butler Ye ats is not known .
+
+ The name" False Lady Gregory and William Butler Yeats to establish the Irish 6 [' Lady', ' Gregory', ' and', ' William', ' Butler', ' Ye', 'ats']
+1313 296 Name of father of x -1 Name of father of William Butler Yeats John Butler Yeats William Butler Yeats "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' William' ' Butler'
+ ' Ye' 'ats' ' is' ' not' ' known' '.' '\n' '\n' 'The' ' name']" "
+
+ The name of the father of William Butler Ye ats is not known .
+
+ The name" False " recitation of the William Butler Yeats poem ""Mother" 7 [' rec', 'itation', ' of', ' the', ' William', ' Butler', ' Ye', 'ats']
+1314 297 Name of father of x -1 Name of father of Claude Debussy Manuel Debussy Claude Debussy "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' son' ',' ' the' ' painter' ',' ' and' ' his']" , the composer , and his wife , the painter , and his son , the painter , and his False " composers like Claude Debussy are common.
+" 5 [' compos', 'ers', ' like', ' Claude', ' Deb', 'ussy']
+1315 297 Name of father of x -1 Name of father of Claude Debussy Manuel Debussy Claude Debussy "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' son' ',' ' the' ' painter' ',' ' and' ' his']" , the composer , and his wife , the painter , and his son , the painter , and his False works composed by Claude Debussy and recorded by 5 [' works', ' composed', ' by', ' Claude', ' Deb', 'ussy']
+1316 297 Name of father of x -1 Name of father of Claude Debussy Manuel Debussy Claude Debussy "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' son' ',' ' the' ' painter' ',' ' and' ' his']" , the composer , and his wife , the painter , and his son , the painter , and his False A meeting with Claude Debussy produced further 5 [' A', ' meeting', ' with', ' Claude', ' Deb', 'ussy']
+1317 297 Name of father of x -1 Name of father of Claude Debussy Manuel Debussy Claude Debussy "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' son' ',' ' the' ' painter' ',' ' and' ' his']" , the composer , and his wife , the painter , and his son , the painter , and his False echoes of La mer by Claude Debussy as well of Igor Stravinsky's 7 [' echoes', ' of', ' La', ' mer', ' by', ' Claude', ' Deb', 'ussy']
+1318 297 Name of father of x -1 Name of father of Claude Debussy Manuel Debussy Claude Debussy "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' son' ',' ' the' ' painter' ',' ' and' ' his']" , the composer , and his wife , the painter , and his son , the painter , and his False Head at his request. Claude Debussy finished composing 7 [' Head', ' at', ' his', ' request', '.', ' Claude', ' Deb', 'ussy']
+1319 298 Name of father of x -1 Name of father of Franklin Delano Roosevelt James Roosevelt Franklin Delano Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False July 1932, Franklin Delano Roosevelt made a campaign 6 [' July', ' 1932', ',', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1320 298 Name of father of x -1 Name of father of Franklin Delano Roosevelt James Roosevelt Franklin Delano Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False in 1937, President Franklin Delano Roosevelt issued Executive 7 [' in', ' 1937', ',', ' President', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1321 298 Name of father of x -1 Name of father of Franklin Delano Roosevelt James Roosevelt Franklin Delano Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False year, President Franklin Delano Roosevelt ordered the construction 6 [' year', ',', ' President', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1322 298 Name of father of x -1 Name of father of Franklin Delano Roosevelt James Roosevelt Franklin Delano Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False President Franklin Delano Roosevelt reviewed the battleship 4 [' President', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1323 298 Name of father of x -1 Name of father of Franklin Delano Roosevelt James Roosevelt Franklin Delano Roosevelt "[',' ' Jr' '.' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' ('
+ '18' '84' '-' '19' '44' ')' ' was' ' the']" ", Jr .
+
+ The odore Roosevelt , Jr . ( 18 84 - 19 44 ) was the" False year, President Franklin Delano Roosevelt ordered the 6 [' year', ',', ' President', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1324 299 Name of father of x -1 Name of father of Poseidon Kronos Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' of' ' the' ' earth'
+ ',' ' and' ' of' ' the' ' underworld' ',' ' and' ' of' ' the']" , the god of the sea , and of the earth , and of the underworld , and of the False glorious kings ... and Poseidon [lay with] ... Ares 5 [' glorious', ' kings', '...', ' and', ' Pose', 'idon']
+1325 299 Name of father of x -1 Name of father of Poseidon Kronos Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' of' ' the' ' earth'
+ ',' ' and' ' of' ' the' ' underworld' ',' ' and' ' of' ' the']" , the god of the sea , and of the earth , and of the underworld , and of the False primary antagonist; Poseidon (Gideon Emery), 4 [' primary', ' antagonist', ';', ' Pose', 'idon']
+1326 299 Name of father of x -1 Name of father of Poseidon Kronos Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' of' ' the' ' earth'
+ ',' ' and' ' of' ' the' ' underworld' ',' ' and' ' of' ' the']" , the god of the sea , and of the earth , and of the underworld , and of the False guidance systems for the Poseidon missile. While 5 [' guidance', ' systems', ' for', ' the', ' Pose', 'idon']
+1327 299 Name of father of x -1 Name of father of Poseidon Kronos Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' of' ' the' ' earth'
+ ',' ' and' ' of' ' the' ' underworld' ',' ' and' ' of' ' the']" , the god of the sea , and of the earth , and of the underworld , and of the False (Τρίτων), the son of Poseidon (the Greek god 14 [' (', '�', '�', 'ρ', '�', '�', 'τ', 'ω', 'ν', '),', ' the', ' son', ' of', ' Pose', 'idon']
+1328 299 Name of father of x -1 Name of father of Poseidon Kronos Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' of' ' the' ' earth'
+ ',' ' and' ' of' ' the' ' underworld' ',' ' and' ' of' ' the']" , the god of the sea , and of the earth , and of the underworld , and of the False kings ... and Poseidon [lay with] ... Ares 4 [' kings', '...', ' and', ' Pose', 'idon']
+1329 300 Name of father of x -1 Name of father of Niccolò Machiavelli Bernardo di Niccolò Machiavelli Niccolò Machiavelli "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Nic' 'col' '�' '�' ' Mach'
+ 'ia' 've' 'lli' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Nic col � � Mach ia ve lli
+
+ ! Name of" False Italian philosopher Niccolò Machiavelli remarked upon the 9 [' Italian', ' philosopher', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1330 300 Name of father of x -1 Name of father of Niccolò Machiavelli Bernardo di Niccolò Machiavelli Niccolò Machiavelli "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Nic' 'col' '�' '�' ' Mach'
+ 'ia' 've' 'lli' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Nic col � � Mach ia ve lli
+
+ ! Name of" False project on which Niccolò Machiavelli also worked. Leonardo's 10 [' project', ' on', ' which', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1331 300 Name of father of x -1 Name of father of Niccolò Machiavelli Bernardo di Niccolò Machiavelli Niccolò Machiavelli "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Nic' 'col' '�' '�' ' Mach'
+ 'ia' 've' 'lli' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Nic col � � Mach ia ve lli
+
+ ! Name of" False project on which Niccolò Machiavelli also worked. Leonardo's 10 [' project', ' on', ' which', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1332 300 Name of father of x -1 Name of father of Niccolò Machiavelli Bernardo di Niccolò Machiavelli Niccolò Machiavelli "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Nic' 'col' '�' '�' ' Mach'
+ 'ia' 've' 'lli' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Nic col � � Mach ia ve lli
+
+ ! Name of" False " the Republic, Niccolò Machiavelli remarked
+" 10 [' the', ' Republic', ',', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1333 300 Name of father of x -1 Name of father of Niccolò Machiavelli Bernardo di Niccolò Machiavelli Niccolò Machiavelli "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' Nic' 'col' '�' '�' ' Mach'
+ 'ia' 've' 'lli' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of Nic col � � Mach ia ve lli
+
+ ! Name of" False project on which Niccolò Machiavelli also worked. Leonardo's 10 [' project', ' on', ' which', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1334 301 Name of father of x -1 Name of father of Emily Brontë Patrick Brontë Emily Brontë "[',' ' the' ' author' ' of' ' W' 'uther' 'ing' ' Heights' '.' '\n' '\n'
+ 'The' ' Br' 'ont' 'ë' 's' ' were' ' a' ' family' ' of']" ", the author of W uther ing Heights .
+
+ The Br ont ë s were a family of" False the juvenilia of Emily Brontë at eight o 7 [' the', ' juven', 'ilia', ' of', ' Emily', ' Br', 'ont', 'ë']
+1335 301 Name of father of x -1 Name of father of Emily Brontë Patrick Brontë Emily Brontë "[',' ' the' ' author' ' of' ' W' 'uther' 'ing' ' Heights' '.' '\n' '\n'
+ 'The' ' Br' 'ont' 'ë' 's' ' were' ' a' ' family' ' of']" ", the author of W uther ing Heights .
+
+ The Br ont ë s were a family of" False about the juvenilia of Emily Brontë at eight o 'clock 8 [' about', ' the', ' juven', 'ilia', ' of', ' Emily', ' Br', 'ont', 'ë']
+1336 301 Name of father of x -1 Name of father of Emily Brontë Patrick Brontë Emily Brontë "[',' ' the' ' author' ' of' ' W' 'uther' 'ing' ' Heights' '.' '\n' '\n'
+ 'The' ' Br' 'ont' 'ë' 's' ' were' ' a' ' family' ' of']" ", the author of W uther ing Heights .
+
+ The Br ont ë s were a family of" False " Mine"", a poem by Emily Brontë that had been" 8 "[' Mine', '"",', ' a', ' poem', ' by', ' Emily', ' Br', 'ont', 'ë']"
+1337 302 Name of father of x -1 Name of father of John Ruskin John James Ruskin John Ruskin "[',' ' the' ' famous' ' English' ' critic' ' and' ' art' ' historian' ','
+ ' who' ' was' ' born' ' in' ' 18' '19' '.' '\n' '\n' 'The' ' Rus']" ", the famous English critic and art historian , who was born in 18 19 .
+
+ The Rus" False Temeraire's wood. John Ruskin foreshadowed 8 "[' Tem', 'er', 'aire', ""'s"", ' wood', '.', ' John', ' Rus', 'kin']"
+1338 302 Name of father of x -1 Name of father of John Ruskin John James Ruskin John Ruskin "[',' ' the' ' famous' ' English' ' critic' ' and' ' art' ' historian' ','
+ ' who' ' was' ' born' ' in' ' 18' '19' '.' '\n' '\n' 'The' ' Rus']" ", the famous English critic and art historian , who was born in 18 19 .
+
+ The Rus" False social critic John Ruskin lived at 163 4 [' social', ' critic', ' John', ' Rus', 'kin']
+1339 302 Name of father of x -1 Name of father of John Ruskin John James Ruskin John Ruskin "[',' ' the' ' famous' ' English' ' critic' ' and' ' art' ' historian' ','
+ ' who' ' was' ' born' ' in' ' 18' '19' '.' '\n' '\n' 'The' ' Rus']" ", the famous English critic and art historian , who was born in 18 19 .
+
+ The Rus" False " with a lucid mind"". John Ruskin however, in 1860," 7 "[' with', ' a', ' lucid', ' mind', '"".', ' John', ' Rus', 'kin']"
+1340 302 Name of father of x -1 Name of father of John Ruskin John James Ruskin John Ruskin "[',' ' the' ' famous' ' English' ' critic' ' and' ' art' ' historian' ','
+ ' who' ' was' ' born' ' in' ' 18' '19' '.' '\n' '\n' 'The' ' Rus']" ", the famous English critic and art historian , who was born in 18 19 .
+
+ The Rus" False claimed by John Ruskin that, architecturally, 4 [' claimed', ' by', ' John', ' Rus', 'kin']
+1341 302 Name of father of x -1 Name of father of John Ruskin John James Ruskin John Ruskin "[',' ' the' ' famous' ' English' ' critic' ' and' ' art' ' historian' ','
+ ' who' ' was' ' born' ' in' ' 18' '19' '.' '\n' '\n' 'The' ' Rus']" ", the famous English critic and art historian , who was born in 18 19 .
+
+ The Rus" False acclaim. The art critic John Ruskin said of the model, 7 [' acclaim', '.', ' The', ' art', ' critic', ' John', ' Rus', 'kin']
+1342 303 Name of father of x -1 Name of father of Audrey Hepburn Joseph Victor Anthony Hepburn-Ruston Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' author' ' of' ' the' ' book'
+ ',' ' _' 'The' ' Children' ' of' ' the' ' World' '_' ',' ' which']" , the actress , and the author of the book , _ The Children of the World _ , which False the role went to Audrey Hepburn as its producers 6 [' the', ' role', ' went', ' to', ' Audrey', ' Hep', 'burn']
+1343 303 Name of father of x -1 Name of father of Audrey Hepburn Joseph Victor Anthony Hepburn-Ruston Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' author' ' of' ' the' ' book'
+ ',' ' _' 'The' ' Children' ' of' ' the' ' World' '_' ',' ' which']" , the actress , and the author of the book , _ The Children of the World _ , which False Children's Hour with Audrey Hepburn and Shirley MacLaine. 6 "[' Children', ""'s"", ' Hour', ' with', ' Audrey', ' Hep', 'burn']"
+1344 303 Name of father of x -1 Name of father of Audrey Hepburn Joseph Victor Anthony Hepburn-Ruston Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' author' ' of' ' the' ' book'
+ ',' ' _' 'The' ' Children' ' of' ' the' ' World' '_' ',' ' which']" , the actress , and the author of the book , _ The Children of the World _ , which False roles, including Audrey Hepburn for Empress 5 [' roles', ',', ' including', ' Audrey', ' Hep', 'burn']
+1345 303 Name of father of x -1 Name of father of Audrey Hepburn Joseph Victor Anthony Hepburn-Ruston Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' author' ' of' ' the' ' book'
+ ',' ' _' 'The' ' Children' ' of' ' the' ' World' '_' ',' ' which']" , the actress , and the author of the book , _ The Children of the World _ , which False the Afternoon with Audrey Hepburn and Maurice Chevalier. 6 [' the', ' After', 'noon', ' with', ' Audrey', ' Hep', 'burn']
+1346 303 Name of father of x -1 Name of father of Audrey Hepburn Joseph Victor Anthony Hepburn-Ruston Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' author' ' of' ' the' ' book'
+ ',' ' _' 'The' ' Children' ' of' ' the' ' World' '_' ',' ' which']" , the actress , and the author of the book , _ The Children of the World _ , which False " character as ""an Audrey Hepburn type"" and was initially" 6 "[' character', ' as', ' ""', 'an', ' Audrey', ' Hep', 'burn']"
+1347 305 Name of father of x -1 Name of father of Charles V Philip I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' ' King' ' of'
+ ' Spain' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and' ' the']" , the Emperor of Germany , and the King of Spain , and the King of France , and the False the Indians with Charles V who was by now 4 [' the', ' Indians', ' with', ' Charles', ' V']
+1348 305 Name of father of x -1 Name of father of Charles V Philip I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' ' King' ' of'
+ ' Spain' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and' ' the']" , the Emperor of Germany , and the King of Spain , and the King of France , and the False closer due to Charles V ’ s victory over 4 [' closer', ' due', ' to', ' Charles', ' V']
+1349 305 Name of father of x -1 Name of father of Charles V Philip I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' ' King' ' of'
+ ' Spain' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and' ' the']" , the Emperor of Germany , and the King of Spain , and the King of France , and the False during the reigns of Charles V (1500 – 1558) and Phillip 6 [' during', ' the', ' reign', 's', ' of', ' Charles', ' V']
+1350 305 Name of father of x -1 Name of father of Charles V Philip I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' ' King' ' of'
+ ' Spain' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and' ' the']" , the Emperor of Germany , and the King of Spain , and the King of France , and the False Holy Roman Emperor Charles V and Pope Leo X 4 [' Holy', ' Roman', ' Emperor', ' Charles', ' V']
+1351 305 Name of father of x -1 Name of father of Charles V Philip I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' ' King' ' of'
+ ' Spain' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and' ' the']" , the Emperor of Germany , and the King of Spain , and the King of France , and the False it passes through Charles V Wall, just before 4 [' it', ' passes', ' through', ' Charles', ' V']
+1352 306 Name of father of x -1 Name of father of Johannes Brahms Johann Jakob Brahms Johannes Brahms "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Hamburg' ',' ' Germany'
+ ',' ' in' ' 18' '33' '.' '\n' '\n' 'The' ' son' ' of']" ", the composer , was born in Hamburg , Germany , in 18 33 .
+
+ The son of" False Schütz and Johannes Brahms set to music, 6 [' Sch', 'ü', 'tz', ' and', ' Johannes', ' Brah', 'ms']
+1353 306 Name of father of x -1 Name of father of Johannes Brahms Johann Jakob Brahms Johannes Brahms "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Hamburg' ',' ' Germany'
+ ',' ' in' ' 18' '33' '.' '\n' '\n' 'The' ' son' ' of']" ", the composer , was born in Hamburg , Germany , in 18 33 .
+
+ The son of" False particular the models of Johannes Brahms and the orchestral 6 [' particular', ' the', ' models', ' of', ' Johannes', ' Brah', 'ms']
+1354 306 Name of father of x -1 Name of father of Johannes Brahms Johann Jakob Brahms Johannes Brahms "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Hamburg' ',' ' Germany'
+ ',' ' in' ' 18' '33' '.' '\n' '\n' 'The' ' son' ' of']" ", the composer , was born in Hamburg , Germany , in 18 33 .
+
+ The son of" False " like that."" Johannes Brahms also admired the concerto," 5 "[' like', ' that', '.""', ' Johannes', ' Brah', 'ms']"
+1355 306 Name of father of x -1 Name of father of Johannes Brahms Johann Jakob Brahms Johannes Brahms "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Hamburg' ',' ' Germany'
+ ',' ' in' ' 18' '33' '.' '\n' '\n' 'The' ' son' ' of']" ", the composer , was born in Hamburg , Germany , in 18 33 .
+
+ The son of" False Heinrich Schütz and Johannes Brahms set to music, 8 [' Hein', 'rich', ' Sch', 'ü', 'tz', ' and', ' Johannes', ' Brah', 'ms']
+1356 306 Name of father of x -1 Name of father of Johannes Brahms Johann Jakob Brahms Johannes Brahms "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Hamburg' ',' ' Germany'
+ ',' ' in' ' 18' '33' '.' '\n' '\n' 'The' ' son' ' of']" ", the composer , was born in Hamburg , Germany , in 18 33 .
+
+ The son of" False " anything like that."" Johannes Brahms also admired" 6 "[' anything', ' like', ' that', '.""', ' Johannes', ' Brah', 'ms']"
+1357 307 Name of father of x -1 Name of father of Cameron Diaz Emilio Diaz Cameron Diaz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False TheWrap reported that Cameron Diaz was in talks 6 [' The', 'W', 'rap', ' reported', ' that', ' Cameron', ' Diaz']
+1358 307 Name of father of x -1 Name of father of Cameron Diaz Emilio Diaz Cameron Diaz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Tom Cruise and Cameron Diaz in James Mangold's 4 [' Tom', ' Cruise', ' and', ' Cameron', ' Diaz']
+1359 307 Name of father of x -1 Name of father of Cameron Diaz Emilio Diaz Cameron Diaz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Grammy-winning album Unchained. Cameron Diaz made the gesture 8 [' Grammy', '-', 'winning', ' album', ' Unch', 'ained', '.', ' Cameron', ' Diaz']
+1360 307 Name of father of x -1 Name of father of Cameron Diaz Emilio Diaz Cameron Diaz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False TheWrap reported that Cameron Diaz was in talks for 6 [' The', 'W', 'rap', ' reported', ' that', ' Cameron', ' Diaz']
+1361 307 Name of father of x -1 Name of father of Cameron Diaz Emilio Diaz Cameron Diaz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False opposite Tom Cruise and Cameron Diaz in James Mangold's 5 [' opposite', ' Tom', ' Cruise', ' and', ' Cameron', ' Diaz']
+1362 308 Name of father of x -1 Name of father of Max Planck Wilhelm von Planck Max Planck "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' Physics' ' in' ' 1918' '.' '\n' '\n' 'The' ' Nobel' ' Prize' ' in'
+ ' Physics']" ", the physicist who won the Nobel Prize for Physics in 1918 .
+
+ The Nobel Prize in Physics" False " Studies, and the Max Planck Society.
+" 6 [' Studies', ',', ' and', ' the', ' Max', ' Plan', 'ck']
+1363 308 Name of father of x -1 Name of father of Max Planck Wilhelm von Planck Max Planck "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' Physics' ' in' ' 1918' '.' '\n' '\n' 'The' ' Nobel' ' Prize' ' in'
+ ' Physics']" ", the physicist who won the Nobel Prize for Physics in 1918 .
+
+ The Nobel Prize in Physics" False Germany include the Max Planck Society, the Helmholtz 5 [' Germany', ' include', ' the', ' Max', ' Plan', 'ck']
+1364 308 Name of father of x -1 Name of father of Max Planck Wilhelm von Planck Max Planck "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' Physics' ' in' ' 1918' '.' '\n' '\n' 'The' ' Nobel' ' Prize' ' in'
+ ' Physics']" ", the physicist who won the Nobel Prize for Physics in 1918 .
+
+ The Nobel Prize in Physics" False physicists, including Max Planck and Niels Bohr. This 5 [' physicists', ',', ' including', ' Max', ' Plan', 'ck']
+1365 308 Name of father of x -1 Name of father of Max Planck Wilhelm von Planck Max Planck "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' Physics' ' in' ' 1918' '.' '\n' '\n' 'The' ' Nobel' ' Prize' ' in'
+ ' Physics']" ", the physicist who won the Nobel Prize for Physics in 1918 .
+
+ The Nobel Prize in Physics" False to join the Max Planck Institute for 5 [' to', ' join', ' the', ' Max', ' Plan', 'ck']
+1366 308 Name of father of x -1 Name of father of Max Planck Wilhelm von Planck Max Planck "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' Physics' ' in' ' 1918' '.' '\n' '\n' 'The' ' Nobel' ' Prize' ' in'
+ ' Physics']" ", the physicist who won the Nobel Prize for Physics in 1918 .
+
+ The Nobel Prize in Physics" False joining the Max Planck Institute for 4 [' joining', ' the', ' Max', ' Plan', 'ck']
+1367 309 Name of father of x -1 Name of father of Neil Young Scott Young Neil Young "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Peg' 'i' ' Young' ',' ' the' ' singer' '-' 'song' 'writer']" , the singer - song writer , and his wife , Peg i Young , the singer - song writer False Along with Neil Young and John Mellencamp, 3 [' Along', ' with', ' Neil', ' Young']
+1368 309 Name of father of x -1 Name of father of Neil Young Scott Young Neil Young "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Peg' 'i' ' Young' ',' ' the' ' singer' '-' 'song' 'writer']" , the singer - song writer , and his wife , Peg i Young , the singer - song writer False Jansch opened for Neil Young on his Twisted 6 [' J', 'ans', 'ch', ' opened', ' for', ' Neil', ' Young']
+1369 309 Name of father of x -1 Name of father of Neil Young Scott Young Neil Young "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Peg' 'i' ' Young' ',' ' the' ' singer' '-' 'song' 'writer']" , the singer - song writer , and his wife , Peg i Young , the singer - song writer False of Gram Parsons and Neil Young on the music. However, 5 [' of', ' Gram', ' Parsons', ' and', ' Neil', ' Young']
+1370 309 Name of father of x -1 Name of father of Neil Young Scott Young Neil Young "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Peg' 'i' ' Young' ',' ' the' ' singer' '-' 'song' 'writer']" , the singer - song writer , and his wife , Peg i Young , the singer - song writer False " ""shed some of the Neil Young obsession."" Tom" 6 "[' ""', 'shed', ' some', ' of', ' the', ' Neil', ' Young']"
+1371 309 Name of father of x -1 Name of father of Neil Young Scott Young Neil Young "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' Peg' 'i' ' Young' ',' ' the' ' singer' '-' 'song' 'writer']" , the singer - song writer , and his wife , Peg i Young , the singer - song writer False the feet of Neil Young for a simple finish. 4 [' the', ' feet', ' of', ' Neil', ' Young']
+1372 310 Name of father of x -1 Name of father of Cristiano Ronaldo José Dinis Aveiro Cristiano Ronaldo "[',' ' the' ' world' ""'s"" ' most' ' expensive' ' footballer' ',' ' and'
+ ' the' ' most' ' expensive' ' player' ' in' ' the' ' world' '.' '\n' '\n'
+ 'The']" ", the world 's most expensive footballer , and the most expensive player in the world .
+
+ The" False Ryan Giggs and Cristiano Ronaldo both pleaded 6 [' Ryan', ' G', 'iggs', ' and', ' Crist', 'iano', ' Ronaldo']
+1373 310 Name of father of x -1 Name of father of Cristiano Ronaldo José Dinis Aveiro Cristiano Ronaldo "[',' ' the' ' world' ""'s"" ' most' ' expensive' ' footballer' ',' ' and'
+ ' the' ' most' ' expensive' ' player' ' in' ' the' ' world' '.' '\n' '\n'
+ 'The']" ", the world 's most expensive footballer , and the most expensive player in the world .
+
+ The" False assisted on a Cristiano Ronaldo goal in the 5 [' assisted', ' on', ' a', ' Crist', 'iano', ' Ronaldo']
+1374 310 Name of father of x -1 Name of father of Cristiano Ronaldo José Dinis Aveiro Cristiano Ronaldo "[',' ' the' ' world' ""'s"" ' most' ' expensive' ' footballer' ',' ' and'
+ ' the' ' most' ' expensive' ' player' ' in' ' the' ' world' '.' '\n' '\n'
+ 'The']" ", the world 's most expensive footballer , and the most expensive player in the world .
+
+ The" False competition, with Cristiano Ronaldo scoring the 5 [' competition', ',', ' with', ' Crist', 'iano', ' Ronaldo']
+1375 310 Name of father of x -1 Name of father of Cristiano Ronaldo José Dinis Aveiro Cristiano Ronaldo "[',' ' the' ' world' ""'s"" ' most' ' expensive' ' footballer' ',' ' and'
+ ' the' ' most' ' expensive' ' player' ' in' ' the' ' world' '.' '\n' '\n'
+ 'The']" ", the world 's most expensive footballer , and the most expensive player in the world .
+
+ The" False glamorous players like Cristiano Ronaldo and David Beckham. 5 [' glamorous', ' players', ' like', ' Crist', 'iano', ' Ronaldo']
+1376 310 Name of father of x -1 Name of father of Cristiano Ronaldo José Dinis Aveiro Cristiano Ronaldo "[',' ' the' ' world' ""'s"" ' most' ' expensive' ' footballer' ',' ' and'
+ ' the' ' most' ' expensive' ' player' ' in' ' the' ' world' '.' '\n' '\n'
+ 'The']" ", the world 's most expensive footballer , and the most expensive player in the world .
+
+ The" False competition, with Cristiano Ronaldo scoring the decisive 5 [' competition', ',', ' with', ' Crist', 'iano', ' Ronaldo']
+1377 311 Name of father of x -1 Name of father of Niels Bohr Christian Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The name of the" False itself was entrusted to Niels Bohr for safekeeping. 7 [' itself', ' was', ' entrusted', ' to', ' Ni', 'els', ' Boh', 'r']
+1378 311 Name of father of x -1 Name of father of Niels Bohr Christian Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The name of the" False 3 ['Ni', 'els', ' Boh', 'r']
+1379 311 Name of father of x -1 Name of father of Niels Bohr Christian Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The name of the" False In that period, Niels Bohr was on a lecture 7 [' In', ' that', ' period', ',', ' Ni', 'els', ' Boh', 'r']
+1380 311 Name of father of x -1 Name of father of Niels Bohr Christian Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The name of the" False spectrometer. Niels Bohr won the Physics 7 [' spect', 'rom', 'eter', '.', ' Ni', 'els', ' Boh', 'r']
+1381 311 Name of father of x -1 Name of father of Niels Bohr Christian Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The name of the" False was eager to invite Niels Bohr to the Tube Alloys 7 [' was', ' eager', ' to', ' invite', ' Ni', 'els', ' Boh', 'r']
+1382 312 Name of father of x -1 Name of father of Martin Luther King Jr. Martin Luther King Sr. Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' Catholic' ',' ' a'
+ ' conservative' ',' ' a' ' Republican' ',' ' and' ' a' ' writer' '.']" "
+
+ I am a Christian , a Catholic , a conservative , a Republican , and a writer ." False Xavier / Magneto and Martin Luther King Jr. / Malcolm X stating 9 [' Xavier', ' /', ' Magnet', 'o', ' and', ' Martin', ' Luther', ' King', ' Jr', '.']
+1383 312 Name of father of x -1 Name of father of Martin Luther King Jr. Martin Luther King Sr. Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' Catholic' ',' ' a'
+ ' conservative' ',' ' a' ' Republican' ',' ' and' ' a' ' writer' '.']" "
+
+ I am a Christian , a Catholic , a conservative , a Republican , and a writer ." False Maryland. Known as Martin Luther King Jr. Highway, the highway 8 [' Maryland', '.', ' Known', ' as', ' Martin', ' Luther', ' King', ' Jr', '.']
+1384 312 Name of father of x -1 Name of father of Martin Luther King Jr. Martin Luther King Sr. Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' Catholic' ',' ' a'
+ ' conservative' ',' ' a' ' Republican' ',' ' and' ' a' ' writer' '.']" "
+
+ I am a Christian , a Catholic , a conservative , a Republican , and a writer ." False have marched with Martin Luther King Jr. to repeal Jim 7 [' have', ' marched', ' with', ' Martin', ' Luther', ' King', ' Jr', '.']
+1385 312 Name of father of x -1 Name of father of Martin Luther King Jr. Martin Luther King Sr. Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' Catholic' ',' ' a'
+ ' conservative' ',' ' a' ' Republican' ',' ' and' ' a' ' writer' '.']" "
+
+ I am a Christian , a Catholic , a conservative , a Republican , and a writer ." False " samples, including Martin Luther King Jr. ' s ""I Have a" 7 [' samples', ',', ' including', ' Martin', ' Luther', ' King', ' Jr', '.']
+1386 312 Name of father of x -1 Name of father of Martin Luther King Jr. Martin Luther King Sr. Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' Christian' ',' ' a' ' Catholic' ',' ' a'
+ ' conservative' ',' ' a' ' Republican' ',' ' and' ' a' ' writer' '.']" "
+
+ I am a Christian , a Catholic , a conservative , a Republican , and a writer ." False Des Moines. Martin Luther King Jr. Parkway and 19th 7 [' Des', ' Moines', '.', ' Martin', ' Luther', ' King', ' Jr', '.']
+1387 313 Name of father of x -1 Name of father of Henrik Ibsen Knud Ibsen Henrik Ibsen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hen' 'rik' ' I'
+ 'bs' 'en' ' is' ' not' ' known' '.' '\n' '\n' 'Hen']" "
+
+ The name of the father of Hen rik I bs en is not known .
+
+ Hen" False Norwegian playwright Henrik Ibsen in that work, 7 [' Norwegian', ' play', 'wright', ' Hen', 'rik', ' I', 'bs', 'en']
+1388 313 Name of father of x -1 Name of father of Henrik Ibsen Knud Ibsen Henrik Ibsen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hen' 'rik' ' I'
+ 'bs' 'en' ' is' ' not' ' known' '.' '\n' '\n' 'Hen']" "
+
+ The name of the father of Hen rik I bs en is not known .
+
+ Hen" False Sverdrup and Henrik Ibsen spanned several 9 [' S', 'ver', 'd', 'rup', ' and', ' Hen', 'rik', ' I', 'bs', 'en']
+1389 313 Name of father of x -1 Name of father of Henrik Ibsen Knud Ibsen Henrik Ibsen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hen' 'rik' ' I'
+ 'bs' 'en' ' is' ' not' ' known' '.' '\n' '\n' 'Hen']" "
+
+ The name of the father of Hen rik I bs en is not known .
+
+ Hen" False Along with Henrik Ibsen and August Strindberg, 6 [' Along', ' with', ' Hen', 'rik', ' I', 'bs', 'en']
+1390 313 Name of father of x -1 Name of father of Henrik Ibsen Knud Ibsen Henrik Ibsen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hen' 'rik' ' I'
+ 'bs' 'en' ' is' ' not' ' known' '.' '\n' '\n' 'Hen']" "
+
+ The name of the father of Hen rik I bs en is not known .
+
+ Hen" False Johan Sverdrup and Henrik Ibsen spanned several 11 [' Joh', 'an', ' S', 'ver', 'd', 'rup', ' and', ' Hen', 'rik', ' I', 'bs', 'en']
+1391 313 Name of father of x -1 Name of father of Henrik Ibsen Knud Ibsen Henrik Ibsen "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Hen' 'rik' ' I'
+ 'bs' 'en' ' is' ' not' ' known' '.' '\n' '\n' 'Hen']" "
+
+ The name of the father of Hen rik I bs en is not known .
+
+ Hen" False in turn influenced Henrik Ibsen in his 1863 play 7 [' in', ' turn', ' influenced', ' Hen', 'rik', ' I', 'bs', 'en']
+1392 314 Name of father of x -1 Name of father of George W. Bush George H. W. Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Laura' ',' ' who' ' is' ' a' ' former'
+ ' first']" , the former president of the United States , and his wife , Laura , who is a former first False then-President George W. Bush because, according 6 [' then', '-', 'President', ' George', ' W', '.', ' Bush']
+1393 314 Name of father of x -1 Name of father of George W. Bush George H. W. Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Laura' ',' ' who' ' is' ' a' ' former'
+ ' first']" , the former president of the United States , and his wife , Laura , who is a former first False States President George W. Bush ordered federal 5 [' States', ' President', ' George', ' W', '.', ' Bush']
+1394 314 Name of father of x -1 Name of father of George W. Bush George H. W. Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Laura' ',' ' who' ' is' ' a' ' former'
+ ' first']" , the former president of the United States , and his wife , Laura , who is a former first False summit, President George W. Bush and First Lady 6 [' summit', ',', ' President', ' George', ' W', '.', ' Bush']
+1395 314 Name of father of x -1 Name of father of George W. Bush George H. W. Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Laura' ',' ' who' ' is' ' a' ' former'
+ ' first']" , the former president of the United States , and his wife , Laura , who is a former first False President George W. Bush in the 2004 Republican 4 [' President', ' George', ' W', '.', ' Bush']
+1396 314 Name of father of x -1 Name of father of George W. Bush George H. W. Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Laura' ',' ' who' ' is' ' a' ' former'
+ ' first']" , the former president of the United States , and his wife , Laura , who is a former first False Kaniewski, a former George W. Bush administration homeland 9 [' Kan', 'iew', 'ski', ',', ' a', ' former', ' George', ' W', '.', ' Bush']
+1397 315 Name of father of x -1 Name of father of Aeschylus Euphorion of Eleusis Aeschylus "[',' ' the' ' poet' ',' ' who' ' was' ' the' ' son' ' of' ' P' 'here' 'cy'
+ 'des' ',' ' and' ' of' '\n' ' ' ' ' ' ']" ", the poet , who was the son of P here cy des , and of
+ " False Ancient Greek playwright Aeschylus has a gadfly pursue 8 [' Ancient', ' Greek', ' play', 'wright', ' A', 'es', 'ch', 'yl', 'us']
+1398 315 Name of father of x -1 Name of father of Aeschylus Euphorion of Eleusis Aeschylus "[',' ' the' ' poet' ',' ' who' ' was' ' the' ' son' ' of' ' P' 'here' 'cy'
+ 'des' ',' ' and' ' of' '\n' ' ' ' ' ' ']" ", the poet , who was the son of P here cy des , and of
+ " False same time; thus, as Aeschylus states in The Persians, 10 [' same', ' time', ';', ' thus', ',', ' as', ' A', 'es', 'ch', 'yl', 'us']
+1399 315 Name of father of x -1 Name of father of Aeschylus Euphorion of Eleusis Aeschylus "[',' ' the' ' poet' ',' ' who' ' was' ' the' ' son' ' of' ' P' 'here' 'cy'
+ 'des' ',' ' and' ' of' '\n' ' ' ' ' ' ']" ", the poet , who was the son of P here cy des , and of
+ " False presented The Persians of Aeschylus at the Greater 9 [' presented', ' The', ' Pers', 'ians', ' of', ' A', 'es', 'ch', 'yl', 'us']
+1400 315 Name of father of x -1 Name of father of Aeschylus Euphorion of Eleusis Aeschylus "[',' ' the' ' poet' ',' ' who' ' was' ' the' ' son' ' of' ' P' 'here' 'cy'
+ 'des' ',' ' and' ' of' '\n' ' ' ' ' ' ']" ", the poet , who was the son of P here cy des , and of
+ " False Persians of Aeschylus at the Greater 7 [' Pers', 'ians', ' of', ' A', 'es', 'ch', 'yl', 'us']
+1401 315 Name of father of x -1 Name of father of Aeschylus Euphorion of Eleusis Aeschylus "[',' ' the' ' poet' ',' ' who' ' was' ' the' ' son' ' of' ' P' 'here' 'cy'
+ 'des' ',' ' and' ' of' '\n' ' ' ' ' ' ']" ", the poet , who was the son of P here cy des , and of
+ " False Bacon was introduced to Aeschylus through T. S. 8 [' Bacon', ' was', ' introduced', ' to', ' A', 'es', 'ch', 'yl', 'us']
+1402 316 Name of father of x -1 Name of father of Osamu Tezuka Yutaka Tezuka Osamu Tezuka "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False 5 ['Os', 'am', 'u', ' Te', 'z', 'uka']
+1403 316 Name of father of x -1 Name of father of Osamu Tezuka Yutaka Tezuka Osamu Tezuka "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False 5 ['Os', 'am', 'u', ' Te', 'z', 'uka']
+1404 316 Name of father of x -1 Name of father of Osamu Tezuka Yutaka Tezuka Osamu Tezuka "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False to flourish when Osamu Tezuka produced a prolific 8 [' to', ' flourish', ' when', ' Os', 'am', 'u', ' Te', 'z', 'uka']
+1405 316 Name of father of x -1 Name of father of Osamu Tezuka Yutaka Tezuka Osamu Tezuka "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False the prolific Osamu Tezuka and the comic 7 [' the', ' prolific', ' Os', 'am', 'u', ' Te', 'z', 'uka']
+1406 316 Name of father of x -1 Name of father of Osamu Tezuka Yutaka Tezuka Osamu Tezuka "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' looking']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm looking" False ideas of manga author Osamu Tezuka and his Star System. 9 [' ideas', ' of', ' manga', ' author', ' Os', 'am', 'u', ' Te', 'z', 'uka']
+1407 317 Name of father of x -1 Name of father of Peter Gabriel Ralph Parton Gabriel Peter Gabriel "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' and' ' singer' '-' 'song' 'writer' ' Kate' ' Bush']" , the singer - song writer , and his wife , the actress and singer - song writer Kate Bush False with Paul Simon, Peter Gabriel and Mickey Hart 5 [' with', ' Paul', ' Simon', ',', ' Peter', ' Gabriel']
+1408 317 Name of father of x -1 Name of father of Peter Gabriel Ralph Parton Gabriel Peter Gabriel "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' and' ' singer' '-' 'song' 'writer' ' Kate' ' Bush']" , the singer - song writer , and his wife , the actress and singer - song writer Kate Bush False " ""Down to Earth"" by Peter Gabriel and ""O ..." 7 "[' ""', 'Down', ' to', ' Earth', '""', ' by', ' Peter', ' Gabriel']"
+1409 317 Name of father of x -1 Name of father of Peter Gabriel Ralph Parton Gabriel Peter Gabriel "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' and' ' singer' '-' 'song' 'writer' ' Kate' ' Bush']" , the singer - song writer , and his wife , the actress and singer - song writer Kate Bush False " appearance to duet with Peter Gabriel on ""Don 't" 6 [' appearance', ' to', ' du', 'et', ' with', ' Peter', ' Gabriel']
+1410 317 Name of father of x -1 Name of father of Peter Gabriel Ralph Parton Gabriel Peter Gabriel "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' and' ' singer' '-' 'song' 'writer' ' Kate' ' Bush']" , the singer - song writer , and his wife , the actress and singer - song writer Kate Bush False Roger Waters and Peter Gabriel as influences 4 [' Roger', ' Waters', ' and', ' Peter', ' Gabriel']
+1411 317 Name of father of x -1 Name of father of Peter Gabriel Ralph Parton Gabriel Peter Gabriel "[',' ' the' ' singer' '-' 'song' 'writer' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' and' ' singer' '-' 'song' 'writer' ' Kate' ' Bush']" , the singer - song writer , and his wife , the actress and singer - song writer Kate Bush False original lead singer Peter Gabriel and guitarist 4 [' original', ' lead', ' singer', ' Peter', ' Gabriel']
+1412 318 Name of father of x -1 Name of father of Nicolaus Copernicus Niklas Koppernigk the Elder Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 14' '73' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Nicola']" ", the astronomer , who was born in 14 73 .
+
+ The name of the father of Nicola" False 5 ['Nic', 'ola', 'us', ' Cop', 'ern', 'icus']
+1413 318 Name of father of x -1 Name of father of Nicolaus Copernicus Niklas Koppernigk the Elder Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 14' '73' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Nicola']" ", the astronomer , who was born in 14 73 .
+
+ The name of the father of Nicola" False 5 ['Nic', 'ola', 'us', ' Cop', 'ern', 'icus']
+1414 318 Name of father of x -1 Name of father of Nicolaus Copernicus Niklas Koppernigk the Elder Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 14' '73' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Nicola']" ", the astronomer , who was born in 14 73 .
+
+ The name of the father of Nicola" False of the cosmos, Nicolaus Copernicus was the first to 8 [' of', ' the', ' cosmos', ',', ' Nicola', 'us', ' Cop', 'ern', 'icus']
+1415 318 Name of father of x -1 Name of father of Nicolaus Copernicus Niklas Koppernigk the Elder Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 14' '73' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Nicola']" ", the astronomer , who was born in 14 73 .
+
+ The name of the father of Nicola" False Chemistry after Nicolaus Copernicus on February 19, 2010, 6 [' Chemistry', ' after', ' Nicola', 'us', ' Cop', 'ern', 'icus']
+1416 318 Name of father of x -1 Name of father of Nicolaus Copernicus Niklas Koppernigk the Elder Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 14' '73' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Nicola']" ", the astronomer , who was born in 14 73 .
+
+ The name of the father of Nicola" False the 16th century, Nicolaus Copernicus proposed a heliocentric 9 [' the', ' 16', 'th', ' century', ',', ' Nicola', 'us', ' Cop', 'ern', 'icus']
+1417 320 Name of father of x -1 Name of father of Douglas Adams Christopher Douglas Adams Douglas Adams "[',' ' the' ' author' ' of' ' the' ' Hitch' 'h' 'iker' ""'s"" ' Guide' ' to'
+ ' the' ' Galaxy' '.' '\n' '\n' 'The' ' Hitch' 'h' 'iker']" ", the author of the Hitch h iker 's Guide to the Galaxy .
+
+ The Hitch h iker" False endangered animals Douglas Adams and Mark Carwardine 3 [' endangered', ' animals', ' Douglas', ' Adams']
+1418 320 Name of father of x -1 Name of father of Douglas Adams Christopher Douglas Adams Douglas Adams "[',' ' the' ' author' ' of' ' the' ' Hitch' 'h' 'iker' ""'s"" ' Guide' ' to'
+ ' the' ' Galaxy' '.' '\n' '\n' 'The' ' Hitch' 'h' 'iker']" ", the author of the Hitch h iker 's Guide to the Galaxy .
+
+ The Hitch h iker" False Chapman and Douglas Adams wrote a pilot 3 [' Chapman', ' and', ' Douglas', ' Adams']
+1419 320 Name of father of x -1 Name of father of Douglas Adams Christopher Douglas Adams Douglas Adams "[',' ' the' ' author' ' of' ' the' ' Hitch' 'h' 'iker' ""'s"" ' Guide' ' to'
+ ' the' ' Galaxy' '.' '\n' '\n' 'The' ' Hitch' 'h' 'iker']" ", the author of the Hitch h iker 's Guide to the Galaxy .
+
+ The Hitch h iker" False chain. When Douglas Adams attempted to attend 4 [' chain', '.', ' When', ' Douglas', ' Adams']
+1420 320 Name of father of x -1 Name of father of Douglas Adams Christopher Douglas Adams Douglas Adams "[',' ' the' ' author' ' of' ' the' ' Hitch' 'h' 'iker' ""'s"" ' Guide' ' to'
+ ' the' ' Galaxy' '.' '\n' '\n' 'The' ' Hitch' 'h' 'iker']" ", the author of the Hitch h iker 's Guide to the Galaxy .
+
+ The Hitch h iker" False collaborated with author Douglas Adams in a week-long 4 [' collaborated', ' with', ' author', ' Douglas', ' Adams']
+1421 320 Name of father of x -1 Name of father of Douglas Adams Christopher Douglas Adams Douglas Adams "[',' ' the' ' author' ' of' ' the' ' Hitch' 'h' 'iker' ""'s"" ' Guide' ' to'
+ ' the' ' Galaxy' '.' '\n' '\n' 'The' ' Hitch' 'h' 'iker']" ", the author of the Hitch h iker 's Guide to the Galaxy .
+
+ The Hitch h iker" False Target approached Douglas Adams on a number of 3 [' Target', ' approached', ' Douglas', ' Adams']
+1422 321 Name of father of x -1 Name of father of Karl Popper Simon Siegmund Carl Popper Karl Popper "[',' ' the' ' philosopher' ',' ' who' ' was' ' a' ' friend' ' of' ' mine'
+ '.' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' book' ' of']" ", the philosopher , who was a friend of mine .
+
+ I have been reading the book of" False efficacious. Karl Popper develops in 5 [' effic', 'acious', '.', ' Karl', ' Po', 'pper']
+1423 321 Name of father of x -1 Name of father of Karl Popper Simon Siegmund Carl Popper Karl Popper "[',' ' the' ' philosopher' ',' ' who' ' was' ' a' ' friend' ' of' ' mine'
+ '.' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' book' ' of']" ", the philosopher , who was a friend of mine .
+
+ I have been reading the book of" False hoc hypotheses. Karl Popper accepted this thesis, 5 [' hoc', ' hypotheses', '.', ' Karl', ' Po', 'pper']
+1424 321 Name of father of x -1 Name of father of Karl Popper Simon Siegmund Carl Popper Karl Popper "[',' ' the' ' philosopher' ',' ' who' ' was' ' a' ' friend' ' of' ' mine'
+ '.' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' book' ' of']" ", the philosopher , who was a friend of mine .
+
+ I have been reading the book of" False defenders have been Karl Popper and John Carew 5 [' defenders', ' have', ' been', ' Karl', ' Po', 'pper']
+1425 321 Name of father of x -1 Name of father of Karl Popper Simon Siegmund Carl Popper Karl Popper "[',' ' the' ' philosopher' ',' ' who' ' was' ' a' ' friend' ' of' ' mine'
+ '.' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' book' ' of']" ", the philosopher , who was a friend of mine .
+
+ I have been reading the book of" False efficacious. Karl Popper develops in 5 [' effic', 'acious', '.', ' Karl', ' Po', 'pper']
+1426 321 Name of father of x -1 Name of father of Karl Popper Simon Siegmund Carl Popper Karl Popper "[',' ' the' ' philosopher' ',' ' who' ' was' ' a' ' friend' ' of' ' mine'
+ '.' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' book' ' of']" ", the philosopher , who was a friend of mine .
+
+ I have been reading the book of" False claim at hand. Karl Popper would also describe 6 [' claim', ' at', ' hand', '.', ' Karl', ' Po', 'pper']
+1427 322 Name of father of x -1 Name of father of Seneca Seneca the Elder Seneca "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' Empire' '.' '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' the'
+ ' first']" ", the philosopher , and the father of the Roman Empire .
+
+ The Roman Empire was the first" False provides that the Seneca Nation of Indians 4 [' provides', ' that', ' the', ' Sen', 'eca']
+1428 322 Name of father of x -1 Name of father of Seneca Seneca the Elder Seneca "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' Empire' '.' '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' the'
+ ' first']" ", the philosopher , and the father of the Roman Empire .
+
+ The Roman Empire was the first" False southward toward the Seneca County line. At the 5 [' south', 'ward', ' toward', ' the', ' Sen', 'eca']
+1429 322 Name of father of x -1 Name of father of Seneca Seneca the Elder Seneca "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' Empire' '.' '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' the'
+ ' first']" ", the philosopher , and the father of the Roman Empire .
+
+ The Roman Empire was the first" False ambiguity appears in Seneca and in Ausonius'19th 4 [' ambiguity', ' appears', ' in', ' Sen', 'eca']
+1430 322 Name of father of x -1 Name of father of Seneca Seneca the Elder Seneca "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' Empire' '.' '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' the'
+ ' first']" ", the philosopher , and the father of the Roman Empire .
+
+ The Roman Empire was the first" False " oonah!"", the Seneca signal to retire, and" 6 "[' o', 'on', 'ah', '!"",', ' the', ' Sen', 'eca']"
+1431 322 Name of father of x -1 Name of father of Seneca Seneca the Elder Seneca "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Roman' ' Empire' '.' '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' the'
+ ' first']" ", the philosopher , and the father of the Roman Empire .
+
+ The Roman Empire was the first" False ancient historian. Seneca the elder and Philo, 4 [' ancient', ' historian', '.', ' Sen', 'eca']
+1432 323 Name of father of x -1 Name of father of George H. W. Bush Prescott Bush George H. W. Bush "['\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' George' ' H' '.' ' W'
+ '.' ' Bush' '.' ' I' ' am' ' also' ' a' ' proud']" "
+
+ I am a proud father of George H . W . Bush . I am also a proud" False Administration and George H. W. Bush Administration, 7 [' Administration', ' and', ' George', ' H', '.', ' W', '.', ' Bush']
+1433 323 Name of father of x -1 Name of father of George H. W. Bush Prescott Bush George H. W. Bush "['\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' George' ' H' '.' ' W'
+ '.' ' Bush' '.' ' I' ' am' ' also' ' a' ' proud']" "
+
+ I am a proud father of George H . W . Bush . I am also a proud" False Republican incumbent George H. W. Bush (37.4 percent of 7 [' Republican', ' incumbent', ' George', ' H', '.', ' W', '.', ' Bush']
+1434 323 Name of father of x -1 Name of father of George H. W. Bush Prescott Bush George H. W. Bush "['\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' George' ' H' '.' ' W'
+ '.' ' Bush' '.' ' I' ' am' ' also' ' a' ' proud']" "
+
+ I am a proud father of George H . W . Bush . I am also a proud" False that when the George H. W. Bush administrative staff 8 [' that', ' when', ' the', ' George', ' H', '.', ' W', '.', ' Bush']
+1435 323 Name of father of x -1 Name of father of George H. W. Bush Prescott Bush George H. W. Bush "['\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' George' ' H' '.' ' W'
+ '.' ' Bush' '.' ' I' ' am' ' also' ' a' ' proud']" "
+
+ I am a proud father of George H . W . Bush . I am also a proud" False pardon from George H. W. Bush on December 24, 7 [' pardon', ' from', ' George', ' H', '.', ' W', '.', ' Bush']
+1436 323 Name of father of x -1 Name of father of George H. W. Bush Prescott Bush George H. W. Bush "['\n' '\n' 'I' ' am' ' a' ' proud' ' father' ' of' ' George' ' H' '.' ' W'
+ '.' ' Bush' '.' ' I' ' am' ' also' ' a' ' proud']" "
+
+ I am a proud father of George H . W . Bush . I am also a proud" False 1991, by President George H. W. Bush to a seat on the 9 [' 1991', ',', ' by', ' President', ' George', ' H', '.', ' W', '.', ' Bush']
+1437 324 Name of father of x -1 Name of father of Joseph Haydn Mathias Haydn Joseph Haydn "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '32' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' composer']" ", the composer , was born in 17 32 .
+
+ The name of the father of the composer" False " Favorite Overture – Joseph Haydn (1786)
+" 6 [' Favorite', ' Overt', 'ure', ' –', ' Joseph', ' Hay', 'dn']
+1438 324 Name of father of x -1 Name of father of Joseph Haydn Mathias Haydn Joseph Haydn "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '32' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' composer']" ", the composer , was born in 17 32 .
+
+ The name of the father of the composer" False who included Joseph Haydn and Ludwig van 4 [' who', ' included', ' Joseph', ' Hay', 'dn']
+1439 324 Name of father of x -1 Name of father of Joseph Haydn Mathias Haydn Joseph Haydn "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '32' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' composer']" ", the composer , was born in 17 32 .
+
+ The name of the father of the composer" False this room that Joseph Haydn gave a series of 5 [' this', ' room', ' that', ' Joseph', ' Hay', 'dn']
+1440 324 Name of father of x -1 Name of father of Joseph Haydn Mathias Haydn Joseph Haydn "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '32' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' composer']" ", the composer , was born in 17 32 .
+
+ The name of the father of the composer" False in B minor. Joseph Haydn later set the same 6 [' in', ' B', ' minor', '.', ' Joseph', ' Hay', 'dn']
+1441 324 Name of father of x -1 Name of father of Joseph Haydn Mathias Haydn Joseph Haydn "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '32' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' father' ' of' ' the' ' composer']" ", the composer , was born in 17 32 .
+
+ The name of the father of the composer" False 36-41), were composed by Joseph Haydn in 1787. The 9 [' 36', '-', '41', '),', ' were', ' composed', ' by', ' Joseph', ' Hay', 'dn']
+1442 326 Name of father of x -1 Name of father of Julia Roberts Walter Grady Roberts Julia Roberts "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' three'
+ ' children' '.' '\n' '\n' 'The' ' actress' ',' ' who' ' is' ' married'
+ ' to']" ", the actress , and the mother of three children .
+
+ The actress , who is married to" False for The Proposal. Julia Roberts was originally approached 6 [' for', ' The', ' Pro', 'posal', '.', ' Julia', ' Roberts']
+1443 326 Name of father of x -1 Name of father of Julia Roberts Walter Grady Roberts Julia Roberts "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' three'
+ ' children' '.' '\n' '\n' 'The' ' actress' ',' ' who' ' is' ' married'
+ ' to']" ", the actress , and the mother of three children .
+
+ The actress , who is married to" False finds the outfit Julia Roberts wears in the 1990 4 [' finds', ' the', ' outfit', ' Julia', ' Roberts']
+1444 326 Name of father of x -1 Name of father of Julia Roberts Walter Grady Roberts Julia Roberts "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' three'
+ ' children' '.' '\n' '\n' 'The' ' actress' ',' ' who' ' is' ' married'
+ ' to']" ", the actress , and the mother of three children .
+
+ The actress , who is married to" False (1990), which earned Julia Roberts a nomination for the 6 [' (', '1990', '),', ' which', ' earned', ' Julia', ' Roberts']
+1445 326 Name of father of x -1 Name of father of Julia Roberts Walter Grady Roberts Julia Roberts "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' three'
+ ' children' '.' '\n' '\n' 'The' ' actress' ',' ' who' ' is' ' married'
+ ' to']" ", the actress , and the mother of three children .
+
+ The actress , who is married to" False stars like Julia Roberts and Jean-Claude 3 [' stars', ' like', ' Julia', ' Roberts']
+1446 326 Name of father of x -1 Name of father of Julia Roberts Walter Grady Roberts Julia Roberts "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' three'
+ ' children' '.' '\n' '\n' 'The' ' actress' ',' ' who' ' is' ' married'
+ ' to']" ", the actress , and the mother of three children .
+
+ The actress , who is married to" False seller, alongside Julia Roberts in Eat Pray Love. Arriving 4 [' seller', ',', ' alongside', ' Julia', ' Roberts']
+1447 327 Name of father of x -1 Name of father of Thomas Gainsborough John Gainsborough Thomas Gainsborough "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 17' '27' '.' '\n'
+ '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the']" ", the painter , who was born in 17 27 .
+
+ The following is a list of the" False Richard Cosway and Thomas Gainsborough lived at Schomberg 7 [' Richard', ' Cos', 'way', ' and', ' Thomas', ' G', 'ains', 'borough']
+1448 327 Name of father of x -1 Name of father of Thomas Gainsborough John Gainsborough Thomas Gainsborough "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 17' '27' '.' '\n'
+ '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the']" ", the painter , who was born in 17 27 .
+
+ The following is a list of the" False types. His portrait by Thomas Gainsborough shows him 8 [' types', '.', ' His', ' portrait', ' by', ' Thomas', ' G', 'ains', 'borough']
+1449 327 Name of father of x -1 Name of father of Thomas Gainsborough John Gainsborough Thomas Gainsborough "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 17' '27' '.' '\n'
+ '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the']" ", the painter , who was born in 17 27 .
+
+ The following is a list of the" False famous painting by Thomas Gainsborough in 1782, which 6 [' famous', ' painting', ' by', ' Thomas', ' G', 'ains', 'borough']
+1450 327 Name of father of x -1 Name of father of Thomas Gainsborough John Gainsborough Thomas Gainsborough "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 17' '27' '.' '\n'
+ '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the']" ", the painter , who was born in 17 27 .
+
+ The following is a list of the" False Hogarth and Thomas Gainsborough as well as the by-then 6 [' Hog', 'arth', ' and', ' Thomas', ' G', 'ains', 'borough']
+1451 327 Name of father of x -1 Name of father of Thomas Gainsborough John Gainsborough Thomas Gainsborough "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 17' '27' '.' '\n'
+ '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the']" ", the painter , who was born in 17 27 .
+
+ The following is a list of the" False 18th century Thomas Gainsborough and Sir Thomas 6 [' 18', 'th', ' century', ' Thomas', ' G', 'ains', 'borough']
+1452 328 Name of father of x -1 Name of father of Ingmar Bergman Erik Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' The' ' Seventh' ' Seal' ','
+ ' The' ' Silence' ',' ' Wild' ' Straw' 'berries' ',' ' and' ' F' 'anny'
+ ' and']" , the Swedish director of The Seventh Seal , The Silence , Wild Straw berries , and F anny and False 1962. In 1965 he and Ingmar Bergman were joint 9 [' 1962', '.', ' In', ' 1965', ' he', ' and', ' Ing', 'mar', ' Berg', 'man']
+1453 328 Name of father of x -1 Name of father of Ingmar Bergman Erik Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' The' ' Seventh' ' Seal' ','
+ ' The' ' Silence' ',' ' Wild' ' Straw' 'berries' ',' ' and' ' F' 'anny'
+ ' and']" , the Swedish director of The Seventh Seal , The Silence , Wild Straw berries , and F anny and False betrayal worthy of an Ingmar Bergman film, and 7 [' betrayal', ' worthy', ' of', ' an', ' Ing', 'mar', ' Berg', 'man']
+1454 328 Name of father of x -1 Name of father of Ingmar Bergman Erik Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' The' ' Seventh' ' Seal' ','
+ ' The' ' Silence' ',' ' Wild' ' Straw' 'berries' ',' ' and' ' F' 'anny'
+ ' and']" , the Swedish director of The Seventh Seal , The Silence , Wild Straw berries , and F anny and False " Kieślowski, Ingmar Bergman and J. Mahendran.
+" 10 [' K', 'ie', '�', '�', 'l', 'owski', ',', ' Ing', 'mar', ' Berg', 'man']
+1455 328 Name of father of x -1 Name of father of Ingmar Bergman Erik Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' The' ' Seventh' ' Seal' ','
+ ' The' ' Silence' ',' ' Wild' ' Straw' 'berries' ',' ' and' ' F' 'anny'
+ ' and']" , the Swedish director of The Seventh Seal , The Silence , Wild Straw berries , and F anny and False years later, Ingmar Bergman referred to 6 [' years', ' later', ',', ' Ing', 'mar', ' Berg', 'man']
+1456 328 Name of father of x -1 Name of father of Ingmar Bergman Erik Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' The' ' Seventh' ' Seal' ','
+ ' The' ' Silence' ',' ' Wild' ' Straw' 'berries' ',' ' and' ' F' 'anny'
+ ' and']" , the Swedish director of The Seventh Seal , The Silence , Wild Straw berries , and F anny and False Aghed, the director Ingmar Bergman referenced Magnolia 8 [' Ag', 'hed', ',', ' the', ' director', ' Ing', 'mar', ' Berg', 'man']
+1457 329 Name of father of x -1 Name of father of Hector Berlioz Louis Berlioz Hector Berlioz "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Hector' ' Ber' 'lio' 'z' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Hector Ber lio z , the composer ." False symphony genre. Hector Berlioz was the only composer 7 [' sym', 'phony', ' genre', '.', ' Hector', ' Ber', 'lio', 'z']
+1458 329 Name of father of x -1 Name of father of Hector Berlioz Louis Berlioz Hector Berlioz "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Hector' ' Ber' 'lio' 'z' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Hector Ber lio z , the composer ." False in 1868 for Hector Berlioz as a sequel to that 7 [' in', ' 18', '68', ' for', ' Hector', ' Ber', 'lio', 'z']
+1459 329 Name of father of x -1 Name of father of Hector Berlioz Louis Berlioz Hector Berlioz "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Hector' ' Ber' 'lio' 'z' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Hector Ber lio z , the composer ." False Act 3; and so on. Hector Berlioz was particularly impressed; 10 [' Act', ' 3', ';', ' and', ' so', ' on', '.', ' Hector', ' Ber', 'lio', 'z']
+1460 329 Name of father of x -1 Name of father of Hector Berlioz Louis Berlioz Hector Berlioz "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Hector' ' Ber' 'lio' 'z' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Hector Ber lio z , the composer ." False context was coined by Hector Berlioz when he described 7 [' context', ' was', ' coined', ' by', ' Hector', ' Ber', 'lio', 'z']
+1461 329 Name of father of x -1 Name of father of Hector Berlioz Louis Berlioz Hector Berlioz "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Hector' ' Ber' 'lio' 'z' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Hector Ber lio z , the composer ." False concert organized by Hector Berlioz for his bankrupt 6 [' concert', ' organized', ' by', ' Hector', ' Ber', 'lio', 'z']
+1462 330 Name of father of x -1 Name of father of Leon Trotsky David Bronstein Leon Trotsky "[',' ' the' ' great' ' revolutionary' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution']" , the great revolutionary , the great leader of the Russian Revolution , the great leader of the Russian Revolution False Vladimir Lenin and Leon Trotsky and many others, 4 [' Vladimir', ' Lenin', ' and', ' Leon', ' Trotsky']
+1463 330 Name of father of x -1 Name of father of Leon Trotsky David Bronstein Leon Trotsky "[',' ' the' ' great' ' revolutionary' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution']" , the great revolutionary , the great leader of the Russian Revolution , the great leader of the Russian Revolution False numerous songs. In 1913, Leon Trotsky described Andranik 7 [' numerous', ' songs', '.', ' In', ' 1913', ',', ' Leon', ' Trotsky']
+1464 330 Name of father of x -1 Name of father of Leon Trotsky David Bronstein Leon Trotsky "[',' ' the' ' great' ' revolutionary' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution']" , the great revolutionary , the great leader of the Russian Revolution , the great leader of the Russian Revolution False 17 June when Leon Trotsky promised them their 4 [' 17', ' June', ' when', ' Leon', ' Trotsky']
+1465 330 Name of father of x -1 Name of father of Leon Trotsky David Bronstein Leon Trotsky "[',' ' the' ' great' ' revolutionary' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution']" , the great revolutionary , the great leader of the Russian Revolution , the great leader of the Russian Revolution False 1969 play about Leon Trotsky (Trotzki im Exil), 4 [' 1969', ' play', ' about', ' Leon', ' Trotsky']
+1466 330 Name of father of x -1 Name of father of Leon Trotsky David Bronstein Leon Trotsky "[',' ' the' ' great' ' revolutionary' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution' ',' ' the' ' great' ' leader' ' of'
+ ' the' ' Russian' ' Revolution']" , the great revolutionary , the great leader of the Russian Revolution , the great leader of the Russian Revolution False joined supporters of Leon Trotsky against those 4 [' joined', ' supporters', ' of', ' Leon', ' Trotsky']
+1467 331 Name of father of x -1 Name of father of Wilhelm II Friedrich III of Germany Wilhelm II "[',' ' the' ' Kaiser' ""'s"" ' son' ',' ' who' ' was' ' born' ' in' ' 18'
+ '82' '.' '\n' '\n' 'The' ' Kaiser' ""'s"" ' son' ',']" ", the Kaiser 's son , who was born in 18 82 .
+
+ The Kaiser 's son ," False 1889, where Wilhelm II took part in the 4 [' 1889', ',', ' where', ' Wilhelm', ' II']
+1468 331 Name of father of x -1 Name of father of Wilhelm II Friedrich III of Germany Wilhelm II "[',' ' the' ' Kaiser' ""'s"" ' son' ',' ' who' ' was' ' born' ' in' ' 18'
+ '82' '.' '\n' '\n' 'The' ' Kaiser' ""'s"" ' son' ',']" ", the Kaiser 's son , who was born in 18 82 .
+
+ The Kaiser 's son ," False hosted Kaiser Wilhelm II during the celebration 3 [' hosted', ' Kaiser', ' Wilhelm', ' II']
+1469 331 Name of father of x -1 Name of father of Wilhelm II Friedrich III of Germany Wilhelm II "[',' ' the' ' Kaiser' ""'s"" ' son' ',' ' who' ' was' ' born' ' in' ' 18'
+ '82' '.' '\n' '\n' 'The' ' Kaiser' ""'s"" ' son' ',']" ", the Kaiser 's son , who was born in 18 82 .
+
+ The Kaiser 's son ," False " II =
+" 4 [' II', ' =', 'Wil', 'helm', ' II']
+1470 331 Name of father of x -1 Name of father of Wilhelm II Friedrich III of Germany Wilhelm II "[',' ' the' ' Kaiser' ""'s"" ' son' ',' ' who' ' was' ' born' ' in' ' 18'
+ '82' '.' '\n' '\n' 'The' ' Kaiser' ""'s"" ' son' ',']" ", the Kaiser 's son , who was born in 18 82 .
+
+ The Kaiser 's son ," False roadstead for Wilhelm II and Archduke Franz 4 [' road', 'stead', ' for', ' Wilhelm', ' II']
+1471 331 Name of father of x -1 Name of father of Wilhelm II Friedrich III of Germany Wilhelm II "[',' ' the' ' Kaiser' ""'s"" ' son' ',' ' who' ' was' ' born' ' in' ' 18'
+ '82' '.' '\n' '\n' 'The' ' Kaiser' ""'s"" ' son' ',']" ", the Kaiser 's son , who was born in 18 82 .
+
+ The Kaiser 's son ," False sister of Kaiser Wilhelm II of Germany. His policy 4 [' sister', ' of', ' Kaiser', ' Wilhelm', ' II']
+1472 332 Name of father of x -1 Name of father of Robert Schumann August Schumann Robert Schumann "[',' ' the' ' composer' ',' ' who' ' was' ' born' ' in' ' 18' '10' '.'
+ '\n' '\n' 'The' ' Sch' 'umann' 's' ' were' ' a' ' musical']" ", the composer , who was born in 18 10 .
+
+ The Sch umann s were a musical" False studied at the Robert Schumann Hochschule in Düsseldorf, 5 [' studied', ' at', ' the', ' Robert', ' Sch', 'umann']
+1473 332 Name of father of x -1 Name of father of Robert Schumann August Schumann Robert Schumann "[',' ' the' ' composer' ',' ' who' ' was' ' born' ' in' ' 18' '10' '.'
+ '\n' '\n' 'The' ' Sch' 'umann' 's' ' were' ' a' ' musical']" ", the composer , who was born in 18 10 .
+
+ The Sch umann s were a musical" False Western contemporaries, Robert Schumann stands out as an 5 [' Western', ' contemporaries', ',', ' Robert', ' Sch', 'umann']
+1474 332 Name of father of x -1 Name of father of Robert Schumann August Schumann Robert Schumann "[',' ' the' ' composer' ',' ' who' ' was' ' born' ' in' ' 18' '10' '.'
+ '\n' '\n' 'The' ' Sch' 'umann' 's' ' were' ' a' ' musical']" ", the composer , who was born in 18 10 .
+
+ The Sch umann s were a musical" False example, on hearing Robert Schumann conduct in 6 [' example', ',', ' on', ' hearing', ' Robert', ' Sch', 'umann']
+1475 332 Name of father of x -1 Name of father of Robert Schumann August Schumann Robert Schumann "[',' ' the' ' composer' ',' ' who' ' was' ' born' ' in' ' 18' '10' '.'
+ '\n' '\n' 'The' ' Sch' 'umann' 's' ' were' ' a' ' musical']" ", the composer , who was born in 18 10 .
+
+ The Sch umann s were a musical" False contemporaries, Robert Schumann stands out 4 [' contemporaries', ',', ' Robert', ' Sch', 'umann']
+1476 332 Name of father of x -1 Name of father of Robert Schumann August Schumann Robert Schumann "[',' ' the' ' composer' ',' ' who' ' was' ' born' ' in' ' 18' '10' '.'
+ '\n' '\n' 'The' ' Sch' 'umann' 's' ' were' ' a' ' musical']" ", the composer , who was born in 18 10 .
+
+ The Sch umann s were a musical" False for a singer, Robert Schumann ignored the 6 [' for', ' a', ' singer', ',', ' Robert', ' Sch', 'umann']
+1477 333 Name of father of x -1 Name of father of Michael Jordan James R. Jordan, Sr. Michael Jordan "[',' ' the' ' basketball' ' player' '.' '\n' '\n' 'I' ' am' ' a' ' big'
+ ' fan' ' of' ' the' ' NBA' '.' ' I' ' am' ' a' ' big']" ", the basketball player .
+
+ I am a big fan of the NBA . I am a big" False Bobcats owner Michael Jordan had been fined 4 [' Bob', 'cats', ' owner', ' Michael', ' Jordan']
+1478 333 Name of father of x -1 Name of father of Michael Jordan James R. Jordan, Sr. Michael Jordan "[',' ' the' ' basketball' ' player' '.' '\n' '\n' 'I' ' am' ' a' ' big'
+ ' fan' ' of' ' the' ' NBA' '.' ' I' ' am' ' a' ' big']" ", the basketball player .
+
+ I am a big fan of the NBA . I am a big" False Australia's answer to Michael Jordan or Shaquille 5 "[' Australia', ""'s"", ' answer', ' to', ' Michael', ' Jordan']"
+1479 333 Name of father of x -1 Name of father of Michael Jordan James R. Jordan, Sr. Michael Jordan "[',' ' the' ' basketball' ' player' '.' '\n' '\n' 'I' ' am' ' a' ' big'
+ ' fan' ' of' ' the' ' NBA' '.' ' I' ' am' ' a' ' big']" ", the basketball player .
+
+ I am a big fan of the NBA . I am a big" False Bulls assigned Michael Jordan to defend Recasner 3 [' Bulls', ' assigned', ' Michael', ' Jordan']
+1480 333 Name of father of x -1 Name of father of Michael Jordan James R. Jordan, Sr. Michael Jordan "[',' ' the' ' basketball' ' player' '.' '\n' '\n' 'I' ' am' ' a' ' big'
+ ' fan' ' of' ' the' ' NBA' '.' ' I' ' am' ' a' ' big']" ", the basketball player .
+
+ I am a big fan of the NBA . I am a big" False 1 ['Michael', ' Jordan']
+1481 333 Name of father of x -1 Name of father of Michael Jordan James R. Jordan, Sr. Michael Jordan "[',' ' the' ' basketball' ' player' '.' '\n' '\n' 'I' ' am' ' a' ' big'
+ ' fan' ' of' ' the' ' NBA' '.' ' I' ' am' ' a' ' big']" ", the basketball player .
+
+ I am a big fan of the NBA . I am a big" False acclamation, Michael Jordan is the greatest 5 [' acc', 'lam', 'ation', ',', ' Michael', ' Jordan']
+1482 334 Name of father of x -1 Name of father of Orson Welles Richard Head Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ '.' '\n' '\n' 'The' ' couple' ' had' ' been' ' married' ' for']" ", the actor , and his wife , Rita Hay worth .
+
+ The couple had been married for" False Bergman, we have a love of Orson Welles and his use 11 [' Berg', 'man', ',', ' we', ' have', ' a', ' love', ' of', ' Or', 'son', ' Well', 'es']
+1483 334 Name of father of x -1 Name of father of Orson Welles Richard Head Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ '.' '\n' '\n' 'The' ' couple' ' had' ' been' ' married' ' for']" ", the actor , and his wife , Rita Hay worth .
+
+ The couple had been married for" False of Othello starring Orson Welles as the pick of Olivier's 8 [' of', ' O', 'the', 'llo', ' starring', ' Or', 'son', ' Well', 'es']
+1484 334 Name of father of x -1 Name of father of Orson Welles Richard Head Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ '.' '\n' '\n' 'The' ' couple' ' had' ' been' ' married' ' for']" ", the actor , and his wife , Rita Hay worth .
+
+ The couple had been married for" False of the 1941 Orson Welles film Citizen Kane. 6 [' of', ' the', ' 1941', ' Or', 'son', ' Well', 'es']
+1485 334 Name of father of x -1 Name of father of Orson Welles Richard Head Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ '.' '\n' '\n' 'The' ' couple' ' had' ' been' ' married' ' for']" ", the actor , and his wife , Rita Hay worth .
+
+ The couple had been married for" False Classic Hollywood Cinema. Orson Welles said of Cagney, 7 [' Classic', ' Hollywood', ' Cinema', '.', ' Or', 'son', ' Well', 'es']
+1486 334 Name of father of x -1 Name of father of Orson Welles Richard Head Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ '.' '\n' '\n' 'The' ' couple' ' had' ' been' ' married' ' for']" ", the actor , and his wife , Rita Hay worth .
+
+ The couple had been married for" False importance to Orson Welles … In the rhetoric of 5 [' importance', ' to', ' Or', 'son', ' Well', 'es']
+1487 335 Name of father of x -1 Name of father of Ella Fitzgerald William Fitzgerald Ella Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False American songbook Ella Fitzgerald owned by then, 5 [' American', ' song', 'book', ' Ell', 'a', ' Fitzgerald']
+1488 335 Name of father of x -1 Name of father of Ella Fitzgerald William Fitzgerald Ella Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Wilson in 1968 and Ella Fitzgerald in 1969 (on her album 6 [' Wilson', ' in', ' 1968', ' and', ' Ell', 'a', ' Fitzgerald']
+1489 335 Name of father of x -1 Name of father of Ella Fitzgerald William Fitzgerald Ella Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False including Duke Ellington, Ella Fitzgerald and Earl Hines, 8 [' including', ' Duke', ' E', 'lling', 'ton', ',', ' Ell', 'a', ' Fitzgerald']
+1490 335 Name of father of x -1 Name of father of Ella Fitzgerald William Fitzgerald Ella Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False American songbook Ella Fitzgerald owned by then, 5 [' American', ' song', 'book', ' Ell', 'a', ' Fitzgerald']
+1491 335 Name of father of x -1 Name of father of Ella Fitzgerald William Fitzgerald Ella Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 2 ['E', 'lla', ' Fitzgerald']
+1492 336 Name of father of x -1 Name of father of Gian Lorenzo Bernini Pietro Bernini Gian Lorenzo Bernini "[',' ' the' ' sculpt' 'or' ',' ' who' ' was' ' born' ' in' ' 15' '98' ','
+ ' and' ' died' ' in' ' 16' '80' '.' '\n' '\n']" ", the sculpt or , who was born in 15 98 , and died in 16 80 .
+
+" False was lost when Gian Lorenzo Bernini was commissioned 6 [' was', ' lost', ' when', ' Gian', ' Lorenzo', ' Bern', 'ini']
+1493 336 Name of father of x -1 Name of father of Gian Lorenzo Bernini Pietro Bernini Gian Lorenzo Bernini "[',' ' the' ' sculpt' 'or' ',' ' who' ' was' ' born' ' in' ' 15' '98' ','
+ ' and' ' died' ' in' ' 16' '80' '.' '\n' '\n']" ", the sculpt or , who was born in 15 98 , and died in 16 80 .
+
+" False was lost when Gian Lorenzo Bernini was commissioned 6 [' was', ' lost', ' when', ' Gian', ' Lorenzo', ' Bern', 'ini']
+1494 336 Name of father of x -1 Name of father of Gian Lorenzo Bernini Pietro Bernini Gian Lorenzo Bernini "[',' ' the' ' sculpt' 'or' ',' ' who' ' was' ' born' ' in' ' 15' '98' ','
+ ' and' ' died' ' in' ' 16' '80' '.' '\n' '\n']" ", the sculpt or , who was born in 15 98 , and died in 16 80 .
+
+" False Peter's square by Gian Lorenzo Bernini made it necessary 7 "[' Peter', ""'s"", ' square', ' by', ' Gian', ' Lorenzo', ' Bern', 'ini']"
+1495 336 Name of father of x -1 Name of father of Gian Lorenzo Bernini Pietro Bernini Gian Lorenzo Bernini "[',' ' the' ' sculpt' 'or' ',' ' who' ' was' ' born' ' in' ' 15' '98' ','
+ ' and' ' died' ' in' ' 16' '80' '.' '\n' '\n']" ", the sculpt or , who was born in 15 98 , and died in 16 80 .
+
+" False momentum was lost when Gian Lorenzo Bernini was commissioned 7 [' momentum', ' was', ' lost', ' when', ' Gian', ' Lorenzo', ' Bern', 'ini']
+1496 336 Name of father of x -1 Name of father of Gian Lorenzo Bernini Pietro Bernini Gian Lorenzo Bernini "[',' ' the' ' sculpt' 'or' ',' ' who' ' was' ' born' ' in' ' 15' '98' ','
+ ' and' ' died' ' in' ' 16' '80' '.' '\n' '\n']" ", the sculpt or , who was born in 15 98 , and died in 16 80 .
+
+" False square by Gian Lorenzo Bernini made it necessary 5 [' square', ' by', ' Gian', ' Lorenzo', ' Bern', 'ini']
+1497 337 Name of father of x -1 Name of father of Marcus Aurelius Marcus Annius Verus Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' of' ' his' ' mother' ',' ' Faust'
+ 'ina' ',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Anton' 'inus' ' P']" , the Emperor , and of his mother , Faust ina , the daughter of the Emperor Anton inus P False envoys sent by Emperor Marcus Aurelius (r. 161 – 180 8 [' env', 'oys', ' sent', ' by', ' Emperor', ' Marcus', ' Aure', 'l', 'ius']
+1498 337 Name of father of x -1 Name of father of Marcus Aurelius Marcus Annius Verus Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' of' ' his' ' mother' ',' ' Faust'
+ 'ina' ',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Anton' 'inus' ' P']" , the Emperor , and of his mother , Faust ina , the daughter of the Emperor Anton inus P False landscape. Lt. Col. Marcus Aurelius Belt (Ed Lauter), 9 [' landscape', '.', ' Lt', '.', ' Col', '.', ' Marcus', ' Aure', 'l', 'ius']
+1499 337 Name of father of x -1 Name of father of Marcus Aurelius Marcus Annius Verus Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' of' ' his' ' mother' ',' ' Faust'
+ 'ina' ',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Anton' 'inus' ' P']" , the Emperor , and of his mother , Faust ina , the daughter of the Emperor Anton inus P False " co-emperor had existed from Marcus Aurelius (r. 161 – 180) on.
+" 10 [' co', '-', 'em', 'peror', ' had', ' existed', ' from', ' Marcus', ' Aure', 'l', 'ius']
+1500 337 Name of father of x -1 Name of father of Marcus Aurelius Marcus Annius Verus Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' of' ' his' ' mother' ',' ' Faust'
+ 'ina' ',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Anton' 'inus' ' P']" , the Emperor , and of his mother , Faust ina , the daughter of the Emperor Anton inus P False embassy of Emperor Marcus Aurelius (r. 161 – 180 AD) 6 [' embassy', ' of', ' Emperor', ' Marcus', ' Aure', 'l', 'ius']
+1501 337 Name of father of x -1 Name of father of Marcus Aurelius Marcus Annius Verus Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' of' ' his' ' mother' ',' ' Faust'
+ 'ina' ',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Anton' 'inus' ' P']" , the Emperor , and of his mother , Faust ina , the daughter of the Emperor Anton inus P False represented Roman emperor Marcus Aurelius Antoninus (Andun 6 [' represented', ' Roman', ' emperor', ' Marcus', ' Aure', 'l', 'ius']
+1502 338 Name of father of x -1 Name of father of William Ewart Gladstone Sir John Gladstone, 1st Baronet William Ewart Gladstone "[',' ' 1' 'st' ' Earl' ' of' ' Mid' 'leton' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Mid' 'leton' ',' ' 1' 'st' ' Baron' ' Mid']" , 1 st Earl of Mid leton , 1 st V isc ount Mid leton , 1 st Baron Mid False Prime Minister William Ewart Gladstone sent him papers 6 [' Prime', ' Minister', ' William', ' E', 'wart', ' Glad', 'stone']
+1503 338 Name of father of x -1 Name of father of William Ewart Gladstone Sir John Gladstone, 1st Baronet William Ewart Gladstone "[',' ' 1' 'st' ' Earl' ' of' ' Mid' 'leton' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Mid' 'leton' ',' ' 1' 'st' ' Baron' ' Mid']" , 1 st Earl of Mid leton , 1 st V isc ount Mid leton , 1 st Baron Mid False Robert Peel and William Ewart Gladstone (who grew up 7 [' Robert', ' Peel', ' and', ' William', ' E', 'wart', ' Glad', 'stone']
+1504 338 Name of father of x -1 Name of father of William Ewart Gladstone Sir John Gladstone, 1st Baronet William Ewart Gladstone "[',' ' 1' 'st' ' Earl' ' of' ' Mid' 'leton' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Mid' 'leton' ',' ' 1' 'st' ' Baron' ' Mid']" , 1 st Earl of Mid leton , 1 st V isc ount Mid leton , 1 st Baron Mid False Liberal Party's William Ewart Gladstone to the Premiership 7 "[' Liberal', ' Party', ""'s"", ' William', ' E', 'wart', ' Glad', 'stone']"
+1505 338 Name of father of x -1 Name of father of William Ewart Gladstone Sir John Gladstone, 1st Baronet William Ewart Gladstone "[',' ' 1' 'st' ' Earl' ' of' ' Mid' 'leton' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Mid' 'leton' ',' ' 1' 'st' ' Baron' ' Mid']" , 1 st Earl of Mid leton , 1 st V isc ount Mid leton , 1 st Baron Mid False Prime Minister William Ewart Gladstone and poet Robert 6 [' Prime', ' Minister', ' William', ' E', 'wart', ' Glad', 'stone']
+1506 338 Name of father of x -1 Name of father of William Ewart Gladstone Sir John Gladstone, 1st Baronet William Ewart Gladstone "[',' ' 1' 'st' ' Earl' ' of' ' Mid' 'leton' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Mid' 'leton' ',' ' 1' 'st' ' Baron' ' Mid']" , 1 st Earl of Mid leton , 1 st V isc ount Mid leton , 1 st Baron Mid False on Robert Peel and William Ewart Gladstone (who grew up 8 [' on', ' Robert', ' Peel', ' and', ' William', ' E', 'wart', ' Glad', 'stone']
+1507 339 Name of father of x -1 Name of father of Glenn Close William Close Glenn Close "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False their graves. Glenn Close is up to the 4 [' their', ' graves', '.', ' Glenn', ' Close']
+1508 339 Name of father of x -1 Name of father of Glenn Close William Close Glenn Close "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False in 2001, starring Glenn Close as Nellie, Harry 5 [' in', ' 2001', ',', ' starring', ' Glenn', ' Close']
+1509 339 Name of father of x -1 Name of father of Glenn Close William Close Glenn Close "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False Bérénice Marlohe. Glenn Close was cast through Curtis 9 [' B', 'é', 'ré', 'nice', ' Mar', 'lo', 'he', '.', ' Glenn', ' Close']
+1510 339 Name of father of x -1 Name of father of Glenn Close William Close Glenn Close "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False " but perfectly cast Glenn Close and Frank Langella"".
+" 4 [' but', ' perfectly', ' cast', ' Glenn', ' Close']
+1511 339 Name of father of x -1 Name of father of Glenn Close William Close Glenn Close "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False Arlene 4 [' Ar', 'len', 'Gl', 'enn', ' Close']
+1512 340 Name of father of x -1 Name of father of Céline Dion Adhémar Dion Céline Dion "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' C' 'é' 'line'
+ ' Dion' ' is' ' not' ' known' '.' '\n' '\n' 'C' 'é']" "
+
+ The name of the father of C é line Dion is not known .
+
+ C é" False Decade of Song by Céline Dion from hitting number 8 [' Dec', 'ade', ' of', ' Song', ' by', ' C', 'é', 'line', ' Dion']
+1513 340 Name of father of x -1 Name of father of Céline Dion Adhémar Dion Céline Dion "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' C' 'é' 'line'
+ ' Dion' ' is' ' not' ' known' '.' '\n' '\n' 'C' 'é']" "
+
+ The name of the father of C é line Dion is not known .
+
+ C é" False early career: Céline Dion chante Noël (1981) 6 [' early', ' career', ':', ' C', 'é', 'line', ' Dion']
+1514 340 Name of father of x -1 Name of father of Céline Dion Adhémar Dion Céline Dion "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' C' 'é' 'line'
+ ' Dion' ' is' ' not' ' known' '.' '\n' '\n' 'C' 'é']" "
+
+ The name of the father of C é line Dion is not known .
+
+ C é" False Decade of Song by Céline Dion from hitting number 8 [' Dec', 'ade', ' of', ' Song', ' by', ' C', 'é', 'line', ' Dion']
+1515 340 Name of father of x -1 Name of father of Céline Dion Adhémar Dion Céline Dion "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' C' 'é' 'line'
+ ' Dion' ' is' ' not' ' known' '.' '\n' '\n' 'C' 'é']" "
+
+ The name of the father of C é line Dion is not known .
+
+ C é" False taped La spéciale Céline Dion in Paris, France 9 [' taped', ' La', ' sp', 'é', 'cial', 'e', ' C', 'é', 'line', ' Dion']
+1516 340 Name of father of x -1 Name of father of Céline Dion Adhémar Dion Céline Dion "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' C' 'é' 'line'
+ ' Dion' ' is' ' not' ' known' '.' '\n' '\n' 'C' 'é']" "
+
+ The name of the father of C é line Dion is not known .
+
+ C é" False her early career: Céline Dion chante Noël 7 [' her', ' early', ' career', ':', ' C', 'é', 'line', ' Dion']
+1517 341 Name of father of x -1 Name of father of Ray Bradbury Leonard Spaulding Bradbury, Sr. Ray Bradbury "[',' ' the' ' author' ' of' ' Fahrenheit' ' 451' '.' '\n' '\n' 'The'
+ ' book' ' is' ' a' ' dystopian' ' novel' ' about' ' a' ' future' ' where'
+ ' books']" ", the author of Fahrenheit 451 .
+
+ The book is a dystopian novel about a future where books" False stories by Ray Bradbury and Edgar Rice 4 [' stories', ' by', ' Ray', ' Brad', 'bury']
+1518 341 Name of father of x -1 Name of father of Ray Bradbury Leonard Spaulding Bradbury, Sr. Ray Bradbury "[',' ' the' ' author' ' of' ' Fahrenheit' ' 451' '.' '\n' '\n' 'The'
+ ' book' ' is' ' a' ' dystopian' ' novel' ' about' ' a' ' future' ' where'
+ ' books']" ", the author of Fahrenheit 451 .
+
+ The book is a dystopian novel about a future where books" False to the film. Ray Bradbury praised the film's 6 [' to', ' the', ' film', '.', ' Ray', ' Brad', 'bury']
+1519 341 Name of father of x -1 Name of father of Ray Bradbury Leonard Spaulding Bradbury, Sr. Ray Bradbury "[',' ' the' ' author' ' of' ' Fahrenheit' ' 451' '.' '\n' '\n' 'The'
+ ' book' ' is' ' a' ' dystopian' ' novel' ' about' ' a' ' future' ' where'
+ ' books']" ", the author of Fahrenheit 451 .
+
+ The book is a dystopian novel about a future where books" False 2006, and the Ray Bradbury Award for Outstanding 6 [' 2006', ',', ' and', ' the', ' Ray', ' Brad', 'bury']
+1520 341 Name of father of x -1 Name of father of Ray Bradbury Leonard Spaulding Bradbury, Sr. Ray Bradbury "[',' ' the' ' author' ' of' ' Fahrenheit' ' 451' '.' '\n' '\n' 'The'
+ ' book' ' is' ' a' ' dystopian' ' novel' ' about' ' a' ' future' ' where'
+ ' books']" ", the author of Fahrenheit 451 .
+
+ The book is a dystopian novel about a future where books" False Hemingway); Ray Bradbury wrote The 6 [' Hem', 'ing', 'way', ');', ' Ray', ' Brad', 'bury']
+1521 341 Name of father of x -1 Name of father of Ray Bradbury Leonard Spaulding Bradbury, Sr. Ray Bradbury "[',' ' the' ' author' ' of' ' Fahrenheit' ' 451' '.' '\n' '\n' 'The'
+ ' book' ' is' ' a' ' dystopian' ' novel' ' about' ' a' ' future' ' where'
+ ' books']" ", the author of Fahrenheit 451 .
+
+ The book is a dystopian novel about a future where books" False " Year Picnic"", by Ray Bradbury — the first of" 7 "[' Year', ' Pic', 'nic', '"",', ' by', ' Ray', ' Brad', 'bury']"
+1522 342 Name of father of x -1 Name of father of Giovanni Bellini Jacopo Bellini Giovanni Bellini "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 14' '30' ','
+ ' and' ' died' ' in' '\n' '14' '94' '.' '\n' '\n']" ", the painter , who was born in 14 30 , and died in
+ 14 94 .
+
+" False and artists like Giovanni Bellini came under 5 [' and', ' artists', ' like', ' Giovanni', ' Bell', 'ini']
+1523 342 Name of father of x -1 Name of father of Giovanni Bellini Jacopo Bellini Giovanni Bellini "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 14' '30' ','
+ ' and' ' died' ' in' '\n' '14' '94' '.' '\n' '\n']" ", the painter , who was born in 14 30 , and died in
+ 14 94 .
+
+" False leading painter Giovanni Bellini adopted the technique 4 [' leading', ' painter', ' Giovanni', ' Bell', 'ini']
+1524 342 Name of father of x -1 Name of father of Giovanni Bellini Jacopo Bellini Giovanni Bellini "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 14' '30' ','
+ ' and' ' died' ' in' '\n' '14' '94' '.' '\n' '\n']" ", the painter , who was born in 14 30 , and died in
+ 14 94 .
+
+" False and artists like Giovanni Bellini came under 5 [' and', ' artists', ' like', ' Giovanni', ' Bell', 'ini']
+1525 342 Name of father of x -1 Name of father of Giovanni Bellini Jacopo Bellini Giovanni Bellini "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 14' '30' ','
+ ' and' ' died' ' in' '\n' '14' '94' '.' '\n' '\n']" ", the painter , who was born in 14 30 , and died in
+ 14 94 .
+
+" False leading painter Giovanni Bellini adopted the 4 [' leading', ' painter', ' Giovanni', ' Bell', 'ini']
+1526 342 Name of father of x -1 Name of father of Giovanni Bellini Jacopo Bellini Giovanni Bellini "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 14' '30' ','
+ ' and' ' died' ' in' '\n' '14' '94' '.' '\n' '\n']" ", the painter , who was born in 14 30 , and died in
+ 14 94 .
+
+" False artists like Giovanni Bellini came under the 4 [' artists', ' like', ' Giovanni', ' Bell', 'ini']
+1527 343 Name of father of x -1 Name of father of Lily Allen Keith Allen Lily Allen "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Simon' ' Cow'
+ 'ell' ',' ' the' ' music' ' producer' '.' '\n' '\n' 'The' ' couple']" ", the singer , and her husband , Simon Cow ell , the music producer .
+
+ The couple" False singer and songwriter Lily Allen was quoted to have 5 [' singer', ' and', ' song', 'writer', ' Lily', ' Allen']
+1528 343 Name of father of x -1 Name of father of Lily Allen Keith Allen Lily Allen "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Simon' ' Cow'
+ 'ell' ',' ' the' ' music' ' producer' '.' '\n' '\n' 'The' ' couple']" ", the singer , and her husband , Simon Cow ell , the music producer .
+
+ The couple" False " Written by — Lily Allen and Greg Kurstin
+" 4 [' Written', ' by', ' —', ' Lily', ' Allen']
+1529 343 Name of father of x -1 Name of father of Lily Allen Keith Allen Lily Allen "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Simon' ' Cow'
+ 'ell' ',' ' the' ' music' ' producer' '.' '\n' '\n' 'The' ' couple']" ", the singer , and her husband , Simon Cow ell , the music producer .
+
+ The couple" False as British pop star Lily Allen on her second album, 5 [' as', ' British', ' pop', ' star', ' Lily', ' Allen']
+1530 343 Name of father of x -1 Name of father of Lily Allen Keith Allen Lily Allen "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Simon' ' Cow'
+ 'ell' ',' ' the' ' music' ' producer' '.' '\n' '\n' 'The' ' couple']" ", the singer , and her husband , Simon Cow ell , the music producer .
+
+ The couple" False 2014, English singer Lily Allen announced that 5 [' 2014', ',', ' English', ' singer', ' Lily', ' Allen']
+1531 343 Name of father of x -1 Name of father of Lily Allen Keith Allen Lily Allen "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Simon' ' Cow'
+ 'ell' ',' ' the' ' music' ' producer' '.' '\n' '\n' 'The' ' couple']" ", the singer , and her husband , Simon Cow ell , the music producer .
+
+ The couple" False songwriter Lily Allen was quoted to have 3 [' song', 'writer', ' Lily', ' Allen']
+1532 344 Name of father of x -1 Name of father of Uma Thurman Robert Thurman Uma Thurman "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' wife' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' ""'
+ 'Kill' ' Bill']" ", the actress who played the role of the wife of the main character in the movie "" Kill Bill" False pay homage to Uma Thurman as The Bride and Lucy 6 [' pay', ' homage', ' to', ' U', 'ma', ' Thur', 'man']
+1533 344 Name of father of x -1 Name of father of Uma Thurman Robert Thurman Uma Thurman "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' wife' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' ""'
+ 'Kill' ' Bill']" ", the actress who played the role of the wife of the main character in the movie "" Kill Bill" False Supporting Actor), Uma Thurman (Worst Supporting Actress), 6 [' Supporting', ' Actor', '),', ' U', 'ma', ' Thur', 'man']
+1534 344 Name of father of x -1 Name of father of Uma Thurman Robert Thurman Uma Thurman "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' wife' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' ""'
+ 'Kill' ' Bill']" ", the actress who played the role of the wife of the main character in the movie "" Kill Bill" False Patrick Stewart and Uma Thurman had been cast. Natalie 6 [' Patrick', ' Stewart', ' and', ' U', 'ma', ' Thur', 'man']
+1535 344 Name of father of x -1 Name of father of Uma Thurman Robert Thurman Uma Thurman "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' wife' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' ""'
+ 'Kill' ' Bill']" ", the actress who played the role of the wife of the main character in the movie "" Kill Bill" False 3 ['U', 'ma', ' Thur', 'man']
+1536 344 Name of father of x -1 Name of father of Uma Thurman Robert Thurman Uma Thurman "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' wife' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' ""'
+ 'Kill' ' Bill']" ", the actress who played the role of the wife of the main character in the movie "" Kill Bill" False Mr. Freeze, while Uma Thurman starred as Poison 8 [' Mr', '.', ' Freeze', ',', ' while', ' U', 'ma', ' Thur', 'man']
+1537 345 Name of father of x -1 Name of father of Joseph Vernet Antoine Vernet Joseph Vernet "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of his wife , the
+ " False by sea. Claude Joseph Vernet (1714 – 1789) 6 [' by', ' sea', '.', ' Claude', ' Joseph', ' Vern', 'et']
+1538 345 Name of father of x -1 Name of father of Joseph Vernet Antoine Vernet Joseph Vernet "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of his wife , the
+ " False made by sea. Claude Joseph Vernet (1714 – 1789) 7 [' made', ' by', ' sea', '.', ' Claude', ' Joseph', ' Vern', 'et']
+1539 345 Name of father of x -1 Name of father of Joseph Vernet Antoine Vernet Joseph Vernet "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and of his wife , the
+ " False by sea. Claude Joseph Vernet (1714 – 1789) 6 [' by', ' sea', '.', ' Claude', ' Joseph', ' Vern', 'et']
+1540 346 Name of father of x -1 Name of father of Alexander the Great Philip II of Macedon Alexander the Great "[',' ' and' ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and'
+ ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and' ' the'
+ ' father']" , and the father of Alexander the Great , and the father of Alexander the Great , and the father False Persian Empire. Alexander the Great besieged Gaza, 5 [' Persian', ' Empire', '.', ' Alexander', ' the', ' Great']
+1541 346 Name of father of x -1 Name of father of Alexander the Great Philip II of Macedon Alexander the Great "[',' ' and' ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and'
+ ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and' ' the'
+ ' father']" , and the father of Alexander the Great , and the father of Alexander the Great , and the father False 2 ['Alexander', ' the', ' Great']
+1542 346 Name of father of x -1 Name of father of Alexander the Great Philip II of Macedon Alexander the Great "[',' ' and' ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and'
+ ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and' ' the'
+ ' father']" , and the father of Alexander the Great , and the father of Alexander the Great , and the father False 2 ['Alexander', ' the', ' Great']
+1543 346 Name of father of x -1 Name of father of Alexander the Great Philip II of Macedon Alexander the Great "[',' ' and' ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and'
+ ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and' ' the'
+ ' father']" , and the father of Alexander the Great , and the father of Alexander the Great , and the father False luminaries as Alexander the Great and Hannibal, this 5 [' lumin', 'aries', ' as', ' Alexander', ' the', ' Great']
+1544 346 Name of father of x -1 Name of father of Alexander the Great Philip II of Macedon Alexander the Great "[',' ' and' ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and'
+ ' the' ' father' ' of' ' Alexander' ' the' ' Great' ',' ' and' ' the'
+ ' father']" , and the father of Alexander the Great , and the father of Alexander the Great , and the father False title role of Alexander the Great in Oliver Stone's 5 [' title', ' role', ' of', ' Alexander', ' the', ' Great']
+1545 347 Name of father of x -1 Name of father of Dante Gabriel Rossetti Gabriele Rossetti Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ' William' ' Morris' '.']" , the poet , and his wife , Christina , who was the daughter of the poet William Morris . False There he met Dante Gabriel Rossetti and other members 6 [' There', ' he', ' met', ' Dante', ' Gabriel', ' Ross', 'etti']
+1546 347 Name of father of x -1 Name of father of Dante Gabriel Rossetti Gabriele Rossetti Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ' William' ' Morris' '.']" , the poet , and his wife , Christina , who was the daughter of the poet William Morris . False style — Holman Hunt, Dante Gabriel Rossetti and John Everett 9 [' style', ' —', ' Hol', 'man', ' Hunt', ',', ' Dante', ' Gabriel', ' Ross', 'etti']
+1547 347 Name of father of x -1 Name of father of Dante Gabriel Rossetti Gabriele Rossetti Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ' William' ' Morris' '.']" , the poet , and his wife , Christina , who was the daughter of the poet William Morris . False triptych by Dante Gabriel Rossetti was designed for use 7 [' tri', 'pty', 'ch', ' by', ' Dante', ' Gabriel', ' Ross', 'etti']
+1548 347 Name of father of x -1 Name of father of Dante Gabriel Rossetti Gabriele Rossetti Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ' William' ' Morris' '.']" , the poet , and his wife , Christina , who was the daughter of the poet William Morris . False There he met Dante Gabriel Rossetti and other members 6 [' There', ' he', ' met', ' Dante', ' Gabriel', ' Ross', 'etti']
+1549 347 Name of father of x -1 Name of father of Dante Gabriel Rossetti Gabriele Rossetti Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ' William' ' Morris' '.']" , the poet , and his wife , Christina , who was the daughter of the poet William Morris . False there including Dante Gabriel Rossetti and William Michael 5 [' there', ' including', ' Dante', ' Gabriel', ' Ross', 'etti']
+1550 348 Name of father of x -1 Name of father of Jack London William Henry Chaney Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' was' ' a' ' man' ' of' ' the']" , the author of the book , and the man who wrote the book , was a man of the False rush. The writer Jack London incorporated 5 [' rush', '.', ' The', ' writer', ' Jack', ' London']
+1551 348 Name of father of x -1 Name of father of Jack London William Henry Chaney Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' was' ' a' ' man' ' of' ' the']" , the author of the book , and the man who wrote the book , was a man of the False patriotism, Jack London on society, and Nietzsche 3 [' patriotism', ',', ' Jack', ' London']
+1552 348 Name of father of x -1 Name of father of Jack London William Henry Chaney Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' was' ' a' ' man' ' of' ' the']" , the author of the book , and the man who wrote the book , was a man of the False patriotism, Jack London on society, 3 [' patriotism', ',', ' Jack', ' London']
+1553 348 Name of father of x -1 Name of father of Jack London William Henry Chaney Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' was' ' a' ' man' ' of' ' the']" , the author of the book , and the man who wrote the book , was a man of the False readers including Jack London and Isadora 3 [' readers', ' including', ' Jack', ' London']
+1554 348 Name of father of x -1 Name of father of Jack London William Henry Chaney Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' wrote' ' the' ' book' ',' ' was' ' a' ' man' ' of' ' the']" , the author of the book , and the man who wrote the book , was a man of the False rebuttals. As Jack London would later 5 [' rebutt', 'als', '.', ' As', ' Jack', ' London']
+1555 349 Name of father of x -1 Name of father of Stendhal Chérubin Beyle Stendhal "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 17'
+ '83' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 17 83 .
+
+ The first of the great French" False " Voltaire, Goethe, Stendhal and Victor Hugo.
+" 8 [' Volt', 'aire', ',', ' Go', 'ethe', ',', ' St', 'end', 'hal']
+1556 349 Name of father of x -1 Name of father of Stendhal Chérubin Beyle Stendhal "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 17'
+ '83' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 17 83 .
+
+ The first of the great French" False Unknown Letters of Stendhal '’, which featured 5 [' Unknown', ' Letters', ' of', ' St', 'end', 'hal']
+1557 349 Name of father of x -1 Name of father of Stendhal Chérubin Beyle Stendhal "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 17'
+ '83' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 17 83 .
+
+ The first of the great French" False of recreation when Stendhal visited in 1827. A 5 [' of', ' recreation', ' when', ' St', 'end', 'hal']
+1558 349 Name of father of x -1 Name of father of Stendhal Chérubin Beyle Stendhal "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 17'
+ '83' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 17 83 .
+
+ The first of the great French" False " Voltaire, Goethe, Stendhal and Victor Hugo.
+" 8 [' Volt', 'aire', ',', ' Go', 'ethe', ',', ' St', 'end', 'hal']
+1559 349 Name of father of x -1 Name of father of Stendhal Chérubin Beyle Stendhal "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 17'
+ '83' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 17 83 .
+
+ The first of the great French" False Unknown Letters of Stendhal '’, which featured 5 [' Unknown', ' Letters', ' of', ' St', 'end', 'hal']
+1560 350 Name of father of x -1 Name of father of Werner Heisenberg August Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False principle, formulated by Werner Heisenberg in 1926. In this 7 [' principle', ',', ' formulated', ' by', ' Werner', ' He', 'isen', 'berg']
+1561 350 Name of father of x -1 Name of father of Werner Heisenberg August Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False his student Werner Heisenberg to be Born's assistant. 5 [' his', ' student', ' Werner', ' He', 'isen', 'berg']
+1562 350 Name of father of x -1 Name of father of Werner Heisenberg August Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False where he heard Werner Heisenberg present a paper 6 [' where', ' he', ' heard', ' Werner', ' He', 'isen', 'berg']
+1563 350 Name of father of x -1 Name of father of Werner Heisenberg August Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False 5 ['W', 'er', 'ner', ' He', 'isen', 'berg']
+1564 350 Name of father of x -1 Name of father of Werner Heisenberg August Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False Max von Laue, Werner Heisenberg and Carl Friedrich 8 [' Max', ' von', ' La', 'ue', ',', ' Werner', ' He', 'isen', 'berg']
+1565 352 Name of father of x -1 Name of father of Ernest Rutherford James Rutherford Ernest Rutherford "[',' ' the' ' famous' ' physicist' ',' ' who' ' discovered' ' the'
+ ' atomic' ' nucleus' '.' '\n' '\n' 'The' ' first' ' atomic' ' bomb'
+ ' was' ' invented' ' by']" ", the famous physicist , who discovered the atomic nucleus .
+
+ The first atomic bomb was invented by" False atomic number. Ernest Rutherford (1919) observed 4 [' atomic', ' number', '.', ' Ernest', ' Rutherford']
+1566 352 Name of father of x -1 Name of father of Ernest Rutherford James Rutherford Ernest Rutherford "[',' ' the' ' famous' ' physicist' ',' ' who' ' discovered' ' the'
+ ' atomic' ' nucleus' '.' '\n' '\n' 'The' ' first' ' atomic' ' bomb'
+ ' was' ' invented' ' by']" ", the famous physicist , who discovered the atomic nucleus .
+
+ The first atomic bomb was invented by" False 3 ['Er', 'n', 'est', ' Rutherford']
+1567 352 Name of father of x -1 Name of father of Ernest Rutherford James Rutherford Ernest Rutherford "[',' ' the' ' famous' ' physicist' ',' ' who' ' discovered' ' the'
+ ' atomic' ' nucleus' '.' '\n' '\n' 'The' ' first' ' atomic' ' bomb'
+ ' was' ' invented' ' by']" ", the famous physicist , who discovered the atomic nucleus .
+
+ The first atomic bomb was invented by" False studied under Ernest Rutherford and earned 3 [' studied', ' under', ' Ernest', ' Rutherford']
+1568 352 Name of father of x -1 Name of father of Ernest Rutherford James Rutherford Ernest Rutherford "[',' ' the' ' famous' ' physicist' ',' ' who' ' discovered' ' the'
+ ' atomic' ' nucleus' '.' '\n' '\n' 'The' ' first' ' atomic' ' bomb'
+ ' was' ' invented' ' by']" ", the famous physicist , who discovered the atomic nucleus .
+
+ The first atomic bomb was invented by" False different systems. After Ernest Rutherford discovered the 5 [' different', ' systems', '.', ' After', ' Ernest', ' Rutherford']
+1569 352 Name of father of x -1 Name of father of Ernest Rutherford James Rutherford Ernest Rutherford "[',' ' the' ' famous' ' physicist' ',' ' who' ' discovered' ' the'
+ ' atomic' ' nucleus' '.' '\n' '\n' 'The' ' first' ' atomic' ' bomb'
+ ' was' ' invented' ' by']" ", the famous physicist , who discovered the atomic nucleus .
+
+ The first atomic bomb was invented by" False solution offered. Ernest Rutherford suggested that the 4 [' solution', ' offered', '.', ' Ernest', ' Rutherford']
+1570 353 Name of father of x -1 Name of father of Frederick II of Prussia Friedrich Wilhelm I of Prussia Frederick II of Prussia "['\n' '\n' 'Category' ':' '17' 'th' '-' 'century' ' births' '\n'
+ 'Category' ':' '17' 'th' '-' 'century' ' deaths' '\n' 'Category' ':']" "
+
+ Category : 17 th - century births
+ Category : 17 th - century deaths
+ Category :" False hegemony. King Frederick II of Prussia had no intention 7 [' hegemony', '.', ' King', ' Frederick', ' II', ' of', ' Pr', 'ussia']
+1571 353 Name of father of x -1 Name of father of Frederick II of Prussia Friedrich Wilhelm I of Prussia Frederick II of Prussia "['\n' '\n' 'Category' ':' '17' 'th' '-' 'century' ' births' '\n'
+ 'Category' ':' '17' 'th' '-' 'century' ' deaths' '\n' 'Category' ':']" "
+
+ Category : 17 th - century births
+ Category : 17 th - century deaths
+ Category :" False December, King Frederick II of Prussia invaded the 7 [' December', ',', ' King', ' Frederick', ' II', ' of', ' Pr', 'ussia']
+1572 353 Name of father of x -1 Name of father of Frederick II of Prussia Friedrich Wilhelm I of Prussia Frederick II of Prussia "['\n' '\n' 'Category' ':' '17' 'th' '-' 'century' ' births' '\n'
+ 'Category' ':' '17' 'th' '-' 'century' ' deaths' '\n' 'Category' ':']" "
+
+ Category : 17 th - century births
+ Category : 17 th - century deaths
+ Category :" False In December, King Frederick II of Prussia invaded the Duchy 8 [' In', ' December', ',', ' King', ' Frederick', ' II', ' of', ' Pr', 'ussia']
+1573 353 Name of father of x -1 Name of father of Frederick II of Prussia Friedrich Wilhelm I of Prussia Frederick II of Prussia "['\n' '\n' 'Category' ':' '17' 'th' '-' 'century' ' births' '\n'
+ 'Category' ':' '17' 'th' '-' 'century' ' deaths' '\n' 'Category' ':']" "
+
+ Category : 17 th - century births
+ Category : 17 th - century deaths
+ Category :" False Years'War — used by Frederick II of Prussia — and in Napoleon's 10 "[' Years', ""'"", 'War', ' —', ' used', ' by', ' Frederick', ' II', ' of', ' Pr', 'ussia']"
+1574 353 Name of father of x -1 Name of father of Frederick II of Prussia Friedrich Wilhelm I of Prussia Frederick II of Prussia "['\n' '\n' 'Category' ':' '17' 'th' '-' 'century' ' births' '\n'
+ 'Category' ':' '17' 'th' '-' 'century' ' deaths' '\n' 'Category' ':']" "
+
+ Category : 17 th - century births
+ Category : 17 th - century deaths
+ Category :" False German hegemony. King Frederick II of Prussia had no intention 8 [' German', ' hegemony', '.', ' King', ' Frederick', ' II', ' of', ' Pr', 'ussia']
+1575 354 Name of father of x -1 Name of father of Joseph Conrad Apollo Korzeniowski Joseph Conrad "[',' ' the' ' author' ' of' ' ""' 'The' ' Secret' ' Share' 'r' '""' ' and'
+ ' ""' 'The' ' Shadow' '-' 'Line' '.""' '\n' '\n' 'The']" ", the author of "" The Secret Share r "" and "" The Shadow - Line .""
+
+ The" False Crane's friend Joseph Conrad wrote while 4 "[' Crane', ""'s"", ' friend', ' Joseph', ' Conrad']"
+1576 354 Name of father of x -1 Name of father of Joseph Conrad Apollo Korzeniowski Joseph Conrad "[',' ' the' ' author' ' of' ' ""' 'The' ' Secret' ' Share' 'r' '""' ' and'
+ ' ""' 'The' ' Shadow' '-' 'Line' '.""' '\n' '\n' 'The']" ", the author of "" The Secret Share r "" and "" The Shadow - Line .""
+
+ The" False " famous criticism of Joseph Conrad as ""a thoroughgoing" 4 [' famous', ' criticism', ' of', ' Joseph', ' Conrad']
+1577 354 Name of father of x -1 Name of father of Joseph Conrad Apollo Korzeniowski Joseph Conrad "[',' ' the' ' author' ' of' ' ""' 'The' ' Secret' ' Share' 'r' '""' ' and'
+ ' ""' 'The' ' Shadow' '-' 'Line' '.""' '\n' '\n' 'The']" ", the author of "" The Secret Share r "" and "" The Shadow - Line .""
+
+ The" False first published book, Joseph Conrad and the Fiction 5 [' first', ' published', ' book', ',', ' Joseph', ' Conrad']
+1578 354 Name of father of x -1 Name of father of Joseph Conrad Apollo Korzeniowski Joseph Conrad "[',' ' the' ' author' ' of' ' ""' 'The' ' Secret' ' Share' 'r' '""' ' and'
+ ' ""' 'The' ' Shadow' '-' 'Line' '.""' '\n' '\n' 'The']" ", the author of "" The Secret Share r "" and "" The Shadow - Line .""
+
+ The" False " Horseman"") and Joseph Conrad (Under Western" 5 "[' Horse', 'man', '"")', ' and', ' Joseph', ' Conrad']"
+1579 354 Name of father of x -1 Name of father of Joseph Conrad Apollo Korzeniowski Joseph Conrad "[',' ' the' ' author' ' of' ' ""' 'The' ' Secret' ' Share' 'r' '""' ' and'
+ ' ""' 'The' ' Shadow' '-' 'Line' '.""' '\n' '\n' 'The']" ", the author of "" The Secret Share r "" and "" The Shadow - Line .""
+
+ The" False " Horseman"") and Joseph Conrad (Under Western" 5 "[' Horse', 'man', '"")', ' and', ' Joseph', ' Conrad']"
+1580 355 Name of father of x -1 Name of father of Gwyneth Paltrow Bruce Paltrow Gwyneth Paltrow "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False Baker Hall, Gwyneth Paltrow and John C. 8 [' Baker', ' Hall', ',', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1581 355 Name of father of x -1 Name of father of Gwyneth Paltrow Bruce Paltrow Gwyneth Paltrow "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False the mother of Gwyneth Paltrow and former mother-in-law 8 [' the', ' mother', ' of', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1582 355 Name of father of x -1 Name of father of Gwyneth Paltrow Bruce Paltrow Gwyneth Paltrow "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False Stark, as did Gwyneth Paltrow as Pepper Potts and 9 [' Stark', ',', ' as', ' did', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1583 355 Name of father of x -1 Name of father of Gwyneth Paltrow Bruce Paltrow Gwyneth Paltrow "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False Philip Baker Hall, Gwyneth Paltrow and John C. Reilly 9 [' Philip', ' Baker', ' Hall', ',', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1584 355 Name of father of x -1 Name of father of Gwyneth Paltrow Bruce Paltrow Gwyneth Paltrow "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' ask' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False Freeman and Gwyneth Paltrow in the crime 7 [' Freeman', ' and', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1585 356 Name of father of x -1 Name of father of John Milton John Milton John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' English'
+ ' language' '.' '\n' '\n' 'The' ' first' ' edition' ' of' ' the' ' _']" ", the poet , and the father of the English language .
+
+ The first edition of the _" False propaganda. John Milton wrote a Parliamentary 3 [' propaganda', '.', ' John', ' Milton']
+1586 356 Name of father of x -1 Name of father of John Milton John Milton John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' English'
+ ' language' '.' '\n' '\n' 'The' ' first' ' edition' ' of' ' the' ' _']" ", the poet , and the father of the English language .
+
+ The first edition of the _" False in Paradise Lost, John Milton depicted Satan 5 [' in', ' Paradise', ' Lost', ',', ' John', ' Milton']
+1587 356 Name of father of x -1 Name of father of John Milton John Milton John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' English'
+ ' language' '.' '\n' '\n' 'The' ' first' ' edition' ' of' ' the' ' _']" ", the poet , and the father of the English language .
+
+ The first edition of the _" False respectively. Poet John Milton wrote anonymous 5 [' respectively', '.', ' Po', 'et', ' John', ' Milton']
+1588 356 Name of father of x -1 Name of father of John Milton John Milton John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' English'
+ ' language' '.' '\n' '\n' 'The' ' first' ' edition' ' of' ' the' ' _']" ", the poet , and the father of the English language .
+
+ The first edition of the _" False John Eachard, and John Milton had previously advocated 6 [' John', ' Each', 'ard', ',', ' and', ' John', ' Milton']
+1589 356 Name of father of x -1 Name of father of John Milton John Milton John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' English'
+ ' language' '.' '\n' '\n' 'The' ' first' ' edition' ' of' ' the' ' _']" ", the poet , and the father of the English language .
+
+ The first edition of the _" False poetry. Dante and John Milton reference the star, 5 [' poetry', '.', ' Dante', ' and', ' John', ' Milton']
+1590 357 Name of father of x -1 Name of father of Caravaggio Fermo Merixio Caravaggio "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '10' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 71 , and died in 16 10 .
+
+ The" False Michelangelo Merisi da Caravaggio located at Italy 8 [' Michel', 'angelo', ' Mer', 'isi', ' da', ' Car', 'av', 'agg', 'io']
+1591 357 Name of father of x -1 Name of father of Caravaggio Fermo Merixio Caravaggio "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '10' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 71 , and died in 16 10 .
+
+ The" False 3 ['Car', 'av', 'agg', 'io']
+1592 357 Name of father of x -1 Name of father of Caravaggio Fermo Merixio Caravaggio "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '10' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 71 , and died in 16 10 .
+
+ The" False 3 ['Car', 'av', 'agg', 'io']
+1593 357 Name of father of x -1 Name of father of Caravaggio Fermo Merixio Caravaggio "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '10' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 71 , and died in 16 10 .
+
+ The" False admirer of Caravaggio and utilised tenebrism 6 [' admire', 'r', ' of', ' Car', 'av', 'agg', 'io']
+1594 357 Name of father of x -1 Name of father of Caravaggio Fermo Merixio Caravaggio "[',' ' the' ' painter' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '10' '.' '\n' '\n' 'The']" ", the painter , who was born in 15 71 , and died in 16 10 .
+
+ The" False at the Battle of Caravaggio in 1448. It 7 [' at', ' the', ' Battle', ' of', ' Car', 'av', 'agg', 'io']
+1595 358 Name of father of x -1 Name of father of Francis Bacon Capt. Anthony Edward Mortimer Bacon Francis Bacon "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',']" ", the philosopher , and the father of the modern world .
+
+ The first of the three ," False art of the painter Francis Bacon when developing the 5 [' art', ' of', ' the', ' painter', ' Francis', ' Bacon']
+1596 358 Name of father of x -1 Name of father of Francis Bacon Capt. Anthony Edward Mortimer Bacon Francis Bacon "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',']" ", the philosopher , and the father of the modern world .
+
+ The first of the three ," False Hertfordshire. Sir Francis Bacon (1561 – 1626), 6 [' Hert', 'ford', 'shire', '.', ' Sir', ' Francis', ' Bacon']
+1597 358 Name of father of x -1 Name of father of Francis Bacon Capt. Anthony Edward Mortimer Bacon Francis Bacon "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',']" ", the philosopher , and the father of the modern world .
+
+ The first of the three ," False " biographer of Francis Bacon noted that ""[t] he" 4 [' bi', 'ographer', ' of', ' Francis', ' Bacon']
+1598 358 Name of father of x -1 Name of father of Francis Bacon Capt. Anthony Edward Mortimer Bacon Francis Bacon "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',']" ", the philosopher , and the father of the modern world .
+
+ The first of the three ," False 2 ['Franc', 'is', ' Bacon']
+1599 358 Name of father of x -1 Name of father of Francis Bacon Capt. Anthony Edward Mortimer Bacon Francis Bacon "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' the'
+ ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',']" ", the philosopher , and the father of the modern world .
+
+ The first of the three ," False philosopher and scientist Francis Bacon (1561 – 1626), in 4 [' philosopher', ' and', ' scientist', ' Francis', ' Bacon']
+1600 359 Name of father of x -1 Name of father of George Eliot Robert Evans George Eliot "[',' ' the' ' author' ' of' ' _' 'Middle' 'm' 'arch' '_' ',' ' and' ' _'
+ 'Daniel' ' Der' 'onda' '_' ',' ' and' '\n' ' ']" ", the author of _ Middle m arch _ , and _ Daniel Der onda _ , and
+ " False Review from 1851. George Eliot lived at No. 6 [' Review', ' from', ' 18', '51', '.', ' George', ' Eliot']
+1601 359 Name of father of x -1 Name of father of George Eliot Robert Evans George Eliot "[',' ' the' ' author' ' of' ' _' 'Middle' 'm' 'arch' '_' ',' ' and' ' _'
+ 'Daniel' ' Der' 'onda' '_' ',' ' and' '\n' ' ']" ", the author of _ Middle m arch _ , and _ Daniel Der onda _ , and
+ " False including George Eliot and William Michael 2 [' including', ' George', ' Eliot']
+1602 359 Name of father of x -1 Name of father of George Eliot Robert Evans George Eliot "[',' ' the' ' author' ' of' ' _' 'Middle' 'm' 'arch' '_' ',' ' and' ' _'
+ 'Daniel' ' Der' 'onda' '_' ',' ' and' '\n' ' ']" ", the author of _ Middle m arch _ , and _ Daniel Der onda _ , and
+ " False 1 ['George', ' Eliot']
+1603 359 Name of father of x -1 Name of father of George Eliot Robert Evans George Eliot "[',' ' the' ' author' ' of' ' _' 'Middle' 'm' 'arch' '_' ',' ' and' ' _'
+ 'Daniel' ' Der' 'onda' '_' ',' ' and' '\n' ' ']" ", the author of _ Middle m arch _ , and _ Daniel Der onda _ , and
+ " False writers such as George Eliot professed profound 4 [' writers', ' such', ' as', ' George', ' Eliot']
+1604 359 Name of father of x -1 Name of father of George Eliot Robert Evans George Eliot "[',' ' the' ' author' ' of' ' _' 'Middle' 'm' 'arch' '_' ',' ' and' ' _'
+ 'Daniel' ' Der' 'onda' '_' ',' ' and' '\n' ' ']" ", the author of _ Middle m arch _ , and _ Daniel Der onda _ , and
+ " False while writers such as George Eliot professed profound 5 [' while', ' writers', ' such', ' as', ' George', ' Eliot']
+1605 360 Name of father of x -1 Name of father of George VI George V George VI "['.' '\n' '\n' 'The' ' King' ' was' ' born' ' at' ' Sand' 'ring' 'ham'
+ ' on' ' the' ' 20' 'th' ' of' ' May' ',' ' 18' '92']" ".
+
+ The King was born at Sand ring ham on the 20 th of May , 18 92" False to convey King George VI and his family on 4 [' to', ' convey', ' King', ' George', ' VI']
+1606 360 Name of father of x -1 Name of father of George VI George V George VI "['.' '\n' '\n' 'The' ' King' ' was' ' born' ' at' ' Sand' 'ring' 'ham'
+ ' on' ' the' ' 20' 'th' ' of' ' May' ',' ' 18' '92']" ".
+
+ The King was born at Sand ring ham on the 20 th of May , 18 92" False United Kingdom, George VI announced he 4 [' United', ' Kingdom', ',', ' George', ' VI']
+1607 360 Name of father of x -1 Name of father of George VI George V George VI "['.' '\n' '\n' 'The' ' King' ' was' ' born' ' at' ' Sand' 'ring' 'ham'
+ ' on' ' the' ' 20' 'th' ' of' ' May' ',' ' 18' '92']" ".
+
+ The King was born at Sand ring ham on the 20 th of May , 18 92" False 1 ['George', ' VI']
+1608 360 Name of father of x -1 Name of father of George VI George V George VI "['.' '\n' '\n' 'The' ' King' ' was' ' born' ' at' ' Sand' 'ring' 'ham'
+ ' on' ' the' ' 20' 'th' ' of' ' May' ',' ' 18' '92']" ".
+
+ The King was born at Sand ring ham on the 20 th of May , 18 92" False audience with King George VI at Buckingham Palace. 4 [' audience', ' with', ' King', ' George', ' VI']
+1609 360 Name of father of x -1 Name of father of George VI George V George VI "['.' '\n' '\n' 'The' ' King' ' was' ' born' ' at' ' Sand' 'ring' 'ham'
+ ' on' ' the' ' 20' 'th' ' of' ' May' ',' ' 18' '92']" ".
+
+ The King was born at Sand ring ham on the 20 th of May , 18 92" False coronation of King George VI and Queen Elizabeth, 5 [' coron', 'ation', ' of', ' King', ' George', ' VI']
+1610 361 Name of father of x -1 Name of father of Christopher Columbus Domenico Colombo Christopher Columbus "[',' ' the' ' disc' 'ove' 'rer' ' of' ' America' ',' ' and' ' the'
+ ' first' ' European' ' to' ' reach' ' the' ' New' ' World' '.' '\n' '\n']" ", the disc ove rer of America , and the first European to reach the New World .
+
+" False as pilot under Christopher Columbus on his final voyage. 4 [' as', ' pilot', ' under', ' Christopher', ' Columbus']
+1611 361 Name of father of x -1 Name of father of Christopher Columbus Domenico Colombo Christopher Columbus "[',' ' the' ' disc' 'ove' 'rer' ' of' ' America' ',' ' and' ' the'
+ ' first' ' European' ' to' ' reach' ' the' ' New' ' World' '.' '\n' '\n']" ", the disc ove rer of America , and the first European to reach the New World .
+
+" False widely known. Christopher Columbus was inspired enough 4 [' widely', ' known', '.', ' Christopher', ' Columbus']
+1612 361 Name of father of x -1 Name of father of Christopher Columbus Domenico Colombo Christopher Columbus "[',' ' the' ' disc' 'ove' 'rer' ' of' ' America' ',' ' and' ' the'
+ ' first' ' European' ' to' ' reach' ' the' ' New' ' World' '.' '\n' '\n']" ", the disc ove rer of America , and the first European to reach the New World .
+
+" False claiming that Christopher Columbus subdued a revolt 3 [' claiming', ' that', ' Christopher', ' Columbus']
+1613 361 Name of father of x -1 Name of father of Christopher Columbus Domenico Colombo Christopher Columbus "[',' ' the' ' disc' 'ove' 'rer' ' of' ' America' ',' ' and' ' the'
+ ' first' ' European' ' to' ' reach' ' the' ' New' ' World' '.' '\n' '\n']" ", the disc ove rer of America , and the first European to reach the New World .
+
+" False gifts offered to Christopher Columbus when he reached 4 [' gifts', ' offered', ' to', ' Christopher', ' Columbus']
+1614 361 Name of father of x -1 Name of father of Christopher Columbus Domenico Colombo Christopher Columbus "[',' ' the' ' disc' 'ove' 'rer' ' of' ' America' ',' ' and' ' the'
+ ' first' ' European' ' to' ' reach' ' the' ' New' ' World' '.' '\n' '\n']" ", the disc ove rer of America , and the first European to reach the New World .
+
+" False exploration by Christopher Columbus (d. 1506) in 1492 3 [' exploration', ' by', ' Christopher', ' Columbus']
+1615 362 Name of father of x -1 Name of father of Henry Wadsworth Longfellow Stephen Longfellow Henry Wadsworth Longfellow "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' works'
+ ' of' ' Henry' ' Wad' 'sworth' ' Long' 'f' 'ellow' '.' ' I']" "
+
+ I am a great admire r of the works of Henry Wad sworth Long f ellow . I" False property. Henry Wadsworth Longfellow oversaw the creation 7 [' property', '.', ' Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1616 362 Name of father of x -1 Name of father of Henry Wadsworth Longfellow Stephen Longfellow Henry Wadsworth Longfellow "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' works'
+ ' of' ' Henry' ' Wad' 'sworth' ' Long' 'f' 'ellow' '.' ' I']" "
+
+ I am a great admire r of the works of Henry Wad sworth Long f ellow . I" False replica of the Henry Wadsworth Longfellow House in Cambridge, 8 [' replica', ' of', ' the', ' Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1617 362 Name of father of x -1 Name of father of Henry Wadsworth Longfellow Stephen Longfellow Henry Wadsworth Longfellow "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' works'
+ ' of' ' Henry' ' Wad' 'sworth' ' Long' 'f' 'ellow' '.' ' I']" "
+
+ I am a great admire r of the works of Henry Wad sworth Long f ellow . I" False a boarder that Henry Wadsworth Longfellow came into the 9 [' a', ' board', 'er', ' that', ' Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1618 362 Name of father of x -1 Name of father of Henry Wadsworth Longfellow Stephen Longfellow Henry Wadsworth Longfellow "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' works'
+ ' of' ' Henry' ' Wad' 'sworth' ' Long' 'f' 'ellow' '.' ' I']" "
+
+ I am a great admire r of the works of Henry Wad sworth Long f ellow . I" False publicly accusing Henry Wadsworth Longfellow of plagiarism, though 7 [' publicly', ' accusing', ' Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1619 362 Name of father of x -1 Name of father of Henry Wadsworth Longfellow Stephen Longfellow Henry Wadsworth Longfellow "['\n' '\n' 'I' ' am' ' a' ' great' ' admire' 'r' ' of' ' the' ' works'
+ ' of' ' Henry' ' Wad' 'sworth' ' Long' 'f' 'ellow' '.' ' I']" "
+
+ I am a great admire r of the works of Henry Wad sworth Long f ellow . I" False 5 ['Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1620 363 Name of father of x -1 Name of father of Neil Gaiman David Gaiman Neil Gaiman "[',' ' author' ' of' ' the' ' Sand' 'man' ' series' ',' ' and' ' the'
+ ' author' ' of' ' the' ' graphic' ' novel' ',' ' The' ' Graveyard'
+ ' Book' '.']" , author of the Sand man series , and the author of the graphic novel , The Graveyard Book . False (issues # 25 & 26) and Neil Gaiman (issue # 27) 10 [' (', 'issues', ' #', ' 25', ' &', ' 26', ')', ' and', ' Neil', ' G', 'aiman']
+1621 363 Name of father of x -1 Name of father of Neil Gaiman David Gaiman Neil Gaiman "[',' ' author' ' of' ' the' ' Sand' 'man' ' series' ',' ' and' ' the'
+ ' author' ' of' ' the' ' graphic' ' novel' ',' ' The' ' Graveyard'
+ ' Book' '.']" , author of the Sand man series , and the author of the graphic novel , The Graveyard Book . False series written by Neil Gaiman and published by 5 [' series', ' written', ' by', ' Neil', ' G', 'aiman']
+1622 363 Name of father of x -1 Name of father of Neil Gaiman David Gaiman Neil Gaiman "[',' ' author' ' of' ' the' ' Sand' 'man' ' series' ',' ' and' ' the'
+ ' author' ' of' ' the' ' graphic' ' novel' ',' ' The' ' Graveyard'
+ ' Book' '.']" , author of the Sand man series , and the author of the graphic novel , The Graveyard Book . False collaboration with Neil Gaiman (which was nominated 4 [' collaboration', ' with', ' Neil', ' G', 'aiman']
+1623 363 Name of father of x -1 Name of father of Neil Gaiman David Gaiman Neil Gaiman "[',' ' author' ' of' ' the' ' Sand' 'man' ' series' ',' ' and' ' the'
+ ' author' ' of' ' the' ' graphic' ' novel' ',' ' The' ' Graveyard'
+ ' Book' '.']" , author of the Sand man series , and the author of the graphic novel , The Graveyard Book . False 2 ['Neil', ' G', 'aiman']
+1624 363 Name of father of x -1 Name of father of Neil Gaiman David Gaiman Neil Gaiman "[',' ' author' ' of' ' the' ' Sand' 'man' ' series' ',' ' and' ' the'
+ ' author' ' of' ' the' ' graphic' ' novel' ',' ' The' ' Graveyard'
+ ' Book' '.']" , author of the Sand man series , and the author of the graphic novel , The Graveyard Book . False " Omens, written with Neil Gaiman (1990)
+" 7 [' Om', 'ens', ',', ' written', ' with', ' Neil', ' G', 'aiman']
+1625 364 Name of father of x -1 Name of father of Tom Hanks Amos Mefford Hanks Tom Hanks "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Forrest' ' G'
+ 'ump' ' in' ' the' ' movie' ' Forrest' ' G' 'ump' '.' '\n' '\n']" ", the actor who played the role of Forrest G ump in the movie Forrest G ump .
+
+" False 2009 [1]. Tom Hanks is the Honorary 6 [' 2009', ' [', '1', '].', ' Tom', ' H', 'anks']
+1626 364 Name of father of x -1 Name of father of Tom Hanks Amos Mefford Hanks Tom Hanks "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Forrest' ' G'
+ 'ump' ' in' ' the' ' movie' ' Forrest' ' G' 'ump' '.' '\n' '\n']" ", the actor who played the role of Forrest G ump in the movie Forrest G ump .
+
+" False Seymour Hoffman and Tom Hanks do these amazing 5 [' Seymour', ' Hoffman', ' and', ' Tom', ' H', 'anks']
+1627 364 Name of father of x -1 Name of father of Tom Hanks Amos Mefford Hanks Tom Hanks "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Forrest' ' G'
+ 'ump' ' in' ' the' ' movie' ' Forrest' ' G' 'ump' '.' '\n' '\n']" ", the actor who played the role of Forrest G ump in the movie Forrest G ump .
+
+" False night. Instead of Tom Hanks and Tim Allen, 6 [' night', '.', ' Instead', ' of', ' Tom', ' H', 'anks']
+1628 364 Name of father of x -1 Name of father of Tom Hanks Amos Mefford Hanks Tom Hanks "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Forrest' ' G'
+ 'ump' ' in' ' the' ' movie' ' Forrest' ' G' 'ump' '.' '\n' '\n']" ", the actor who played the role of Forrest G ump in the movie Forrest G ump .
+
+" False night. Instead of Tom Hanks and Tim Allen, 6 [' night', '.', ' Instead', ' of', ' Tom', ' H', 'anks']
+1629 364 Name of father of x -1 Name of father of Tom Hanks Amos Mefford Hanks Tom Hanks "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Forrest' ' G'
+ 'ump' ' in' ' the' ' movie' ' Forrest' ' G' 'ump' '.' '\n' '\n']" ", the actor who played the role of Forrest G ump in the movie Forrest G ump .
+
+" False first: whether stars Tom Hanks and Tim Allen would 6 [' first', ':', ' whether', ' stars', ' Tom', ' H', 'anks']
+1630 365 Name of father of x -1 Name of father of Sarah Bernhardt Edouard Bernhardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '44' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , who was born in Paris in 18 44 .
+
+ The name of the father" False " theatre."" He saw Sarah Bernhardt act, Adeline" 6 "[' theatre', '.""', ' He', ' saw', ' Sarah', ' Bern', 'hardt']"
+1631 365 Name of father of x -1 Name of father of Sarah Bernhardt Edouard Bernhardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '44' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , who was born in Paris in 18 44 .
+
+ The name of the father" False Lillie Langtry, Sarah Bernhardt and Lady Randolph 8 [' L', 'ill', 'ie', ' Lang', 'try', ',', ' Sarah', ' Bern', 'hardt']
+1632 365 Name of father of x -1 Name of father of Sarah Bernhardt Edouard Bernhardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '44' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , who was born in Paris in 18 44 .
+
+ The name of the father" False included Lillie Langtry, Sarah Bernhardt and Lady Randolph Churchill. 9 [' included', ' L', 'ill', 'ie', ' Lang', 'try', ',', ' Sarah', ' Bern', 'hardt']
+1633 365 Name of father of x -1 Name of father of Sarah Bernhardt Edouard Bernhardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '44' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , who was born in Paris in 18 44 .
+
+ The name of the father" False Lillie Langtry, Sarah Bernhardt and Lady Randolph 8 [' L', 'ill', 'ie', ' Lang', 'try', ',', ' Sarah', ' Bern', 'hardt']
+1634 365 Name of father of x -1 Name of father of Sarah Bernhardt Edouard Bernhardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' who' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '44' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the actress , who was born in Paris in 18 44 .
+
+ The name of the father" False Lillie Langtry, Sarah Bernhardt and Lady Randolph 8 [' L', 'ill', 'ie', ' Lang', 'try', ',', ' Sarah', ' Bern', 'hardt']
+1635 366 Name of father of x -1 Name of father of Boethius Manlius Boethius Boethius "[',' ' the' ' son' ' of' ' The' 'od' 'oric' ',' ' king' ' of' ' the'
+ ' Ost' 'rog' 'oths' ',' ' and' ' of' ' the' '\n' ' ']" ", the son of The od oric , king of the Ost rog oths , and of the
+ " False back at least to Boethius, if not to 6 [' back', ' at', ' least', ' to', ' Bo', 'eth', 'ius']
+1636 366 Name of father of x -1 Name of father of Boethius Manlius Boethius Boethius "[',' ' the' ' son' ' of' ' The' 'od' 'oric' ',' ' king' ' of' ' the'
+ ' Ost' 'rog' 'oths' ',' ' and' ' of' ' the' '\n' ' ']" ", the son of The od oric , king of the Ost rog oths , and of the
+ " False translation of Boethius' Consolation 4 [' translation', ' of', ' Bo', 'eth', 'ius']
+1637 366 Name of father of x -1 Name of father of Boethius Manlius Boethius Boethius "[',' ' the' ' son' ' of' ' The' 'od' 'oric' ',' ' king' ' of' ' the'
+ ' Ost' 'rog' 'oths' ',' ' and' ' of' ' the' '\n' ' ']" ", the son of The od oric , king of the Ost rog oths , and of the
+ " False translation of Boethius' Consolation of Philosophy, 4 [' translation', ' of', ' Bo', 'eth', 'ius']
+1638 366 Name of father of x -1 Name of father of Boethius Manlius Boethius Boethius "[',' ' the' ' son' ' of' ' The' 'od' 'oric' ',' ' king' ' of' ' the'
+ ' Ost' 'rog' 'oths' ',' ' and' ' of' ' the' '\n' ' ']" ", the son of The od oric , king of the Ost rog oths , and of the
+ " False (d. c. 585), and Boethius (d. c. 525) were 11 [' (', 'd', '.', ' c', '.', ' 5', '85', '),', ' and', ' Bo', 'eth', 'ius']
+1639 366 Name of father of x -1 Name of father of Boethius Manlius Boethius Boethius "[',' ' the' ' son' ' of' ' The' 'od' 'oric' ',' ' king' ' of' ' the'
+ ' Ost' 'rog' 'oths' ',' ' and' ' of' ' the' '\n' ' ']" ", the son of The od oric , king of the Ost rog oths , and of the
+ " False Anglo-Saxon translation of Boethius' Consolation of Philosophy, 8 [' Anglo', '-', 'Sax', 'on', ' translation', ' of', ' Bo', 'eth', 'ius']
+1640 367 Name of father of x -1 Name of father of Édith Piaf Louis Gassion Édith Piaf "['\n' '\n' 'The' ' name' ' É' 'd' 'ith' ' P' 'ia' 'f' ' is' ' a' ' French'
+ ' feminine' ' given' ' name' '.' ' It' ' is' ' a']" "
+
+ The name É d ith P ia f is a French feminine given name . It is a" False cabaret singer Édith Piaf and ingenues in 8 [' cab', 'aret', ' singer', ' É', 'd', 'ith', ' P', 'ia', 'f']
+1641 367 Name of father of x -1 Name of father of Édith Piaf Louis Gassion Édith Piaf "['\n' '\n' 'The' ' name' ' É' 'd' 'ith' ' P' 'ia' 'f' ' is' ' a' ' French'
+ ' feminine' ' given' ' name' '.' ' It' ' is' ' a']" "
+
+ The name É d ith P ia f is a French feminine given name . It is a" False singer named Édith Piaf was discovered 7 [' singer', ' named', ' É', 'd', 'ith', ' P', 'ia', 'f']
+1642 367 Name of father of x -1 Name of father of Édith Piaf Louis Gassion Édith Piaf "['\n' '\n' 'The' ' name' ' É' 'd' 'ith' ' P' 'ia' 'f' ' is' ' a' ' French'
+ ' feminine' ' given' ' name' '.' ' It' ' is' ' a']" "
+
+ The name É d ith P ia f is a French feminine given name . It is a" False French chanteuse Édith Piaf in Pam Gems' 9 [' French', ' ch', 'ante', 'use', ' É', 'd', 'ith', ' P', 'ia', 'f']
+1643 368 Name of father of x -1 Name of father of Franz Schubert Franz Theodor Schubert Franz Schubert "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Franz' ' Sch' 'u' 'bert' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Franz Sch u bert , the composer ." False and of the works of Franz Schubert and Robert 8 [' and', ' of', ' the', ' works', ' of', ' Franz', ' Sch', 'u', 'bert']
+1644 368 Name of father of x -1 Name of father of Franz Schubert Franz Theodor Schubert Franz Schubert "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Franz' ' Sch' 'u' 'bert' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Franz Sch u bert , the composer ." False Sechs Melodien von Franz Schubert (S 560); the 10 [' Se', 'ch', 's', ' Mel', 'od', 'ien', ' von', ' Franz', ' Sch', 'u', 'bert']
+1645 368 Name of father of x -1 Name of father of Franz Schubert Franz Theodor Schubert Franz Schubert "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Franz' ' Sch' 'u' 'bert' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Franz Sch u bert , the composer ." False Sechs Melodien von Franz Schubert (S 560); the second 10 [' Se', 'ch', 's', ' Mel', 'od', 'ien', ' von', ' Franz', ' Sch', 'u', 'bert']
+1646 368 Name of father of x -1 Name of father of Franz Schubert Franz Theodor Schubert Franz Schubert "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Franz' ' Sch' 'u' 'bert' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Franz Sch u bert , the composer ." False Melodien von Franz Schubert (S 560); the second 7 [' Mel', 'od', 'ien', ' von', ' Franz', ' Sch', 'u', 'bert']
+1647 368 Name of father of x -1 Name of father of Franz Schubert Franz Theodor Schubert Franz Schubert "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Franz' ' Sch' 'u' 'bert' ',' ' the' ' composer' '.']" ", the composer .
+
+ The name of the father of Franz Sch u bert , the composer ." False of the works of Franz Schubert and Robert 7 [' of', ' the', ' works', ' of', ' Franz', ' Sch', 'u', 'bert']
+1648 369 Name of father of x -1 Name of father of Alberto Giacometti Giovanni Giacometti Alberto Giacometti "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Alberto' ' Gi'
+ 'ac' 'omet' 'ti' ' is' ' Alberto' ' Gi' 'ac' 'omet' 'ti' '.']" "
+
+ The name of the father of Alberto Gi ac omet ti is Alberto Gi ac omet ti ." False 1949, Swiss artist Alberto Giacometti made Tzara the subject 8 [' 1949', ',', ' Swiss', ' artist', ' Alberto', ' Gi', 'ac', 'omet', 'ti']
+1649 369 Name of father of x -1 Name of father of Alberto Giacometti Giovanni Giacometti Alberto Giacometti "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Alberto' ' Gi'
+ 'ac' 'omet' 'ti' ' is' ' Alberto' ' Gi' 'ac' 'omet' 'ti' '.']" "
+
+ The name of the father of Alberto Gi ac omet ti is Alberto Gi ac omet ti ." False Swiss artist Alberto Giacometti made Tzara the subject 6 [' Swiss', ' artist', ' Alberto', ' Gi', 'ac', 'omet', 'ti']
+1650 369 Name of father of x -1 Name of father of Alberto Giacometti Giovanni Giacometti Alberto Giacometti "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Alberto' ' Gi'
+ 'ac' 'omet' 'ti' ' is' ' Alberto' ' Gi' 'ac' 'omet' 'ti' '.']" "
+
+ The name of the father of Alberto Gi ac omet ti is Alberto Gi ac omet ti ." False 1949, Swiss artist Alberto Giacometti made Tzara the 8 [' 1949', ',', ' Swiss', ' artist', ' Alberto', ' Gi', 'ac', 'omet', 'ti']
+1651 369 Name of father of x -1 Name of father of Alberto Giacometti Giovanni Giacometti Alberto Giacometti "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Alberto' ' Gi'
+ 'ac' 'omet' 'ti' ' is' ' Alberto' ' Gi' 'ac' 'omet' 'ti' '.']" "
+
+ The name of the father of Alberto Gi ac omet ti is Alberto Gi ac omet ti ." False 1949, Swiss artist Alberto Giacometti made Tzara the subject 8 [' 1949', ',', ' Swiss', ' artist', ' Alberto', ' Gi', 'ac', 'omet', 'ti']
+1652 369 Name of father of x -1 Name of father of Alberto Giacometti Giovanni Giacometti Alberto Giacometti "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Alberto' ' Gi'
+ 'ac' 'omet' 'ti' ' is' ' Alberto' ' Gi' 'ac' 'omet' 'ti' '.']" "
+
+ The name of the father of Alberto Gi ac omet ti is Alberto Gi ac omet ti ." False 1949, Swiss artist Alberto Giacometti made Tzara 8 [' 1949', ',', ' Swiss', ' artist', ' Alberto', ' Gi', 'ac', 'omet', 'ti']
+1653 370 Name of father of x -1 Name of father of Emma Thompson Eric Thompson Emma Thompson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False the book [and] Emma Thompson delivered a charming 6 [' the', ' book', ' [', 'and', ']', ' Emma', ' Thompson']
+1654 370 Name of father of x -1 Name of father of Emma Thompson Eric Thompson Emma Thompson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False starred alongside Emma Thompson in the film Carrington 3 [' starred', ' alongside', ' Emma', ' Thompson']
+1655 370 Name of father of x -1 Name of father of Emma Thompson Eric Thompson Emma Thompson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False name. Actress Emma Thompson wrote the script 4 [' name', '.', ' Actress', ' Emma', ' Thompson']
+1656 370 Name of father of x -1 Name of father of Emma Thompson Eric Thompson Emma Thompson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False the film stars Emma Thompson as author P. 4 [' the', ' film', ' stars', ' Emma', ' Thompson']
+1657 370 Name of father of x -1 Name of father of Emma Thompson Eric Thompson Emma Thompson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False idea that Emma Thompson would be providing 3 [' idea', ' that', ' Emma', ' Thompson']
+1658 371 Name of father of x -1 Name of father of Henry James Henry James, Sr. Henry James "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ',' ' The' ' Port'
+ 'rait' ' of' ' a' ' Lady' ',' ' and' ' the' ' son' ' of' ' a']" , the author of the famous novel , The Port rait of a Lady , and the son of a False for the war effort. Henry James was so incensed 6 [' for', ' the', ' war', ' effort', '.', ' Henry', ' James']
+1659 371 Name of father of x -1 Name of father of Henry James Henry James, Sr. Henry James "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ',' ' The' ' Port'
+ 'rait' ' of' ' a' ' Lady' ',' ' and' ' the' ' son' ' of' ' a']" , the author of the famous novel , The Port rait of a Lady , and the son of a False was described by Henry James as the most 4 [' was', ' described', ' by', ' Henry', ' James']
+1660 371 Name of father of x -1 Name of father of Henry James Henry James, Sr. Henry James "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ',' ' The' ' Port'
+ 'rait' ' of' ' a' ' Lady' ',' ' and' ' the' ' son' ' of' ' a']" , the author of the famous novel , The Port rait of a Lady , and the son of a False results written by Henry James Forman appeared 4 [' results', ' written', ' by', ' Henry', ' James']
+1661 371 Name of father of x -1 Name of father of Henry James Henry James, Sr. Henry James "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ',' ' The' ' Port'
+ 'rait' ' of' ' a' ' Lady' ',' ' and' ' the' ' son' ' of' ' a']" , the author of the famous novel , The Port rait of a Lady , and the son of a False Campion's adaptation of the Henry James novel The 7 "[' Camp', 'ion', ""'s"", ' adaptation', ' of', ' the', ' Henry', ' James']"
+1662 371 Name of father of x -1 Name of father of Henry James Henry James, Sr. Henry James "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ',' ' The' ' Port'
+ 'rait' ' of' ' a' ' Lady' ',' ' and' ' the' ' son' ' of' ' a']" , the author of the famous novel , The Port rait of a Lady , and the son of a False protagonists of the Henry James novels Washington 4 [' protagonists', ' of', ' the', ' Henry', ' James']
+1663 372 Name of father of x -1 Name of father of Jackie Chan Charles Chan Jackie Chan "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Jackie' ' Chan' '.' '\n' '\n' 'The' ' actor' ',' ' who']" ", the actor , and the father of the actor , Jackie Chan .
+
+ The actor , who" False collaboration between Jackie Chan and Yuen Biao. The 3 [' collaboration', ' between', ' Jackie', ' Chan']
+1664 372 Name of father of x -1 Name of father of Jackie Chan Charles Chan Jackie Chan "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Jackie' ' Chan' '.' '\n' '\n' 'The' ' actor' ',' ' who']" ", the actor , and the father of the actor , Jackie Chan .
+
+ The actor , who" False other commitments. Jackie Chan said he had been 4 [' other', ' commitments', '.', ' Jackie', ' Chan']
+1665 372 Name of father of x -1 Name of father of Jackie Chan Charles Chan Jackie Chan "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Jackie' ' Chan' '.' '\n' '\n' 'The' ' actor' ',' ' who']" ", the actor , and the father of the actor , Jackie Chan .
+
+ The actor , who" False " Bachchan, and Jackie Chan titled Gold Struck.
+" 5 [' Bach', 'chan', ',', ' and', ' Jackie', ' Chan']
+1666 372 Name of father of x -1 Name of father of Jackie Chan Charles Chan Jackie Chan "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Jackie' ' Chan' '.' '\n' '\n' 'The' ' actor' ',' ' who']" ", the actor , and the father of the actor , Jackie Chan .
+
+ The actor , who" False and starring Jackie Chan, Louis Koo, Yuen 3 [' and', ' starring', ' Jackie', ' Chan']
+1667 372 Name of father of x -1 Name of father of Jackie Chan Charles Chan Jackie Chan "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Jackie' ' Chan' '.' '\n' '\n' 'The' ' actor' ',' ' who']" ", the actor , and the father of the actor , Jackie Chan .
+
+ The actor , who" False respectively, while Jackie Chan provided the 4 [' respectively', ',', ' while', ' Jackie', ' Chan']
+1668 374 Name of father of x -1 Name of father of Francis of Assisi Pietro di Bernardone dei Moriconi Francis of Assisi "[',' ' the' ' founder' ' of' ' the' ' Franc' 'isc' 'an' ' Order' '.' '\n'
+ '\n' 'The' ' Franc' 'isc' 'an' ' Order' ' is' ' a' ' mon']" ", the founder of the Franc isc an Order .
+
+ The Franc isc an Order is a mon" False 4 ['Franc', 'is', ' of', ' Ass', 'isi']
+1669 374 Name of father of x -1 Name of father of Francis of Assisi Pietro di Bernardone dei Moriconi Francis of Assisi "[',' ' the' ' founder' ' of' ' the' ' Franc' 'isc' 'an' ' Order' '.' '\n'
+ '\n' 'The' ' Franc' 'isc' 'an' ' Order' ' is' ' a' ' mon']" ", the founder of the Franc isc an Order .
+
+ The Franc isc an Order is a mon" False Milford Haven, St. Francis of Assisi on Priory Road. 9 [' Mil', 'ford', ' Haven', ',', ' St', '.', ' Francis', ' of', ' Ass', 'isi']
+1670 374 Name of father of x -1 Name of father of Francis of Assisi Pietro di Bernardone dei Moriconi Francis of Assisi "[',' ' the' ' founder' ' of' ' the' ' Franc' 'isc' 'an' ' Order' '.' '\n'
+ '\n' 'The' ' Franc' 'isc' 'an' ' Order' ' is' ' a' ' mon']" ", the founder of the Franc isc an Order .
+
+ The Franc isc an Order is a mon" False were founded by Francis of Assisi and Dominic 6 [' were', ' founded', ' by', ' Francis', ' of', ' Ass', 'isi']
+1671 374 Name of father of x -1 Name of father of Francis of Assisi Pietro di Bernardone dei Moriconi Francis of Assisi "[',' ' the' ' founder' ' of' ' the' ' Franc' 'isc' 'an' ' Order' '.' '\n'
+ '\n' 'The' ' Franc' 'isc' 'an' ' Order' ' is' ' a' ' mon']" ", the founder of the Franc isc an Order .
+
+ The Franc isc an Order is a mon" False were founded by Francis of Assisi and Dominic de Guzmán. 6 [' were', ' founded', ' by', ' Francis', ' of', ' Ass', 'isi']
+1672 374 Name of father of x -1 Name of father of Francis of Assisi Pietro di Bernardone dei Moriconi Francis of Assisi "[',' ' the' ' founder' ' of' ' the' ' Franc' 'isc' 'an' ' Order' '.' '\n'
+ '\n' 'The' ' Franc' 'isc' 'an' ' Order' ' is' ' a' ' mon']" ", the founder of the Franc isc an Order .
+
+ The Franc isc an Order is a mon" False school, St Francis of Assisi School, is located 6 [' school', ',', ' St', ' Francis', ' of', ' Ass', 'isi']
+1673 375 Name of father of x -1 Name of father of Michel Foucault Paul Foucault Michel Foucault "[',' ' the' ' French' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 18'
+ '19' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' philosopher']" ", the French philosopher , who was born in 18 19 .
+
+ The name of the French philosopher" False as sociologists: Michel Foucault (1), Pierre Bourdieu 8 [' as', ' soc', 'i', 'ologists', ':', ' Michel', ' Fou', 'c', 'ault']
+1674 375 Name of father of x -1 Name of father of Michel Foucault Paul Foucault Michel Foucault "[',' ' the' ' French' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 18'
+ '19' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' philosopher']" ", the French philosopher , who was born in 18 19 .
+
+ The name of the French philosopher" False of social theory. Michel Foucault provides a potent 7 [' of', ' social', ' theory', '.', ' Michel', ' Fou', 'c', 'ault']
+1675 375 Name of father of x -1 Name of father of Michel Foucault Paul Foucault Michel Foucault "[',' ' the' ' French' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 18'
+ '19' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' philosopher']" ", the French philosopher , who was born in 18 19 .
+
+ The name of the French philosopher" False cultural critic Michel Foucault observed that 5 [' cultural', ' critic', ' Michel', ' Fou', 'c', 'ault']
+1676 375 Name of father of x -1 Name of father of Michel Foucault Paul Foucault Michel Foucault "[',' ' the' ' French' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 18'
+ '19' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' philosopher']" ", the French philosopher , who was born in 18 19 .
+
+ The name of the French philosopher" False and cultural critic Michel Foucault observed that the 6 [' and', ' cultural', ' critic', ' Michel', ' Fou', 'c', 'ault']
+1677 375 Name of father of x -1 Name of father of Michel Foucault Paul Foucault Michel Foucault "[',' ' the' ' French' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 18'
+ '19' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' philosopher']" ", the French philosopher , who was born in 18 19 .
+
+ The name of the French philosopher" False discourse, amongst them Michel Foucault (1926 – 1984). 7 [' discourse', ',', ' amongst', ' them', ' Michel', ' Fou', 'c', 'ault']
+1678 376 Name of father of x -1 Name of father of Johnny Cash Ray Cash Johnny Cash "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' June' ' Carter'
+ ' Cash' ',' ' who' ' was' ' a' ' singer' ' and' ' song' 'writer' '.']" , the singer , and his wife , June Carter Cash , who was a singer and song writer . False Kristofferson, and Johnny Cash formed The Highwaymen, 6 [' Krist', 'off', 'erson', ',', ' and', ' Johnny', ' Cash']
+1679 376 Name of father of x -1 Name of father of Johnny Cash Ray Cash Johnny Cash "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' June' ' Carter'
+ ' Cash' ',' ' who' ' was' ' a' ' singer' ' and' ' song' 'writer' '.']" , the singer , and his wife , June Carter Cash , who was a singer and song writer . False Roy Orbison, Johnny Cash and June Carter, Rare 5 [' Roy', ' Orb', 'ison', ',', ' Johnny', ' Cash']
+1680 376 Name of father of x -1 Name of father of Johnny Cash Ray Cash Johnny Cash "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' June' ' Carter'
+ ' Cash' ',' ' who' ' was' ' a' ' singer' ' and' ' song' 'writer' '.']" , the singer , and his wife , June Carter Cash , who was a singer and song writer . False theme. IGN ranked Johnny Cash's performance as 5 [' theme', '.', ' IGN', ' ranked', ' Johnny', ' Cash']
+1681 376 Name of father of x -1 Name of father of Johnny Cash Ray Cash Johnny Cash "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' June' ' Carter'
+ ' Cash' ',' ' who' ' was' ' a' ' singer' ' and' ' song' 'writer' '.']" , the singer , and his wife , June Carter Cash , who was a singer and song writer . False Country singer Johnny Cash parodied the song 3 [' Country', ' singer', ' Johnny', ' Cash']
+1682 376 Name of father of x -1 Name of father of Johnny Cash Ray Cash Johnny Cash "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' June' ' Carter'
+ ' Cash' ',' ' who' ' was' ' a' ' singer' ' and' ' song' 'writer' '.']" , the singer , and his wife , June Carter Cash , who was a singer and song writer . False Carter on a cover of Johnny Cash and June Carter 6 [' Carter', ' on', ' a', ' cover', ' of', ' Johnny', ' Cash']
+1683 377 Name of father of x -1 Name of father of Rosa Bonheur Raymond Bonheur Rosa Bonheur "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False realists such as Rosa Bonheur and Gustave Courbet. 7 [' real', 'ists', ' such', ' as', ' Rosa', ' Bon', 'he', 'ur']
+1684 377 Name of father of x -1 Name of father of Rosa Bonheur Raymond Bonheur Rosa Bonheur "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False realists such as Rosa Bonheur and Gustave Courbet. 7 [' real', 'ists', ' such', ' as', ' Rosa', ' Bon', 'he', 'ur']
+1685 377 Name of father of x -1 Name of father of Rosa Bonheur Raymond Bonheur Rosa Bonheur "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False influenced by the works of Rosa Bonheur and Gustave Courbet, 8 [' influenced', ' by', ' the', ' works', ' of', ' Rosa', ' Bon', 'he', 'ur']
+1686 377 Name of father of x -1 Name of father of Rosa Bonheur Raymond Bonheur Rosa Bonheur "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False by the works of Rosa Bonheur and Gustave Courbet, 7 [' by', ' the', ' works', ' of', ' Rosa', ' Bon', 'he', 'ur']
+1687 377 Name of father of x -1 Name of father of Rosa Bonheur Raymond Bonheur Rosa Bonheur "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False the works of Rosa Bonheur and Gustave 6 [' the', ' works', ' of', ' Rosa', ' Bon', 'he', 'ur']
+1688 378 Name of father of x -1 Name of father of Jean-Luc Godard Paul Godard Jean-Luc Godard "[',' ' the' ' French' ' director' ' of' ' the' ' film' '.' '\n' '\n' 'The'
+ ' film' ' is' ' a' ' remake' ' of' ' the' ' French' ' film' ' of']" ", the French director of the film .
+
+ The film is a remake of the French film of" False acknowledged his debts to Jean-Luc Godard and François Truffaut 8 [' acknowledged', ' his', ' debts', ' to', ' Jean', '-', 'Luc', ' God', 'ard']
+1689 378 Name of father of x -1 Name of father of Jean-Luc Godard Paul Godard Jean-Luc Godard "[',' ' the' ' French' ' director' ' of' ' the' ' film' '.' '\n' '\n' 'The'
+ ' film' ' is' ' a' ' remake' ' of' ' the' ' French' ' film' ' of']" ", the French director of the film .
+
+ The film is a remake of the French film of" False partially inspired by the Jean-Luc Godard film Breathless 8 [' partially', ' inspired', ' by', ' the', ' Jean', '-', 'Luc', ' God', 'ard']
+1690 378 Name of father of x -1 Name of father of Jean-Luc Godard Paul Godard Jean-Luc Godard "[',' ' the' ' French' ' director' ' of' ' the' ' film' '.' '\n' '\n' 'The'
+ ' film' ' is' ' a' ' remake' ' of' ' the' ' French' ' film' ' of']" ", the French director of the film .
+
+ The film is a remake of the French film of" False inspired by the Jean-Luc Godard film Breathless 7 [' inspired', ' by', ' the', ' Jean', '-', 'Luc', ' God', 'ard']
+1691 378 Name of father of x -1 Name of father of Jean-Luc Godard Paul Godard Jean-Luc Godard "[',' ' the' ' French' ' director' ' of' ' the' ' film' '.' '\n' '\n' 'The'
+ ' film' ' is' ' a' ' remake' ' of' ' the' ' French' ' film' ' of']" ", the French director of the film .
+
+ The film is a remake of the French film of" False his debts to Jean-Luc Godard and François Truffaut 7 [' his', ' debts', ' to', ' Jean', '-', 'Luc', ' God', 'ard']
+1692 378 Name of father of x -1 Name of father of Jean-Luc Godard Paul Godard Jean-Luc Godard "[',' ' the' ' French' ' director' ' of' ' the' ' film' '.' '\n' '\n' 'The'
+ ' film' ' is' ' a' ' remake' ' of' ' the' ' French' ' film' ' of']" ", the French director of the film .
+
+ The film is a remake of the French film of" False his debts to Jean-Luc Godard and François Truffaut 7 [' his', ' debts', ' to', ' Jean', '-', 'Luc', ' God', 'ard']
+1693 379 Name of father of x -1 Name of father of Aleksandr Solzhenitsyn Isaackly Semyonovich Solzhenitsyn Aleksandr Solzhenitsyn "['\n' '\n' 'The' ' Russian' ' writer' ' Ale' 'ks' 'andr' ' Sol' 'zhen'
+ 'its' 'yn' ' was' ' born' ' in' ' 1918' ' in' ' the' ' village' ' of']" "
+
+ The Russian writer Ale ks andr Sol zhen its yn was born in 1918 in the village of" False personalities such as Aleksandr Solzhenitsyn and Andrei Sakharov, 9 [' personalities', ' such', ' as', ' Ale', 'ks', 'andr', ' Sol', 'zhen', 'its', 'yn']
+1694 379 Name of father of x -1 Name of father of Aleksandr Solzhenitsyn Isaackly Semyonovich Solzhenitsyn Aleksandr Solzhenitsyn "['\n' '\n' 'The' ' Russian' ' writer' ' Ale' 'ks' 'andr' ' Sol' 'zhen'
+ 'its' 'yn' ' was' ' born' ' in' ' 1918' ' in' ' the' ' village' ' of']" "
+
+ The Russian writer Ale ks andr Sol zhen its yn was born in 1918 in the village of" False personalities such as Aleksandr Solzhenitsyn and Andrei Sakharov, 9 [' personalities', ' such', ' as', ' Ale', 'ks', 'andr', ' Sol', 'zhen', 'its', 'yn']
+1695 379 Name of father of x -1 Name of father of Aleksandr Solzhenitsyn Isaackly Semyonovich Solzhenitsyn Aleksandr Solzhenitsyn "['\n' '\n' 'The' ' Russian' ' writer' ' Ale' 'ks' 'andr' ' Sol' 'zhen'
+ 'its' 'yn' ' was' ' born' ' in' ' 1918' ' in' ' the' ' village' ' of']" "
+
+ The Russian writer Ale ks andr Sol zhen its yn was born in 1918 in the village of" False personalities such as Aleksandr Solzhenitsyn and Andrei Sakharov, 9 [' personalities', ' such', ' as', ' Ale', 'ks', 'andr', ' Sol', 'zhen', 'its', 'yn']
+1696 380 Name of father of x -1 Name of father of Béla Bartók Béla Bartók Béla Bartók "['\n' '\n' 'B' 'é' 'la' ' Bart' 'ó' 'k' ' (' '18' '81' '–' '1945' ')'
+ ' was' ' a' ' Hungarian' ' composer' ',' ' pian']" "
+
+ B é la Bart ó k ( 18 81 – 1945 ) was a Hungarian composer , pian" True to sign up composers Béla Bartók and Zoltán Kodály, 10 [' to', ' sign', ' up', ' compos', 'ers', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1697 380 Name of father of x -1 Name of father of Béla Bartók Béla Bartók Béla Bartók "['\n' '\n' 'B' 'é' 'la' ' Bart' 'ó' 'k' ' (' '18' '81' '–' '1945' ')'
+ ' was' ' a' ' Hungarian' ' composer' ',' ' pian']" "
+
+ B é la Bart ó k ( 18 81 – 1945 ) was a Hungarian composer , pian" True Composers such as Béla Bartók and, later, Lou Harrison 9 [' Compos', 'ers', ' such', ' as', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1698 380 Name of father of x -1 Name of father of Béla Bartók Béla Bartók Béla Bartók "['\n' '\n' 'B' 'é' 'la' ' Bart' 'ó' 'k' ' (' '18' '81' '–' '1945' ')'
+ ' was' ' a' ' Hungarian' ' composer' ',' ' pian']" "
+
+ B é la Bart ó k ( 18 81 – 1945 ) was a Hungarian composer , pian" True up composers Béla Bartók and Zoltán Kodály, 8 [' up', ' compos', 'ers', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1699 380 Name of father of x -1 Name of father of Béla Bartók Béla Bartók Béla Bartók "['\n' '\n' 'B' 'é' 'la' ' Bart' 'ó' 'k' ' (' '18' '81' '–' '1945' ')'
+ ' was' ' a' ' Hungarian' ' composer' ',' ' pian']" "
+
+ B é la Bart ó k ( 18 81 – 1945 ) was a Hungarian composer , pian" True considered Béla Bartók and Dmitri Shostakovich 6 [' considered', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1700 380 Name of father of x -1 Name of father of Béla Bartók Béla Bartók Béla Bartók "['\n' '\n' 'B' 'é' 'la' ' Bart' 'ó' 'k' ' (' '18' '81' '–' '1945' ')'
+ ' was' ' a' ' Hungarian' ' composer' ',' ' pian']" "
+
+ B é la Bart ó k ( 18 81 – 1945 ) was a Hungarian composer , pian" True to sign up composers Béla Bartók and Zoltán Kodály, 10 [' to', ' sign', ' up', ' compos', 'ers', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1701 381 Name of father of x -1 Name of father of Judit Polgár László Polgár Judit Polgár "['\n' '\n' 'Category' ':' '19' '30' ' births' '\n' 'Category' ':' 'Living'
+ ' people' '\n' 'Category' ':' 'Hung' 'arian' ' female' ' tennis'
+ ' players']" "
+
+ Category : 19 30 births
+ Category : Living people
+ Category : Hung arian female tennis players" False for a final against Judit Polgár (who had in 9 [' for', ' a', ' final', ' against', ' Jud', 'it', ' Pol', 'g', 'á', 'r']
+1702 382 Name of father of x -1 Name of father of Bernard of Clairvaux Tescelin de Fontaine Bernard of Clairvaux "[',' ' the' ' great' ' preacher' ' of' ' the' ' tw' 'elfth' ' century' ','
+ ' who' '\n' 'was' ' the' ' founder' ' of' ' the' ' C' 'ister' 'c']" ", the great preacher of the tw elfth century , who
+ was the founder of the C ister c" False the guidance of Bernard of Clairvaux (d. 1153). 7 [' the', ' guidance', ' of', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1703 382 Name of father of x -1 Name of father of Bernard of Clairvaux Tescelin de Fontaine Bernard of Clairvaux "[',' ' the' ' great' ' preacher' ' of' ' the' ' tw' 'elfth' ' century' ','
+ ' who' '\n' 'was' ' the' ' founder' ' of' ' the' ' C' 'ister' 'c']" ", the great preacher of the tw elfth century , who
+ was the founder of the C ister c" False was attended by Bernard of Clairvaux and Suger of 7 [' was', ' attended', ' by', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1704 382 Name of father of x -1 Name of father of Bernard of Clairvaux Tescelin de Fontaine Bernard of Clairvaux "[',' ' the' ' great' ' preacher' ' of' ' the' ' tw' 'elfth' ' century' ','
+ ' who' '\n' 'was' ' the' ' founder' ' of' ' the' ' C' 'ister' 'c']" ", the great preacher of the tw elfth century , who
+ was the founder of the C ister c" False contemporaries like Bernard of Clairvaux for being more in 6 [' contemporaries', ' like', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1705 382 Name of father of x -1 Name of father of Bernard of Clairvaux Tescelin de Fontaine Bernard of Clairvaux "[',' ' the' ' great' ' preacher' ' of' ' the' ' tw' 'elfth' ' century' ','
+ ' who' '\n' 'was' ' the' ' founder' ' of' ' the' ' C' 'ister' 'c']" ", the great preacher of the tw elfth century , who
+ was the founder of the C ister c" False 1145, the year St Bernard of Clairvaux began preaching for 10 [' 11', '45', ',', ' the', ' year', ' St', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1706 382 Name of father of x -1 Name of father of Bernard of Clairvaux Tescelin de Fontaine Bernard of Clairvaux "[',' ' the' ' great' ' preacher' ' of' ' the' ' tw' 'elfth' ' century' ','
+ ' who' '\n' 'was' ' the' ' founder' ' of' ' the' ' C' 'ister' 'c']" ", the great preacher of the tw elfth century , who
+ was the founder of the C ister c" False " passage attributed to Bernard of Clairvaux from his ""Sermones" 7 [' passage', ' attributed', ' to', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1707 383 Name of father of x -1 Name of father of Robin Williams Robert Williams Robin Williams "[',' ' the' ' actor' ' who' ' died' ' in' ' 2014' '.' '\n' '\n' 'The'
+ ' actor' ',' ' who' ' was' ' found' ' dead' ' in' ' his' ' home']" ", the actor who died in 2014 .
+
+ The actor , who was found dead in his home" False film that Robin Williams did for the Disney-MGM 3 [' film', ' that', ' Robin', ' Williams']
+1708 383 Name of father of x -1 Name of father of Robin Williams Robert Williams Robin Williams "[',' ' the' ' actor' ' who' ' died' ' in' ' 2014' '.' '\n' '\n' 'The'
+ ' actor' ',' ' who' ' was' ' found' ' dead' ' in' ' his' ' home']" ", the actor who died in 2014 .
+
+ The actor , who was found dead in his home" False lamenting the loss of Robin Williams and the originality 6 [' lament', 'ing', ' the', ' loss', ' of', ' Robin', ' Williams']
+1709 383 Name of father of x -1 Name of father of Robin Williams Robert Williams Robin Williams "[',' ' the' ' actor' ' who' ' died' ' in' ' 2014' '.' '\n' '\n' 'The'
+ ' actor' ',' ' who' ' was' ' found' ' dead' ' in' ' his' ' home']" ", the actor who died in 2014 .
+
+ The actor , who was found dead in his home" False 1 ['Robin', ' Williams']
+1710 383 Name of father of x -1 Name of father of Robin Williams Robert Williams Robin Williams "[',' ' the' ' actor' ' who' ' died' ' in' ' 2014' '.' '\n' '\n' 'The'
+ ' actor' ',' ' who' ' was' ' found' ' dead' ' in' ' his' ' home']" ", the actor who died in 2014 .
+
+ The actor , who was found dead in his home" False Ben Kingsley and Robin Williams were set to star 5 [' Ben', ' Kings', 'ley', ' and', ' Robin', ' Williams']
+1711 383 Name of father of x -1 Name of father of Robin Williams Robert Williams Robin Williams "[',' ' the' ' actor' ' who' ' died' ' in' ' 2014' '.' '\n' '\n' 'The'
+ ' actor' ',' ' who' ' was' ' found' ' dead' ' in' ' his' ' home']" ", the actor who died in 2014 .
+
+ The actor , who was found dead in his home" False " race; Cambridge coach Robin Williams agreed: ""We've got" 5 [' race', ';', ' Cambridge', ' coach', ' Robin', ' Williams']
+1712 384 Name of father of x -1 Name of father of Bhimrao Ramji Ambedkar Ramji Maloji Sakpal Bhimrao Ramji Ambedkar "['\n' '\n' 'B' 'him' 'ra' 'o' ' Ram' 'ji' ' Am' 'bed' 'kar' ' was' ' born'
+ ' on' ' 14' 'th' ' April' ' 18' '91' ' in']" "
+
+ B him ra o Ram ji Am bed kar was born on 14 th April 18 91 in" False people such as Bhimrao Ramji Ambedkar and Muhammad Ali 11 [' people', ' such', ' as', ' Bh', 'im', 'ra', 'o', ' Ram', 'ji', ' Am', 'bed', 'kar']
+1713 384 Name of father of x -1 Name of father of Bhimrao Ramji Ambedkar Ramji Maloji Sakpal Bhimrao Ramji Ambedkar "['\n' '\n' 'B' 'him' 'ra' 'o' ' Ram' 'ji' ' Am' 'bed' 'kar' ' was' ' born'
+ ' on' ' 14' 'th' ' April' ' 18' '91' ' in']" "
+
+ B him ra o Ram ji Am bed kar was born on 14 th April 18 91 in" False people such as Bhimrao Ramji Ambedkar and Muhammad Ali 11 [' people', ' such', ' as', ' Bh', 'im', 'ra', 'o', ' Ram', 'ji', ' Am', 'bed', 'kar']
+1714 384 Name of father of x -1 Name of father of Bhimrao Ramji Ambedkar Ramji Maloji Sakpal Bhimrao Ramji Ambedkar "['\n' '\n' 'B' 'him' 'ra' 'o' ' Ram' 'ji' ' Am' 'bed' 'kar' ' was' ' born'
+ ' on' ' 14' 'th' ' April' ' 18' '91' ' in']" "
+
+ B him ra o Ram ji Am bed kar was born on 14 th April 18 91 in" False people such as Bhimrao Ramji Ambedkar and Muhammad Ali Jinnah 11 [' people', ' such', ' as', ' Bh', 'im', 'ra', 'o', ' Ram', 'ji', ' Am', 'bed', 'kar']
+1715 384 Name of father of x -1 Name of father of Bhimrao Ramji Ambedkar Ramji Maloji Sakpal Bhimrao Ramji Ambedkar "['\n' '\n' 'B' 'him' 'ra' 'o' ' Ram' 'ji' ' Am' 'bed' 'kar' ' was' ' born'
+ ' on' ' 14' 'th' ' April' ' 18' '91' ' in']" "
+
+ B him ra o Ram ji Am bed kar was born on 14 th April 18 91 in" False people such as Bhimrao Ramji Ambedkar and Muhammad Ali 11 [' people', ' such', ' as', ' Bh', 'im', 'ra', 'o', ' Ram', 'ji', ' Am', 'bed', 'kar']
+1716 384 Name of father of x -1 Name of father of Bhimrao Ramji Ambedkar Ramji Maloji Sakpal Bhimrao Ramji Ambedkar "['\n' '\n' 'B' 'him' 'ra' 'o' ' Ram' 'ji' ' Am' 'bed' 'kar' ' was' ' born'
+ ' on' ' 14' 'th' ' April' ' 18' '91' ' in']" "
+
+ B him ra o Ram ji Am bed kar was born on 14 th April 18 91 in" False from people such as Bhimrao Ramji Ambedkar and Muhammad Ali 12 [' from', ' people', ' such', ' as', ' Bh', 'im', 'ra', 'o', ' Ram', 'ji', ' Am', 'bed', 'kar']
+1717 385 Name of father of x -1 Name of father of Bette Davis Harlow Morrell Davis Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False (1965) starring Bette Davis and William Castle's 6 [' (', '1965', ')', ' starring', ' Bet', 'te', ' Davis']
+1718 385 Name of father of x -1 Name of father of Bette Davis Harlow Morrell Davis Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False reportedly suggested Bette Davis for the role 4 [' reportedly', ' suggested', ' Bet', 'te', ' Davis']
+1719 385 Name of father of x -1 Name of father of Bette Davis Harlow Morrell Davis Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Elizabeth and Essex with Bette Davis and Errol Flynn. 6 [' Elizabeth', ' and', ' Essex', ' with', ' Bet', 'te', ' Davis']
+1720 385 Name of father of x -1 Name of father of Bette Davis Harlow Morrell Davis Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False reportedly suggested Bette Davis for the role of Jane. 4 [' reportedly', ' suggested', ' Bet', 'te', ' Davis']
+1721 385 Name of father of x -1 Name of father of Bette Davis Harlow Morrell Davis Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False echoes of Kim Carnes'' Bette Davis Eyes, 'Rihanna 8 "[' echoes', ' of', ' Kim', ' Carn', 'es', ""''"", ' Bet', 'te', ' Davis']"
+1722 386 Name of father of x -1 Name of father of Judy Garland Francis Avent Gumm Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the actress , and the mother of the actress , Judy Garland .
+
+ The first time I" False audience that included Judy Garland and Marlene Dietrich. 4 [' audience', ' that', ' included', ' Judy', ' Garland']
+1723 386 Name of father of x -1 Name of father of Judy Garland Francis Avent Gumm Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the actress , and the mother of the actress , Judy Garland .
+
+ The first time I" False audience that included Judy Garland and Marlene Dietrich. 4 [' audience', ' that', ' included', ' Judy', ' Garland']
+1724 386 Name of father of x -1 Name of father of Judy Garland Francis Avent Gumm Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the actress , and the mother of the actress , Judy Garland .
+
+ The first time I" False " love letter to the Judy Garland specials it honored""" 5 [' love', ' letter', ' to', ' the', ' Judy', ' Garland']
+1725 386 Name of father of x -1 Name of father of Judy Garland Francis Avent Gumm Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the actress , and the mother of the actress , Judy Garland .
+
+ The first time I" False was entitled The Judy Garland Show. This special, 4 [' was', ' entitled', ' The', ' Judy', ' Garland']
+1726 386 Name of father of x -1 Name of father of Judy Garland Francis Avent Gumm Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the actress , and the mother of the actress , Judy Garland .
+
+ The first time I" False with them singing the Judy Garland / Barbra Streisand 5 [' with', ' them', ' singing', ' the', ' Judy', ' Garland']
+1727 387 Name of father of x -1 Name of father of Jean Sibelius Christian Gustaf Sibelius Jean Sibelius "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' father' ' of' ' the' ' composer' ',' ' the'
+ ' painter']" , the composer , and his wife , the painter , and the father of the composer , the painter False 3 ['Jean', ' S', 'ibel', 'ius']
+1728 387 Name of father of x -1 Name of father of Jean Sibelius Christian Gustaf Sibelius Jean Sibelius "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' father' ' of' ' the' ' composer' ',' ' the'
+ ' painter']" , the composer , and his wife , the painter , and the father of the composer , the painter False incidental music of Jean Sibelius on October 30, 6 [' incidental', ' music', ' of', ' Jean', ' S', 'ibel', 'ius']
+1729 387 Name of father of x -1 Name of father of Jean Sibelius Christian Gustaf Sibelius Jean Sibelius "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' father' ' of' ' the' ' composer' ',' ' the'
+ ' painter']" , the composer , and his wife , the painter , and the father of the composer , the painter False " Jean Sibelius =
+" 3 [' Jean', ' S', 'ibel', 'ius']
+1730 387 Name of father of x -1 Name of father of Jean Sibelius Christian Gustaf Sibelius Jean Sibelius "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' father' ' of' ' the' ' composer' ',' ' the'
+ ' painter']" , the composer , and his wife , the painter , and the father of the composer , the painter False quinquennial International Jean Sibelius Violin Competition, 7 [' qu', 'inqu', 'ennial', ' International', ' Jean', ' S', 'ibel', 'ius']
+1731 387 Name of father of x -1 Name of father of Jean Sibelius Christian Gustaf Sibelius Jean Sibelius "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' father' ' of' ' the' ' composer' ',' ' the'
+ ' painter']" , the composer , and his wife , the painter , and the father of the composer , the painter False International Jean Sibelius Violin Competition, 4 [' International', ' Jean', ' S', 'ibel', 'ius']
+1732 389 Name of father of x -1 Name of father of Graham Greene Charles Henry Greene Graham Greene "[',' ' the' ' author' ' of' ' The' ' Power' ' and' ' the' ' Glory' ','
+ ' The' ' Heart' ' of' ' the' ' Matter' ',' ' The' ' Comed' 'ians' ',']" , the author of The Power and the Glory , The Heart of the Matter , The Comed ians , False praised by Yorke, Graham Greene and, in glowing 6 [' praised', ' by', ' Yor', 'ke', ',', ' Graham', ' Greene']
+1733 389 Name of father of x -1 Name of father of Graham Greene Charles Henry Greene Graham Greene "[',' ' the' ' author' ' of' ' The' ' Power' ' and' ' the' ' Glory' ','
+ ' The' ' Heart' ' of' ' the' ' Matter' ',' ' The' ' Comed' 'ians' ',']" , the author of The Power and the Glory , The Heart of the Matter , The Comed ians , False wrote Frankenstein. Graham Greene (1904 – 1991) was 4 [' wrote', ' Frankenstein', '.', ' Graham', ' Greene']
+1734 389 Name of father of x -1 Name of father of Graham Greene Charles Henry Greene Graham Greene "[',' ' the' ' author' ' of' ' The' ' Power' ' and' ' the' ' Glory' ','
+ ' The' ' Heart' ' of' ' the' ' Matter' ',' ' The' ' Comed' 'ians' ',']" , the author of The Power and the Glory , The Heart of the Matter , The Comed ians , False " Real Glory, Graham Greene wrote, ""Sometimes his" 4 [' Real', ' Glory', ',', ' Graham', ' Greene']
+1735 389 Name of father of x -1 Name of father of Graham Greene Charles Henry Greene Graham Greene "[',' ' the' ' author' ' of' ' The' ' Power' ' and' ' the' ' Glory' ','
+ ' The' ' Heart' ' of' ' the' ' Matter' ',' ' The' ' Comed' 'ians' ',']" , the author of The Power and the Glory , The Heart of the Matter , The Comed ians , False John le Carré and Graham Greene this way, and it's 6 [' John', ' le', ' Carr', 'é', ' and', ' Graham', ' Greene']
+1736 389 Name of father of x -1 Name of father of Graham Greene Charles Henry Greene Graham Greene "[',' ' the' ' author' ' of' ' The' ' Power' ' and' ' the' ' Glory' ','
+ ' The' ' Heart' ' of' ' the' ' Matter' ',' ' The' ' Comed' 'ians' ',']" , the author of The Power and the Glory , The Heart of the Matter , The Comed ians , False Fitzgerald, Evelyn Waugh, Graham Greene and William Golding. 8 [' Fitzgerald', ',', ' Eve', 'lyn', ' W', 'augh', ',', ' Graham', ' Greene']
+1737 390 Name of father of x -1 Name of father of Joan Crawford Thomas Le Sueur Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Today We Live with Joan Crawford and One Sunday Afternoon 5 [' Today', ' We', ' Live', ' with', ' Joan', ' Crawford']
+1738 390 Name of father of x -1 Name of father of Joan Crawford Thomas Le Sueur Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False played to excess. Joan Crawford ultimately reforms 5 [' played', ' to', ' excess', '.', ' Joan', ' Crawford']
+1739 390 Name of father of x -1 Name of father of Joan Crawford Thomas Le Sueur Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False " forcing anything."" Joan Crawford likewise expressed" 4 "[' forcing', ' anything', '.""', ' Joan', ' Crawford']"
+1740 390 Name of father of x -1 Name of father of Joan Crawford Thomas Le Sueur Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False " her series, The Joan Crawford Show.
+" 5 [' her', ' series', ',', ' The', ' Joan', ' Crawford']
+1741 390 Name of father of x -1 Name of father of Joan Crawford Thomas Le Sueur Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False veteran star Joan Crawford to describe her 3 [' veteran', ' star', ' Joan', ' Crawford']
+1742 391 Name of father of x -1 Name of father of Alessandro Manzoni Pietro Manzoni Alessandro Manzoni "[',' ' the' ' Italian' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' year' ' of' ' the' ' great' '\n' '\n' 'The' ' Italian' ' poet' ',']" ", the Italian poet , who was born in the year of the great
+
+ The Italian poet ," False Tolstoy in Russia and Alessandro Manzoni in Italy. The tradition 9 [' Tol', 'st', 'oy', ' in', ' Russia', ' and', ' Aless', 'andro', ' Manz', 'oni']
+1743 391 Name of father of x -1 Name of father of Alessandro Manzoni Pietro Manzoni Alessandro Manzoni "[',' ' the' ' Italian' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' year' ' of' ' the' ' great' '\n' '\n' 'The' ' Italian' ' poet' ',']" ", the Italian poet , who was born in the year of the great
+
+ The Italian poet ," False Tolstoy in Russia and Alessandro Manzoni in Italy. The tradition 9 [' Tol', 'st', 'oy', ' in', ' Russia', ' and', ' Aless', 'andro', ' Manz', 'oni']
+1744 391 Name of father of x -1 Name of father of Alessandro Manzoni Pietro Manzoni Alessandro Manzoni "[',' ' the' ' Italian' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' year' ' of' ' the' ' great' '\n' '\n' 'The' ' Italian' ' poet' ',']" ", the Italian poet , who was born in the year of the great
+
+ The Italian poet ," False Tolstoy in Russia and Alessandro Manzoni in Italy. The 9 [' Tol', 'st', 'oy', ' in', ' Russia', ' and', ' Aless', 'andro', ' Manz', 'oni']
+1745 391 Name of father of x -1 Name of father of Alessandro Manzoni Pietro Manzoni Alessandro Manzoni "[',' ' the' ' Italian' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' year' ' of' ' the' ' great' '\n' '\n' 'The' ' Italian' ' poet' ',']" ", the Italian poet , who was born in the year of the great
+
+ The Italian poet ," False Tolstoy in Russia and Alessandro Manzoni in Italy. 9 [' Tol', 'st', 'oy', ' in', ' Russia', ' and', ' Aless', 'andro', ' Manz', 'oni']
+1746 392 Name of father of x -1 Name of father of Arthur Schopenhauer Heinrich Floris Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '88' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur']" ", the philosopher , who was born in 17 88 .
+
+ The name of the mother of Arthur" False reading of Arthur Schopenhauer in 1854, and this 6 [' reading', ' of', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1747 392 Name of father of x -1 Name of father of Arthur Schopenhauer Heinrich Floris Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '88' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur']" ", the philosopher , who was born in 17 88 .
+
+ The name of the mother of Arthur" False German philosopher Arthur Schopenhauer argued that ethics 6 [' German', ' philosopher', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1748 392 Name of father of x -1 Name of father of Arthur Schopenhauer Heinrich Floris Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '88' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur']" ", the philosopher , who was born in 17 88 .
+
+ The name of the mother of Arthur" False " Schopenhauer ====
+" 10 [' Sch', 'open', 'h', 'auer', ' =', '===', 'Arthur', ' Sch', 'open', 'h', 'auer']
+1749 392 Name of father of x -1 Name of father of Arthur Schopenhauer Heinrich Floris Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '88' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur']" ", the philosopher , who was born in 17 88 .
+
+ The name of the mother of Arthur" False the attention of Arthur Schopenhauer and other 7 [' the', ' attention', ' of', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1750 392 Name of father of x -1 Name of father of Arthur Schopenhauer Heinrich Floris Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '88' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur']" ", the philosopher , who was born in 17 88 .
+
+ The name of the mother of Arthur" False western audience. Arthur Schopenhauer was deeply impressed 7 [' western', ' audience', '.', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1751 393 Name of father of x -1 Name of father of Daniel Defoe James Foe Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False travelogue by Daniel Defoe where he described 5 [' travel', 'ogue', ' by', ' Daniel', ' Def', 'oe']
+1752 393 Name of father of x -1 Name of father of Daniel Defoe James Foe Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False Year (1722) by Daniel Defoe is a fictionalisation 8 [' Year', ' (', '17', '22', ')', ' by', ' Daniel', ' Def', 'oe']
+1753 393 Name of father of x -1 Name of father of Daniel Defoe James Foe Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False cheese-making; Daniel Defoe thought Stroma 6 [' cheese', '-', 'making', ';', ' Daniel', ' Def', 'oe']
+1754 393 Name of father of x -1 Name of father of Daniel Defoe James Foe Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False collapse; the writer Daniel Defoe visited in 6 [' collapse', ';', ' the', ' writer', ' Daniel', ' Def', 'oe']
+1755 393 Name of father of x -1 Name of father of Daniel Defoe James Foe Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False Plague Year (1722) by Daniel Defoe is a fictionalisation 9 [' Plague', ' Year', ' (', '17', '22', ')', ' by', ' Daniel', ' Def', 'oe']
+1756 394 Name of father of x -1 Name of father of Michelle Yeoh Yeoh Kian Teik Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' and' ' her' ' husband' ',' ' actor' ' John' ' Woo' ',' ' who']" , the actress who plays the lead in the film , and her husband , actor John Woo , who False Hong Kong actress Michelle Yeoh play the role of 5 [' Hong', ' Kong', ' actress', ' Michelle', ' Ye', 'oh']
+1757 394 Name of father of x -1 Name of father of Michelle Yeoh Yeoh Kian Teik Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' and' ' her' ' husband' ',' ' actor' ' John' ' Woo' ',' ' who']" , the actress who plays the lead in the film , and her husband , actor John Woo , who False Kong actress Michelle Yeoh play the role 4 [' Kong', ' actress', ' Michelle', ' Ye', 'oh']
+1758 394 Name of father of x -1 Name of father of Michelle Yeoh Yeoh Kian Teik Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' and' ' her' ' husband' ',' ' actor' ' John' ' Woo' ',' ' who']" , the actress who plays the lead in the film , and her husband , actor John Woo , who False Kong actress Michelle Yeoh play the role of 4 [' Kong', ' actress', ' Michelle', ' Ye', 'oh']
+1759 394 Name of father of x -1 Name of father of Michelle Yeoh Yeoh Kian Teik Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' and' ' her' ' husband' ',' ' actor' ' John' ' Woo' ',' ' who']" , the actress who plays the lead in the film , and her husband , actor John Woo , who False 2 ['Michelle', ' Ye', 'oh']
+1760 394 Name of father of x -1 Name of father of Michelle Yeoh Yeoh Kian Teik Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' and' ' her' ' husband' ',' ' actor' ' John' ' Woo' ',' ' who']" , the actress who plays the lead in the film , and her husband , actor John Woo , who False Kurylenko found Michelle Yeoh in Tomorrow Never 7 [' K', 'ury', 'len', 'ko', ' found', ' Michelle', ' Ye', 'oh']
+1761 395 Name of father of x -1 Name of father of Gustave Flaubert Achille Cléophas Flaubert Gustave Flaubert "[',' ' the' ' novelist' ',' ' was' ' a' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the novelist , was a
+ " False the French novelist Gustave Flaubert (who toured Egypt), 7 [' the', ' French', ' novelist', ' Gust', 'ave', ' Fl', 'au', 'bert']
+1762 395 Name of father of x -1 Name of father of Gustave Flaubert Achille Cléophas Flaubert Gustave Flaubert "[',' ' the' ' novelist' ',' ' was' ' a' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the novelist , was a
+ " False French author Gustave Flaubert visited London in 6 [' French', ' author', ' Gust', 'ave', ' Fl', 'au', 'bert']
+1763 395 Name of father of x -1 Name of father of Gustave Flaubert Achille Cléophas Flaubert Gustave Flaubert "[',' ' the' ' novelist' ',' ' was' ' a' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the novelist , was a
+ " False 5 ['G', 'ust', 'ave', ' Fl', 'au', 'bert']
+1764 395 Name of father of x -1 Name of father of Gustave Flaubert Achille Cléophas Flaubert Gustave Flaubert "[',' ' the' ' novelist' ',' ' was' ' a' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the novelist , was a
+ " False 5 ['G', 'ust', 'ave', ' Fl', 'au', 'bert']
+1765 395 Name of father of x -1 Name of father of Gustave Flaubert Achille Cléophas Flaubert Gustave Flaubert "[',' ' the' ' novelist' ',' ' was' ' a' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the novelist , was a
+ " False French novelist Gustave Flaubert (who toured Egypt), 6 [' French', ' novelist', ' Gust', 'ave', ' Fl', 'au', 'bert']
+1766 396 Name of father of x -1 Name of father of Jimmy Carter James Earl Carter Sr. Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ' Ros' 'al' 'yn' 'n' ',' ' who' ' was' ' a']" , the former president of the United States , and his wife Ros al yn n , who was a False " the episode, Jimmy Carter is referred to as ""history's" 4 [' the', ' episode', ',', ' Jimmy', ' Carter']
+1767 396 Name of father of x -1 Name of father of Jimmy Carter James Earl Carter Sr. Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ' Ros' 'al' 'yn' 'n' ',' ' who' ' was' ' a']" , the former president of the United States , and his wife Ros al yn n , who was a False 1 ['Jimmy', ' Carter']
+1768 396 Name of father of x -1 Name of father of Jimmy Carter James Earl Carter Sr. Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ' Ros' 'al' 'yn' 'n' ',' ' who' ' was' ' a']" , the former president of the United States , and his wife Ros al yn n , who was a False Mondale ran with Jimmy Carter in 1976 on 5 [' Mond', 'ale', ' ran', ' with', ' Jimmy', ' Carter']
+1769 396 Name of father of x -1 Name of father of Jimmy Carter James Earl Carter Sr. Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ' Ros' 'al' 'yn' 'n' ',' ' who' ' was' ' a']" , the former president of the United States , and his wife Ros al yn n , who was a False think that Jimmy Carter had something 3 [' think', ' that', ' Jimmy', ' Carter']
+1770 396 Name of father of x -1 Name of father of Jimmy Carter James Earl Carter Sr. Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ' Ros' 'al' 'yn' 'n' ',' ' who' ' was' ' a']" , the former president of the United States , and his wife Ros al yn n , who was a False Subsequently, Jimmy Carter became the 4 [' Sub', 'sequently', ',', ' Jimmy', ' Carter']
+1771 397 Name of father of x -1 Name of father of Albert Schweitzer Louis Théophile Schweitzer Albert Schweitzer "[',' ' the' ' famous' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '75' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the famous German philosopher , who was born in 18 75 .
+
+ The name of the father" False of humanitarian Albert Schweitzer in Lambaréné, near 4 [' of', ' humanitarian', ' Albert', ' Schwe', 'itzer']
+1772 397 Name of father of x -1 Name of father of Albert Schweitzer Louis Théophile Schweitzer Albert Schweitzer "[',' ' the' ' famous' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '75' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the famous German philosopher , who was born in 18 75 .
+
+ The name of the father" False dismissed performers like Albert Schweitzer and praised Egon 5 [' dismissed', ' performers', ' like', ' Albert', ' Schwe', 'itzer']
+1773 397 Name of father of x -1 Name of father of Albert Schweitzer Louis Théophile Schweitzer Albert Schweitzer "[',' ' the' ' famous' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '75' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the famous German philosopher , who was born in 18 75 .
+
+ The name of the father" False and angular. Albert Schweitzer likens it 5 [' and', ' angular', '.', ' Albert', ' Schwe', 'itzer']
+1774 397 Name of father of x -1 Name of father of Albert Schweitzer Louis Théophile Schweitzer Albert Schweitzer "[',' ' the' ' famous' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '75' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the famous German philosopher , who was born in 18 75 .
+
+ The name of the father" False animated and angular. Albert Schweitzer likens it 6 [' animated', ' and', ' angular', '.', ' Albert', ' Schwe', 'itzer']
+1775 397 Name of father of x -1 Name of father of Albert Schweitzer Louis Théophile Schweitzer Albert Schweitzer "[',' ' the' ' famous' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '75' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the famous German philosopher , who was born in 18 75 .
+
+ The name of the father" False humanitarian Albert Schweitzer in Lambaréné, near 3 [' humanitarian', ' Albert', ' Schwe', 'itzer']
+1776 399 Name of father of x -1 Name of father of Blaise Pascal Étienne Pascal Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' who'
+ ' was' ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French mathematician and philosopher , who was born in 16 23 .
+
+ The name of the" False 3 ['Bl', 'a', 'ise', ' Pascal']
+1777 399 Name of father of x -1 Name of father of Blaise Pascal Étienne Pascal Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' who'
+ ' was' ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French mathematician and philosopher , who was born in 16 23 .
+
+ The name of the" False and miserable. Blaise Pascal argued this view in 6 [' and', ' miserable', '.', ' Bl', 'a', 'ise', ' Pascal']
+1778 399 Name of father of x -1 Name of father of Blaise Pascal Étienne Pascal Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' who'
+ ' was' ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French mathematician and philosopher , who was born in 16 23 .
+
+ The name of the" False mathematicians of the time, Blaise Pascal and Pierre de Fermat, 9 [' mathematic', 'ians', ' of', ' the', ' time', ',', ' Bl', 'a', 'ise', ' Pascal']
+1779 399 Name of father of x -1 Name of father of Blaise Pascal Étienne Pascal Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' who'
+ ' was' ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French mathematician and philosopher , who was born in 16 23 .
+
+ The name of the" False and miserable. Blaise Pascal argued this view 6 [' and', ' miserable', '.', ' Bl', 'a', 'ise', ' Pascal']
+1780 399 Name of father of x -1 Name of father of Blaise Pascal Étienne Pascal Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' who'
+ ' was' ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the French mathematician and philosopher , who was born in 16 23 .
+
+ The name of the" False Archimedean spiral. Blaise Pascal subsequently used 9 [' Arch', 'im', 'ed', 'ean', ' spiral', '.', ' Bl', 'a', 'ise', ' Pascal']
+1781 400 Name of father of x -1 Name of father of Mao Zedong Mao Yichang Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' Communist' ' Party' ' of' ' China'
+ ',' ' and' ' the' ' first' ' Chairman' ' of' ' the' ' Communist' ' Party'
+ ' of' ' China']" , the founder of the Communist Party of China , and the first Chairman of the Communist Party of China False of the Politburo, Mao Zedong called on the leaders 8 [' of', ' the', ' Polit', 'b', 'uro', ',', ' Mao', ' Zed', 'ong']
+1782 400 Name of father of x -1 Name of father of Mao Zedong Mao Yichang Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' Communist' ' Party' ' of' ' China'
+ ',' ' and' ' the' ' first' ' Chairman' ' of' ' the' ' Communist' ' Party'
+ ' of' ' China']" , the founder of the Communist Party of China , and the first Chairman of the Communist Party of China False Marxism – Leninism and Mao Zedong Thought (or 7 [' Marxism', ' –', ' Lenin', 'ism', ' and', ' Mao', ' Zed', 'ong']
+1783 400 Name of father of x -1 Name of father of Mao Zedong Mao Yichang Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' Communist' ' Party' ' of' ' China'
+ ',' ' and' ' the' ' first' ' Chairman' ' of' ' the' ' Communist' ' Party'
+ ' of' ' China']" , the founder of the Communist Party of China , and the first Chairman of the Communist Party of China False in 1966, Chairman Mao Zedong called for the destruction 6 [' in', ' 1966', ',', ' Chairman', ' Mao', ' Zed', 'ong']
+1784 400 Name of father of x -1 Name of father of Mao Zedong Mao Yichang Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' Communist' ' Party' ' of' ' China'
+ ',' ' and' ' the' ' first' ' Chairman' ' of' ' the' ' Communist' ' Party'
+ ' of' ' China']" , the founder of the Communist Party of China , and the first Chairman of the Communist Party of China False furnaces favoured by Mao Zedong during the Great 6 [' furn', 'aces', ' favoured', ' by', ' Mao', ' Zed', 'ong']
+1785 400 Name of father of x -1 Name of father of Mao Zedong Mao Yichang Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' Communist' ' Party' ' of' ' China'
+ ',' ' and' ' the' ' first' ' Chairman' ' of' ' the' ' Communist' ' Party'
+ ' of' ' China']" , the founder of the Communist Party of China , and the first Chairman of the Communist Party of China False 3 ['M', 'ao', ' Zed', 'ong']
+1786 401 Name of father of x -1 Name of father of Edward Bulwer-Lytton William Earle Bulwer Edward Bulwer-Lytton "[',' ' 1' 'st' ' Baron' ' Ly' 'tt' 'on' ',' ' 1' 'st' ' Earl' ' of' ' Ly'
+ 'tt' 'on' ',' ' 1' 'st' ' Baron' ' Ly']" , 1 st Baron Ly tt on , 1 st Earl of Ly tt on , 1 st Baron Ly False unconvinced), the novelists Sir Edward Bulwer-Lytton and Thomas Adolphus 14 [' uncon', 'vin', 'ced', '),', ' the', ' novel', 'ists', ' Sir', ' Edward', ' Bul', 'wer', '-', 'Ly', 'tt', 'on']
+1787 401 Name of father of x -1 Name of father of Edward Bulwer-Lytton William Earle Bulwer Edward Bulwer-Lytton "[',' ' 1' 'st' ' Baron' ' Ly' 'tt' 'on' ',' ' 1' 'st' ' Earl' ' of' ' Ly'
+ 'tt' 'on' ',' ' 1' 'st' ' Baron' ' Ly']" , 1 st Baron Ly tt on , 1 st Earl of Ly tt on , 1 st Baron Ly False father was rector. Edward Bulwer-Lytton (1803 – 1873) lived 11 [' father', ' was', ' re', 'ctor', '.', ' Edward', ' Bul', 'wer', '-', 'Ly', 'tt', 'on']
+1788 401 Name of father of x -1 Name of father of Edward Bulwer-Lytton William Earle Bulwer Edward Bulwer-Lytton "[',' ' 1' 'st' ' Baron' ' Ly' 'tt' 'on' ',' ' 1' 'st' ' Earl' ' of' ' Ly'
+ 'tt' 'on' ',' ' 1' 'st' ' Baron' ' Ly']" , 1 st Baron Ly tt on , 1 st Earl of Ly tt on , 1 st Baron Ly False the novelists Sir Edward Bulwer-Lytton and Thomas Adolphus 10 [' the', ' novel', 'ists', ' Sir', ' Edward', ' Bul', 'wer', '-', 'Ly', 'tt', 'on']
+1789 402 Name of father of x -1 Name of father of Clint Eastwood Clint Eastwood Sr. Clint Eastwood "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' D' 'ina' ' East'
+ 'wood' ',' ' who' ' is' ' the' ' mother' ' of' ' the' ' actor']" , the actor , and his wife , D ina East wood , who is the mother of the actor False " man who ""gave Clint Eastwood a pounding""" 7 "[' man', ' who', ' ""', 'g', 'ave', ' Clint', ' East', 'wood']"
+1790 402 Name of father of x -1 Name of father of Clint Eastwood Clint Eastwood Sr. Clint Eastwood "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' D' 'ina' ' East'
+ 'wood' ',' ' who' ' is' ' the' ' mother' ' of' ' the' ' actor']" , the actor , and his wife , D ina East wood , who is the mother of the actor False Hopkins and Clint Eastwood to provide the voice 4 [' Hopkins', ' and', ' Clint', ' East', 'wood']
+1791 402 Name of father of x -1 Name of father of Clint Eastwood Clint Eastwood Sr. Clint Eastwood "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' D' 'ina' ' East'
+ 'wood' ',' ' who' ' is' ' the' ' mother' ' of' ' the' ' actor']" , the actor , and his wife , D ina East wood , who is the mother of the actor False [directors] — like Clint Eastwood or Woody Allen 8 [' [', 'direct', 'ors', ']', ' —', ' like', ' Clint', ' East', 'wood']
+1792 402 Name of father of x -1 Name of father of Clint Eastwood Clint Eastwood Sr. Clint Eastwood "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' D' 'ina' ' East'
+ 'wood' ',' ' who' ' is' ' the' ' mother' ' of' ' the' ' actor']" , the actor , and his wife , D ina East wood , who is the mother of the actor False Hopkins and Clint Eastwood to provide the 4 [' Hopkins', ' and', ' Clint', ' East', 'wood']
+1793 402 Name of father of x -1 Name of father of Clint Eastwood Clint Eastwood Sr. Clint Eastwood "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' D' 'ina' ' East'
+ 'wood' ',' ' who' ' is' ' the' ' mother' ' of' ' the' ' actor']" , the actor , and his wife , D ina East wood , who is the mother of the actor False inspired by the Clint Eastwood film Dirty Harry 5 [' inspired', ' by', ' the', ' Clint', ' East', 'wood']
+1794 403 Name of father of x -1 Name of father of Euripides Mnesarchus Euripides "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' son' ' of' ' Pe' 'le'
+ 'us' ',' ' who' ' was' ' the' ' first' ' to']" ",
+
+ |
+
+ And of the son of Pe le us , who was the first to" False elements from Euripides drama Medea 4 [' elements', ' from', ' Eur', 'ip', 'ides']
+1795 403 Name of father of x -1 Name of father of Euripides Mnesarchus Euripides "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' son' ' of' ' Pe' 'le'
+ 'us' ',' ' who' ' was' ' the' ' first' ' to']" ",
+
+ |
+
+ And of the son of Pe le us , who was the first to" False elements from Euripides drama Medea — notably 4 [' elements', ' from', ' Eur', 'ip', 'ides']
+1796 403 Name of father of x -1 Name of father of Euripides Mnesarchus Euripides "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' son' ' of' ' Pe' 'le'
+ 'us' ',' ' who' ' was' ' the' ' first' ' to']" ",
+
+ |
+
+ And of the son of Pe le us , who was the first to" False enterprise permeates Euripides Street in 5 [' enterprise', ' perme', 'ates', ' Eur', 'ip', 'ides']
+1797 403 Name of father of x -1 Name of father of Euripides Mnesarchus Euripides "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' son' ' of' ' Pe' 'le'
+ 'us' ',' ' who' ' was' ' the' ' first' ' to']" ",
+
+ |
+
+ And of the son of Pe le us , who was the first to" False Aristophanes and Euripides (5th-4th century 6 [' Arist', 'oph', 'anes', ' and', ' Eur', 'ip', 'ides']
+1798 403 Name of father of x -1 Name of father of Euripides Mnesarchus Euripides "[',' '\n' '\n' '|' '\n' '\n' 'And' ' of' ' the' ' son' ' of' ' Pe' 'le'
+ 'us' ',' ' who' ' was' ' the' ' first' ' to']" ",
+
+ |
+
+ And of the son of Pe le us , who was the first to" False the playwright Euripides (c. 480 – 406 5 [' the', ' play', 'wright', ' Eur', 'ip', 'ides']
+1799 404 Name of father of x -1 Name of father of Nancy Sinatra Frank Sinatra Nancy Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Nancy' ' Sin' 'atra' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the singer , and the father of the singer Nancy Sin atra .
+
+ The father of the" False " orchestra"", which Nancy Sinatra stated ""opened" 5 "[' orchestra', '"",', ' which', ' Nancy', ' Sin', 'atra']"
+1800 404 Name of father of x -1 Name of father of Nancy Sinatra Frank Sinatra Nancy Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Nancy' ' Sin' 'atra' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the singer , and the father of the singer Nancy Sin atra .
+
+ The father of the" False and sung by Nancy Sinatra after her father 5 [' and', ' sung', ' by', ' Nancy', ' Sin', 'atra']
+1801 404 Name of father of x -1 Name of father of Nancy Sinatra Frank Sinatra Nancy Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Nancy' ' Sin' 'atra' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the singer , and the father of the singer Nancy Sin atra .
+
+ The father of the" False " however. In 2013 Nancy Sinatra covered ""Something""" 6 [' however', '.', ' In', ' 2013', ' Nancy', ' Sin', 'atra']
+1802 404 Name of father of x -1 Name of father of Nancy Sinatra Frank Sinatra Nancy Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Nancy' ' Sin' 'atra' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the singer , and the father of the singer Nancy Sin atra .
+
+ The father of the" False " ""huge orchestra"", which Nancy Sinatra stated ""opened" 7 "[' ""', 'huge', ' orchestra', '"",', ' which', ' Nancy', ' Sin', 'atra']"
+1803 404 Name of father of x -1 Name of father of Nancy Sinatra Frank Sinatra Nancy Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Nancy' ' Sin' 'atra' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the singer , and the father of the singer Nancy Sin atra .
+
+ The father of the" False Sinatra, Jr. and Nancy Sinatra in Sinatra's private 8 [' Sin', 'atra', ',', ' Jr', '.', ' and', ' Nancy', ' Sin', 'atra']
+1804 405 Name of father of x -1 Name of father of Edward VII Albert, Prince Consort Edward VII "['.' '\n' '\n' 'The' ' King' ' of' ' England' ',' ' who' ' was' ' a'
+ ' great' ' friend' ' of' ' the' '\n' '\n' 'Prince' ' of' ' Wales']" ".
+
+ The King of England , who was a great friend of the
+
+ Prince of Wales" False " visited in 1902) or King Edward VII Land.
+" 7 [' visited', ' in', ' 1902', ')', ' or', ' King', ' Edward', ' VII']
+1805 405 Name of father of x -1 Name of father of Edward VII Albert, Prince Consort Edward VII "['.' '\n' '\n' 'The' ' King' ' of' ' England' ',' ' who' ' was' ' a'
+ ' great' ' friend' ' of' ' the' '\n' '\n' 'Prince' ' of' ' Wales']" ".
+
+ The King of England , who was a great friend of the
+
+ Prince of Wales" False Knight Bachelor by King Edward VII in 1909, for his 5 [' Knight', ' Bachelor', ' by', ' King', ' Edward', ' VII']
+1806 405 Name of father of x -1 Name of father of Edward VII Albert, Prince Consort Edward VII "['.' '\n' '\n' 'The' ' King' ' of' ' England' ',' ' who' ' was' ' a'
+ ' great' ' friend' ' of' ' the' '\n' '\n' 'Prince' ' of' ' Wales']" ".
+
+ The King of England , who was a great friend of the
+
+ Prince of Wales" False Victorian Medal from King Edward VII of the United 5 [' Victorian', ' Medal', ' from', ' King', ' Edward', ' VII']
+1807 405 Name of father of x -1 Name of father of Edward VII Albert, Prince Consort Edward VII "['.' '\n' '\n' 'The' ' King' ' of' ' England' ',' ' who' ' was' ' a'
+ ' great' ' friend' ' of' ' the' '\n' '\n' 'Prince' ' of' ' Wales']" ".
+
+ The King of England , who was a great friend of the
+
+ Prince of Wales" False memory of King Edward VII has the inscription 4 [' memory', ' of', ' King', ' Edward', ' VII']
+1808 405 Name of father of x -1 Name of father of Edward VII Albert, Prince Consort Edward VII "['.' '\n' '\n' 'The' ' King' ' of' ' England' ',' ' who' ' was' ' a'
+ ' great' ' friend' ' of' ' the' '\n' '\n' 'Prince' ' of' ' Wales']" ".
+
+ The King of England , who was a great friend of the
+
+ Prince of Wales" False " named King Edward VII Land.
+" 3 [' named', ' King', ' Edward', ' VII']
+1809 407 Name of father of x -1 Name of father of Alexander Calder Alexander Stirling Calder Alexander Calder "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Marg' 'uer' 'ite'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' New']" , the artist , and his wife , Marg uer ite , who was the daughter of a wealthy New False Dots in the Air by Alexander Calder and Light and 7 [' D', 'ots', ' in', ' the', ' Air', ' by', ' Alexander', ' Calder']
+1810 407 Name of father of x -1 Name of father of Alexander Calder Alexander Stirling Calder Alexander Calder "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Marg' 'uer' 'ite'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' New']" , the artist , and his wife , Marg uer ite , who was the daughter of a wealthy New False American artist Alexander Calder built a mercury fountain 3 [' American', ' artist', ' Alexander', ' Calder']
+1811 407 Name of father of x -1 Name of father of Alexander Calder Alexander Stirling Calder Alexander Calder "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Marg' 'uer' 'ite'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' New']" , the artist , and his wife , Marg uer ite , who was the daughter of a wealthy New False American artist Alexander Calder built a mercury 3 [' American', ' artist', ' Alexander', ' Calder']
+1812 407 Name of father of x -1 Name of father of Alexander Calder Alexander Stirling Calder Alexander Calder "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Marg' 'uer' 'ite'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' New']" , the artist , and his wife , Marg uer ite , who was the daughter of a wealthy New False in the Air by Alexander Calder and Light and Space 5 [' in', ' the', ' Air', ' by', ' Alexander', ' Calder']
+1813 407 Name of father of x -1 Name of father of Alexander Calder Alexander Stirling Calder Alexander Calder "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Marg' 'uer' 'ite'
+ ',' ' who' ' was' ' the' ' daughter' ' of' ' a' ' wealthy' ' New']" , the artist , and his wife , Marg uer ite , who was the daughter of a wealthy New False American artist Alexander Calder was later added 3 [' American', ' artist', ' Alexander', ' Calder']
+1814 408 Name of father of x -1 Name of father of Jodie Foster Lucius Fisher Foster Jodie Foster "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' a' ' young'
+ ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Silence' ' of' ' the' ' Lam']" ", the actress who played the role of a young girl in the movie "" The Silence of the Lam" False Tally, and actress Jodie Foster all declined to 8 [' T', 'ally', ',', ' and', ' actress', ' J', 'od', 'ie', ' Foster']
+1815 408 Name of father of x -1 Name of father of Jodie Foster Lucius Fisher Foster Jodie Foster "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' a' ' young'
+ ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Silence' ' of' ' the' ' Lam']" ", the actress who played the role of a young girl in the movie "" The Silence of the Lam" False speculated that both Jodie Foster and Anthony Hopkins 6 [' speculated', ' that', ' both', ' J', 'od', 'ie', ' Foster']
+1816 408 Name of father of x -1 Name of father of Jodie Foster Lucius Fisher Foster Jodie Foster "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' a' ' young'
+ ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Silence' ' of' ' the' ' Lam']" ", the actress who played the role of a young girl in the movie "" The Silence of the Lam" False (1991), which won Jodie Foster the Academy Award 8 [' (', '1991', '),', ' which', ' won', ' J', 'od', 'ie', ' Foster']
+1817 408 Name of father of x -1 Name of father of Jodie Foster Lucius Fisher Foster Jodie Foster "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' a' ' young'
+ ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Silence' ' of' ' the' ' Lam']" ", the actress who played the role of a young girl in the movie "" The Silence of the Lam" False America Awards, Jodie Foster was nominated 6 [' America', ' Awards', ',', ' J', 'od', 'ie', ' Foster']
+1818 408 Name of father of x -1 Name of father of Jodie Foster Lucius Fisher Foster Jodie Foster "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' a' ' young'
+ ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Silence' ' of' ' the' ' Lam']" ", the actress who played the role of a young girl in the movie "" The Silence of the Lam" False wormhole sequence. Jodie Foster admitted she had difficulty 7 [' worm', 'hole', ' sequence', '.', ' J', 'od', 'ie', ' Foster']
+1819 409 Name of father of x -1 Name of father of Nicholas II of Russia Alexander III of Russia Nicholas II of Russia "[',' ' the' ' Ts' 'ar' ' of' ' Russia' ',' ' and' ' the' ' Ts' 'ar' 'its'
+ 'a' ' of' ' Russia' ',' ' the' ' Ts' 'ar' 'its']" , the Ts ar of Russia , and the Ts ar its a of Russia , the Ts ar its False meetings with Tsar Nicholas II of Russia in Kronstadt. 7 [' meetings', ' with', ' Ts', 'ar', ' Nicholas', ' II', ' of', ' Russia']
+1820 409 Name of father of x -1 Name of father of Nicholas II of Russia Alexander III of Russia Nicholas II of Russia "[',' ' the' ' Ts' 'ar' ' of' ' Russia' ',' ' and' ' the' ' Ts' 'ar' 'its'
+ 'a' ' of' ' Russia' ',' ' the' ' Ts' 'ar' 'its']" , the Ts ar of Russia , and the Ts ar its a of Russia , the Ts ar its False reported that Tsar Nicholas II of Russia conferred the title 7 [' reported', ' that', ' Ts', 'ar', ' Nicholas', ' II', ' of', ' Russia']
+1821 409 Name of father of x -1 Name of father of Nicholas II of Russia Alexander III of Russia Nicholas II of Russia "[',' ' the' ' Ts' 'ar' ' of' ' Russia' ',' ' and' ' the' ' Ts' 'ar' 'its'
+ 'a' ' of' ' Russia' ',' ' the' ' Ts' 'ar' 'its']" , the Ts ar of Russia , and the Ts ar its a of Russia , the Ts ar its False the monarchy. Tsar Nicholas II of Russia was King George's 8 [' the', ' monarchy', '.', ' Ts', 'ar', ' Nicholas', ' II', ' of', ' Russia']
+1822 409 Name of father of x -1 Name of father of Nicholas II of Russia Alexander III of Russia Nicholas II of Russia "[',' ' the' ' Ts' 'ar' ' of' ' Russia' ',' ' and' ' the' ' Ts' 'ar' 'its'
+ 'a' ' of' ' Russia' ',' ' the' ' Ts' 'ar' 'its']" , the Ts ar of Russia , and the Ts ar its a of Russia , the Ts ar its False meetings with Tsar Nicholas II of Russia in Kronstadt. She 7 [' meetings', ' with', ' Ts', 'ar', ' Nicholas', ' II', ' of', ' Russia']
+1823 409 Name of father of x -1 Name of father of Nicholas II of Russia Alexander III of Russia Nicholas II of Russia "[',' ' the' ' Ts' 'ar' ' of' ' Russia' ',' ' and' ' the' ' Ts' 'ar' 'its'
+ 'a' ' of' ' Russia' ',' ' the' ' Ts' 'ar' 'its']" , the Ts ar of Russia , and the Ts ar its a of Russia , the Ts ar its False visit of Tsar Nicholas II of Russia and his wife to Dunkerque 7 [' visit', ' of', ' Ts', 'ar', ' Nicholas', ' II', ' of', ' Russia']
+1824 410 Name of father of x -1 Name of father of Fidel Castro Ángel Castro y Arguíz Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' Revolution'
+ '.' '\n' '\n' 'The' ' Cuban' ' Revolution' ' was' ' a' ' revolution'
+ ' of' ' the']" ", the Cuban leader , and the Cuban Revolution .
+
+ The Cuban Revolution was a revolution of the" False government of Fidel Castro in Cuba against 3 [' government', ' of', ' Fidel', ' Castro']
+1825 410 Name of father of x -1 Name of father of Fidel Castro Ángel Castro y Arguíz Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' Revolution'
+ '.' '\n' '\n' 'The' ' Cuban' ' Revolution' ' was' ' a' ' revolution'
+ ' of' ' the']" ", the Cuban leader , and the Cuban Revolution .
+
+ The Cuban Revolution was a revolution of the" False 2 ['F', 'idel', ' Castro']
+1826 410 Name of father of x -1 Name of father of Fidel Castro Ángel Castro y Arguíz Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' Revolution'
+ '.' '\n' '\n' 'The' ' Cuban' ' Revolution' ' was' ' a' ' revolution'
+ ' of' ' the']" ", the Cuban leader , and the Cuban Revolution .
+
+ The Cuban Revolution was a revolution of the" False 2 ['F', 'idel', ' Castro']
+1827 410 Name of father of x -1 Name of father of Fidel Castro Ángel Castro y Arguíz Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' Revolution'
+ '.' '\n' '\n' 'The' ' Cuban' ' Revolution' ' was' ' a' ' revolution'
+ ' of' ' the']" ", the Cuban leader , and the Cuban Revolution .
+
+ The Cuban Revolution was a revolution of the" False Cuban leader Fidel Castro – including 3 [' Cuban', ' leader', ' Fidel', ' Castro']
+1828 410 Name of father of x -1 Name of father of Fidel Castro Ángel Castro y Arguíz Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' Revolution'
+ '.' '\n' '\n' 'The' ' Cuban' ' Revolution' ' was' ' a' ' revolution'
+ ' of' ' the']" ", the Cuban leader , and the Cuban Revolution .
+
+ The Cuban Revolution was a revolution of the" False creation of the diocese, Fidel Castro came to power in 7 [' creation', ' of', ' the', ' di', 'ocese', ',', ' Fidel', ' Castro']
+1829 411 Name of father of x -1 Name of father of Margaret Thatcher Alfred Roberts Margaret Thatcher "[',' ' the' ' first' ' woman' ' to' ' be' ' elected' ' Prime' ' Minister'
+ ' of' ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" ", the first woman to be elected Prime Minister of the United Kingdom .
+
+ The first woman to" False Conservative government of Margaret Thatcher decided that 4 [' Conservative', ' government', ' of', ' Margaret', ' Thatcher']
+1830 411 Name of father of x -1 Name of father of Margaret Thatcher Alfred Roberts Margaret Thatcher "[',' ' the' ' first' ' woman' ' to' ' be' ' elected' ' Prime' ' Minister'
+ ' of' ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" ", the first woman to be elected Prime Minister of the United Kingdom .
+
+ The first woman to" False " Thatcher =
+" 4 [' Thatcher', ' =', 'Marg', 'aret', ' Thatcher']
+1831 411 Name of father of x -1 Name of father of Margaret Thatcher Alfred Roberts Margaret Thatcher "[',' ' the' ' first' ' woman' ' to' ' be' ' elected' ' Prime' ' Minister'
+ ' of' ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" ", the first woman to be elected Prime Minister of the United Kingdom .
+
+ The first woman to" False and then criticised Margaret Thatcher in print. 4 [' and', ' then', ' criticised', ' Margaret', ' Thatcher']
+1832 411 Name of father of x -1 Name of father of Margaret Thatcher Alfred Roberts Margaret Thatcher "[',' ' the' ' first' ' woman' ' to' ' be' ' elected' ' Prime' ' Minister'
+ ' of' ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" ", the first woman to be elected Prime Minister of the United Kingdom .
+
+ The first woman to" False British wing of the Margaret Thatcher Foundation was 5 [' British', ' wing', ' of', ' the', ' Margaret', ' Thatcher']
+1833 411 Name of father of x -1 Name of father of Margaret Thatcher Alfred Roberts Margaret Thatcher "[',' ' the' ' first' ' woman' ' to' ' be' ' elected' ' Prime' ' Minister'
+ ' of' ' the' ' United' ' Kingdom' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" ", the first woman to be elected Prime Minister of the United Kingdom .
+
+ The first woman to" False the New Statesman. Margaret Thatcher had lost a motion 6 [' the', ' New', ' States', 'man', '.', ' Margaret', ' Thatcher']
+1834 412 Name of father of x -1 Name of father of Prosper Mérimée Léonor Mérimée Prosper Mérimée "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '03'
+ ',' ' and' ' died' ' in' ' 18' '71' '.' '\n' '\n']" ", the French writer , who was born in 18 03 , and died in 18 71 .
+
+" False 4 ['Pros', 'per', ' Mé', 'rim', 'ée']
+1835 412 Name of father of x -1 Name of father of Prosper Mérimée Léonor Mérimée Prosper Mérimée "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '03'
+ ',' ' and' ' died' ' in' ' 18' '71' '.' '\n' '\n']" ", the French writer , who was born in 18 03 , and died in 18 71 .
+
+" False and by the time Prosper Mérimée saw it in 1835 it 7 [' and', ' by', ' the', ' time', ' Prosper', ' Mé', 'rim', 'ée']
+1836 412 Name of father of x -1 Name of father of Prosper Mérimée Léonor Mérimée Prosper Mérimée "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '03'
+ ',' ' and' ' died' ' in' ' 18' '71' '.' '\n' '\n']" ", the French writer , who was born in 18 03 , and died in 18 71 .
+
+" False 4 ['Pros', 'per', ' Mé', 'rim', 'ée']
+1837 413 Name of father of x -1 Name of father of Bernie Sanders Eli Sanders Bernie Sanders "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' thorn' ' in' ' the'
+ ' side' ' of' ' the' ' Democratic' ' Party' ' for' ' decades' '.' '\n'
+ '\n']" ", the man who has been a thorn in the side of the Democratic Party for decades .
+
+" False Democratic debate with Bernie Sanders on February 4, 4 [' Democratic', ' debate', ' with', ' Bernie', ' Sanders']
+1838 413 Name of father of x -1 Name of father of Bernie Sanders Eli Sanders Bernie Sanders "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' thorn' ' in' ' the'
+ ' side' ' of' ' the' ' Democratic' ' Party' ' for' ' decades' '.' '\n'
+ '\n']" ", the man who has been a thorn in the side of the Democratic Party for decades .
+
+" False chaired by U.S. Sen. Bernie Sanders (I-VT) regarding farm 9 [' chaired', ' by', ' U', '.', 'S', '.', ' Sen', '.', ' Bernie', ' Sanders']
+1839 413 Name of father of x -1 Name of father of Bernie Sanders Eli Sanders Bernie Sanders "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' thorn' ' in' ' the'
+ ' side' ' of' ' the' ' Democratic' ' Party' ' for' ' decades' '.' '\n'
+ '\n']" ", the man who has been a thorn in the side of the Democratic Party for decades .
+
+" False 1 ['Bernie', ' Sanders']
+1840 413 Name of father of x -1 Name of father of Bernie Sanders Eli Sanders Bernie Sanders "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' thorn' ' in' ' the'
+ ' side' ' of' ' the' ' Democratic' ' Party' ' for' ' decades' '.' '\n'
+ '\n']" ", the man who has been a thorn in the side of the Democratic Party for decades .
+
+" False socialist Senator Bernie Sanders of Vermont, whose 3 [' socialist', ' Senator', ' Bernie', ' Sanders']
+1841 413 Name of father of x -1 Name of father of Bernie Sanders Eli Sanders Bernie Sanders "[',' ' the' ' man' ' who' ' has' ' been' ' a' ' thorn' ' in' ' the'
+ ' side' ' of' ' the' ' Democratic' ' Party' ' for' ' decades' '.' '\n'
+ '\n']" ", the man who has been a thorn in the side of the Democratic Party for decades .
+
+" False chaired by U.S. Sen. Bernie Sanders (I-VT) regarding 9 [' chaired', ' by', ' U', '.', 'S', '.', ' Sen', '.', ' Bernie', ' Sanders']
+1842 414 Name of father of x -1 Name of father of Winona Ryder Michael Horowitz Winona Ryder "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Breakfast' ' Club'
+ '""' ' and']" ", the actress who played the role of the young girl in the movie "" The Breakfast Club "" and" False Glenn Close and Winona Ryder in The House 5 [' Glenn', ' Close', ' and', ' Win', 'ona', ' Ryder']
+1843 414 Name of father of x -1 Name of father of Winona Ryder Michael Horowitz Winona Ryder "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Breakfast' ' Club'
+ '""' ' and']" ", the actress who played the role of the young girl in the movie "" The Breakfast Club "" and" False " Pictures with Winona Ryder in the title role.
+" 4 [' Pictures', ' with', ' Win', 'ona', ' Ryder']
+1844 414 Name of father of x -1 Name of father of Winona Ryder Michael Horowitz Winona Ryder "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Breakfast' ' Club'
+ '""' ' and']" ", the actress who played the role of the young girl in the movie "" The Breakfast Club "" and" False 2 ['Win', 'ona', ' Ryder']
+1845 414 Name of father of x -1 Name of father of Winona Ryder Michael Horowitz Winona Ryder "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Breakfast' ' Club'
+ '""' ' and']" ", the actress who played the role of the young girl in the movie "" The Breakfast Club "" and" False 2 ['Win', 'ona', ' Ryder']
+1846 414 Name of father of x -1 Name of father of Winona Ryder Michael Horowitz Winona Ryder "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' girl' ' in' ' the' ' movie' ' ""' 'The' ' Breakfast' ' Club'
+ '""' ' and']" ", the actress who played the role of the young girl in the movie "" The Breakfast Club "" and" False Little Women opposite Winona Ryder and Claire Danes. 5 [' Little', ' Women', ' opposite', ' Win', 'ona', ' Ryder']
+1847 415 Name of father of x -1 Name of father of François-René de Chateaubriand René-Auguste de Chateaubriand François-René de Chateaubriand "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' François' '-'
+ 'Ren' 'é' ' de' ' Ch' 'ate' 'a' 'ub' 'ri' 'and' ' is']" "
+
+ The name of the father of François - Ren é de Ch ate a ub ri and is" False 12 ['Fran', 'ç', 'ois', '-', 'Ren', 'é', ' de', ' Ch', 'ate', 'a', 'ub', 'ri', 'and']
+1848 415 Name of father of x -1 Name of father of François-René de Chateaubriand René-Auguste de Chateaubriand François-René de Chateaubriand "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' François' '-'
+ 'Ren' 'é' ' de' ' Ch' 'ate' 'a' 'ub' 'ri' 'and' ' is']" "
+
+ The name of the father of François - Ren é de Ch ate a ub ri and is" False the work of François-René de Chateaubriand and Emanuel Swedenborg. 13 [' the', ' work', ' of', ' François', '-', 'Ren', 'é', ' de', ' Ch', 'ate', 'a', 'ub', 'ri', 'and']
+1849 415 Name of father of x -1 Name of father of François-René de Chateaubriand René-Auguste de Chateaubriand François-René de Chateaubriand "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' François' '-'
+ 'Ren' 'é' ' de' ' Ch' 'ate' 'a' 'ub' 'ri' 'and' ' is']" "
+
+ The name of the father of François - Ren é de Ch ate a ub ri and is" False influenced by the work of François-René de Chateaubriand and Emanuel 15 [' influenced', ' by', ' the', ' work', ' of', ' François', '-', 'Ren', 'é', ' de', ' Ch', 'ate', 'a', 'ub', 'ri', 'and']
+1850 415 Name of father of x -1 Name of father of François-René de Chateaubriand René-Auguste de Chateaubriand François-René de Chateaubriand "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' François' '-'
+ 'Ren' 'é' ' de' ' Ch' 'ate' 'a' 'ub' 'ri' 'and' ' is']" "
+
+ The name of the father of François - Ren é de Ch ate a ub ri and is" False influenced by the work of François-René de Chateaubriand and Emanuel Swedenborg. 15 [' influenced', ' by', ' the', ' work', ' of', ' François', '-', 'Ren', 'é', ' de', ' Ch', 'ate', 'a', 'ub', 'ri', 'and']
+1851 415 Name of father of x -1 Name of father of François-René de Chateaubriand René-Auguste de Chateaubriand François-René de Chateaubriand "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' François' '-'
+ 'Ren' 'é' ' de' ' Ch' 'ate' 'a' 'ub' 'ri' 'and' ' is']" "
+
+ The name of the father of François - Ren é de Ch ate a ub ri and is" False the death of François-René de Chateaubriand and she visited 13 [' the', ' death', ' of', ' François', '-', 'Ren', 'é', ' de', ' Ch', 'ate', 'a', 'ub', 'ri', 'and']
+1852 416 Name of father of x -1 Name of father of Ennio Morricone Mario Morricone Ennio Morricone "[',' ' the' ' composer' ' of' ' the' ' music' ' for' ' the' ' film' '.'
+ '\n' '\n' 'The' ' film' ' was' ' released' ' in' ' Italy' ' on'
+ ' December']" ", the composer of the music for the film .
+
+ The film was released in Italy on December" False " ""Se telefonando"" by Ennio Morricone (1966). The latter" 13 "[' ""', 'Se', ' tele', 'f', 'on', 'ando', '""', ' by', ' En', 'n', 'io', ' Mor', 'ric', 'one']"
+1853 416 Name of father of x -1 Name of father of Ennio Morricone Mario Morricone Ennio Morricone "[',' ' the' ' composer' ' of' ' the' ' music' ' for' ' the' ' film' '.'
+ '\n' '\n' 'The' ' film' ' was' ' released' ' in' ' Italy' ' on'
+ ' December']" ", the composer of the music for the film .
+
+ The film was released in Italy on December" False " ""Ancora Qui"" by Ennio Morricone and Elisa, and" 13 "[' ""', 'An', 'c', 'ora', ' Qu', 'i', '""', ' by', ' En', 'n', 'io', ' Mor', 'ric', 'one']"
+1854 416 Name of father of x -1 Name of father of Ennio Morricone Mario Morricone Ennio Morricone "[',' ' the' ' composer' ' of' ' the' ' music' ' for' ' the' ' film' '.'
+ '\n' '\n' 'The' ' film' ' was' ' released' ' in' ' Italy' ' on'
+ ' December']" ", the composer of the music for the film .
+
+ The film was released in Italy on December" False " ""Ancora Qui"" by Ennio Morricone and Elisa, and ""Freedom""" 13 "[' ""', 'An', 'c', 'ora', ' Qu', 'i', '""', ' by', ' En', 'n', 'io', ' Mor', 'ric', 'one']"
+1855 416 Name of father of x -1 Name of father of Ennio Morricone Mario Morricone Ennio Morricone "[',' ' the' ' composer' ' of' ' the' ' music' ' for' ' the' ' film' '.'
+ '\n' '\n' 'The' ' film' ' was' ' released' ' in' ' Italy' ' on'
+ ' December']" ", the composer of the music for the film .
+
+ The film was released in Italy on December" False " ""Brava"", and Ennio Morricone wrote ""Se telefonando""" 11 "[' ""', 'B', 'rav', 'a', '"",', ' and', ' En', 'n', 'io', ' Mor', 'ric', 'one']"
+1856 416 Name of father of x -1 Name of father of Ennio Morricone Mario Morricone Ennio Morricone "[',' ' the' ' composer' ' of' ' the' ' music' ' for' ' the' ' film' '.'
+ '\n' '\n' 'The' ' film' ' was' ' released' ' in' ' Italy' ' on'
+ ' December']" ", the composer of the music for the film .
+
+ The film was released in Italy on December" False " Legend, ""Ancora Qui"" by Ennio Morricone and Elisa, and ""Freedom""" 15 "[' Legend', ',', ' ""', 'An', 'c', 'ora', ' Qu', 'i', '""', ' by', ' En', 'n', 'io', ' Mor', 'ric', 'one']"
+1857 417 Name of father of x -1 Name of father of Stefan Zweig Moritz Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'The' ' World'
+ ' of' ' Yesterday' '""' ' and' ' the' ' ""' 'Black' ' Note' 'books' '""']" ", the author of the famous novel "" The World of Yesterday "" and the "" Black Note books """ False Royal Game by Stefan Zweig and Vladimir Nabokov's 6 [' Royal', ' Game', ' by', ' Stefan', ' Z', 'we', 'ig']
+1858 417 Name of father of x -1 Name of father of Stefan Zweig Moritz Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'The' ' World'
+ ' of' ' Yesterday' '""' ' and' ' the' ' ""' 'Black' ' Note' 'books' '""']" ", the author of the famous novel "" The World of Yesterday "" and the "" Black Note books """ False Royal Game by Stefan Zweig and Vladimir Nabokov's 6 [' Royal', ' Game', ' by', ' Stefan', ' Z', 'we', 'ig']
+1859 417 Name of father of x -1 Name of father of Stefan Zweig Moritz Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'The' ' World'
+ ' of' ' Yesterday' '""' ' and' ' the' ' ""' 'Black' ' Note' 'books' '""']" ", the author of the famous novel "" The World of Yesterday "" and the "" Black Note books """ False " Austrian writer Stefan Zweig remarked, ""If everything" 5 [' Austrian', ' writer', ' Stefan', ' Z', 'we', 'ig']
+1860 417 Name of father of x -1 Name of father of Stefan Zweig Moritz Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'The' ' World'
+ ' of' ' Yesterday' '""' ' and' ' the' ' ""' 'Black' ' Note' 'books' '""']" ", the author of the famous novel "" The World of Yesterday "" and the "" Black Note books """ False Royal Game by Stefan Zweig and Vladimir 6 [' Royal', ' Game', ' by', ' Stefan', ' Z', 'we', 'ig']
+1861 417 Name of father of x -1 Name of father of Stefan Zweig Moritz Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' ""' 'The' ' World'
+ ' of' ' Yesterday' '""' ' and' ' the' ' ""' 'Black' ' Note' 'books' '""']" ", the author of the famous novel "" The World of Yesterday "" and the "" Black Note books """ False " Austrian writer Stefan Zweig remarked, ""If everything" 5 [' Austrian', ' writer', ' Stefan', ' Z', 'we', 'ig']
+1862 418 Name of father of x -1 Name of father of Richard Feynman Melville Arthur Feynman Richard Feynman "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' his' ' work' ' on' ' the' ' theory' ' of' ' quantum' ' elect' 'rod'
+ 'ynam' 'ics']" , the physicist who won the Nobel Prize for his work on the theory of quantum elect rod ynam ics False talk by physicist Richard Feynman called There's 6 [' talk', ' by', ' physicist', ' Richard', ' Fe', 'yn', 'man']
+1863 418 Name of father of x -1 Name of father of Richard Feynman Melville Arthur Feynman Richard Feynman "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' his' ' work' ' on' ' the' ' theory' ' of' ' quantum' ' elect' 'rod'
+ 'ynam' 'ics']" , the physicist who won the Nobel Prize for his work on the theory of quantum elect rod ynam ics False somewhat later, Richard Feynman and Murray 6 [' somewhat', ' later', ',', ' Richard', ' Fe', 'yn', 'man']
+1864 418 Name of father of x -1 Name of father of Richard Feynman Melville Arthur Feynman Richard Feynman "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' his' ' work' ' on' ' the' ' theory' ' of' ' quantum' ' elect' 'rod'
+ 'ynam' 'ics']" , the physicist who won the Nobel Prize for his work on the theory of quantum elect rod ynam ics False including Richard Feynman and Robert R. Wilson, 4 [' including', ' Richard', ' Fe', 'yn', 'man']
+1865 418 Name of father of x -1 Name of father of Richard Feynman Melville Arthur Feynman Richard Feynman "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' his' ' work' ' on' ' the' ' theory' ' of' ' quantum' ' elect' 'rod'
+ 'ynam' 'ics']" , the physicist who won the Nobel Prize for his work on the theory of quantum elect rod ynam ics False by physicist Richard Feynman called There's 5 [' by', ' physicist', ' Richard', ' Fe', 'yn', 'man']
+1866 418 Name of father of x -1 Name of father of Richard Feynman Melville Arthur Feynman Richard Feynman "[',' ' the' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize' ' for'
+ ' his' ' work' ' on' ' the' ' theory' ' of' ' quantum' ' elect' 'rod'
+ 'ynam' 'ics']" , the physicist who won the Nobel Prize for his work on the theory of quantum elect rod ynam ics False hypothesis of Richard Feynman and Murray Gell-Mann. 5 [' hypothesis', ' of', ' Richard', ' Fe', 'yn', 'man']
+1867 420 Name of father of x -1 Name of father of Carl Friedrich Gauss Gebhard Dietrich Gauss Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '55' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 77 , and died in 18 55 .
+
+ The" False system when, in 1832, Carl Friedrich Gauss used it, the 10 [' system', ' when', ',', ' in', ' 18', '32', ',', ' Carl', ' Friedrich', ' Ga', 'uss']
+1868 420 Name of father of x -1 Name of father of Carl Friedrich Gauss Gebhard Dietrich Gauss Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '55' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 77 , and died in 18 55 .
+
+ The" False respectively. In the 1830s Carl Friedrich Gauss laid the foundations 9 [' respectively', '.', ' In', ' the', ' 1830', 's', ' Carl', ' Friedrich', ' Ga', 'uss']
+1869 420 Name of father of x -1 Name of father of Carl Friedrich Gauss Gebhard Dietrich Gauss Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '55' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 77 , and died in 18 55 .
+
+ The" False respectively. In the 1830s Carl Friedrich Gauss laid the foundations 9 [' respectively', '.', ' In', ' the', ' 1830', 's', ' Carl', ' Friedrich', ' Ga', 'uss']
+1870 420 Name of father of x -1 Name of father of Carl Friedrich Gauss Gebhard Dietrich Gauss Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '55' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 77 , and died in 18 55 .
+
+ The" False system when, in 1832, Carl Friedrich Gauss used it, the centimetre 10 [' system', ' when', ',', ' in', ' 18', '32', ',', ' Carl', ' Friedrich', ' Ga', 'uss']
+1871 420 Name of father of x -1 Name of father of Carl Friedrich Gauss Gebhard Dietrich Gauss Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '55' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 77 , and died in 18 55 .
+
+ The" False respectively. In the 1830s Carl Friedrich Gauss laid the foundations 9 [' respectively', '.', ' In', ' the', ' 1830', 's', ' Carl', ' Friedrich', ' Ga', 'uss']
+1872 421 Name of father of x -1 Name of father of Bill Gates Bill Gates Sr. Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Microsoft' ' Windows' ' operating' ' system' '.' '\n' '\n'
+ 'The' ' Microsoft']" ", the founder of Microsoft , and the father of the Microsoft Windows operating system .
+
+ The Microsoft" False Skapinker met Bill Gates at a Microsoft-sponsored 5 [' Sk', 'ap', 'inker', ' met', ' Bill', ' Gates']
+1873 421 Name of father of x -1 Name of father of Bill Gates Bill Gates Sr. Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Microsoft' ' Windows' ' operating' ' system' '.' '\n' '\n'
+ 'The' ' Microsoft']" ", the founder of Microsoft , and the father of the Microsoft Windows operating system .
+
+ The Microsoft" False hands. It is owned by Bill Gates and is displayed once 7 [' hands', '.', ' It', ' is', ' owned', ' by', ' Bill', ' Gates']
+1874 421 Name of father of x -1 Name of father of Bill Gates Bill Gates Sr. Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Microsoft' ' Windows' ' operating' ' system' '.' '\n' '\n'
+ 'The' ' Microsoft']" ", the founder of Microsoft , and the father of the Microsoft Windows operating system .
+
+ The Microsoft" False " been recognized by Bill Gates as a leading institution.
+" 4 [' been', ' recognized', ' by', ' Bill', ' Gates']
+1875 421 Name of father of x -1 Name of father of Bill Gates Bill Gates Sr. Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Microsoft' ' Windows' ' operating' ' system' '.' '\n' '\n'
+ 'The' ' Microsoft']" ", the founder of Microsoft , and the father of the Microsoft Windows operating system .
+
+ The Microsoft" False interview with Bill Gates and that Microsoft 3 [' interview', ' with', ' Bill', ' Gates']
+1876 421 Name of father of x -1 Name of father of Bill Gates Bill Gates Sr. Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Microsoft' ' Windows' ' operating' ' system' '.' '\n' '\n'
+ 'The' ' Microsoft']" ", the founder of Microsoft , and the father of the Microsoft Windows operating system .
+
+ The Microsoft" False the Melinda and Bill Gates Foundation, 5 [' the', ' Mel', 'inda', ' and', ' Bill', ' Gates']
+1877 422 Name of father of x -1 Name of father of Percy Bysshe Shelley Timothy Shelley Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' Mary' ' Shelley' ',' ' the' ' novelist'
+ '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two' ' was' ' born']" ", the poet , and Mary Shelley , the novelist .
+
+ The first of these two was born" False " Poetical Works of Percy Bysshe Shelley (1839)
+" 9 [' Po', 'et', 'ical', ' Works', ' of', ' Percy', ' By', 'ss', 'he', ' Shelley']
+1878 422 Name of father of x -1 Name of father of Percy Bysshe Shelley Timothy Shelley Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' Mary' ' Shelley' ',' ' the' ' novelist'
+ '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two' ' was' ' born']" ", the poet , and Mary Shelley , the novelist .
+
+ The first of these two was born" False " == The Life of Percy Bysshe Shelley ==
+" 8 [' ==', ' The', ' Life', ' of', ' Percy', ' By', 'ss', 'he', ' Shelley']
+1879 422 Name of father of x -1 Name of father of Percy Bysshe Shelley Timothy Shelley Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' Mary' ' Shelley' ',' ' the' ' novelist'
+ '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two' ' was' ' born']" ", the poet , and Mary Shelley , the novelist .
+
+ The first of these two was born" False 5 ['Per', 'cy', ' By', 'ss', 'he', ' Shelley']
+1880 422 Name of father of x -1 Name of father of Percy Bysshe Shelley Timothy Shelley Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' Mary' ' Shelley' ',' ' the' ' novelist'
+ '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two' ' was' ' born']" ", the poet , and Mary Shelley , the novelist .
+
+ The first of these two was born" False time) ran off with Percy Bysshe Shelley to continental 9 [' time', ')', ' ran', ' off', ' with', ' Percy', ' By', 'ss', 'he', ' Shelley']
+1881 422 Name of father of x -1 Name of father of Percy Bysshe Shelley Timothy Shelley Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' Mary' ' Shelley' ',' ' the' ' novelist'
+ '.' '\n' '\n' 'The' ' first' ' of' ' these' ' two' ' was' ' born']" ", the poet , and Mary Shelley , the novelist .
+
+ The first of these two was born" False " Poetical Works of Percy Bysshe Shelley (1839)
+" 9 [' Po', 'et', 'ical', ' Works', ' of', ' Percy', ' By', 'ss', 'he', ' Shelley']
+1882 423 Name of father of x -1 Name of father of Aristophanes Philippus Aristophanes "[',' ' the' ' son' ' of' ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of'
+ ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of' ' Arist' 'oph']" , the son of Arist oph anes , the son of Arist oph anes , the son of Arist oph False playwright Aristophanes wrote the comedy 4 [' play', 'wright', ' Arist', 'oph', 'anes']
+1883 423 Name of father of x -1 Name of father of Aristophanes Philippus Aristophanes "[',' ' the' ' son' ' of' ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of'
+ ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of' ' Arist' 'oph']" , the son of Arist oph anes , the son of Arist oph anes , the son of Arist oph False 3 ['A', 'rist', 'oph', 'anes']
+1884 423 Name of father of x -1 Name of father of Aristophanes Philippus Aristophanes "[',' ' the' ' son' ' of' ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of'
+ ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of' ' Arist' 'oph']" , the son of Arist oph anes , the son of Arist oph anes , the son of Arist oph False comic playwright Aristophanes also used myths, in 5 [' comic', ' play', 'wright', ' Arist', 'oph', 'anes']
+1885 423 Name of father of x -1 Name of father of Aristophanes Philippus Aristophanes "[',' ' the' ' son' ' of' ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of'
+ ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of' ' Arist' 'oph']" , the son of Arist oph anes , the son of Arist oph anes , the son of Arist oph False of Rhodes, Aristophanes of Byzantium 5 [' of', ' Rhodes', ',', ' Arist', 'oph', 'anes']
+1886 423 Name of father of x -1 Name of father of Aristophanes Philippus Aristophanes "[',' ' the' ' son' ' of' ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of'
+ ' Arist' 'oph' 'anes' ',' ' the' ' son' ' of' ' Arist' 'oph']" , the son of Arist oph anes , the son of Arist oph anes , the son of Arist oph False Thratta, used by Aristophanes in The Wasps, 7 [' Thr', 'atta', ',', ' used', ' by', ' Arist', 'oph', 'anes']
+1887 424 Name of father of x -1 Name of father of Norman Mailer Isaac Barnett Mailer Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ' ""' 'The' ' Naked' ' and'
+ ' the' ' Dead' '""' ' and' ' the' ' movie' ' ""' 'The' ' Deer' ' Hunter']" ", the author of the book "" The Naked and the Dead "" and the movie "" The Deer Hunter" False Sex was written by Norman Mailer in response 6 [' Sex', ' was', ' written', ' by', ' Norman', ' Mail', 'er']
+1888 424 Name of father of x -1 Name of father of Norman Mailer Isaac Barnett Mailer Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ' ""' 'The' ' Naked' ' and'
+ ' the' ' Dead' '""' ' and' ' the' ' movie' ' ""' 'The' ' Deer' ' Hunter']" ", the author of the book "" The Naked and the Dead "" and the movie "" The Deer Hunter" False Henry Miller, and Norman Mailer and contrasted their 6 [' Henry', ' Miller', ',', ' and', ' Norman', ' Mail', 'er']
+1889 424 Name of father of x -1 Name of father of Norman Mailer Isaac Barnett Mailer Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ' ""' 'The' ' Naked' ' and'
+ ' the' ' Dead' '""' ' and' ' the' ' movie' ' ""' 'The' ' Deer' ' Hunter']" ", the author of the book "" The Naked and the Dead "" and the movie "" The Deer Hunter" False Lowell and Norman Mailer publicly denounced 4 [' Lowell', ' and', ' Norman', ' Mail', 'er']
+1890 424 Name of father of x -1 Name of father of Norman Mailer Isaac Barnett Mailer Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ' ""' 'The' ' Naked' ' and'
+ ' the' ' Dead' '""' ' and' ' the' ' movie' ' ""' 'The' ' Deer' ' Hunter']" ", the author of the book "" The Naked and the Dead "" and the movie "" The Deer Hunter" False Sex was written by Norman Mailer in response 6 [' Sex', ' was', ' written', ' by', ' Norman', ' Mail', 'er']
+1891 424 Name of father of x -1 Name of father of Norman Mailer Isaac Barnett Mailer Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ' ""' 'The' ' Naked' ' and'
+ ' the' ' Dead' '""' ' and' ' the' ' movie' ' ""' 'The' ' Deer' ' Hunter']" ", the author of the book "" The Naked and the Dead "" and the movie "" The Deer Hunter" False was written by Norman Mailer in response 5 [' was', ' written', ' by', ' Norman', ' Mail', 'er']
+1892 425 Name of father of x -1 Name of father of John Constable Golding Constable John Constable "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' 'John' ' Constable']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ John Constable" False (1775 – 1851) and John Constable (1776 – 1837) sought 9 [' (', '17', '75', ' –', ' 18', '51', ')', ' and', ' John', ' Constable']
+1893 425 Name of father of x -1 Name of father of John Constable Golding Constable John Constable "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' 'John' ' Constable']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ John Constable" False (1775 – 1851) and John Constable (1776 – 1837) 9 [' (', '17', '75', ' –', ' 18', '51', ')', ' and', ' John', ' Constable']
+1894 425 Name of father of x -1 Name of father of John Constable Golding Constable John Constable "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' 'John' ' Constable']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ John Constable" False fellow artist John Constable sarcastically wrote 3 [' fellow', ' artist', ' John', ' Constable']
+1895 425 Name of father of x -1 Name of father of John Constable Golding Constable John Constable "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' 'John' ' Constable']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ John Constable" False Having approached John Constable and other major 3 [' Having', ' approached', ' John', ' Constable']
+1896 425 Name of father of x -1 Name of father of John Constable Golding Constable John Constable "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' a' '\n' '\n' 'John' ' Constable']" ", the painter , and his wife , Mary , who was the daughter of a
+
+ John Constable" False compared with John Constable and described as 3 [' compared', ' with', ' John', ' Constable']
+1897 426 Name of father of x -1 Name of father of Charles Aznavour Misha Aznavourian Charles Aznavour "[',' ' the' ' famous' ' French' ' singer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' B' 'é' 'z' 'iers' ',' ' in' ' the']" , the famous French singer , who was born in the town of B é z iers , in the False " cover of the Charles Aznavour song ""She"", Shania" 6 [' cover', ' of', ' the', ' Charles', ' Az', 'nav', 'our']
+1898 427 Name of father of x -1 Name of father of George Gershwin Morris Gershwine George Gershwin "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' I' 'ra' ' G' 'ers'
+ 'h' 'win' ',' ' who' ' was' ' a' ' pian' 'ist']" , the composer , and his wife , I ra G ers h win , who was a pian ist False he later said to George Gershwin that his rhythmic 8 [' he', ' later', ' said', ' to', ' George', ' G', 'ers', 'h', 'win']
+1899 427 Name of father of x -1 Name of father of George Gershwin Morris Gershwine George Gershwin "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' I' 'ra' ' G' 'ers'
+ 'h' 'win' ',' ' who' ' was' ' a' ' pian' 'ist']" , the composer , and his wife , I ra G ers h win , who was a pian ist False The Red Shoes at the George Gershwin Theater. He replaced 9 [' The', ' Red', ' Shoes', ' at', ' the', ' George', ' G', 'ers', 'h', 'win']
+1900 427 Name of father of x -1 Name of father of George Gershwin Morris Gershwine George Gershwin "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' I' 'ra' ' G' 'ers'
+ 'h' 'win' ',' ' who' ' was' ' a' ' pian' 'ist']" , the composer , and his wife , I ra G ers h win , who was a pian ist False Red Shoes at the George Gershwin Theater. He replaced 8 [' Red', ' Shoes', ' at', ' the', ' George', ' G', 'ers', 'h', 'win']
+1901 427 Name of father of x -1 Name of father of George Gershwin Morris Gershwine George Gershwin "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' I' 'ra' ' G' 'ers'
+ 'h' 'win' ',' ' who' ' was' ' a' ' pian' 'ist']" , the composer , and his wife , I ra G ers h win , who was a pian ist False performance of the George Gershwin composition Rhapsody 7 [' performance', ' of', ' the', ' George', ' G', 'ers', 'h', 'win']
+1902 427 Name of father of x -1 Name of father of George Gershwin Morris Gershwine George Gershwin "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' I' 'ra' ' G' 'ers'
+ 'h' 'win' ',' ' who' ' was' ' a' ' pian' 'ist']" , the composer , and his wife , I ra G ers h win , who was a pian ist False as composer George Gershwin had not been 6 [' as', ' composer', ' George', ' G', 'ers', 'h', 'win']
+1903 428 Name of father of x -1 Name of father of Mariah Carey Alfred Roy Carey Mariah Carey "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Nick' ' Cannon'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Nick Cannon , are expecting their first child together .
+
+" False 2 ['Mar', 'iah', ' Carey']
+1904 428 Name of father of x -1 Name of father of Mariah Carey Alfred Roy Carey Mariah Carey "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Nick' ' Cannon'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Nick Cannon , are expecting their first child together .
+
+" False and songwriter Mariah Carey from her twelfth 5 [' and', ' song', 'writer', ' Mar', 'iah', ' Carey']
+1905 428 Name of father of x -1 Name of father of Mariah Carey Alfred Roy Carey Mariah Carey "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Nick' ' Cannon'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Nick Cannon , are expecting their first child together .
+
+" False ever hitting Mariah Carey levels of skull-rattling 4 [' ever', ' hitting', ' Mar', 'iah', ' Carey']
+1906 428 Name of father of x -1 Name of father of Mariah Carey Alfred Roy Carey Mariah Carey "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Nick' ' Cannon'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Nick Cannon , are expecting their first child together .
+
+" False singer-songwriter Mariah Carey from her sixth 6 [' singer', '-', 'song', 'writer', ' Mar', 'iah', ' Carey']
+1907 428 Name of father of x -1 Name of father of Mariah Carey Alfred Roy Carey Mariah Carey "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Nick' ' Cannon'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Nick Cannon , are expecting their first child together .
+
+" False history; only Mariah Carey (23) has more. 5 [' history', ';', ' only', ' Mar', 'iah', ' Carey']
+1908 429 Name of father of x -1 Name of father of Ludwig Wittgenstein Karl Wittgenstein Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' his' ' son' ',' ' the'
+ ' philosopher' ',' ' and' ' his' ' son' ',' ' the' ' philosopher' ','
+ ' and' ' his']" , the philosopher , and his son , the philosopher , and his son , the philosopher , and his False philosophers, such as Ludwig Wittgenstein and his followers 7 [' philosophers', ',', ' such', ' as', ' Ludwig', ' Witt', 'gen', 'stein']
+1909 429 Name of father of x -1 Name of father of Ludwig Wittgenstein Karl Wittgenstein Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' his' ' son' ',' ' the'
+ ' philosopher' ',' ' and' ' his' ' son' ',' ' the' ' philosopher' ','
+ ' and' ' his']" , the philosopher , and his son , the philosopher , and his son , the philosopher , and his False philosopher Ludwig Wittgenstein finished his Tractatus 4 [' philosopher', ' Ludwig', ' Witt', 'gen', 'stein']
+1910 429 Name of father of x -1 Name of father of Ludwig Wittgenstein Karl Wittgenstein Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' his' ' son' ',' ' the'
+ ' philosopher' ',' ' and' ' his' ' son' ',' ' the' ' philosopher' ','
+ ' and' ' his']" , the philosopher , and his son , the philosopher , and his son , the philosopher , and his False In his early work, Ludwig Wittgenstein attempted to separate 8 [' In', ' his', ' early', ' work', ',', ' Ludwig', ' Witt', 'gen', 'stein']
+1911 429 Name of father of x -1 Name of father of Ludwig Wittgenstein Karl Wittgenstein Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' his' ' son' ',' ' the'
+ ' philosopher' ',' ' and' ' his' ' son' ',' ' the' ' philosopher' ','
+ ' and' ' his']" , the philosopher , and his son , the philosopher , and his son , the philosopher , and his False philosophers, such as Ludwig Wittgenstein and his followers 7 [' philosophers', ',', ' such', ' as', ' Ludwig', ' Witt', 'gen', 'stein']
+1912 429 Name of father of x -1 Name of father of Ludwig Wittgenstein Karl Wittgenstein Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' his' ' son' ',' ' the'
+ ' philosopher' ',' ' and' ' his' ' son' ',' ' the' ' philosopher' ','
+ ' and' ' his']" , the philosopher , and his son , the philosopher , and his son , the philosopher , and his False philosopher Ludwig Wittgenstein finished his 4 [' philosopher', ' Ludwig', ' Witt', 'gen', 'stein']
+1913 430 Name of father of x -1 Name of father of Arthur Wellesley, 1st Duke of Wellington Garret Wesley, 1st Earl of Mornington Arthur Wellesley, 1st Duke of Wellington "['\n' '\n' 'The' ' Duke' ' of' ' Wellington' ' was' ' born' ' in' ' 17'
+ '69' ',' ' the' ' son' ' of' ' Arthur' ' Well' 'esley' ',' ' 1']" "
+
+ The Duke of Wellington was born in 17 69 , the son of Arthur Well esley , 1" False simultaneous support of the Arthur Wellesley, 1st Duke of Wellington (who was also 12 [' simultaneous', ' support', ' of', ' the', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1914 430 Name of father of x -1 Name of father of Arthur Wellesley, 1st Duke of Wellington Garret Wesley, 1st Earl of Mornington Arthur Wellesley, 1st Duke of Wellington "['\n' '\n' 'The' ' Duke' ' of' ' Wellington' ' was' ' born' ' in' ' 17'
+ '69' ',' ' the' ' son' ' of' ' Arthur' ' Well' 'esley' ',' ' 1']" "
+
+ The Duke of Wellington was born in 17 69 , the son of Arthur Well esley , 1" False " Battle record of Arthur Wellesley, 1st Duke of Wellington =
+" 11 [' Battle', ' record', ' of', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1915 430 Name of father of x -1 Name of father of Arthur Wellesley, 1st Duke of Wellington Garret Wesley, 1st Earl of Mornington Arthur Wellesley, 1st Duke of Wellington "['\n' '\n' 'The' ' Duke' ' of' ' Wellington' ' was' ' born' ' in' ' 17'
+ '69' ',' ' the' ' son' ' of' ' Arthur' ' Well' 'esley' ',' ' 1']" "
+
+ The Duke of Wellington was born in 17 69 , the son of Arthur Well esley , 1" False support of the Arthur Wellesley, 1st Duke of Wellington (who was also UK's 11 [' support', ' of', ' the', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1916 430 Name of father of x -1 Name of father of Arthur Wellesley, 1st Duke of Wellington Garret Wesley, 1st Earl of Mornington Arthur Wellesley, 1st Duke of Wellington "['\n' '\n' 'The' ' Duke' ' of' ' Wellington' ' was' ' born' ' in' ' 17'
+ '69' ',' ' the' ' son' ' of' ' Arthur' ' Well' 'esley' ',' ' 1']" "
+
+ The Duke of Wellington was born in 17 69 , the son of Arthur Well esley , 1" False " Battle record of Arthur Wellesley, 1st Duke of Wellington =
+" 11 [' Battle', ' record', ' of', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1917 430 Name of father of x -1 Name of father of Arthur Wellesley, 1st Duke of Wellington Garret Wesley, 1st Earl of Mornington Arthur Wellesley, 1st Duke of Wellington "['\n' '\n' 'The' ' Duke' ' of' ' Wellington' ' was' ' born' ' in' ' 17'
+ '69' ',' ' the' ' son' ' of' ' Arthur' ' Well' 'esley' ',' ' 1']" "
+
+ The Duke of Wellington was born in 17 69 , the son of Arthur Well esley , 1" False " Battle record of Arthur Wellesley, 1st Duke of Wellington =
+" 11 [' Battle', ' record', ' of', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1918 431 Name of father of x -1 Name of father of Justin Bieber Jeremy Bieber Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' fan' ' of' ' the' ' show']" ".
+
+ I am a big fan of the show , and I am a fan of the show" False " conceded to it."" Justin Bieber has embraced the" 5 "[' conceded', ' to', ' it', '.""', ' Justin', ' Bieber']"
+1919 431 Name of father of x -1 Name of father of Justin Bieber Jeremy Bieber Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' fan' ' of' ' the' ' show']" ".
+
+ I am a big fan of the show , and I am a fan of the show" False band to teen singer Justin Bieber in order to 5 [' band', ' to', ' teen', ' singer', ' Justin', ' Bieber']
+1920 431 Name of father of x -1 Name of father of Justin Bieber Jeremy Bieber Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' fan' ' of' ' the' ' show']" ".
+
+ I am a big fan of the show , and I am a fan of the show" False " on the keyboard. Justin Bieber performed ""Look at" 5 [' on', ' the', ' keyboard', '.', ' Justin', ' Bieber']
+1921 431 Name of father of x -1 Name of father of Justin Bieber Jeremy Bieber Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' fan' ' of' ' the' ' show']" ".
+
+ I am a big fan of the show , and I am a fan of the show" False was planning a Justin Bieber tribute episode, 4 [' was', ' planning', ' a', ' Justin', ' Bieber']
+1922 431 Name of father of x -1 Name of father of Justin Bieber Jeremy Bieber Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' fan' ' of' ' the' ' show']" ".
+
+ I am a big fan of the show , and I am a fan of the show" False Legend and Justin Bieber where he made substantially 3 [' Legend', ' and', ' Justin', ' Bieber']
+1923 433 Name of father of x -1 Name of father of John Everett Millais John William Millais John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False by English painter John Everett Millais and Edward VII, Prince 7 [' by', ' English', ' painter', ' John', ' Everett', ' M', 'illa', 'is']
+1924 433 Name of father of x -1 Name of father of John Everett Millais John William Millais John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False " early paintings of John Everett Millais and ""the wonderful" 7 [' early', ' paintings', ' of', ' John', ' Everett', ' M', 'illa', 'is']
+1925 433 Name of father of x -1 Name of father of John Everett Millais John William Millais John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False Browning, Lord Tennyson, John Everett Millais and Henry 11 [' Brown', 'ing', ',', ' Lord', ' Tenn', 'yson', ',', ' John', ' Everett', ' M', 'illa', 'is']
+1926 433 Name of father of x -1 Name of father of John Everett Millais John William Millais John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False Lord Tennyson, John Everett Millais and Henry James. 8 [' Lord', ' Tenn', 'yson', ',', ' John', ' Everett', ' M', 'illa', 'is']
+1927 433 Name of father of x -1 Name of father of John Everett Millais John William Millais John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False by English painter John Everett Millais and Edward VII, 7 [' by', ' English', ' painter', ' John', ' Everett', ' M', 'illa', 'is']
+1928 434 Name of father of x -1 Name of father of Thomas Alva Edison Samuel Ogden Edison Thomas Alva Edison "[',' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' ' electric'
+ ' light' ',' ' the' '\n' 'te' 'legraph' ',' ' the' ' telephone' ','
+ ' the']" ", inventor of the phon ograph , the electric light , the
+ te legraph , the telephone , the" False and passes near the Thomas Alva Edison Memorial Tower 7 [' and', ' passes', ' near', ' the', ' Thomas', ' Al', 'va', ' Edison']
+1929 434 Name of father of x -1 Name of father of Thomas Alva Edison Samuel Ogden Edison Thomas Alva Edison "[',' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' ' electric'
+ ' light' ',' ' the' '\n' 'te' 'legraph' ',' ' the' ' telephone' ','
+ ' the']" ", inventor of the phon ograph , the electric light , the
+ te legraph , the telephone , the" False created by Thomas Alva Edison in 1930. The overhead 5 [' created', ' by', ' Thomas', ' Al', 'va', ' Edison']
+1930 434 Name of father of x -1 Name of father of Thomas Alva Edison Samuel Ogden Edison Thomas Alva Edison "[',' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' ' electric'
+ ' light' ',' ' the' '\n' 'te' 'legraph' ',' ' the' ' telephone' ','
+ ' the']" ", inventor of the phon ograph , the electric light , the
+ te legraph , the telephone , the" False passes near the Thomas Alva Edison Memorial Tower and 6 [' passes', ' near', ' the', ' Thomas', ' Al', 'va', ' Edison']
+1931 434 Name of father of x -1 Name of father of Thomas Alva Edison Samuel Ogden Edison Thomas Alva Edison "[',' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' ' electric'
+ ' light' ',' ' the' '\n' 'te' 'legraph' ',' ' the' ' telephone' ','
+ ' the']" ", inventor of the phon ograph , the electric light , the
+ te legraph , the telephone , the" False created by Thomas Alva Edison in 1930. The 5 [' created', ' by', ' Thomas', ' Al', 'va', ' Edison']
+1932 434 Name of father of x -1 Name of father of Thomas Alva Edison Samuel Ogden Edison Thomas Alva Edison "[',' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' ' electric'
+ ' light' ',' ' the' '\n' 'te' 'legraph' ',' ' the' ' telephone' ','
+ ' the']" ", inventor of the phon ograph , the electric light , the
+ te legraph , the telephone , the" False passes near the Thomas Alva Edison Memorial Tower 6 [' passes', ' near', ' the', ' Thomas', ' Al', 'va', ' Edison']
+1933 435 Name of father of x -1 Name of father of Lauren Bacall Wulf William Perski Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False photographed with actress Lauren Bacall sitting atop 5 [' photographed', ' with', ' actress', ' Lauren', ' Bac', 'all']
+1934 435 Name of father of x -1 Name of father of Lauren Bacall Wulf William Perski Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False prisoner, is invited by Lauren Bacall into her apartment 7 [' prisoner', ',', ' is', ' invited', ' by', ' Lauren', ' Bac', 'all']
+1935 435 Name of father of x -1 Name of father of Lauren Bacall Wulf William Perski Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False movies, such as Lauren Bacall in To Have and Have 6 [' movies', ',', ' such', ' as', ' Lauren', ' Bac', 'all']
+1936 435 Name of father of x -1 Name of father of Lauren Bacall Wulf William Perski Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False off engagements to Lauren Bacall in 1958, and Juliet 5 [' off', ' engagements', ' to', ' Lauren', ' Bac', 'all']
+1937 435 Name of father of x -1 Name of father of Lauren Bacall Wulf William Perski Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False Jonah Jameson. Lauren Bacall and Katharine Hepburn 6 [' Jonah', ' James', 'on', '.', ' Lauren', ' Bac', 'all']
+1938 436 Name of father of x -1 Name of father of Carl Jung Johann Paul Achilles Jung Carl Jung "[',' ' the' ' founder' ' of' ' modern' ' psychology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychiatry' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the founder of modern psychology , and the father of modern psychiatry .
+
+ The first thing that" False was a student of Carl Jung in Zurich in May 5 [' was', ' a', ' student', ' of', ' Carl', ' Jung']
+1939 436 Name of father of x -1 Name of father of Carl Jung Johann Paul Achilles Jung Carl Jung "[',' ' the' ' founder' ' of' ' modern' ' psychology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychiatry' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the founder of modern psychology , and the father of modern psychiatry .
+
+ The first thing that" False influenced both Carl Jung and Joseph Campbell. 3 [' influenced', ' both', ' Carl', ' Jung']
+1940 436 Name of father of x -1 Name of father of Carl Jung Johann Paul Achilles Jung Carl Jung "[',' ' the' ' founder' ' of' ' modern' ' psychology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychiatry' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the founder of modern psychology , and the father of modern psychiatry .
+
+ The first thing that" False who was a student of Carl Jung in Zurich in May 6 [' who', ' was', ' a', ' student', ' of', ' Carl', ' Jung']
+1941 436 Name of father of x -1 Name of father of Carl Jung Johann Paul Achilles Jung Carl Jung "[',' ' the' ' founder' ' of' ' modern' ' psychology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychiatry' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the founder of modern psychology , and the father of modern psychiatry .
+
+ The first thing that" False psychologist Carl Jung took interest 2 [' psychologist', ' Carl', ' Jung']
+1942 436 Name of father of x -1 Name of father of Carl Jung Johann Paul Achilles Jung Carl Jung "[',' ' the' ' founder' ' of' ' modern' ' psychology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychiatry' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the founder of modern psychology , and the father of modern psychiatry .
+
+ The first thing that" False psychologist Carl Jung took interest 2 [' psychologist', ' Carl', ' Jung']
+1943 437 Name of father of x -1 Name of father of Franz Joseph I of Austria Archduke Franz Karl of Austria Franz Joseph I of Austria "[',' ' Emperor' ' of' ' Austria' ',' ' King' ' of' ' Hungary' ',' ' King'
+ ' of' ' Bohem' 'ia' ',' ' King' ' of' ' Croatia' ',' ' King' ' of']" , Emperor of Austria , King of Hungary , King of Bohem ia , King of Croatia , King of False by Emperor Franz Joseph I of Austria for 50,000 franks. 6 [' by', ' Emperor', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1944 437 Name of father of x -1 Name of father of Franz Joseph I of Austria Archduke Franz Karl of Austria Franz Joseph I of Austria "[',' ' Emperor' ' of' ' Austria' ',' ' King' ' of' ' Hungary' ',' ' King'
+ ' of' ' Bohem' 'ia' ',' ' King' ' of' ' Croatia' ',' ' King' ' of']" , Emperor of Austria , King of Hungary , King of Bohem ia , King of Croatia , King of False of France, Franz Joseph I of Austria and Maximilian 7 [' of', ' France', ',', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1945 437 Name of father of x -1 Name of father of Franz Joseph I of Austria Archduke Franz Karl of Austria Franz Joseph I of Austria "[',' ' Emperor' ' of' ' Austria' ',' ' King' ' of' ' Hungary' ',' ' King'
+ ' of' ' Bohem' 'ia' ',' ' King' ' of' ' Croatia' ',' ' King' ' of']" , Emperor of Austria , King of Hungary , King of Bohem ia , King of Croatia , King of False the Emperor Franz Joseph I of Austria in 1914 after the 6 [' the', ' Emperor', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1946 437 Name of father of x -1 Name of father of Franz Joseph I of Austria Archduke Franz Karl of Austria Franz Joseph I of Austria "[',' ' Emperor' ' of' ' Austria' ',' ' King' ' of' ' Hungary' ',' ' King'
+ ' of' ' Bohem' 'ia' ',' ' King' ' of' ' Croatia' ',' ' King' ' of']" , Emperor of Austria , King of Hungary , King of Bohem ia , King of Croatia , King of False " negotiations with Emperor Franz Joseph I of Austria regarding Italy.
+" 7 [' negotiations', ' with', ' Emperor', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1947 437 Name of father of x -1 Name of father of Franz Joseph I of Austria Archduke Franz Karl of Austria Franz Joseph I of Austria "[',' ' Emperor' ' of' ' Austria' ',' ' King' ' of' ' Hungary' ',' ' King'
+ ' of' ' Bohem' 'ia' ',' ' King' ' of' ' Croatia' ',' ' King' ' of']" , Emperor of Austria , King of Hungary , King of Bohem ia , King of Croatia , King of False by Emperor Franz Joseph I of Austria for 50,000 franks. 6 [' by', ' Emperor', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1948 438 Name of father of x -1 Name of father of Catherine II of Russia Christian August, Prince of Anhalt-Zerbst Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' who' ' was'
+ ' born' ' in' ' 17' '29' ',' ' and' ' died' ' in' ' 17' '96' '.']" , the Empress Catherine the Great , who was born in 17 29 , and died in 17 96 . False its completion as Catherine II of Russia had offered him 6 [' its', ' completion', ' as', ' Catherine', ' II', ' of', ' Russia']
+1949 438 Name of father of x -1 Name of father of Catherine II of Russia Christian August, Prince of Anhalt-Zerbst Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' who' ' was'
+ ' born' ' in' ' 17' '29' ',' ' and' ' died' ' in' ' 17' '96' '.']" , the Empress Catherine the Great , who was born in 17 29 , and died in 17 96 . False For example, Catherine II of Russia sent him seeds 6 [' For', ' example', ',', ' Catherine', ' II', ' of', ' Russia']
+1950 438 Name of father of x -1 Name of father of Catherine II of Russia Christian August, Prince of Anhalt-Zerbst Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' who' ' was'
+ ' born' ' in' ' 17' '29' ',' ' and' ' died' ' in' ' 17' '96' '.']" , the Empress Catherine the Great , who was born in 17 29 , and died in 17 96 . False For example, Catherine II of Russia sent him seeds 6 [' For', ' example', ',', ' Catherine', ' II', ' of', ' Russia']
+1951 438 Name of father of x -1 Name of father of Catherine II of Russia Christian August, Prince of Anhalt-Zerbst Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' who' ' was'
+ ' born' ' in' ' 17' '29' ',' ' and' ' died' ' in' ' 17' '96' '.']" , the Empress Catherine the Great , who was born in 17 29 , and died in 17 96 . False of its completion as Catherine II of Russia had offered him a 7 [' of', ' its', ' completion', ' as', ' Catherine', ' II', ' of', ' Russia']
+1952 438 Name of father of x -1 Name of father of Catherine II of Russia Christian August, Prince of Anhalt-Zerbst Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' who' ' was'
+ ' born' ' in' ' 17' '29' ',' ' and' ' died' ' in' ' 17' '96' '.']" , the Empress Catherine the Great , who was born in 17 29 , and died in 17 96 . False 4 ['C', 'atherine', ' II', ' of', ' Russia']
+1953 439 Name of father of x -1 Name of father of Theodore von Kármán Mór Kármán Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' ' in' ' the'
+ ' air' '.' '\n' '\n' 'The' 'odore' ' von' ' K' 'á' 'rm']" ", the first American to fly a plane in the air .
+
+ The odore von K á rm" False km (62 mi), as Theodore von Kármán calculated, 11 [' km', ' (', '62', ' mi', '),', ' as', ' Theodore', ' von', ' K', 'á', 'rm', 'án']
+1954 439 Name of father of x -1 Name of father of Theodore von Kármán Mór Kármán Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' ' in' ' the'
+ ' air' '.' '\n' '\n' 'The' 'odore' ' von' ' K' 'á' 'rm']" ", the first American to fly a plane in the air .
+
+ The odore von K á rm" False doctoral advisor Theodore von Kármán saw more promise 7 [' doctoral', ' advisor', ' Theodore', ' von', ' K', 'á', 'rm', 'án']
+1955 439 Name of father of x -1 Name of father of Theodore von Kármán Mór Kármán Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' ' in' ' the'
+ ' air' '.' '\n' '\n' 'The' 'odore' ' von' ' K' 'á' 'rm']" ", the first American to fly a plane in the air .
+
+ The odore von K á rm" False achievement, that included Theodore von Kármán (b. 1881), George 9 [' achievement', ',', ' that', ' included', ' Theodore', ' von', ' K', 'á', 'rm', 'án']
+1956 439 Name of father of x -1 Name of father of Theodore von Kármán Mór Kármán Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' ' in' ' the'
+ ' air' '.' '\n' '\n' 'The' 'odore' ' von' ' K' 'á' 'rm']" ", the first American to fly a plane in the air .
+
+ The odore von K á rm" False doctoral advisor Theodore von Kármán saw more promise in 7 [' doctoral', ' advisor', ' Theodore', ' von', ' K', 'á', 'rm', 'án']
+1957 439 Name of father of x -1 Name of father of Theodore von Kármán Mór Kármán Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' ' in' ' the'
+ ' air' '.' '\n' '\n' 'The' 'odore' ' von' ' K' 'á' 'rm']" ", the first American to fly a plane in the air .
+
+ The odore von K á rm" False doctoral advisor Theodore von Kármán saw more promise 7 [' doctoral', ' advisor', ' Theodore', ' von', ' K', 'á', 'rm', 'án']
+1958 440 Name of father of x -1 Name of father of Woodrow Wilson Joseph Ruggles Wilson Woodrow Wilson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Wood' 'row' ' Wilson' ',']" ", the president of the United States , and the
+
+ Name of mother of Wood row Wilson ," False U.S. President Woodrow Wilson announced his Fourteen 7 [' U', '.', 'S', '.', ' President', ' Wood', 'row', ' Wilson']
+1959 440 Name of father of x -1 Name of father of Woodrow Wilson Joseph Ruggles Wilson Woodrow Wilson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Wood' 'row' ' Wilson' ',']" ", the president of the United States , and the
+
+ Name of mother of Wood row Wilson ," False President Woodrow Wilson to make a study of 3 [' President', ' Wood', 'row', ' Wilson']
+1960 440 Name of father of x -1 Name of father of Woodrow Wilson Joseph Ruggles Wilson Woodrow Wilson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Wood' 'row' ' Wilson' ',']" ", the president of the United States , and the
+
+ Name of mother of Wood row Wilson ," False U.S. President Woodrow Wilson ordered the 7 [' U', '.', 'S', '.', ' President', ' Wood', 'row', ' Wilson']
+1961 440 Name of father of x -1 Name of father of Woodrow Wilson Joseph Ruggles Wilson Woodrow Wilson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Wood' 'row' ' Wilson' ',']" ", the president of the United States , and the
+
+ Name of mother of Wood row Wilson ," False President Woodrow Wilson as one of 20 preachers 3 [' President', ' Wood', 'row', ' Wilson']
+1962 440 Name of father of x -1 Name of father of Woodrow Wilson Joseph Ruggles Wilson Woodrow Wilson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Wood' 'row' ' Wilson' ',']" ", the president of the United States , and the
+
+ Name of mother of Wood row Wilson ," False present President Woodrow Wilson with it. Many of the 4 [' present', ' President', ' Wood', 'row', ' Wilson']
+1963 441 Name of father of x -1 Name of father of Adam Mickiewicz Mikołaj Mickiewicz Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False heroes (e.g., Kraków's Adam Mickiewicz monument) – were 13 "[' heroes', ' (', 'e', '.', 'g', '.,', ' K', 'rak', 'ó', 'w', ""'s"", ' Adam', ' Mick', 'iewicz']"
+1964 441 Name of father of x -1 Name of father of Adam Mickiewicz Mikołaj Mickiewicz Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False of sociology at Adam Mickiewicz University 5 [' of', ' sociology', ' at', ' Adam', ' Mick', 'iewicz']
+1965 441 Name of father of x -1 Name of father of Adam Mickiewicz Mikołaj Mickiewicz Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False 2 ['Adam', ' Mick', 'iewicz']
+1966 441 Name of father of x -1 Name of father of Adam Mickiewicz Mikołaj Mickiewicz Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False Cultural Studies from Adam Mickiewicz University, 5 [' Cultural', ' Studies', ' from', ' Adam', ' Mick', 'iewicz']
+1967 441 Name of father of x -1 Name of father of Adam Mickiewicz Mikołaj Mickiewicz Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False Warsaw has an Adam Mickiewicz Museum of Literature. 5 [' Warsaw', ' has', ' an', ' Adam', ' Mick', 'iewicz']
+1968 442 Name of father of x -1 Name of father of Pablo Neruda José del Carmen Reyes Morales Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' father' ' of' ' the' ' poet' ',' ' the' ' father' ' of']" , the poet , and the father of the poet , the father of the poet , the father of False by Chilean poet Pablo Neruda highlighted Neruda's 5 [' by', ' Chilean', ' poet', ' Pablo', ' Ner', 'uda']
+1969 442 Name of father of x -1 Name of father of Pablo Neruda José del Carmen Reyes Morales Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' father' ' of' ' the' ' poet' ',' ' the' ' father' ' of']" , the poet , and the father of the poet , the father of the poet , the father of False Chilean poet Pablo Neruda highlighted Neruda's 4 [' Chilean', ' poet', ' Pablo', ' Ner', 'uda']
+1970 442 Name of father of x -1 Name of father of Pablo Neruda José del Carmen Reyes Morales Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' father' ' of' ' the' ' poet' ',' ' the' ' father' ' of']" , the poet , and the father of the poet , the father of the poet , the father of False poet and politician Pablo Neruda organized the 5 [' poet', ' and', ' politician', ' Pablo', ' Ner', 'uda']
+1971 442 Name of father of x -1 Name of father of Pablo Neruda José del Carmen Reyes Morales Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' father' ' of' ' the' ' poet' ',' ' the' ' father' ' of']" , the poet , and the father of the poet , the father of the poet , the father of False poet and politician Pablo Neruda organized 5 [' poet', ' and', ' politician', ' Pablo', ' Ner', 'uda']
+1972 442 Name of father of x -1 Name of father of Pablo Neruda José del Carmen Reyes Morales Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' ' father' ' of' ' the' ' poet' ',' ' the' ' father' ' of']" , the poet , and the father of the poet , the father of the poet , the father of False Guatemala by Chilean poet Pablo Neruda highlighted Neruda's 6 [' Guatemala', ' by', ' Chilean', ' poet', ' Pablo', ' Ner', 'uda']
+1973 443 Name of father of x -1 Name of father of Henry David Thoreau John Thoreau Jr. Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' ""' 'W' 'ald']" ", the author of Wald en , and the author of the book that inspired the movie "" W ald" False for the experiment. Henry David Thoreau questioned 8 [' for', ' the', ' experiment', '.', ' Henry', ' David', ' Th', 'ore', 'au']
+1974 443 Name of father of x -1 Name of father of Henry David Thoreau John Thoreau Jr. Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' ""' 'W' 'ald']" ", the author of Wald en , and the author of the book that inspired the movie "" W ald" False Allan Poe, and Henry David Thoreau established a distinctive 8 [' Allan', ' Poe', ',', ' and', ' Henry', ' David', ' Th', 'ore', 'au']
+1975 443 Name of father of x -1 Name of father of Henry David Thoreau John Thoreau Jr. Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' ""' 'W' 'ald']" ", the author of Wald en , and the author of the book that inspired the movie "" W ald" False 4 ['Henry', ' David', ' Th', 'ore', 'au']
+1976 443 Name of father of x -1 Name of father of Henry David Thoreau John Thoreau Jr. Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' ""' 'W' 'ald']" ", the author of Wald en , and the author of the book that inspired the movie "" W ald" False example arose when Henry David Thoreau the author of Walden 7 [' example', ' arose', ' when', ' Henry', ' David', ' Th', 'ore', 'au']
+1977 443 Name of father of x -1 Name of father of Henry David Thoreau John Thoreau Jr. Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' ""' 'W' 'ald']" ", the author of Wald en , and the author of the book that inspired the movie "" W ald" False " American naturalist Henry David Thoreau wrote:
+" 7 [' American', ' natural', 'ist', ' Henry', ' David', ' Th', 'ore', 'au']
+1978 444 Name of father of x -1 Name of father of Will Smith Willard Carroll Smith Sr. Will Smith "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' father' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' �'
+ '�' 'The']" , the actor who played the role of the father of the main character in the movie � � The False " Boys II (2003) with Will Smith and Martin Lawrence
+" 7 [' Boys', ' II', ' (', '2003', ')', ' with', ' Will', ' Smith']
+1979 444 Name of father of x -1 Name of father of Will Smith Willard Carroll Smith Sr. Will Smith "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' father' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' �'
+ '�' 'The']" , the actor who played the role of the father of the main character in the movie � � The False Disc, actor Will Smith said that there was 4 [' Disc', ',', ' actor', ' Will', ' Smith']
+1980 444 Name of father of x -1 Name of father of Will Smith Willard Carroll Smith Sr. Will Smith "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' father' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' �'
+ '�' 'The']" , the actor who played the role of the father of the main character in the movie � � The False peripheral. Will Smith of The Hawk 3 [' peripheral', '.', ' Will', ' Smith']
+1981 444 Name of father of x -1 Name of father of Will Smith Willard Carroll Smith Sr. Will Smith "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' father' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' �'
+ '�' 'The']" , the actor who played the role of the father of the main character in the movie � � The False first time in his life, Will Smith doesn ’ t 7 [' first', ' time', ' in', ' his', ' life', ',', ' Will', ' Smith']
+1982 444 Name of father of x -1 Name of father of Will Smith Willard Carroll Smith Sr. Will Smith "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' father' ' of' ' the' ' main' ' character' ' in' ' the' ' movie' ' �'
+ '�' 'The']" , the actor who played the role of the father of the main character in the movie � � The False Disc, actor Will Smith said that there 4 [' Disc', ',', ' actor', ' Will', ' Smith']
+1983 445 Name of father of x -1 Name of father of Heracles Zeus Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the True likening Milo unto Heracles in his enormous 5 [' lik', 'ening', ' Milo', ' unto', ' Her', 'acles']
+1984 445 Name of father of x -1 Name of father of Heracles Zeus Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the True takes center-stage. Heracles sacked Pylos, 6 [' takes', ' center', '-', 'stage', '.', ' Her', 'acles']
+1985 445 Name of father of x -1 Name of father of Heracles Zeus Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the True monumental events of Heracles are regarded as the 4 [' monumental', ' events', ' of', ' Her', 'acles']
+1986 445 Name of father of x -1 Name of father of Heracles Zeus Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the True " mythic hero Heracles (see adjacent image).
+" 4 [' myth', 'ic', ' hero', ' Her', 'acles']
+1987 445 Name of father of x -1 Name of father of Heracles Zeus Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the True from the Shield of Heracles, the first 5 [' from', ' the', ' Shield', ' of', ' Her', 'acles']
+1988 446 Name of father of x -1 Name of father of Wilhelm Grimm Philipp Grimm Wilhelm Grimm "[',' ' the' ' famous' ' German' ' author' ' of' ' the' ' fairy' ' tales'
+ '.' '\n' '\n' 'The' ' Grimm' ' Brothers' ' were' ' born' ' in' ' Han'
+ 'au']" ", the famous German author of the fairy tales .
+
+ The Grimm Brothers were born in Han au" False 2 ['Wil', 'helm', ' Grimm']
+1989 447 Name of father of x -1 Name of father of Robert De Niro Robert De Niro Sr. Robert De Niro "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False Jay Roach. Starring Robert De Niro and Ben Stiller, 9 [' Jay', ' Ro', 'ach', '.', ' Star', 'ring', ' Robert', ' De', ' N', 'iro']
+1990 447 Name of father of x -1 Name of father of Robert De Niro Robert De Niro Sr. Robert De Niro "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False " and also considered Robert De Niro as an ""inspiration""" 6 [' and', ' also', ' considered', ' Robert', ' De', ' N', 'iro']
+1991 447 Name of father of x -1 Name of father of Robert De Niro Robert De Niro Sr. Robert De Niro "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False (1995), in which he and Robert De Niro appeared on-screen 10 [' (', '1995', '),', ' in', ' which', ' he', ' and', ' Robert', ' De', ' N', 'iro']
+1992 447 Name of father of x -1 Name of father of Robert De Niro Robert De Niro Sr. Robert De Niro "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False Jay Roach. Starring Robert De Niro and Ben Stiller, 9 [' Jay', ' Ro', 'ach', '.', ' Star', 'ring', ' Robert', ' De', ' N', 'iro']
+1993 447 Name of father of x -1 Name of father of Robert De Niro Robert De Niro Sr. Robert De Niro "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False Pacino and Robert De Niro co-star as 6 [' Pac', 'ino', ' and', ' Robert', ' De', ' N', 'iro']
+1994 448 Name of father of x -1 Name of father of Quentin Tarantino Tony Tarantino Quentin Tarantino "['�' '�' 's' ' �' '�' 'D' 'j' 'ango' ' Unch' 'ained' '�' '�' ' and' ' �'
+ '�' 'Ing' 'l' 'ouri' 'ous' ' B']" � � s � � D j ango Unch ained � � and � � Ing l ouri ous B False 3 ['Qu', 'entin', ' Tarant', 'ino']
+1995 448 Name of father of x -1 Name of father of Quentin Tarantino Tony Tarantino Quentin Tarantino "['�' '�' 's' ' �' '�' 'D' 'j' 'ango' ' Unch' 'ained' '�' '�' ' and' ' �'
+ '�' 'Ing' 'l' 'ouri' 'ous' ' B']" � � s � � D j ango Unch ained � � and � � Ing l ouri ous B False " have known them"". Quentin Tarantino said it was" 6 "[' have', ' known', ' them', '"".', ' Quentin', ' Tarant', 'ino']"
+1996 448 Name of father of x -1 Name of father of Quentin Tarantino Tony Tarantino Quentin Tarantino "['�' '�' 's' ' �' '�' 'D' 'j' 'ango' ' Unch' 'ained' '�' '�' ' and' ' �'
+ '�' 'Ing' 'l' 'ouri' 'ous' ' B']" � � s � � D j ango Unch ained � � and � � Ing l ouri ous B False airings and home video. Quentin Tarantino stated that he 8 [' air', 'ings', ' and', ' home', ' video', '.', ' Quentin', ' Tarant', 'ino']
+1997 448 Name of father of x -1 Name of father of Quentin Tarantino Tony Tarantino Quentin Tarantino "['�' '�' 's' ' �' '�' 'D' 'j' 'ango' ' Unch' 'ained' '�' '�' ' and' ' �'
+ '�' 'Ing' 'l' 'ouri' 'ous' ' B']" � � s � � D j ango Unch ained � � and � � Ing l ouri ous B False screenwriter Quentin Tarantino declared his appreciation 4 [' screen', 'writer', ' Quentin', ' Tarant', 'ino']
+1998 448 Name of father of x -1 Name of father of Quentin Tarantino Tony Tarantino Quentin Tarantino "['�' '�' 's' ' �' '�' 'D' 'j' 'ango' ' Unch' 'ained' '�' '�' ' and' ' �'
+ '�' 'Ing' 'l' 'ouri' 'ous' ' B']" � � s � � D j ango Unch ained � � and � � Ing l ouri ous B False " himself into a Quentin Tarantino character.""
+" 5 [' himself', ' into', ' a', ' Quentin', ' Tarant', 'ino']
+1999 449 Name of father of x -1 Name of father of Giotto Bondone Giotto "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ' of' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the painter of the
+ " False successors such as Giotto favoured presenting 4 [' successors', ' such', ' as', ' Gi', 'otto']
+2000 449 Name of father of x -1 Name of father of Giotto Bondone Giotto "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ' of' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the painter of the
+ " False Space Agency's Giotto probe and the 4 "[' Space', ' Agency', ""'s"", ' Gi', 'otto']"
+2001 449 Name of father of x -1 Name of father of Giotto Bondone Giotto "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ' of' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the painter of the
+ " False Renaissance, Giotto and Masaccio. 3 [' Renaissance', ',', ' Gi', 'otto']
+2002 449 Name of father of x -1 Name of father of Giotto Bondone Giotto "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ' of' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the painter of the
+ " False artists, such as Giotto or Duccio. If you were 5 [' artists', ',', ' such', ' as', ' Gi', 'otto']
+2003 449 Name of father of x -1 Name of father of Giotto Bondone Giotto "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ' of' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the painter of the
+ " False perspective arrived with Giotto (1266 / 7 – 1337), 4 [' perspective', ' arrived', ' with', ' Gi', 'otto']
+2004 450 Name of father of x -1 Name of father of Mary Cassatt Robert S. Cassatt Mary Cassatt "[',' ' the' ' painter' ',' ' and' ' the' ' mother' ' of' ' the' ' painter'
+ ',' ' the' ' painter' ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s""]" , the painter , and the mother of the painter , the painter 's wife , and the painter 's False 1870, Sartain met Mary Cassatt in Philadelphia 8 [' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+2005 450 Name of father of x -1 Name of father of Mary Cassatt Robert S. Cassatt Mary Cassatt "[',' ' the' ' painter' ',' ' and' ' the' ' mother' ' of' ' the' ' painter'
+ ',' ' the' ' painter' ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s""]" , the painter , and the mother of the painter , the painter 's wife , and the painter 's False 1870, Sartain met Mary Cassatt in Philadelphia and 8 [' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+2006 450 Name of father of x -1 Name of father of Mary Cassatt Robert S. Cassatt Mary Cassatt "[',' ' the' ' painter' ',' ' and' ' the' ' mother' ' of' ' the' ' painter'
+ ',' ' the' ' painter' ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s""]" , the painter , and the mother of the painter , the painter 's wife , and the painter 's False 1870, Sartain met Mary Cassatt in Philadelphia 8 [' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+2007 450 Name of father of x -1 Name of father of Mary Cassatt Robert S. Cassatt Mary Cassatt "[',' ' the' ' painter' ',' ' and' ' the' ' mother' ' of' ' the' ' painter'
+ ',' ' the' ' painter' ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s""]" , the painter , and the mother of the painter , the painter 's wife , and the painter 's False 1870, Sartain met Mary Cassatt in Philadelphia 8 [' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+2008 450 Name of father of x -1 Name of father of Mary Cassatt Robert S. Cassatt Mary Cassatt "[',' ' the' ' painter' ',' ' and' ' the' ' mother' ' of' ' the' ' painter'
+ ',' ' the' ' painter' ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s""]" , the painter , and the mother of the painter , the painter 's wife , and the painter 's False 1870, Sartain met Mary Cassatt in Philadelphia 8 [' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+2009 451 Name of father of x -1 Name of father of Kanye West Ray West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' have' ' been' ' married' ' for'
+ ' a']" , the rapper , and his wife Kim Kardashian , the reality TV star , have been married for a False " ""E.T."" featuring rapper Kanye West was released" 8 "[' ""', 'E', '.', 'T', '.""', ' featuring', ' rapper', ' Kanye', ' West']"
+2010 451 Name of father of x -1 Name of father of Kanye West Ray West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' have' ' been' ' married' ' for'
+ ' a']" , the rapper , and his wife Kim Kardashian , the reality TV star , have been married for a False According to Malay, Kanye West helped Ocean during 6 [' According', ' to', ' Mal', 'ay', ',', ' Kanye', ' West']
+2011 451 Name of father of x -1 Name of father of Kanye West Ray West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' have' ' been' ' married' ' for'
+ ' a']" , the rapper , and his wife Kim Kardashian , the reality TV star , have been married for a False " ""Ignorant Shit"". Kanye West mentor No I.D." 7 "[' ""', 'Ign', 'or', 'ant', ' Shit', '"".', ' Kanye', ' West']"
+2012 451 Name of father of x -1 Name of father of Kanye West Ray West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' have' ' been' ' married' ' for'
+ ' a']" , the rapper , and his wife Kim Kardashian , the reality TV star , have been married for a False hop artists Kanye West and Jay-Z, from their 3 [' hop', ' artists', ' Kanye', ' West']
+2013 451 Name of father of x -1 Name of father of Kanye West Ray West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' have' ' been' ' married' ' for'
+ ' a']" , the rapper , and his wife Kim Kardashian , the reality TV star , have been married for a False when you hear Kanye West do it, or include 4 [' when', ' you', ' hear', ' Kanye', ' West']
+2014 452 Name of father of x -1 Name of father of Claude Lévi-Strauss Raymond Levi-strauss Claude Lévi-Strauss "[',' ' the' ' French' ' anthrop' 'ologist' ' and' ' philosopher' '.' '\n'
+ '\n' '**' 'L' 'é' 'vi' '-' 'Stra' 'uss' ',' ' C' '.']" ", the French anthrop ologist and philosopher .
+
+ ** L é vi - Stra uss , C ." False and identicality. Claude Lévi-Strauss argues the former 10 [' and', ' identical', 'ity', '.', ' Claude', ' L', 'é', 'vi', '-', 'Stra', 'uss']
+2015 452 Name of father of x -1 Name of father of Claude Lévi-Strauss Raymond Levi-strauss Claude Lévi-Strauss "[',' ' the' ' French' ' anthrop' 'ologist' ' and' ' philosopher' '.' '\n'
+ '\n' '**' 'L' 'é' 'vi' '-' 'Stra' 'uss' ',' ' C' '.']" ", the French anthrop ologist and philosopher .
+
+ ** L é vi - Stra uss , C ." False social functions. Claude Lévi-Strauss and other structuralists 9 [' social', ' functions', '.', ' Claude', ' L', 'é', 'vi', '-', 'Stra', 'uss']
+2016 452 Name of father of x -1 Name of father of Claude Lévi-Strauss Raymond Levi-strauss Claude Lévi-Strauss "[',' ' the' ' French' ' anthrop' 'ologist' ' and' ' philosopher' '.' '\n'
+ '\n' '**' 'L' 'é' 'vi' '-' 'Stra' 'uss' ',' ' C' '.']" ", the French anthrop ologist and philosopher .
+
+ ** L é vi - Stra uss , C ." False identicality. Claude Lévi-Strauss argues the former 9 [' identical', 'ity', '.', ' Claude', ' L', 'é', 'vi', '-', 'Stra', 'uss']
+2017 452 Name of father of x -1 Name of father of Claude Lévi-Strauss Raymond Levi-strauss Claude Lévi-Strauss "[',' ' the' ' French' ' anthrop' 'ologist' ' and' ' philosopher' '.' '\n'
+ '\n' '**' 'L' 'é' 'vi' '-' 'Stra' 'uss' ',' ' C' '.']" ", the French anthrop ologist and philosopher .
+
+ ** L é vi - Stra uss , C ." False social functions. Claude Lévi-Strauss and other 9 [' social', ' functions', '.', ' Claude', ' L', 'é', 'vi', '-', 'Stra', 'uss']
+2018 452 Name of father of x -1 Name of father of Claude Lévi-Strauss Raymond Levi-strauss Claude Lévi-Strauss "[',' ' the' ' French' ' anthrop' 'ologist' ' and' ' philosopher' '.' '\n'
+ '\n' '**' 'L' 'é' 'vi' '-' 'Stra' 'uss' ',' ' C' '.']" ", the French anthrop ologist and philosopher .
+
+ ** L é vi - Stra uss , C ." False social functions. Claude Lévi-Strauss and other structuralists 9 [' social', ' functions', '.', ' Claude', ' L', 'é', 'vi', '-', 'Stra', 'uss']
+2019 453 Name of father of x -1 Name of father of Jacques Chirac François Chirac Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Marg'
+ 'uer' 'ite' ',' ' who' ' was' ' a' ' former' ' model' '.' '\n']" ", the French president , and his wife , Marg uer ite , who was a former model .
+" False came to an end when Jacques Chirac reinstated the two-round 8 [' came', ' to', ' an', ' end', ' when', ' Jacques', ' Ch', 'ir', 'ac']
+2020 453 Name of father of x -1 Name of father of Jacques Chirac François Chirac Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Marg'
+ 'uer' 'ite' ',' ' who' ' was' ' a' ' former' ' model' '.' '\n']" ", the French president , and his wife , Marg uer ite , who was a former model .
+" False " until French president Jacques Chirac termed it ""unacceptable""" 6 [' until', ' French', ' president', ' Jacques', ' Ch', 'ir', 'ac']
+2021 453 Name of father of x -1 Name of father of Jacques Chirac François Chirac Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Marg'
+ 'uer' 'ite' ',' ' who' ' was' ' a' ' former' ' model' '.' '\n']" ", the French president , and his wife , Marg uer ite , who was a former model .
+" False Then-French President Jacques Chirac became the subject 7 [' Then', '-', 'French', ' President', ' Jacques', ' Ch', 'ir', 'ac']
+2022 453 Name of father of x -1 Name of father of Jacques Chirac François Chirac Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Marg'
+ 'uer' 'ite' ',' ' who' ' was' ' a' ' former' ' model' '.' '\n']" ", the French president , and his wife , Marg uer ite , who was a former model .
+" False Then-French President Jacques Chirac became the subject 7 [' Then', '-', 'French', ' President', ' Jacques', ' Ch', 'ir', 'ac']
+2023 453 Name of father of x -1 Name of father of Jacques Chirac François Chirac Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Marg'
+ 'uer' 'ite' ',' ' who' ' was' ' a' ' former' ' model' '.' '\n']" ", the French president , and his wife , Marg uer ite , who was a former model .
+" False stance reversed after Jacques Chirac was elected president 6 [' stance', ' reversed', ' after', ' Jacques', ' Ch', 'ir', 'ac']
+2024 454 Name of father of x -1 Name of father of Herman Melville Allan Melville Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False novel Moby-Dick by Herman Melville and its main 8 [' novel', ' Mob', 'y', '-', 'Dick', ' by', ' Herman', ' Mel', 'ville']
+2025 454 Name of father of x -1 Name of father of Herman Melville Allan Melville Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False Battle-Pieces publication, Herman Melville penned a poem 8 [' Battle', '-', 'Pie', 'ces', ' publication', ',', ' Herman', ' Mel', 'ville']
+2026 454 Name of father of x -1 Name of father of Herman Melville Allan Melville Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False " Melville ==
+" 6 [' Mel', 'ville', ' ==', 'H', 'erman', ' Mel', 'ville']
+2027 454 Name of father of x -1 Name of father of Herman Melville Allan Melville Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False Battle-Pieces publication, Herman Melville penned a poem 8 [' Battle', '-', 'Pie', 'ces', ' publication', ',', ' Herman', ' Mel', 'ville']
+2028 454 Name of father of x -1 Name of father of Herman Melville Allan Melville Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False known as the Herman Melville House, was the 5 [' known', ' as', ' the', ' Herman', ' Mel', 'ville']
+2029 455 Name of father of x -1 Name of father of Georgia O'Keeffe Francis O'Keeffe Georgia O'Keeffe "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ',' ' a' ' daughter' ',' ' a' ' sister']" "
+
+ I am a mother of two , a wife , a grandmother , a daughter , a sister" False Vincent van Gogh and Georgia O'Keeffe on her paintings 10 "[' Vincent', ' van', ' Go', 'gh', ' and', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+2030 455 Name of father of x -1 Name of father of Georgia O'Keeffe Francis O'Keeffe Georgia O'Keeffe "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ',' ' a' ' daughter' ',' ' a' ' sister']" "
+
+ I am a mother of two , a wife , a grandmother , a daughter , a sister" False van Gogh and Georgia O'Keeffe on her paintings 9 "[' van', ' Go', 'gh', ' and', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+2031 455 Name of father of x -1 Name of father of Georgia O'Keeffe Francis O'Keeffe Georgia O'Keeffe "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ',' ' a' ' daughter' ',' ' a' ' sister']" "
+
+ I am a mother of two , a wife , a grandmother , a daughter , a sister" False bust of her with Georgia O'Keeffe and Susan B. Anthony 9 "[' bust', ' of', ' her', ' with', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+2032 455 Name of father of x -1 Name of father of Georgia O'Keeffe Francis O'Keeffe Georgia O'Keeffe "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ',' ' a' ' daughter' ',' ' a' ' sister']" "
+
+ I am a mother of two , a wife , a grandmother , a daughter , a sister" False including works by Georgia O'Keeffe and George Inness. 8 "[' including', ' works', ' by', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+2033 455 Name of father of x -1 Name of father of Georgia O'Keeffe Francis O'Keeffe Georgia O'Keeffe "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ',' ' a' ' daughter' ',' ' a' ' sister']" "
+
+ I am a mother of two , a wife , a grandmother , a daughter , a sister" False Vincent van Gogh and Georgia O'Keeffe on her paintings and 10 "[' Vincent', ' van', ' Go', 'gh', ' and', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+2034 456 Name of father of x -1 Name of father of Quincy Jones Quincy Delight Jones Quincy Jones "[',' ' the' ' father' ' of' ' the' ' famous' ' jazz' ' musician' ' Quincy'
+ ' Jones' ',' ' and' ' the' ' father' ' of' ' the' ' famous' ' jazz'
+ ' musician' ' Quincy']" , the father of the famous jazz musician Quincy Jones , and the father of the famous jazz musician Quincy False and produced by Quincy Jones (with Jackson 4 [' and', ' produced', ' by', ' Quincy', ' Jones']
+2035 456 Name of father of x -1 Name of father of Quincy Jones Quincy Delight Jones Quincy Jones "[',' ' the' ' father' ' of' ' the' ' famous' ' jazz' ' musician' ' Quincy'
+ ' Jones' ',' ' and' ' the' ' father' ' of' ' the' ' famous' ' jazz'
+ ' musician' ' Quincy']" , the father of the famous jazz musician Quincy Jones , and the father of the famous jazz musician Quincy False 3 ['Qu', 'in', 'cy', ' Jones']
+2036 456 Name of father of x -1 Name of father of Quincy Jones Quincy Delight Jones Quincy Jones "[',' ' the' ' father' ' of' ' the' ' famous' ' jazz' ' musician' ' Quincy'
+ ' Jones' ',' ' and' ' the' ' father' ' of' ' the' ' famous' ' jazz'
+ ' musician' ' Quincy']" , the father of the famous jazz musician Quincy Jones , and the father of the famous jazz musician Quincy False Thriller album. Producer Quincy Jones had wanted to include 6 [' Thr', 'iller', ' album', '.', ' Producer', ' Quincy', ' Jones']
+2037 456 Name of father of x -1 Name of father of Quincy Jones Quincy Delight Jones Quincy Jones "[',' ' the' ' father' ' of' ' the' ' famous' ' jazz' ' musician' ' Quincy'
+ ' Jones' ',' ' and' ' the' ' father' ' of' ' the' ' famous' ' jazz'
+ ' musician' ' Quincy']" , the father of the famous jazz musician Quincy Jones , and the father of the famous jazz musician Quincy False 3 ['Qu', 'in', 'cy', ' Jones']
+2038 456 Name of father of x -1 Name of father of Quincy Jones Quincy Delight Jones Quincy Jones "[',' ' the' ' father' ' of' ' the' ' famous' ' jazz' ' musician' ' Quincy'
+ ' Jones' ',' ' and' ' the' ' father' ' of' ' the' ' famous' ' jazz'
+ ' musician' ' Quincy']" , the father of the famous jazz musician Quincy Jones , and the father of the famous jazz musician Quincy False Greg Phillinganes. Quincy Jones passed on the 6 [' Greg', ' Ph', 'illing', 'anes', '.', ' Quincy', ' Jones']
+2039 457 Name of father of x -1 Name of father of Jean Racine Jean Racine Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at' ' Paris'
+ ',' ' in' ' 16' '39' '.' ' He' ' was' ' a' '\n']" ", the French dram at ist , was born at Paris , in 16 39 . He was a
+" False the Cantique de Jean Racine are in the 6 [' the', ' Cant', 'ique', ' de', ' Jean', ' Rac', 'ine']
+2040 457 Name of father of x -1 Name of father of Jean Racine Jean Racine Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at' ' Paris'
+ ',' ' in' ' 16' '39' '.' ' He' ' was' ' a' '\n']" ", the French dram at ist , was born at Paris , in 16 39 . He was a
+" False the Cantique de Jean Racine are in the tradition 6 [' the', ' Cant', 'ique', ' de', ' Jean', ' Rac', 'ine']
+2041 457 Name of father of x -1 Name of father of Jean Racine Jean Racine Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at' ' Paris'
+ ',' ' in' ' 16' '39' '.' ' He' ' was' ' a' '\n']" ", the French dram at ist , was born at Paris , in 16 39 . He was a
+" False the Cantique de Jean Racine are in the 6 [' the', ' Cant', 'ique', ' de', ' Jean', ' Rac', 'ine']
+2042 457 Name of father of x -1 Name of father of Jean Racine Jean Racine Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at' ' Paris'
+ ',' ' in' ' 16' '39' '.' ' He' ' was' ' a' '\n']" ", the French dram at ist , was born at Paris , in 16 39 . He was a
+" False libretto, based on Jean Racine ’ s Athalie and 8 [' lib', 'rett', 'o', ',', ' based', ' on', ' Jean', ' Rac', 'ine']
+2043 457 Name of father of x -1 Name of father of Jean Racine Jean Racine Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' at' ' Paris'
+ ',' ' in' ' 16' '39' '.' ' He' ' was' ' a' '\n']" ", the French dram at ist , was born at Paris , in 16 39 . He was a
+" False Corneille and Jean Racine and through them 6 [' Cor', 'ne', 'ille', ' and', ' Jean', ' Rac', 'ine']
+2044 458 Name of father of x -1 Name of father of Andrzej Wajda Jakub Wajda Andrzej Wajda "[',' ' Polish' ' film' ' director' '\n' '\n' 'Category' ':' '19' '28'
+ ' births' '\n' 'Category' ':' 'Living' ' people' '\n' 'Category' ':'
+ 'Pol']" ", Polish film director
+
+ Category : 19 28 births
+ Category : Living people
+ Category : Pol" False Oscar-winning filmmaker Andrzej Wajda about the 10 [' Oscar', '-', 'winning', ' filmmaker', ' And', 'r', 'ze', 'j', ' W', 'aj', 'da']
+2045 458 Name of father of x -1 Name of father of Andrzej Wajda Jakub Wajda Andrzej Wajda "[',' ' Polish' ' film' ' director' '\n' '\n' 'Category' ':' '19' '28'
+ ' births' '\n' 'Category' ':' 'Living' ' people' '\n' 'Category' ':'
+ 'Pol']" ", Polish film director
+
+ Category : 19 28 births
+ Category : Living people
+ Category : Pol" False is another Andrzej Wajda film about the Solidarity 8 [' is', ' another', ' And', 'r', 'ze', 'j', ' W', 'aj', 'da']
+2046 458 Name of father of x -1 Name of father of Andrzej Wajda Jakub Wajda Andrzej Wajda "[',' ' Polish' ' film' ' director' '\n' '\n' 'Category' ':' '19' '28'
+ ' births' '\n' 'Category' ':' 'Living' ' people' '\n' 'Category' ':'
+ 'Pol']" ", Polish film director
+
+ Category : 19 28 births
+ Category : Living people
+ Category : Pol" False (1981) is another Andrzej Wajda film about the Solidarity 11 [' (', '1981', ')', ' is', ' another', ' And', 'r', 'ze', 'j', ' W', 'aj', 'da']
+2047 458 Name of father of x -1 Name of father of Andrzej Wajda Jakub Wajda Andrzej Wajda "[',' ' Polish' ' film' ' director' '\n' '\n' 'Category' ':' '19' '28'
+ ' births' '\n' 'Category' ':' 'Living' ' people' '\n' 'Category' ':'
+ 'Pol']" ", Polish film director
+
+ Category : 19 28 births
+ Category : Living people
+ Category : Pol" False including those by Andrzej Wajda (A Generation, 9 [' including', ' those', ' by', ' And', 'r', 'ze', 'j', ' W', 'aj', 'da']
+2048 458 Name of father of x -1 Name of father of Andrzej Wajda Jakub Wajda Andrzej Wajda "[',' ' Polish' ' film' ' director' '\n' '\n' 'Category' ':' '19' '28'
+ ' births' '\n' 'Category' ':' 'Living' ' people' '\n' 'Category' ':'
+ 'Pol']" ", Polish film director
+
+ Category : 19 28 births
+ Category : Living people
+ Category : Pol" False film directors Andrzej Wajda and Roman Polanski, 8 [' film', ' directors', ' And', 'r', 'ze', 'j', ' W', 'aj', 'da']
+2049 459 Name of father of x -1 Name of father of Carl Maria von Weber Franz Anton von Weber Carl Maria von Weber "[',' ' the' ' composer' ',' ' who' ' was' ' born' ' in' ' 18' '25' ','
+ ' and' ' died' ' in' ' 18' '83' '.' '\n' '\n' 'The']" ", the composer , who was born in 18 25 , and died in 18 83 .
+
+ The" False and Romantic eras. Carl Maria von Weber and Felix Mendelssohn 7 [' and', ' Romantic', ' eras', '.', ' Carl', ' Maria', ' von', ' Weber']
+2050 459 Name of father of x -1 Name of father of Carl Maria von Weber Franz Anton von Weber Carl Maria von Weber "[',' ' the' ' composer' ',' ' who' ' was' ' born' ' in' ' 18' '25' ','
+ ' and' ' died' ' in' ' 18' '83' '.' '\n' '\n' 'The']" ", the composer , who was born in 18 25 , and died in 18 83 .
+
+ The" False and Romantic eras. Carl Maria von Weber and Felix Mendelssohn 7 [' and', ' Romantic', ' eras', '.', ' Carl', ' Maria', ' von', ' Weber']
+2051 459 Name of father of x -1 Name of father of Carl Maria von Weber Franz Anton von Weber Carl Maria von Weber "[',' ' the' ' composer' ',' ' who' ' was' ' born' ' in' ' 18' '25' ','
+ ' and' ' died' ' in' ' 18' '83' '.' '\n' '\n' 'The']" ", the composer , who was born in 18 25 , and died in 18 83 .
+
+ The" False Romantic eras. Carl Maria von Weber and Felix Mendelssohn 6 [' Romantic', ' eras', '.', ' Carl', ' Maria', ' von', ' Weber']
+2052 460 Name of father of x -1 Name of father of Che Guevara Ernesto Guevara Lynch Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' was' ' killed'
+ ' in' ' the' ' Sierra' ' Ma' 'est' 'ra' ',' ' the' ' man' ' who' ' was']" , the revolutionary , the man who was killed in the Sierra Ma est ra , the man who was False with a picture of Che Guevara being used to advertise 8 [' with', ' a', ' picture', ' of', ' Che', ' G', 'ue', 'v', 'ara']
+2053 460 Name of father of x -1 Name of father of Che Guevara Ernesto Guevara Lynch Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' was' ' killed'
+ ' in' ' the' ' Sierra' ' Ma' 'est' 'ra' ',' ' the' ' man' ' who' ' was']" , the revolutionary , the man who was killed in the Sierra Ma est ra , the man who was False alongside Bob Dylan and Che Guevara as their favorite 8 [' alongside', ' Bob', ' Dylan', ' and', ' Che', ' G', 'ue', 'v', 'ara']
+2054 460 Name of father of x -1 Name of father of Che Guevara Ernesto Guevara Lynch Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' was' ' killed'
+ ' in' ' the' ' Sierra' ' Ma' 'est' 'ra' ',' ' the' ' man' ' who' ' was']" , the revolutionary , the man who was killed in the Sierra Ma est ra , the man who was False support with Che Guevara stating on 6 [' support', ' with', ' Che', ' G', 'ue', 'v', 'ara']
+2055 460 Name of father of x -1 Name of father of Che Guevara Ernesto Guevara Lynch Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' was' ' killed'
+ ' in' ' the' ' Sierra' ' Ma' 'est' 'ra' ',' ' the' ' man' ' who' ' was']" , the revolutionary , the man who was killed in the Sierra Ma est ra , the man who was False who dreamed of Che Guevara and the Black 7 [' who', ' dreamed', ' of', ' Che', ' G', 'ue', 'v', 'ara']
+2056 460 Name of father of x -1 Name of father of Che Guevara Ernesto Guevara Lynch Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' was' ' killed'
+ ' in' ' the' ' Sierra' ' Ma' 'est' 'ra' ',' ' the' ' man' ' who' ' was']" , the revolutionary , the man who was killed in the Sierra Ma est ra , the man who was False seen with a picture of Che Guevara being used to advertise 9 [' seen', ' with', ' a', ' picture', ' of', ' Che', ' G', 'ue', 'v', 'ara']
+2057 461 Name of father of x -1 Name of father of Demi Moore Charles Harmon Demi Moore "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False " Demi Moore ==
+" 2 [' Dem', 'i', ' Moore']
+2058 461 Name of father of x -1 Name of father of Demi Moore Charles Harmon Demi Moore "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Bruce Willis, Demi Moore and Whoopi 5 [' Bruce', ' Willis', ',', ' Dem', 'i', ' Moore']
+2059 461 Name of father of x -1 Name of father of Demi Moore Charles Harmon Demi Moore "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Bruce Willis, Demi Moore and Whoopi Goldberg 5 [' Bruce', ' Willis', ',', ' Dem', 'i', ' Moore']
+2060 461 Name of father of x -1 Name of father of Demi Moore Charles Harmon Demi Moore "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False photos of Demi Moore also appear 4 [' photos', ' of', ' Dem', 'i', ' Moore']
+2061 461 Name of father of x -1 Name of father of Demi Moore Charles Harmon Demi Moore "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False then-wife Demi Moore campaigned for 5 [' then', '-', 'wife', ' Dem', 'i', ' Moore']
+2062 462 Name of father of x -1 Name of father of Johannes Kepler Heinrich Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '30' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 71 , and died in 16 30 .
+
+ The" False He also says that Johannes Kepler came to his conclusions 5 [' He', ' also', ' says', ' that', ' Johannes', ' Kepler']
+2063 462 Name of father of x -1 Name of father of Johannes Kepler Heinrich Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '30' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 71 , and died in 16 30 .
+
+ The" False 4 ['J', 'oh', 'ann', 'es', ' Kepler']
+2064 462 Name of father of x -1 Name of father of Johannes Kepler Heinrich Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '30' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 71 , and died in 16 30 .
+
+ The" False the theories of Johannes Kepler and William Gilbert. 4 [' the', ' theories', ' of', ' Johannes', ' Kepler']
+2065 462 Name of father of x -1 Name of father of Johannes Kepler Heinrich Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '30' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 71 , and died in 16 30 .
+
+ The" False launched by NASA in 2009, Johannes Kepler ATV (Automated 7 [' launched', ' by', ' NASA', ' in', ' 2009', ',', ' Johannes', ' Kepler']
+2066 462 Name of father of x -1 Name of father of Johannes Kepler Heinrich Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '71' ','
+ ' and' ' died' ' in' ' 16' '30' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 15 71 , and died in 16 30 .
+
+ The" False star, but it was Johannes Kepler who became noted 6 [' star', ',', ' but', ' it', ' was', ' Johannes', ' Kepler']
+2067 465 Name of father of x -1 Name of father of John Calvin Gérard Cauvin John Calvin "[',' ' the' ' great' ' reform' 'er' ' of' ' the' ' Re' 'formation' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Re' 'formed' ' Church' ' of'
+ ' Scotland']" , the great reform er of the Re formation , and the founder of the Re formed Church of Scotland False " John Calvin ===
+" 1 [' John', ' Calvin']
+2068 465 Name of father of x -1 Name of father of John Calvin Gérard Cauvin John Calvin "[',' ' the' ' great' ' reform' 'er' ' of' ' the' ' Re' 'formation' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Re' 'formed' ' Church' ' of'
+ ' Scotland']" , the great reform er of the Re formation , and the founder of the Re formed Church of Scotland False Martin Luther and John Calvin about obedience 4 [' Martin', ' Luther', ' and', ' John', ' Calvin']
+2069 465 Name of father of x -1 Name of father of John Calvin Gérard Cauvin John Calvin "[',' ' the' ' great' ' reform' 'er' ' of' ' the' ' Re' 'formation' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Re' 'formed' ' Church' ' of'
+ ' Scotland']" , the great reform er of the Re formation , and the founder of the Re formed Church of Scotland False Luther and John Calvin about obedience 3 [' Luther', ' and', ' John', ' Calvin']
+2070 465 Name of father of x -1 Name of father of John Calvin Gérard Cauvin John Calvin "[',' ' the' ' great' ' reform' 'er' ' of' ' the' ' Re' 'formation' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Re' 'formed' ' Church' ' of'
+ ' Scotland']" , the great reform er of the Re formation , and the founder of the Re formed Church of Scotland False religious, sphere. John Calvin saw conscience as 5 [' religious', ',', ' sphere', '.', ' John', ' Calvin']
+2071 465 Name of father of x -1 Name of father of John Calvin Gérard Cauvin John Calvin "[',' ' the' ' great' ' reform' 'er' ' of' ' the' ' Re' 'formation' ','
+ ' and' ' the' ' founder' ' of' ' the' ' Re' 'formed' ' Church' ' of'
+ ' Scotland']" , the great reform er of the Re formation , and the founder of the Re formed Church of Scotland False " Going further, John Calvin says that ""it" 4 [' Going', ' further', ',', ' John', ' Calvin']
+2072 466 Name of father of x -1 Name of father of Miles Davis Miles Henry Davis Miles Davis "[',' ' the' ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the'
+ ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the' ' jazz'
+ ' musician']" , the jazz musician , and his wife , the jazz musician , and his wife , the jazz musician False Ryuichi Sakamoto, and Miles Davis as musical influences. 7 [' Ryu', 'ichi', ' Sak', 'amoto', ',', ' and', ' Miles', ' Davis']
+2073 466 Name of father of x -1 Name of father of Miles Davis Miles Henry Davis Miles Davis "[',' ' the' ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the'
+ ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the' ' jazz'
+ ' musician']" , the jazz musician , and his wife , the jazz musician , and his wife , the jazz musician False " articulation of Miles Davis and Kenny Dorham""," 4 [' artic', 'ulation', ' of', ' Miles', ' Davis']
+2074 466 Name of father of x -1 Name of father of Miles Davis Miles Henry Davis Miles Davis "[',' ' the' ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the'
+ ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the' ' jazz'
+ ' musician']" , the jazz musician , and his wife , the jazz musician , and his wife , the jazz musician False on 1979's Hydra. Miles Davis was featured on the 6 "[' on', ' 1979', ""'s"", ' Hydra', '.', ' Miles', ' Davis']"
+2075 466 Name of father of x -1 Name of father of Miles Davis Miles Henry Davis Miles Davis "[',' ' the' ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the'
+ ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the' ' jazz'
+ ' musician']" , the jazz musician , and his wife , the jazz musician , and his wife , the jazz musician False jazz artist Miles Davis included a cover version 3 [' jazz', ' artist', ' Miles', ' Davis']
+2076 466 Name of father of x -1 Name of father of Miles Davis Miles Henry Davis Miles Davis "[',' ' the' ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the'
+ ' jazz' ' musician' ',' ' and' ' his' ' wife' ',' ' the' ' jazz'
+ ' musician']" , the jazz musician , and his wife , the jazz musician , and his wife , the jazz musician False " improvised style of a Miles Davis score."" Tim Lucas" 5 [' improvised', ' style', ' of', ' a', ' Miles', ' Davis']
+2077 467 Name of father of x -1 Name of father of Thomas Lawrence Thomas Lawrence Thomas Lawrence "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Thomas'
+ ' Lawrence' ',' ' the' ' painter' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the father of Thomas Lawrence , the painter , is not known ." True the artist Sir Thomas Lawrence and the proprietor 4 [' the', ' artist', ' Sir', ' Thomas', ' Lawrence']
+2078 467 Name of father of x -1 Name of father of Thomas Lawrence Thomas Lawrence Thomas Lawrence "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Thomas'
+ ' Lawrence' ',' ' the' ' painter' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the father of Thomas Lawrence , the painter , is not known ." True he studied under Thomas Lawrence and trained by copying 4 [' he', ' studied', ' under', ' Thomas', ' Lawrence']
+2079 467 Name of father of x -1 Name of father of Thomas Lawrence Thomas Lawrence Thomas Lawrence "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Thomas'
+ ' Lawrence' ',' ' the' ' painter' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the father of Thomas Lawrence , the painter , is not known ." True " Lawrence ===
+" 3 [' Lawrence', ' ===', 'Thomas', ' Lawrence']
+2080 467 Name of father of x -1 Name of father of Thomas Lawrence Thomas Lawrence Thomas Lawrence "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Thomas'
+ ' Lawrence' ',' ' the' ' painter' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the father of Thomas Lawrence , the painter , is not known ." True Gainsborough and Sir Thomas Lawrence lived and worked 6 [' G', 'ains', 'borough', ' and', ' Sir', ' Thomas', ' Lawrence']
+2081 467 Name of father of x -1 Name of father of Thomas Lawrence Thomas Lawrence Thomas Lawrence "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Thomas'
+ ' Lawrence' ',' ' the' ' painter' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the father of Thomas Lawrence , the painter , is not known ." True deaths of his friend Thomas Lawrence and his housekeeper 5 [' deaths', ' of', ' his', ' friend', ' Thomas', ' Lawrence']
+2082 468 Name of father of x -1 Name of father of Pliny the Elder Gaius Plinius Celer Pliny the Elder "[',' ' and' ' the' ' father' ' of' ' Pl' 'iny' ' the' ' Younger' '.' '\n'
+ '\n' 'References' '\n' '\n' 'Category' ':' 'Roman' ' governors' ' of']" ", and the father of Pl iny the Younger .
+
+ References
+
+ Category : Roman governors of" False islands were made by Pliny the Elder in his Natural History, 7 [' islands', ' were', ' made', ' by', ' Pl', 'iny', ' the', ' Elder']
+2083 468 Name of father of x -1 Name of father of Pliny the Elder Gaius Plinius Celer Pliny the Elder "[',' ' and' ' the' ' father' ' of' ' Pl' 'iny' ' the' ' Younger' '.' '\n'
+ '\n' 'References' '\n' '\n' 'Category' ':' 'Roman' ' governors' ' of']" ", and the father of Pl iny the Younger .
+
+ References
+
+ Category : Roman governors of" False Roman naturalist Pliny the Elder (23 – 79 AD), 6 [' Roman', ' natural', 'ist', ' Pl', 'iny', ' the', ' Elder']
+2084 468 Name of father of x -1 Name of father of Pliny the Elder Gaius Plinius Celer Pliny the Elder "[',' ' and' ' the' ' father' ' of' ' Pl' 'iny' ' the' ' Younger' '.' '\n'
+ '\n' 'References' '\n' '\n' 'Category' ':' 'Roman' ' governors' ' of']" ", and the father of Pl iny the Younger .
+
+ References
+
+ Category : Roman governors of" False and historian, Pliny the Elder (29 – 79 CE) 6 [' and', ' historian', ',', ' Pl', 'iny', ' the', ' Elder']
+2085 468 Name of father of x -1 Name of father of Pliny the Elder Gaius Plinius Celer Pliny the Elder "[',' ' and' ' the' ' father' ' of' ' Pl' 'iny' ' the' ' Younger' '.' '\n'
+ '\n' 'References' '\n' '\n' 'Category' ':' 'Roman' ' governors' ' of']" ", and the father of Pl iny the Younger .
+
+ References
+
+ Category : Roman governors of" False 4 ['P', 'lin', 'y', ' the', ' Elder']
+2086 468 Name of father of x -1 Name of father of Pliny the Elder Gaius Plinius Celer Pliny the Elder "[',' ' and' ' the' ' father' ' of' ' Pl' 'iny' ' the' ' Younger' '.' '\n'
+ '\n' 'References' '\n' '\n' 'Category' ':' 'Roman' ' governors' ' of']" ", and the father of Pl iny the Younger .
+
+ References
+
+ Category : Roman governors of" False Pliny's uncle Pliny the Elder was in command of 7 "[' Pl', 'iny', ""'s"", ' uncle', ' Pl', 'iny', ' the', ' Elder']"
+2087 469 Name of father of x -1 Name of father of Napoleon III Louis Bonaparte Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' of'
+ ' great' ' intelligence' ',' ' was' ' not' '\n' '\n' '[']" ".
+
+ The Emperor , who was not a man of great intelligence , was not
+
+ [" False a poem attacking Napoleon III for his reinstitution 4 [' a', ' poem', ' attacking', ' Napoleon', ' III']
+2088 469 Name of father of x -1 Name of father of Napoleon III Louis Bonaparte Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' of'
+ ' great' ' intelligence' ',' ' was' ' not' '\n' '\n' '[']" ".
+
+ The Emperor , who was not a man of great intelligence , was not
+
+ [" False " by a decree of Napoleon III of France.
+" 5 [' by', ' a', ' decree', ' of', ' Napoleon', ' III']
+2089 469 Name of father of x -1 Name of father of Napoleon III Louis Bonaparte Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' of'
+ ' great' ' intelligence' ',' ' was' ' not' '\n' '\n' '[']" ".
+
+ The Emperor , who was not a man of great intelligence , was not
+
+ [" False French emperor, Napoleon III, demanded territories 4 [' French', ' emperor', ',', ' Napoleon', ' III']
+2090 469 Name of father of x -1 Name of father of Napoleon III Louis Bonaparte Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' of'
+ ' great' ' intelligence' ',' ' was' ' not' '\n' '\n' '[']" ".
+
+ The Emperor , who was not a man of great intelligence , was not
+
+ [" False French Empire of Napoleon III; the emperor and 4 [' French', ' Empire', ' of', ' Napoleon', ' III']
+2091 469 Name of father of x -1 Name of father of Napoleon III Louis Bonaparte Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' of'
+ ' great' ' intelligence' ',' ' was' ' not' '\n' '\n' '[']" ".
+
+ The Emperor , who was not a man of great intelligence , was not
+
+ [" False withdrew, but Emperor Napoleon III of France used the 5 [' withdrew', ',', ' but', ' Emperor', ' Napoleon', ' III']
+2092 470 Name of father of x -1 Name of father of Richard Strauss Franz Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' the' ' composer' ',' ' and' ' the' ' composer' ""'s"" ' mother' ',']" , the composer , and his wife , the mother of the composer , and the composer 's mother , False " became aware of Richard Strauss – of any other""." 4 [' became', ' aware', ' of', ' Richard', ' Strauss']
+2093 470 Name of father of x -1 Name of father of Richard Strauss Franz Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' the' ' composer' ',' ' and' ' the' ' composer' ""'s"" ' mother' ',']" , the composer , and his wife , the mother of the composer , and the composer 's mother , False October 1905 with Richard Strauss conducting the 4 [' October', ' 1905', ' with', ' Richard', ' Strauss']
+2094 470 Name of father of x -1 Name of father of Richard Strauss Franz Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' the' ' composer' ',' ' and' ' the' ' composer' ""'s"" ' mother' ',']" , the composer , and his wife , the mother of the composer , and the composer 's mother , False use of the famous Richard Strauss piece Also sprach 5 [' use', ' of', ' the', ' famous', ' Richard', ' Strauss']
+2095 470 Name of father of x -1 Name of father of Richard Strauss Franz Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' the' ' composer' ',' ' and' ' the' ' composer' ""'s"" ' mother' ',']" , the composer , and his wife , the mother of the composer , and the composer 's mother , False 6 (1906), and Richard Strauss deliberately 7 [' 6', ' (', '19', '06', '),', ' and', ' Richard', ' Strauss']
+2096 470 Name of father of x -1 Name of father of Richard Strauss Franz Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' the' ' composer' ',' ' and' ' the' ' composer' ""'s"" ' mother' ',']" , the composer , and his wife , the mother of the composer , and the composer 's mother , False Vaughan Williams, Elgar, Richard Strauss and Debussy. 7 [' Vaughan', ' Williams', ',', ' El', 'gar', ',', ' Richard', ' Strauss']
+2097 472 Name of father of x -1 Name of father of Edward Burne-Jones Edward Richard Jones Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' his' ' wife' ',' ' the' ' artist' ""'s"" ' mother' ',']" , the painter , and his wife , the artist , and his wife , the artist 's mother , False paintings of Edward Burne-Jones and the Pre-Raphaelite 6 [' paintings', ' of', ' Edward', ' Burn', 'e', '-', 'Jones']
+2098 472 Name of father of x -1 Name of father of Edward Burne-Jones Edward Richard Jones Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' his' ' wife' ',' ' the' ' artist' ""'s"" ' mother' ',']" , the painter , and his wife , the artist , and his wife , the artist 's mother , False the painter Edward Burne-Jones came to Malvern for 6 [' the', ' painter', ' Edward', ' Burn', 'e', '-', 'Jones']
+2099 472 Name of father of x -1 Name of father of Edward Burne-Jones Edward Richard Jones Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' his' ' wife' ',' ' the' ' artist' ""'s"" ' mother' ',']" , the painter , and his wife , the artist , and his wife , the artist 's mother , False subject designed by Edward Burne-Jones (and in the related 7 [' subject', ' designed', ' by', ' Edward', ' Burn', 'e', '-', 'Jones']
+2100 472 Name of father of x -1 Name of father of Edward Burne-Jones Edward Richard Jones Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' his' ' wife' ',' ' the' ' artist' ""'s"" ' mother' ',']" , the painter , and his wife , the artist , and his wife , the artist 's mother , False Pre-Raphaelite artists Edward Burne-Jones and Dante Gabriel 10 [' Pre', '-', 'R', 'aphael', 'ite', ' artists', ' Edward', ' Burn', 'e', '-', 'Jones']
+2101 472 Name of father of x -1 Name of father of Edward Burne-Jones Edward Richard Jones Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' his' ' wife' ',' ' the' ' artist' ""'s"" ' mother' ',']" , the painter , and his wife , the artist , and his wife , the artist 's mother , False " Edward Burne-Jones =
+" 4 [' Edward', ' Burn', 'e', '-', 'Jones']
+2102 473 Name of father of x -1 Name of father of Walt Whitman Walter Whitman, Sr. Walt Whitman "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' Walt'
+ ' Whitman' ',' ' was' ' a' ' man' ' of' ' the' ' people' ',']" , the poet , and the poet 's father , Walt Whitman , was a man of the people , False familiar with Walt Whitman is one of David 3 [' familiar', ' with', ' Walt', ' Whitman']
+2103 473 Name of father of x -1 Name of father of Walt Whitman Walter Whitman, Sr. Walt Whitman "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' Walt'
+ ' Whitman' ',' ' was' ' a' ' man' ' of' ' the' ' people' ',']" , the poet , and the poet 's father , Walt Whitman , was a man of the people , False February 1868, Poems of Walt Whitman was published 8 [' February', ' 18', '68', ',', ' Po', 'ems', ' of', ' Walt', ' Whitman']
+2104 473 Name of father of x -1 Name of father of Walt Whitman Walter Whitman, Sr. Walt Whitman "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' Walt'
+ ' Whitman' ',' ' was' ' a' ' man' ' of' ' the' ' people' ',']" , the poet , and the poet 's father , Walt Whitman , was a man of the people , False American writers Walt Whitman and Henry James, 3 [' American', ' writers', ' Walt', ' Whitman']
+2105 473 Name of father of x -1 Name of father of Walt Whitman Walter Whitman, Sr. Walt Whitman "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' Walt'
+ ' Whitman' ',' ' was' ' a' ' man' ' of' ' the' ' people' ',']" , the poet , and the poet 's father , Walt Whitman , was a man of the people , False Pennsylvania via the Walt Whitman Bridge and Ben 4 [' Pennsylvania', ' via', ' the', ' Walt', ' Whitman']
+2106 473 Name of father of x -1 Name of father of Walt Whitman Walter Whitman, Sr. Walt Whitman "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' Walt'
+ ' Whitman' ',' ' was' ' a' ' man' ' of' ' the' ' people' ',']" , the poet , and the poet 's father , Walt Whitman , was a man of the people , False I-76, I-676, and the Walt Whitman Bridge to and from 11 [' I', '-', '76', ',', ' I', '-', '676', ',', ' and', ' the', ' Walt', ' Whitman']
+2107 474 Name of father of x -1 Name of father of Gregory I Gordianus Gregory I "[',' ' the' ' Great' ',' ' and' ' the' ' first' ' Christian' ' emperor'
+ ' of' ' the' ' East' '.' '\n' '\n' 'The' ' first' ' Christian' ' emperor'
+ ' of']" ", the Great , and the first Christian emperor of the East .
+
+ The first Christian emperor of" False 2 ['Greg', 'ory', ' I']
+2108 474 Name of father of x -1 Name of father of Gregory I Gordianus Gregory I "[',' ' the' ' Great' ',' ' and' ' the' ' first' ' Christian' ' emperor'
+ ' of' ' the' ' East' '.' '\n' '\n' 'The' ' first' ' Christian' ' emperor'
+ ' of']" ", the Great , and the first Christian emperor of the East .
+
+ The first Christian emperor of" False when Pope Gregory I encouraged pagan 3 [' when', ' Pope', ' Gregory', ' I']
+2109 474 Name of father of x -1 Name of father of Gregory I Gordianus Gregory I "[',' ' the' ' Great' ',' ' and' ' the' ' first' ' Christian' ' emperor'
+ ' of' ' the' ' East' '.' '\n' '\n' 'The' ' first' ' Christian' ' emperor'
+ ' of']" ", the Great , and the first Christian emperor of the East .
+
+ The first Christian emperor of" False 2 ['Greg', 'ory', ' I']
+2110 474 Name of father of x -1 Name of father of Gregory I Gordianus Gregory I "[',' ' the' ' Great' ',' ' and' ' the' ' first' ' Christian' ' emperor'
+ ' of' ' the' ' East' '.' '\n' '\n' 'The' ' first' ' Christian' ' emperor'
+ ' of']" ", the Great , and the first Christian emperor of the East .
+
+ The first Christian emperor of" False sent in 601 by Pope Gregory I to Christianize 6 [' sent', ' in', ' 601', ' by', ' Pope', ' Gregory', ' I']
+2111 474 Name of father of x -1 Name of father of Gregory I Gordianus Gregory I "[',' ' the' ' Great' ',' ' and' ' the' ' first' ' Christian' ' emperor'
+ ' of' ' the' ' East' '.' '\n' '\n' 'The' ' first' ' Christian' ' emperor'
+ ' of']" ", the Great , and the first Christian emperor of the East .
+
+ The first Christian emperor of" False In 595, when Pope Gregory I decided to send 7 [' In', ' 5', '95', ',', ' when', ' Pope', ' Gregory', ' I']
+2112 475 Name of father of x -1 Name of father of Augustus Gaius Octavius Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False 27 BC was named Augustus by the Roman Senate, 4 [' 27', ' BC', ' was', ' named', ' Augustus']
+2113 475 Name of father of x -1 Name of father of Augustus Gaius Octavius Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False his father's post as Augustus of the West, with Constantius' 5 "[' his', ' father', ""'s"", ' post', ' as', ' Augustus']"
+2114 475 Name of father of x -1 Name of father of Augustus Gaius Octavius Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False 1 ['August', 'us']
+2115 475 Name of father of x -1 Name of father of Augustus Gaius Octavius Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False When William's son Augustus FitzClarence enquired 4 "[' When', ' William', ""'s"", ' son', ' Augustus']"
+2116 475 Name of father of x -1 Name of father of Augustus Gaius Octavius Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False " ""culinary jewel"" of the Augustus Tower is the" 7 "[' ""', 'cul', 'inary', ' jewel', '""', ' of', ' the', ' Augustus']"
+2117 476 Name of father of x -1 Name of father of W. H. Auden George Augustus Auden W. H. Auden "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' the' ' same' ' year'
+ ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in' ' the']" , the poet , who was born in the same year as the poet , and who died in the False " and, as he wrote to W. H. Auden in 1955, ""I" 11 [' and', ',', ' as', ' he', ' wrote', ' to', ' W', '.', ' H', '.', ' Aud', 'en']
+2118 476 Name of father of x -1 Name of father of W. H. Auden George Augustus Auden W. H. Auden "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' the' ' same' ' year'
+ ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in' ' the']" , the poet , who was born in the same year as the poet , and who died in the False " disliked the film were W. H. Auden and James Agee.
+" 9 [' disliked', ' the', ' film', ' were', ' W', '.', ' H', '.', ' Aud', 'en']
+2119 476 Name of father of x -1 Name of father of W. H. Auden George Augustus Auden W. H. Auden "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' the' ' same' ' year'
+ ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in' ' the']" , the poet , who was born in the same year as the poet , and who died in the False nonetheless W. H. Auden — in his long 6 [' nonetheless', ' W', '.', ' H', '.', ' Aud', 'en']
+2120 476 Name of father of x -1 Name of father of W. H. Auden George Augustus Auden W. H. Auden "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' the' ' same' ' year'
+ ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in' ' the']" , the poet , who was born in the same year as the poet , and who died in the False " aspect of Yeats's work; W. H. Auden called it the ""deplorable" 12 "[' aspect', ' of', ' Ye', 'ats', ""'s"", ' work', ';', ' W', '.', ' H', '.', ' Aud', 'en']"
+2121 476 Name of father of x -1 Name of father of W. H. Auden George Augustus Auden W. H. Auden "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' the' ' same' ' year'
+ ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in' ' the']" , the poet , who was born in the same year as the poet , and who died in the False " Auden =
+" 8 [' Aud', 'en', ' =', 'W', '.', ' H', '.', ' Aud', 'en']
+2122 477 Name of father of x -1 Name of father of Björk Guðmundur Gunnarsson Björk "[',' ' the' ' father' ' of' ' the' ' gods' ',' ' and' ' the' ' father'
+ ' of' ' the' ' gods' ',' ' and' ' the' ' father' ' of' ' the' ' gods']" , the father of the gods , and the father of the gods , and the father of the gods False " but in Vespertine Björk wanted to make a ""microcosmos" 7 [' but', ' in', ' Ves', 'pert', 'ine', ' Bj', 'ör', 'k']
+2123 477 Name of father of x -1 Name of father of Björk Guðmundur Gunnarsson Björk "[',' ' the' ' father' ' of' ' the' ' gods' ',' ' and' ' the' ' father'
+ ' of' ' the' ' gods' ',' ' and' ' the' ' father' ' of' ' the' ' gods']" , the father of the gods , and the father of the gods , and the father of the gods False Brit Awards, Björk won awards for 5 [' Brit', ' Awards', ',', ' Bj', 'ör', 'k']
+2124 477 Name of father of x -1 Name of father of Björk Guðmundur Gunnarsson Björk "[',' ' the' ' father' ' of' ' the' ' gods' ',' ' and' ' the' ' father'
+ ' of' ' the' ' gods' ',' ' and' ' the' ' father' ' of' ' the' ' gods']" , the father of the gods , and the father of the gods , and the father of the gods False followed by Björk singing in 4 [' followed', ' by', ' Bj', 'ör', 'k']
+2125 477 Name of father of x -1 Name of father of Björk Guðmundur Gunnarsson Björk "[',' ' the' ' father' ' of' ' the' ' gods' ',' ' and' ' the' ' father'
+ ' of' ' the' ' gods' ',' ' and' ' the' ' father' ' of' ' the' ' gods']" , the father of the gods , and the father of the gods , and the father of the gods False described the way Björk wanted the music on 5 [' described', ' the', ' way', ' Bj', 'ör', 'k']
+2126 477 Name of father of x -1 Name of father of Björk Guðmundur Gunnarsson Björk "[',' ' the' ' father' ' of' ' the' ' gods' ',' ' and' ' the' ' father'
+ ' of' ' the' ' gods' ',' ' and' ' the' ' father' ' of' ' the' ' gods']" , the father of the gods , and the father of the gods , and the father of the gods False artists such as Björk and Metallica are, 5 [' artists', ' such', ' as', ' Bj', 'ör', 'k']
+2127 478 Name of father of x -1 Name of father of Fred Astaire Fritz Austerlitz Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Ad' 'ele' ' Ast'
+ 'aire' ',' ' who' ' was' ' a' ' dancer' '.' '\n' '\n']" ", the actor , and his wife , Ad ele Ast aire , who was a dancer .
+
+" False the tops! With Fred Astaire dancing and singing 6 [' the', ' tops', '!', ' With', ' Fred', ' Ast', 'aire']
+2128 478 Name of father of x -1 Name of father of Fred Astaire Fritz Austerlitz Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Ad' 'ele' ' Ast'
+ 'aire' ',' ' who' ' was' ' a' ' dancer' '.' '\n' '\n']" ", the actor , and his wife , Ad ele Ast aire , who was a dancer .
+
+" False Band Wagon, when Fred Astaire walks down a railroad 7 [' Band', ' W', 'agon', ',', ' when', ' Fred', ' Ast', 'aire']
+2129 478 Name of father of x -1 Name of father of Fred Astaire Fritz Austerlitz Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Ad' 'ele' ' Ast'
+ 'aire' ',' ' who' ' was' ' a' ' dancer' '.' '\n' '\n']" ", the actor , and his wife , Ad ele Ast aire , who was a dancer .
+
+" False performed by Fred Astaire in the 1936 film, 4 [' performed', ' by', ' Fred', ' Ast', 'aire']
+2130 478 Name of father of x -1 Name of father of Fred Astaire Fritz Austerlitz Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Ad' 'ele' ' Ast'
+ 'aire' ',' ' who' ' was' ' a' ' dancer' '.' '\n' '\n']" ", the actor , and his wife , Ad ele Ast aire , who was a dancer .
+
+" False is the tops! With Fred Astaire dancing and singing 7 [' is', ' the', ' tops', '!', ' With', ' Fred', ' Ast', 'aire']
+2131 478 Name of father of x -1 Name of father of Fred Astaire Fritz Austerlitz Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Ad' 'ele' ' Ast'
+ 'aire' ',' ' who' ' was' ' a' ' dancer' '.' '\n' '\n']" ", the actor , and his wife , Ad ele Ast aire , who was a dancer .
+
+" False originally performed by Fred Astaire in the 1936 film, 5 [' originally', ' performed', ' by', ' Fred', ' Ast', 'aire']
+2132 479 Name of father of x -1 Name of father of Nathaniel Hawthorne Nathaniel Hathorne Nathaniel Hawthorne "[',' ' and' ' the' ' author' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the author of the
+ " False elected to Congress) and Nathaniel Hawthorne (the author), 7 [' elected', ' to', ' Congress', ')', ' and', ' Nathaniel', ' Hawth', 'orne']
+2133 479 Name of father of x -1 Name of father of Nathaniel Hawthorne Nathaniel Hathorne Nathaniel Hawthorne "[',' ' and' ' the' ' author' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the author of the
+ " False " 1835 story by Nathaniel Hawthorne titled ""The" 6 [' 18', '35', ' story', ' by', ' Nathaniel', ' Hawth', 'orne']
+2134 479 Name of father of x -1 Name of father of Nathaniel Hawthorne Nathaniel Hathorne Nathaniel Hawthorne "[',' ' and' ' the' ' author' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the author of the
+ " False 1850, Melville met Nathaniel Hawthorne for the first 7 [' 1850', ',', ' Mel', 'ville', ' met', ' Nathaniel', ' Hawth', 'orne']
+2135 479 Name of father of x -1 Name of father of Nathaniel Hawthorne Nathaniel Hathorne Nathaniel Hawthorne "[',' ' and' ' the' ' author' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the author of the
+ " False when critics like Nathaniel Hawthorne and H. L. Mencken 5 [' when', ' critics', ' like', ' Nathaniel', ' Hawth', 'orne']
+2136 479 Name of father of x -1 Name of father of Nathaniel Hawthorne Nathaniel Hathorne Nathaniel Hawthorne "[',' ' and' ' the' ' author' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the author of the
+ " False when critics like Nathaniel Hawthorne and H. L. Mencken 5 [' when', ' critics', ' like', ' Nathaniel', ' Hawth', 'orne']
+2137 480 Name of father of x -1 Name of father of D. H. Lawrence Arthur John Lawrence D. H. Lawrence "[',' ' the' ' author' ' of' ' ""' 'S' 'ons' ' and' ' Lo' 'vers' '""' ' and'
+ ' ""' 'The' ' Rainbow' '.""' '\n' '\n' 'The' ' following']" ", the author of "" S ons and Lo vers "" and "" The Rainbow .""
+
+ The following" False M. Forster, D. H. Lawrence and Joseph Conrad) 9 [' M', '.', ' For', 'ster', ',', ' D', '.', ' H', '.', ' Lawrence']
+2138 480 Name of father of x -1 Name of father of D. H. Lawrence Arthur John Lawrence D. H. Lawrence "[',' ' the' ' author' ' of' ' ""' 'S' 'ons' ' and' ' Lo' 'vers' '""' ' and'
+ ' ""' 'The' ' Rainbow' '.""' '\n' '\n' 'The' ' following']" ", the author of "" S ons and Lo vers "" and "" The Rainbow .""
+
+ The following" False role was in 1980 as D. H. Lawrence in Priest of 9 [' role', ' was', ' in', ' 1980', ' as', ' D', '.', ' H', '.', ' Lawrence']
+2139 480 Name of father of x -1 Name of father of D. H. Lawrence Arthur John Lawrence D. H. Lawrence "[',' ' the' ' author' ' of' ' ""' 'S' 'ons' ' and' ' Lo' 'vers' '""' ' and'
+ ' ""' 'The' ' Rainbow' '.""' '\n' '\n' 'The' ' following']" ", the author of "" S ons and Lo vers "" and "" The Rainbow .""
+
+ The following" False writers such as D. H. Lawrence and William Butler 7 [' writers', ' such', ' as', ' D', '.', ' H', '.', ' Lawrence']
+2140 480 Name of father of x -1 Name of father of D. H. Lawrence Arthur John Lawrence D. H. Lawrence "[',' ' the' ' author' ' of' ' ""' 'S' 'ons' ' and' ' Lo' 'vers' '""' ' and'
+ ' ""' 'The' ' Rainbow' '.""' '\n' '\n' 'The' ' following']" ", the author of "" S ons and Lo vers "" and "" The Rainbow .""
+
+ The following" False philosophy. Reading D. H. Lawrence led him to ethics 7 [' philosophy', '.', ' Reading', ' D', '.', ' H', '.', ' Lawrence']
+2141 480 Name of father of x -1 Name of father of D. H. Lawrence Arthur John Lawrence D. H. Lawrence "[',' ' the' ' author' ' of' ' ""' 'S' 'ons' ' and' ' Lo' 'vers' '""' ' and'
+ ' ""' 'The' ' Rainbow' '.""' '\n' '\n' 'The' ' following']" ", the author of "" S ons and Lo vers "" and "" The Rainbow .""
+
+ The following" False disparaging, with D. H. Lawrence declaring, in reaction 8 [' dispar', 'aging', ',', ' with', ' D', '.', ' H', '.', ' Lawrence']
+2142 482 Name of father of x -1 Name of father of Sigourney Weaver Sylvester Weaver Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False spaceship. Alien star Sigourney Weaver also expressed 6 [' spaceship', '.', ' Alien', ' star', ' Sig', 'ourney', ' Weaver']
+2143 482 Name of father of x -1 Name of father of Sigourney Weaver Sylvester Weaver Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False canon. Actress Sigourney Weaver agreed to reprise 5 [' canon', '.', ' Actress', ' Sig', 'ourney', ' Weaver']
+2144 482 Name of father of x -1 Name of father of Sigourney Weaver Sylvester Weaver Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False celebrities such as Sigourney Weaver and Bruce 5 [' celebrities', ' such', ' as', ' Sig', 'ourney', ' Weaver']
+2145 482 Name of father of x -1 Name of father of Sigourney Weaver Sylvester Weaver Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False confirmed that Sigourney Weaver and Stephen Lang 4 [' confirmed', ' that', ' Sig', 'ourney', ' Weaver']
+2146 482 Name of father of x -1 Name of father of Sigourney Weaver Sylvester Weaver Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False Christmas party; and Sigourney Weaver in the closing sequence 6 [' Christmas', ' party', ';', ' and', ' Sig', 'ourney', ' Weaver']
+2147 483 Name of father of x -1 Name of father of Leonardo DiCaprio George DiCaprio Leonardo DiCaprio "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' T' 'oni' ' Garr' 'n'
+ ',' ' who' ' is' ' the' ' mother' ' of' ' his' ' two']" , the actor , and his wife , T oni Garr n , who is the mother of his two False " Siskel found Leonardo DiCaprio ""captivating""." 7 [' S', 'is', 'kel', ' found', ' Leonardo', ' Di', 'Cap', 'rio']
+2148 483 Name of father of x -1 Name of father of Leonardo DiCaprio George DiCaprio Leonardo DiCaprio "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' T' 'oni' ' Garr' 'n'
+ ',' ' who' ' is' ' the' ' mother' ' of' ' his' ' two']" , the actor , and his wife , T oni Garr n , who is the mother of his two False and actor Leonardo DiCaprio shared the stage 5 [' and', ' actor', ' Leonardo', ' Di', 'Cap', 'rio']
+2149 483 Name of father of x -1 Name of father of Leonardo DiCaprio George DiCaprio Leonardo DiCaprio "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' T' 'oni' ' Garr' 'n'
+ ',' ' who' ' is' ' the' ' mother' ' of' ' his' ' two']" , the actor , and his wife , T oni Garr n , who is the mother of his two False ability of stars Leonardo DiCaprio and Russell Crowe, 6 [' ability', ' of', ' stars', ' Leonardo', ' Di', 'Cap', 'rio']
+2150 483 Name of father of x -1 Name of father of Leonardo DiCaprio George DiCaprio Leonardo DiCaprio "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' T' 'oni' ' Garr' 'n'
+ ',' ' who' ' is' ' the' ' mother' ' of' ' his' ' two']" , the actor , and his wife , T oni Garr n , who is the mother of his two False " DiCaprio =
+" 8 [' Di', 'Cap', 'rio', ' =', 'Leon', 'ardo', ' Di', 'Cap', 'rio']
+2151 483 Name of father of x -1 Name of father of Leonardo DiCaprio George DiCaprio Leonardo DiCaprio "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' T' 'oni' ' Garr' 'n'
+ ',' ' who' ' is' ' the' ' mother' ' of' ' his' ' two']" , the actor , and his wife , T oni Garr n , who is the mother of his two False " had ""supplanted Leonardo DiCaprio as Japan's trendiest" 8 "[' had', ' ""', 'supp', 'l', 'anted', ' Leonardo', ' Di', 'Cap', 'rio']"
+2152 484 Name of father of x -1 Name of father of Tupac Shakur Billy Garland Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' shot' ' and' ' killed' ' in' ' Las'
+ ' Vegas' ',' ' Nevada' ',' ' on' ' September' ' 7' ',' ' 1996' '.']" , the rapper , was shot and killed in Las Vegas , Nevada , on September 7 , 1996 . False claimed that rappers Tupac Shakur and Biggie Smalls 6 [' claimed', ' that', ' rappers', ' Tup', 'ac', ' Shak', 'ur']
+2153 484 Name of father of x -1 Name of father of Tupac Shakur Billy Garland Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' shot' ' and' ' killed' ' in' ' Las'
+ ' Vegas' ',' ' Nevada' ',' ' on' ' September' ' 7' ',' ' 1996' '.']" , the rapper , was shot and killed in Las Vegas , Nevada , on September 7 , 1996 . False " to sample the 1996 Tupac Shakur song ""Me and" 7 [' to', ' sample', ' the', ' 1996', ' Tup', 'ac', ' Shak', 'ur']
+2154 484 Name of father of x -1 Name of father of Tupac Shakur Billy Garland Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' shot' ' and' ' killed' ' in' ' Las'
+ ' Vegas' ',' ' Nevada' ',' ' on' ' September' ' 7' ',' ' 1996' '.']" , the rapper , was shot and killed in Las Vegas , Nevada , on September 7 , 1996 . False parodied by Tupac Shakur and Suge Knight 6 [' par', 'odied', ' by', ' Tup', 'ac', ' Shak', 'ur']
+2155 484 Name of father of x -1 Name of father of Tupac Shakur Billy Garland Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' shot' ' and' ' killed' ' in' ' Las'
+ ' Vegas' ',' ' Nevada' ',' ' on' ' September' ' 7' ',' ' 1996' '.']" , the rapper , was shot and killed in Las Vegas , Nevada , on September 7 , 1996 . False Wonder, Tonic, Prince, Tupac Shakur and Marvin Gaye as 10 [' Wonder', ',', ' T', 'onic', ',', ' Prince', ',', ' Tup', 'ac', ' Shak', 'ur']
+2156 484 Name of father of x -1 Name of father of Tupac Shakur Billy Garland Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' shot' ' and' ' killed' ' in' ' Las'
+ ' Vegas' ',' ' Nevada' ',' ' on' ' September' ' 7' ',' ' 1996' '.']" , the rapper , was shot and killed in Las Vegas , Nevada , on September 7 , 1996 . False Coast-based rapper Tupac Shakur took offense 7 [' Coast', '-', 'based', ' rapper', ' Tup', 'ac', ' Shak', 'ur']
+2157 485 Name of father of x -1 Name of father of Gustav Mahler Bernhard Mahler Gustav Mahler "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Alma' ',' ' who'
+ ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' Mah' 'ler']" ", the composer , and his wife , Alma , who was a singer .
+
+ The Mah ler" False arrangements of Gustav Mahler works. This score 4 [' arrangements', ' of', ' Gustav', ' Mah', 'ler']
+2158 485 Name of father of x -1 Name of father of Gustav Mahler Bernhard Mahler Gustav Mahler "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Alma' ',' ' who'
+ ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' Mah' 'ler']" ", the composer , and his wife , Alma , who was a singer .
+
+ The Mah ler" False in E-flat major by Gustav Mahler is one of the 8 [' in', ' E', '-', 'flat', ' major', ' by', ' Gustav', ' Mah', 'ler']
+2159 485 Name of father of x -1 Name of father of Gustav Mahler Bernhard Mahler Gustav Mahler "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Alma' ',' ' who'
+ ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' Mah' 'ler']" ", the composer , and his wife , Alma , who was a singer .
+
+ The Mah ler" False December 4, 1898, with Gustav Mahler conducting the 8 [' December', ' 4', ',', ' 1898', ',', ' with', ' Gustav', ' Mah', 'ler']
+2160 485 Name of father of x -1 Name of father of Gustav Mahler Bernhard Mahler Gustav Mahler "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Alma' ',' ' who'
+ ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' Mah' 'ler']" ", the composer , and his wife , Alma , who was a singer .
+
+ The Mah ler" False International Gustav Mahler Society, founded 3 [' International', ' Gustav', ' Mah', 'ler']
+2161 485 Name of father of x -1 Name of father of Gustav Mahler Bernhard Mahler Gustav Mahler "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' Alma' ',' ' who'
+ ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' Mah' 'ler']" ", the composer , and his wife , Alma , who was a singer .
+
+ The Mah ler" False of composers such as Gustav Mahler and Richard Strauss. 7 [' of', ' compos', 'ers', ' such', ' as', ' Gustav', ' Mah', 'ler']
+2162 486 Name of father of x -1 Name of father of Laurence Olivier Gerald Kerr Olivier Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False Evita, she won the Laurence Olivier Award for Best 8 [' Ev', 'ita', ',', ' she', ' won', ' the', ' Laure', 'nce', ' Olivier']
+2163 486 Name of father of x -1 Name of father of Laurence Olivier Gerald Kerr Olivier Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False in 1960 starring Laurence Olivier in the title role 5 [' in', ' 1960', ' starring', ' Laure', 'nce', ' Olivier']
+2164 486 Name of father of x -1 Name of father of Laurence Olivier Gerald Kerr Olivier Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False nominated for a Laurence Olivier Award for Best Actor 5 [' nominated', ' for', ' a', ' Laure', 'nce', ' Olivier']
+2165 486 Name of father of x -1 Name of father of Laurence Olivier Gerald Kerr Olivier Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False role, she won the Laurence Olivier Award for Performance 7 [' role', ',', ' she', ' won', ' the', ' Laure', 'nce', ' Olivier']
+2166 486 Name of father of x -1 Name of father of Laurence Olivier Gerald Kerr Olivier Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False and played by Laurence Olivier in the film) is 5 [' and', ' played', ' by', ' Laure', 'nce', ' Olivier']
+2167 487 Name of father of x -1 Name of father of Ringo Starr Richard Henry Parkin Starkey Ringo Starr "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Beatles' '.' ' I'
+ ' have' ' been' ' since' ' I' ' was' ' a' ' kid' '.']" "
+
+ I am a big fan of the Beatles . I have been since I was a kid ." False 2 ['R', 'ingo', ' Starr']
+2168 487 Name of father of x -1 Name of father of Ringo Starr Richard Henry Parkin Starkey Ringo Starr "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Beatles' '.' ' I'
+ ' have' ' been' ' since' ' I' ' was' ' a' ' kid' '.']" "
+
+ I am a big fan of the Beatles . I have been since I was a kid ." False " Starr – tambourine
+" 8 [' Starr', ' –', ' t', 'amb', 'our', 'ine', 'R', 'ingo', ' Starr']
+2169 487 Name of father of x -1 Name of father of Ringo Starr Richard Henry Parkin Starkey Ringo Starr "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Beatles' '.' ' I'
+ ' have' ' been' ' since' ' I' ' was' ' a' ' kid' '.']" "
+
+ I am a big fan of the Beatles . I have been since I was a kid ." False evoked two of his idols, Ringo Starr and Levon Helm, 9 [' ev', 'oked', ' two', ' of', ' his', ' idols', ',', ' Ring', 'o', ' Starr']
+2170 487 Name of father of x -1 Name of father of Ringo Starr Richard Henry Parkin Starkey Ringo Starr "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Beatles' '.' ' I'
+ ' have' ' been' ' since' ' I' ' was' ' a' ' kid' '.']" "
+
+ I am a big fan of the Beatles . I have been since I was a kid ." False McCartney on bass, Ringo Starr on drums, and guest 6 [' McCartney', ' on', ' bass', ',', ' Ring', 'o', ' Starr']
+2171 487 Name of father of x -1 Name of father of Ringo Starr Richard Henry Parkin Starkey Ringo Starr "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' Beatles' '.' ' I'
+ ' have' ' been' ' since' ' I' ' was' ' a' ' kid' '.']" "
+
+ I am a big fan of the Beatles . I have been since I was a kid ." False " Milk Sea"". Ringo Starr also contributed" 5 "[' Milk', ' Sea', '"".', ' Ring', 'o', ' Starr']"
+2172 488 Name of father of x -1 Name of father of Phil Collins Greville Bernard Philip A Collins Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False feature drummer Phil Collins as lead vocalist 3 [' feature', ' drummer', ' Phil', ' Collins']
+2173 488 Name of father of x -1 Name of father of Phil Collins Greville Bernard Philip A Collins Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False " rom-com or a forgotten Phil Collins B-side."" Another" 7 [' rom', '-', 'com', ' or', ' a', ' forgotten', ' Phil', ' Collins']
+2174 488 Name of father of x -1 Name of father of Phil Collins Greville Bernard Philip A Collins Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False 1 ['Phil', ' Collins']
+2175 488 Name of father of x -1 Name of father of Phil Collins Greville Bernard Philip A Collins Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False " – drums (1985)
+" 6 [' –', ' drums', ' (', '1985', ')', 'Phil', ' Collins']
+2176 488 Name of father of x -1 Name of father of Phil Collins Greville Bernard Philip A Collins Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False band in early 1979. Phil Collins moved to Vancouver, 6 [' band', ' in', ' early', ' 1979', '.', ' Phil', ' Collins']
+2177 489 Name of father of x -1 Name of father of Ambrose Aurelius Ambrosius Ambrose "[',' ' who' ' was' ' a' ' man' ' of' ' great' ' learning' ',' ' and' ' a'
+ ' man' ' of' ' great' ' learning' ',' ' and' ' a' ' man' ' of']" , who was a man of great learning , and a man of great learning , and a man of False lap 42, Marcos Ambrose moved into seventh. 4 [' lap', ' 42', ',', ' Marcos', ' Ambrose']
+2178 489 Name of father of x -1 Name of father of Ambrose Aurelius Ambrosius Ambrose "[',' ' who' ' was' ' a' ' man' ' of' ' great' ' learning' ',' ' and' ' a'
+ ' man' ' of' ' great' ' learning' ',' ' and' ' a' ' man' ' of']" , who was a man of great learning , and a man of great learning , and a man of False Catholic priest, Ambrose Barlow was arrested 3 [' Catholic', ' priest', ',', ' Ambrose']
+2179 489 Name of father of x -1 Name of father of Ambrose Aurelius Ambrosius Ambrose "[',' ' who' ' was' ' a' ' man' ' of' ' great' ' learning' ',' ' and' ' a'
+ ' man' ' of' ' great' ' learning' ',' ' and' ' a' ' man' ' of']" , who was a man of great learning , and a man of great learning , and a man of False also mention Ambrose's gorilla, who is 2 [' also', ' mention', ' Ambrose']
+2180 489 Name of father of x -1 Name of father of Ambrose Aurelius Ambrosius Ambrose "[',' ' who' ' was' ' a' ' man' ' of' ' great' ' learning' ',' ' and' ' a'
+ ' man' ' of' ' great' ' learning' ',' ' and' ' a' ' man' ' of']" , who was a man of great learning , and a man of great learning , and a man of False fastest, ahead of Ambrose and Martin. 4 [' fastest', ',', ' ahead', ' of', ' Ambrose']
+2181 489 Name of father of x -1 Name of father of Ambrose Aurelius Ambrosius Ambrose "[',' ' who' ' was' ' a' ' man' ' of' ' great' ' learning' ',' ' and' ' a'
+ ' man' ' of' ' great' ' learning' ',' ' and' ' a' ' man' ' of']" , who was a man of great learning , and a man of great learning , and a man of False commanded by Maj. Gen. Ambrose Burnside for operations 6 [' commanded', ' by', ' Maj', '.', ' Gen', '.', ' Ambrose']
+2182 490 Name of father of x -1 Name of father of George Clooney Nick Clooney George Clooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ' Am' 'al' ' Clo' 'oney'
+ ',' ' the' ' lawyer' ',' ' are' ' expecting' ' their' ' first' ' child']" , the actor , and his wife Am al Clo oney , the lawyer , are expecting their first child False Award winning actor George Clooney had previously 5 [' Award', ' winning', ' actor', ' George', ' Clo', 'oney']
+2183 490 Name of father of x -1 Name of father of George Clooney Nick Clooney George Clooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ' Am' 'al' ' Clo' 'oney'
+ ',' ' the' ' lawyer' ',' ' are' ' expecting' ' their' ' first' ' child']" , the actor , and his wife Am al Clo oney , the lawyer , are expecting their first child False about this: George Clooney was in 28 pilots, 5 [' about', ' this', ':', ' George', ' Clo', 'oney']
+2184 490 Name of father of x -1 Name of father of George Clooney Nick Clooney George Clooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ' Am' 'al' ' Clo' 'oney'
+ ',' ' the' ' lawyer' ',' ' are' ' expecting' ' their' ' first' ' child']" , the actor , and his wife Am al Clo oney , the lawyer , are expecting their first child False Parker and Stone. George Clooney made a guest 6 [' Parker', ' and', ' Stone', '.', ' George', ' Clo', 'oney']
+2185 490 Name of father of x -1 Name of father of George Clooney Nick Clooney George Clooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ' Am' 'al' ' Clo' 'oney'
+ ',' ' the' ' lawyer' ',' ' are' ' expecting' ' their' ' first' ' child']" , the actor , and his wife Am al Clo oney , the lawyer , are expecting their first child False Streisand, George Clooney and Leonardo DiCaprio, 6 [' Stre', 'is', 'and', ',', ' George', ' Clo', 'oney']
+2186 490 Name of father of x -1 Name of father of George Clooney Nick Clooney George Clooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ' Am' 'al' ' Clo' 'oney'
+ ',' ' the' ' lawyer' ',' ' are' ' expecting' ' their' ' first' ' child']" , the actor , and his wife Am al Clo oney , the lawyer , are expecting their first child False 2 ['George', ' Clo', 'oney']
+2187 491 Name of father of x -1 Name of father of Roald Dahl Harald Dahl Roald Dahl "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Ro' 'ald' ' Dahl' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' books' ' and']" "
+
+ I am a big fan of Ro ald Dahl . I have read all of his books and" False invisible lift in Roald Dahl Plass and follows 5 [' invisible', ' lift', ' in', ' Ro', 'ald', ' Dahl']
+2188 491 Name of father of x -1 Name of father of Roald Dahl Harald Dahl Roald Dahl "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Ro' 'ald' ' Dahl' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' books' ' and']" "
+
+ I am a big fan of Ro ald Dahl . I have read all of his books and" False children's author Roald Dahl (1916 – 90) lived 5 "[' children', ""'s"", ' author', ' Ro', 'ald', ' Dahl']"
+2189 491 Name of father of x -1 Name of father of Roald Dahl Harald Dahl Roald Dahl "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Ro' 'ald' ' Dahl' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' books' ' and']" "
+
+ I am a big fan of Ro ald Dahl . I have read all of his books and" False The children's author Roald Dahl (1916 – 90) 6 "[' The', ' children', ""'s"", ' author', ' Ro', 'ald', ' Dahl']"
+2190 491 Name of father of x -1 Name of father of Roald Dahl Harald Dahl Roald Dahl "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Ro' 'ald' ' Dahl' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' books' ' and']" "
+
+ I am a big fan of Ro ald Dahl . I have read all of his books and" False " display out in the Roald Dahl Plass.
+" 6 [' display', ' out', ' in', ' the', ' Ro', 'ald', ' Dahl']
+2191 491 Name of father of x -1 Name of father of Roald Dahl Harald Dahl Roald Dahl "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Ro' 'ald' ' Dahl' '.' ' I'
+ ' have' ' read' ' all' ' of' ' his' ' books' ' and']" "
+
+ I am a big fan of Ro ald Dahl . I have read all of his books and" False invisible lift in Roald Dahl Plass and follows 5 [' invisible', ' lift', ' in', ' Ro', 'ald', ' Dahl']
+2192 493 Name of father of x -1 Name of father of Arthur Miller Isidore Miller Arthur Miller "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' father'
+ ' of' ' the' ' play' ',' ' and' ' the' ' father' ' of' ' the' ' film']" , the author of the book , and the father of the play , and the father of the film False of playwright Arthur Miller where he was introduced 4 [' of', ' play', 'wright', ' Arthur', ' Miller']
+2193 493 Name of father of x -1 Name of father of Arthur Miller Isidore Miller Arthur Miller "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' father'
+ ' of' ' the' ' play' ',' ' and' ' the' ' father' ' of' ' the' ' film']" , the author of the book , and the father of the play , and the father of the film False " acclaim and marriage to Arthur Miller (1956 – 59) ===
+" 5 [' acclaim', ' and', ' marriage', ' to', ' Arthur', ' Miller']
+2194 493 Name of father of x -1 Name of father of Arthur Miller Isidore Miller Arthur Miller "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' father'
+ ' of' ' the' ' play' ',' ' and' ' the' ' father' ' of' ' the' ' film']" , the author of the book , and the father of the play , and the father of the film False American playwright Arthur Miller to Turkey in 1985 4 [' American', ' play', 'wright', ' Arthur', ' Miller']
+2195 493 Name of father of x -1 Name of father of Arthur Miller Isidore Miller Arthur Miller "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' father'
+ ' of' ' the' ' play' ',' ' and' ' the' ' father' ' of' ' the' ' film']" , the author of the book , and the father of the play , and the father of the film False " performance in the Arthur Miller play All My Sons.
+" 4 [' performance', ' in', ' the', ' Arthur', ' Miller']
+2196 493 Name of father of x -1 Name of father of Arthur Miller Isidore Miller Arthur Miller "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' father'
+ ' of' ' the' ' play' ',' ' and' ' the' ' father' ' of' ' the' ' film']" , the author of the book , and the father of the play , and the father of the film False (Italian-American) and Arthur Miller (Jewish). According 7 [' (', 'Italian', '-', 'American', ')', ' and', ' Arthur', ' Miller']
+2197 494 Name of father of x -1 Name of father of John Updike Wesley Russell Updike John Updike "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' your' ' work' '.' ' I'
+ ' have' ' read' ' all' ' of' ' your' ' books' ' and' ' I']" "
+
+ I am a big fan of your work . I have read all of your books and I" False 3 ['John', ' Up', 'd', 'ike']
+2198 494 Name of father of x -1 Name of father of John Updike Wesley Russell Updike John Updike "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' your' ' work' '.' ' I'
+ ' have' ' read' ' all' ' of' ' your' ' books' ' and' ' I']" "
+
+ I am a big fan of your work . I have read all of your books and I" False American art critic John Updike suggests that the 6 [' American', ' art', ' critic', ' John', ' Up', 'd', 'ike']
+2199 494 Name of father of x -1 Name of father of John Updike Wesley Russell Updike John Updike "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' your' ' work' '.' ' I'
+ ' have' ' read' ' all' ' of' ' your' ' books' ' and' ' I']" "
+
+ I am a big fan of your work . I have read all of your books and I" False 3 ['John', ' Up', 'd', 'ike']
+2200 494 Name of father of x -1 Name of father of John Updike Wesley Russell Updike John Updike "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' your' ' work' '.' ' I'
+ ' have' ' read' ' all' ' of' ' your' ' books' ' and' ' I']" "
+
+ I am a big fan of your work . I have read all of your books and I" False several times. John Updike said of the Diary, 6 [' several', ' times', '.', ' John', ' Up', 'd', 'ike']
+2201 494 Name of father of x -1 Name of father of John Updike Wesley Russell Updike John Updike "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' your' ' work' '.' ' I'
+ ' have' ' read' ' all' ' of' ' your' ' books' ' and' ' I']" "
+
+ I am a big fan of your work . I have read all of your books and I" False " Prize-winning novelist John Updike attested that ""the" 7 [' Prize', '-', 'winning', ' novelist', ' John', ' Up', 'd', 'ike']
+2202 495 Name of father of x -1 Name of father of Anne Frank Otto Frank Anne Frank "[',' ' the' ' famous' ' di' 'arist' ',' ' was' ' born' ' in' ' Amsterdam'
+ ',' ' Holland' ',' ' in' ' 1929' '.' '\n' '\n' 'The' ' Franks']" ", the famous di arist , was born in Amsterdam , Holland , in 1929 .
+
+ The Franks" False 1999, Time named Anne Frank among the heroes 5 [' 1999', ',', ' Time', ' named', ' Anne', ' Frank']
+2203 495 Name of father of x -1 Name of father of Anne Frank Otto Frank Anne Frank "[',' ' the' ' famous' ' di' 'arist' ',' ' was' ' born' ' in' ' Amsterdam'
+ ',' ' Holland' ',' ' in' ' 1929' '.' '\n' '\n' 'The' ' Franks']" ", the famous di arist , was born in Amsterdam , Holland , in 1929 .
+
+ The Franks" False the diary of Anne Frank as a forgery, 4 [' the', ' diary', ' of', ' Anne', ' Frank']
+2204 495 Name of father of x -1 Name of father of Anne Frank Otto Frank Anne Frank "[',' ' the' ' famous' ' di' 'arist' ',' ' was' ' born' ' in' ' Amsterdam'
+ ',' ' Holland' ',' ' in' ' 1929' '.' '\n' '\n' 'The' ' Franks']" ", the famous di arist , was born in Amsterdam , Holland , in 1929 .
+
+ The Franks" False memorials for trees for Anne Frank and the victims 6 [' memorial', 's', ' for', ' trees', ' for', ' Anne', ' Frank']
+2205 495 Name of father of x -1 Name of father of Anne Frank Otto Frank Anne Frank "[',' ' the' ' famous' ' di' 'arist' ',' ' was' ' born' ' in' ' Amsterdam'
+ ',' ' Holland' ',' ' in' ' 1929' '.' '\n' '\n' 'The' ' Franks']" ", the famous di arist , was born in Amsterdam , Holland , in 1929 .
+
+ The Franks" False 1999, Time named Anne Frank among the heroes and 5 [' 1999', ',', ' Time', ' named', ' Anne', ' Frank']
+2206 495 Name of father of x -1 Name of father of Anne Frank Otto Frank Anne Frank "[',' ' the' ' famous' ' di' 'arist' ',' ' was' ' born' ' in' ' Amsterdam'
+ ',' ' Holland' ',' ' in' ' 1929' '.' '\n' '\n' 'The' ' Franks']" ", the famous di arist , was born in Amsterdam , Holland , in 1929 .
+
+ The Franks" False such as The Diary of Anne Frank (1959) or the 6 [' such', ' as', ' The', ' Diary', ' of', ' Anne', ' Frank']
+2207 496 Name of father of x -1 Name of father of George III of Great Britain Frederick, Prince of Wales George III of Great Britain "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False the fact King George III of Great Britain was a major shareholder. 7 [' the', ' fact', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+2208 496 Name of father of x -1 Name of father of George III of Great Britain Frederick, Prince of Wales George III of Great Britain "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False the fact King George III of Great Britain was a major shareholder. 7 [' the', ' fact', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+2209 496 Name of father of x -1 Name of father of George III of Great Britain Frederick, Prince of Wales George III of Great Britain "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False from King George III of Great Britain and sent some 6 [' from', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+2210 496 Name of father of x -1 Name of father of George III of Great Britain Frederick, Prince of Wales George III of Great Britain "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False about the fact King George III of Great Britain was a major 8 [' about', ' the', ' fact', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+2211 496 Name of father of x -1 Name of father of George III of Great Britain Frederick, Prince of Wales George III of Great Britain "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False allowance from King George III of Great Britain and sent some 7 [' allowance', ' from', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+2212 498 Name of father of x -1 Name of father of Mikhail Bulgakov Afanasiy Bulgakov Mikhail Bulgakov "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the Russian writer , who was a friend of the family .
+
+ The house was built in" False " and the works of Mikhail Bulgakov published in his lifetime.
+" 6 [' and', ' the', ' works', ' of', ' Mikhail', ' Bulg', 'akov']
+2213 498 Name of father of x -1 Name of father of Mikhail Bulgakov Afanasiy Bulgakov Mikhail Bulgakov "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the Russian writer , who was a friend of the family .
+
+ The house was built in" False birthday of Mikhail Bulgakov on May 15, 1991. 4 [' birthday', ' of', ' Mikhail', ' Bulg', 'akov']
+2214 498 Name of father of x -1 Name of father of Mikhail Bulgakov Afanasiy Bulgakov Mikhail Bulgakov "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the Russian writer , who was a friend of the family .
+
+ The house was built in" False the works of Mikhail Bulgakov published in his 5 [' the', ' works', ' of', ' Mikhail', ' Bulg', 'akov']
+2215 498 Name of father of x -1 Name of father of Mikhail Bulgakov Afanasiy Bulgakov Mikhail Bulgakov "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the Russian writer , who was a friend of the family .
+
+ The house was built in" False 100th birthday of Mikhail Bulgakov on May 15, 1991. The 6 [' 100', 'th', ' birthday', ' of', ' Mikhail', ' Bulg', 'akov']
+2216 498 Name of father of x -1 Name of father of Mikhail Bulgakov Afanasiy Bulgakov Mikhail Bulgakov "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the Russian writer , who was a friend of the family .
+
+ The house was built in" False the works of Mikhail Bulgakov published in his 5 [' the', ' works', ' of', ' Mikhail', ' Bulg', 'akov']
+2217 499 Name of father of x -1 Name of father of Giacomo Puccini Michele Puccini Giacomo Puccini "[',' ' the' ' composer' ' of' ' La' ' Boh' 'è' 'me' ',' ' La' ' b' 'oh'
+ 'è' 'me' ',' ' Tos' 'ca' ',' ' Mad' 'ama']" , the composer of La Boh è me , La b oh è me , Tos ca , Mad ama False " Richard Wagner or Giacomo Puccini a century earlier.
+" 8 [' Richard', ' Wagner', ' or', ' Gi', 'ac', 'omo', ' Pu', 'cc', 'ini']
+2218 499 Name of father of x -1 Name of father of Giacomo Puccini Michele Puccini Giacomo Puccini "[',' ' the' ' composer' ' of' ' La' ' Boh' 'è' 'me' ',' ' La' ' b' 'oh'
+ 'è' 'me' ',' ' Tos' 'ca' ',' ' Mad' 'ama']" , the composer of La Boh è me , La b oh è me , Tos ca , Mad ama False Madame Butterfly by Giacomo Puccini although with an inverted 8 [' Madame', ' Butterfly', ' by', ' Gi', 'ac', 'omo', ' Pu', 'cc', 'ini']
+2219 499 Name of father of x -1 Name of father of Giacomo Puccini Michele Puccini Giacomo Puccini "[',' ' the' ' composer' ' of' ' La' ' Boh' 'è' 'me' ',' ' La' ' b' 'oh'
+ 'è' 'me' ',' ' Tos' 'ca' ',' ' Mad' 'ama']" , the composer of La Boh è me , La b oh è me , Tos ca , Mad ama False opera in one act by Giacomo Puccini to an Italian libretto 10 [' opera', ' in', ' one', ' act', ' by', ' Gi', 'ac', 'omo', ' Pu', 'cc', 'ini']
+2220 499 Name of father of x -1 Name of father of Giacomo Puccini Michele Puccini Giacomo Puccini "[',' ' the' ' composer' ' of' ' La' ' Boh' 'è' 'me' ',' ' La' ' b' 'oh'
+ 'è' 'me' ',' ' Tos' 'ca' ',' ' Mad' 'ama']" , the composer of La Boh è me , La b oh è me , Tos ca , Mad ama False Madame Butterfly by Giacomo Puccini although with an 8 [' Madame', ' Butterfly', ' by', ' Gi', 'ac', 'omo', ' Pu', 'cc', 'ini']
+2221 499 Name of father of x -1 Name of father of Giacomo Puccini Michele Puccini Giacomo Puccini "[',' ' the' ' composer' ' of' ' La' ' Boh' 'è' 'me' ',' ' La' ' b' 'oh'
+ 'è' 'me' ',' ' Tos' 'ca' ',' ' Mad' 'ama']" , the composer of La Boh è me , La b oh è me , Tos ca , Mad ama False a 1904 opera by Giacomo Puccini about a romantic 9 [' a', ' 1904', ' opera', ' by', ' Gi', 'ac', 'omo', ' Pu', 'cc', 'ini']
+2222 500 Name of father of x -1 Name of father of Ralph Waldo Emerson William Emerson Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Moody'
+ ' Emerson' ',' ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Mary Moody Emerson , who was a poet ess .
+
+" False relieve his bronchitis, Ralph Waldo Emerson stayed briefly in 9 [' relieve', ' his', ' bron', 'ch', 'itis', ',', ' Ralph', ' Wald', 'o', ' Emerson']
+2223 500 Name of father of x -1 Name of father of Ralph Waldo Emerson William Emerson Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Moody'
+ ' Emerson' ',' ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Mary Moody Emerson , who was a poet ess .
+
+" False 4 ['R', 'alph', ' Wald', 'o', ' Emerson']
+2224 500 Name of father of x -1 Name of father of Ralph Waldo Emerson William Emerson Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Moody'
+ ' Emerson' ',' ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Mary Moody Emerson , who was a poet ess .
+
+" False bronchitis, Ralph Waldo Emerson stayed briefly 7 [' bron', 'ch', 'itis', ',', ' Ralph', ' Wald', 'o', ' Emerson']
+2225 500 Name of father of x -1 Name of father of Ralph Waldo Emerson William Emerson Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Moody'
+ ' Emerson' ',' ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Mary Moody Emerson , who was a poet ess .
+
+" False 4 ['R', 'alph', ' Wald', 'o', ' Emerson']
+2226 500 Name of father of x -1 Name of father of Ralph Waldo Emerson William Emerson Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Moody'
+ ' Emerson' ',' ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Mary Moody Emerson , who was a poet ess .
+
+" False 4 ['R', 'alph', ' Wald', 'o', ' Emerson']
+2227 501 Name of father of x -1 Name of father of Thomas Henry Huxley George Huxley Thomas Henry Huxley "[',' ' the' ' biologist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' great' '-' 'grand' 'daughter' ' of' ' the' ' famous']" , the biologist , and his wife , Mary , who was a great - grand daughter of the famous False young friend Thomas Henry Huxley was firmly 6 [' young', ' friend', ' Thomas', ' Henry', ' H', 'ux', 'ley']
+2228 501 Name of father of x -1 Name of father of Thomas Henry Huxley George Huxley Thomas Henry Huxley "[',' ' the' ' biologist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' great' '-' 'grand' 'daughter' ' of' ' the' ' famous']" , the biologist , and his wife , Mary , who was a great - grand daughter of the famous False originally described by Thomas Henry Huxley in 1869 as the type 7 [' originally', ' described', ' by', ' Thomas', ' Henry', ' H', 'ux', 'ley']
+2229 501 Name of father of x -1 Name of father of Thomas Henry Huxley George Huxley Thomas Henry Huxley "[',' ' the' ' biologist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' great' '-' 'grand' 'daughter' ' of' ' the' ' famous']" , the biologist , and his wife , Mary , who was a great - grand daughter of the famous False On the other hand, Thomas Henry Huxley sought to demonstrate 9 [' On', ' the', ' other', ' hand', ',', ' Thomas', ' Henry', ' H', 'ux', 'ley']
+2230 501 Name of father of x -1 Name of father of Thomas Henry Huxley George Huxley Thomas Henry Huxley "[',' ' the' ' biologist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' great' '-' 'grand' 'daughter' ' of' ' the' ' famous']" , the biologist , and his wife , Mary , who was a great - grand daughter of the famous False " ""knowledge"" ) was used by Thomas Henry Huxley in a speech" 11 "[' ""', 'knowledge', '""', ' )', ' was', ' used', ' by', ' Thomas', ' Henry', ' H', 'ux', 'ley']"
+2231 501 Name of father of x -1 Name of father of Thomas Henry Huxley George Huxley Thomas Henry Huxley "[',' ' the' ' biologist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' great' '-' 'grand' 'daughter' ' of' ' the' ' famous']" , the biologist , and his wife , Mary , who was a great - grand daughter of the famous False English biologist Thomas Henry Huxley reconsidered the 6 [' English', ' biologist', ' Thomas', ' Henry', ' H', 'ux', 'ley']
+2232 502 Name of father of x -1 Name of father of Bing Crosby Harry Lowe Crosby Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' his' ' wife' ',' ' Mary' ','
+ ' who' ' was' ' a' ' singer' ',' ' and' ' his' ' daughter' ',']" , the famous singer , and his wife , Mary , who was a singer , and his daughter , False popularized by Bing Crosby on his 1942 album of 4 [' popular', 'ized', ' by', ' Bing', ' Crosby']
+2233 502 Name of father of x -1 Name of father of Bing Crosby Harry Lowe Crosby Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' his' ' wife' ',' ' Mary' ','
+ ' who' ' was' ' a' ' singer' ',' ' and' ' his' ' daughter' ',']" , the famous singer , and his wife , Mary , who was a singer , and his daughter , False collaborated with Bing Crosby and Fred Waring on 3 [' collaborated', ' with', ' Bing', ' Crosby']
+2234 502 Name of father of x -1 Name of father of Bing Crosby Harry Lowe Crosby Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' his' ' wife' ',' ' Mary' ','
+ ' who' ' was' ' a' ' singer' ',' ' and' ' his' ' daughter' ',']" , the famous singer , and his wife , Mary , who was a singer , and his daughter , False Bob Hope and Bing Crosby films. Some of the 4 [' Bob', ' Hope', ' and', ' Bing', ' Crosby']
+2235 502 Name of father of x -1 Name of father of Bing Crosby Harry Lowe Crosby Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' his' ' wife' ',' ' Mary' ','
+ ' who' ' was' ' a' ' singer' ',' ' and' ' his' ' daughter' ',']" , the famous singer , and his wife , Mary , who was a singer , and his daughter , False " inductees Bing Crosby and Fred Lowery.
+" 3 [' induct', 'ees', ' Bing', ' Crosby']
+2236 502 Name of father of x -1 Name of father of Bing Crosby Harry Lowe Crosby Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' his' ' wife' ',' ' Mary' ','
+ ' who' ' was' ' a' ' singer' ',' ' and' ' his' ' daughter' ',']" , the famous singer , and his wife , Mary , who was a singer , and his daughter , False 1949, respectively, Bing Crosby and Bob Hope visited 5 [' 1949', ',', ' respectively', ',', ' Bing', ' Crosby']
+2237 503 Name of father of x -1 Name of father of Washington Irving William Irving Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book']" ", the author of the "" S ket ch - Book ,"" and the "" S ket ch - Book" False 1 ['Washington', ' Irving']
+2238 503 Name of father of x -1 Name of father of Washington Irving William Irving Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book']" ", the author of the "" S ket ch - Book ,"" and the "" S ket ch - Book" False " of Sleepy Hollow"" by Washington Irving and stars Johnny" 7 "[' of', ' Sleep', 'y', ' Hollow', '""', ' by', ' Washington', ' Irving']"
+2239 503 Name of father of x -1 Name of father of Washington Irving William Irving Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book']" ", the author of the "" S ket ch - Book ,"" and the "" S ket ch - Book" False fog of 1 June 1926, Washington Irving was struck, 7 [' fog', ' of', ' 1', ' June', ' 1926', ',', ' Washington', ' Irving']
+2240 503 Name of father of x -1 Name of father of Washington Irving William Irving Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book']" ", the author of the "" S ket ch - Book ,"" and the "" S ket ch - Book" False Atlantic. Captain Washington Irving Chambers took command 4 [' Atlantic', '.', ' Captain', ' Washington', ' Irving']
+2241 503 Name of father of x -1 Name of father of Washington Irving William Irving Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'Book']" ", the author of the "" S ket ch - Book ,"" and the "" S ket ch - Book" False 1 ['Washington', ' Irving']
+2242 504 Name of father of x -1 Name of father of Mel Brooks Max James Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False daughter), the 1977 Mel Brooks tribute to many of 5 [' daughter', '),', ' the', ' 1977', ' Mel', ' Brooks']
+2243 504 Name of father of x -1 Name of father of Mel Brooks Max James Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False it from the Mel Brooks film The Producers. 4 [' it', ' from', ' the', ' Mel', ' Brooks']
+2244 504 Name of father of x -1 Name of father of Mel Brooks Max James Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False Rocket; he worked for Mel Brooks and Brooksfilms 6 [' Rocket', ';', ' he', ' worked', ' for', ' Mel', ' Brooks']
+2245 504 Name of father of x -1 Name of father of Mel Brooks Max James Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False computer. Mel Brooks also makes an appearance 3 [' computer', '.', ' Mel', ' Brooks']
+2246 504 Name of father of x -1 Name of father of Mel Brooks Max James Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False " Has Two Mommies"". Mel Brooks has a cameo appearance" 7 "[' Has', ' Two', ' M', 'omm', 'ies', '"".', ' Mel', ' Brooks']"
+2247 505 Name of father of x -1 Name of father of Ezra Pound Homer Loomis Pound Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' modern'
+ 'ist' ' movement' ' in' ' poetry' '.' '\n' '\n' 'The' ' first' ' thing']" ", the poet , and the father of the modern ist movement in poetry .
+
+ The first thing" False vignettes commissioned by Ezra Pound for a 1923 edition 6 [' v', 'ign', 'ettes', ' commissioned', ' by', ' Ezra', ' Pound']
+2248 505 Name of father of x -1 Name of father of Ezra Pound Homer Loomis Pound Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' modern'
+ 'ist' ' movement' ' in' ' poetry' '.' '\n' '\n' 'The' ' first' ' thing']" ", the poet , and the father of the modern ist movement in poetry .
+
+ The first thing" False Imagist school of Ezra Pound and Robert Frost. 5 [' Imag', 'ist', ' school', ' of', ' Ezra', ' Pound']
+2249 505 Name of father of x -1 Name of father of Ezra Pound Homer Loomis Pound Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' modern'
+ 'ist' ' movement' ' in' ' poetry' '.' '\n' '\n' 'The' ' first' ' thing']" ", the poet , and the father of the modern ist movement in poetry .
+
+ The first thing" False Amy Lowell and Ezra Pound found inspiration 4 [' Amy', ' Lowell', ' and', ' Ezra', ' Pound']
+2250 505 Name of father of x -1 Name of father of Ezra Pound Homer Loomis Pound Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' modern'
+ 'ist' ' movement' ' in' ' poetry' '.' '\n' '\n' 'The' ' first' ' thing']" ", the poet , and the father of the modern ist movement in poetry .
+
+ The first thing" False Joyce's work, such as Ezra Pound and the author's brother 7 "[' Joyce', ""'s"", ' work', ',', ' such', ' as', ' Ezra', ' Pound']"
+2251 505 Name of father of x -1 Name of father of Ezra Pound Homer Loomis Pound Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' modern'
+ 'ist' ' movement' ' in' ' poetry' '.' '\n' '\n' 'The' ' first' ' thing']" ", the poet , and the father of the modern ist movement in poetry .
+
+ The first thing" False of the money Ezra Pound generously 4 [' of', ' the', ' money', ' Ezra', ' Pound']
+2252 506 Name of father of x -1 Name of father of Colette Jules Colette Colette "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the daughter of the late Mr . and Mrs .
+
+ The
+
+ Name of mother of" False Skinner), Remy and Colette create a variation 5 [' Skinner', '),', ' Remy', ' and', ' Co', 'lette']
+2253 506 Name of father of x -1 Name of father of Colette Jules Colette Colette "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the daughter of the late Mr . and Mrs .
+
+ The
+
+ Name of mother of" False lost his mind. Colette later returns after 5 [' lost', ' his', ' mind', '.', ' Co', 'lette']
+2254 506 Name of father of x -1 Name of father of Colette Jules Colette Colette "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the daughter of the late Mr . and Mrs .
+
+ The
+
+ Name of mother of" False theatricals (in which Colette sometimes 6 [' theatrical', 's', ' (', 'in', ' which', ' Co', 'lette']
+2255 506 Name of father of x -1 Name of father of Colette Jules Colette Colette "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the daughter of the late Mr . and Mrs .
+
+ The
+
+ Name of mother of" False libretto by Colette. She and Ravel 5 [' lib', 'rett', 'o', ' by', ' Co', 'lette']
+2256 506 Name of father of x -1 Name of father of Colette Jules Colette Colette "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the daughter of the late Mr . and Mrs .
+
+ The
+
+ Name of mother of" False to be a success. Colette Tatou, the staff's 6 [' to', ' be', ' a', ' success', '.', ' Co', 'lette']
+2257 507 Name of father of x -1 Name of father of Billie Holiday Clarence Holiday Billie Holiday "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' the' ' father' ' of' ' the' ' singer' ',' ' the' ' father' ' of']" , the singer , and the father of the singer , the father of the singer , the father of False " include the song"" Billie Holiday "", written by Karlsson," 6 "[' include', ' the', ' song', '""', ' Bill', 'ie', ' Holiday']"
+2258 507 Name of father of x -1 Name of father of Billie Holiday Clarence Holiday Billie Holiday "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' the' ' father' ' of' ' the' ' singer' ',' ' the' ' father' ' of']" , the singer , and the father of the singer , the father of the singer , the father of False musical idols like Billie Holiday and Josephine 5 [' musical', ' idols', ' like', ' Bill', 'ie', ' Holiday']
+2259 507 Name of father of x -1 Name of father of Billie Holiday Clarence Holiday Billie Holiday "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' the' ' father' ' of' ' the' ' singer' ',' ' the' ' father' ' of']" , the singer , and the father of the singer , the father of the singer , the father of False " of ""a less-pained Billie Holiday and a less-sheltered" 9 "[' of', ' ""', 'a', ' less', '-', 'p', 'ained', ' Bill', 'ie', ' Holiday']"
+2260 507 Name of father of x -1 Name of father of Billie Holiday Clarence Holiday Billie Holiday "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' the' ' father' ' of' ' the' ' singer' ',' ' the' ' father' ' of']" , the singer , and the father of the singer , the father of the singer , the father of False " include the song"" Billie Holiday "", written by Karlsson," 6 "[' include', ' the', ' song', '""', ' Bill', 'ie', ' Holiday']"
+2261 507 Name of father of x -1 Name of father of Billie Holiday Clarence Holiday Billie Holiday "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' the' ' father' ' of' ' the' ' singer' ',' ' the' ' father' ' of']" , the singer , and the father of the singer , the father of the singer , the father of False jazz singers like Billie Holiday and Josephine Baker. 5 [' jazz', ' singers', ' like', ' Bill', 'ie', ' Holiday']
+2262 508 Name of father of x -1 Name of father of Satyajit Ray Sukumar Ray Satyajit Ray "[',' ' the' ' famous' ' Indian' ' film' ' director' ',' ' was' ' born'
+ ' in' ' this' ' city' '.' '\n' '\n' 'References' '\n' '\n' 'Category' ':']" ", the famous Indian film director , was born in this city .
+
+ References
+
+ Category :" False his books, and Satyajit Ray — and has said 8 [' his', ' books', ',', ' and', ' Sat', 'y', 'aj', 'it', ' Ray']
+2263 508 Name of father of x -1 Name of father of Satyajit Ray Sukumar Ray Satyajit Ray "[',' ' the' ' famous' ' Indian' ' film' ' director' ',' ' was' ' born'
+ ' in' ' this' ' city' '.' '\n' '\n' 'References' '\n' '\n' 'Category' ':']" ", the famous Indian film director , was born in this city .
+
+ References
+
+ Category :" False 4 ['Sat', 'y', 'aj', 'it', ' Ray']
+2264 508 Name of father of x -1 Name of father of Satyajit Ray Sukumar Ray Satyajit Ray "[',' ' the' ' famous' ' Indian' ' film' ' director' ',' ' was' ' born'
+ ' in' ' this' ' city' '.' '\n' '\n' 'References' '\n' '\n' 'Category' ':']" ", the famous Indian film director , was born in this city .
+
+ References
+
+ Category :" False Ghatak and Satyajit Ray but pointed that 8 [' G', 'hat', 'ak', ' and', ' Sat', 'y', 'aj', 'it', ' Ray']
+2265 508 Name of father of x -1 Name of father of Satyajit Ray Sukumar Ray Satyajit Ray "[',' ' the' ' famous' ' Indian' ' film' ' director' ',' ' was' ' born'
+ ' in' ' this' ' city' '.' '\n' '\n' 'References' '\n' '\n' 'Category' ':']" ", the famous Indian film director , was born in this city .
+
+ References
+
+ Category :" False about them ... When Satyajit Ray passed on, I was 8 [' about', ' them', '...', ' When', ' Sat', 'y', 'aj', 'it', ' Ray']
+2266 508 Name of father of x -1 Name of father of Satyajit Ray Sukumar Ray Satyajit Ray "[',' ' the' ' famous' ' Indian' ' film' ' director' ',' ' was' ' born'
+ ' in' ' this' ' city' '.' '\n' '\n' 'References' '\n' '\n' 'Category' ':']" ", the famous Indian film director , was born in this city .
+
+ References
+
+ Category :" False established the Satyajit Ray Film and Study collection, 6 [' established', ' the', ' Sat', 'y', 'aj', 'it', ' Ray']
+2267 509 Name of father of x -1 Name of father of Mick Jagger Basil Fashawe Jagger Mick Jagger "[',' ' the' ' Rolling' ' Stones' ',' ' and' ' the' ' Beatles' '.' '\n'
+ '\n' 'The' ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a']" ", the Rolling Stones , and the Beatles .
+
+ The Beatles were the first band to have a" False from 1964, footage of Mick Jagger discussing his 7 [' from', ' 1964', ',', ' footage', ' of', ' Mick', ' J', 'agger']
+2268 509 Name of father of x -1 Name of father of Mick Jagger Basil Fashawe Jagger Mick Jagger "[',' ' the' ' Rolling' ' Stones' ',' ' and' ' the' ' Beatles' '.' '\n'
+ '\n' 'The' ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a']" ", the Rolling Stones , and the Beatles .
+
+ The Beatles were the first band to have a" False mashed. While Mick Jagger was 2,000 light years 5 [' mashed', '.', ' While', ' Mick', ' J', 'agger']
+2269 509 Name of father of x -1 Name of father of Mick Jagger Basil Fashawe Jagger Mick Jagger "[',' ' the' ' Rolling' ' Stones' ',' ' and' ' the' ' Beatles' '.' '\n'
+ '\n' 'The' ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a']" ", the Rolling Stones , and the Beatles .
+
+ The Beatles were the first band to have a" False " Harrison and Mick Jagger ""nagging me" 4 [' Harrison', ' and', ' Mick', ' J', 'agger']
+2270 509 Name of father of x -1 Name of father of Mick Jagger Basil Fashawe Jagger Mick Jagger "[',' ' the' ' Rolling' ' Stones' ',' ' and' ' the' ' Beatles' '.' '\n'
+ '\n' 'The' ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a']" ", the Rolling Stones , and the Beatles .
+
+ The Beatles were the first band to have a" False in England with Mick Jagger and Keith Richards 5 [' in', ' England', ' with', ' Mick', ' J', 'agger']
+2271 509 Name of father of x -1 Name of father of Mick Jagger Basil Fashawe Jagger Mick Jagger "[',' ' the' ' Rolling' ' Stones' ',' ' and' ' the' ' Beatles' '.' '\n'
+ '\n' 'The' ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a']" ", the Rolling Stones , and the Beatles .
+
+ The Beatles were the first band to have a" False " ""State of Shock"" with Mick Jagger for the Jacksons'" 8 "[' ""', 'State', ' of', ' Shock', '""', ' with', ' Mick', ' J', 'agger']"
+2272 511 Name of father of x -1 Name of father of Marcel Proust Adrien Proust Marcel Proust "[',' ' the' ' author' ' of' ' _' 'Rem' 'em' 'brance' ' of' ' Things'
+ ' Past' '_' ',' ' and' ' the' ' author' ' of' ' _' 'Sw' 'ann']" , the author of _ Rem em brance of Things Past _ , and the author of _ Sw ann False 4 ['Mar', 'cel', ' P', 'rou', 'st']
+2273 511 Name of father of x -1 Name of father of Marcel Proust Adrien Proust Marcel Proust "[',' ' the' ' author' ' of' ' _' 'Rem' 'em' 'brance' ' of' ' Things'
+ ' Past' '_' ',' ' and' ' the' ' author' ' of' ' _' 'Sw' 'ann']" , the author of _ Rem em brance of Things Past _ , and the author of _ Sw ann False " French novelist Marcel Proust later termed ""retrospective" 5 [' French', ' novelist', ' Marcel', ' P', 'rou', 'st']
+2274 511 Name of father of x -1 Name of father of Marcel Proust Adrien Proust Marcel Proust "[',' ' the' ' author' ' of' ' _' 'Rem' 'em' 'brance' ' of' ' Things'
+ ' Past' '_' ',' ' and' ' the' ' author' ' of' ' _' 'Sw' 'ann']" , the author of _ Rem em brance of Things Past _ , and the author of _ Sw ann False 't care for it. Marcel Proust never attended 9 "["" '"", 't', ' care', ' for', ' it', '.', ' Marcel', ' P', 'rou', 'st']"
+2275 511 Name of father of x -1 Name of father of Marcel Proust Adrien Proust Marcel Proust "[',' ' the' ' author' ' of' ' _' 'Rem' 'em' 'brance' ' of' ' Things'
+ ' Past' '_' ',' ' and' ' the' ' author' ' of' ' _' 'Sw' 'ann']" , the author of _ Rem em brance of Things Past _ , and the author of _ Sw ann False Sand) have included Marcel Proust and André Gide; 7 [' Sand', ')', ' have', ' included', ' Marcel', ' P', 'rou', 'st']
+2276 511 Name of father of x -1 Name of father of Marcel Proust Adrien Proust Marcel Proust "[',' ' the' ' author' ' of' ' _' 'Rem' 'em' 'brance' ' of' ' Things'
+ ' Past' '_' ',' ' and' ' the' ' author' ' of' ' _' 'Sw' 'ann']" , the author of _ Rem em brance of Things Past _ , and the author of _ Sw ann False 4 ['Mar', 'cel', ' P', 'rou', 'st']
+2277 514 Name of father of x -1 Name of father of Josephus Matthias Josephus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' 'us'
+ ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Joseph']" ", and the
+
+ Name of mother of Joseph us , and the name of the father of Joseph" False according to Josephus, 9 metres 3 [' according', ' to', ' Joseph', 'us']
+2278 514 Name of father of x -1 Name of father of Josephus Matthias Josephus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' 'us'
+ ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Joseph']" ", and the
+
+ Name of mother of Joseph us , and the name of the father of Joseph" False Apollo for safety. Josephus notes that Gaza 5 [' Apollo', ' for', ' safety', '.', ' Joseph', 'us']
+2279 514 Name of father of x -1 Name of father of Josephus Matthias Josephus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' 'us'
+ ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Joseph']" ", and the
+
+ Name of mother of Joseph us , and the name of the father of Joseph" False Accounts of Josephus and Tacitus 3 [' Accounts', ' of', ' Joseph', 'us']
+2280 514 Name of father of x -1 Name of father of Josephus Matthias Josephus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' 'us'
+ ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Joseph']" ", and the
+
+ Name of mother of Joseph us , and the name of the father of Joseph" False dispersed. Josephus claims that 1,100,000 3 [' dispersed', '.', ' Joseph', 'us']
+2281 514 Name of father of x -1 Name of father of Josephus Matthias Josephus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' 'us'
+ ',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Joseph']" ", and the
+
+ Name of mother of Joseph us , and the name of the father of Joseph" False Tiberius's reign, as well as Josephus record Tiberius as 10 "[' T', 'iber', 'ius', ""'s"", ' reign', ',', ' as', ' well', ' as', ' Joseph', 'us']"
+2282 515 Name of father of x -1 Name of father of Sting Ernest Matthew Sumner Sting "['ray' ',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" "ray , the
+
+ I am a mother of two , a wife , a daughter , a sister" False lead vocalist Sting provided backing 3 [' lead', ' vocal', 'ist', ' Sting']
+2283 515 Name of father of x -1 Name of father of Sting Ernest Matthew Sumner Sting "['ray' ',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" "ray , the
+
+ I am a mother of two , a wife , a daughter , a sister" False with the Iron Fists: Sting of the Scorpion but 6 [' with', ' the', ' Iron', ' F', 'ists', ':', ' Sting']
+2284 515 Name of father of x -1 Name of father of Sting Ernest Matthew Sumner Sting "['ray' ',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" "ray , the
+
+ I am a mother of two , a wife , a daughter , a sister" False Christian Cage and Sting against the team 3 [' Christian', ' Cage', ' and', ' Sting']
+2285 515 Name of father of x -1 Name of father of Sting Ernest Matthew Sumner Sting "['ray' ',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" "ray , the
+
+ I am a mother of two , a wife , a daughter , a sister" False hold on Booker T. Sting then entered 5 [' hold', ' on', ' Booker', ' T', '.', ' Sting']
+2286 515 Name of father of x -1 Name of father of Sting Ernest Matthew Sumner Sting "['ray' ',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" "ray , the
+
+ I am a mother of two , a wife , a daughter , a sister" False logo of the wrestler Sting appearing on the 4 [' logo', ' of', ' the', ' wrestler', ' Sting']
+2287 516 Name of father of x -1 Name of father of Maupassant Gustave de Maupassant Maupassant "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' child' ' is' ' a' ' girl' '.']" ", and the
+
+ Name of mother of the child .
+
+ The child is a girl ." False four Guy de Maupassant stories that are 5 [' four', ' Guy', ' de', ' Mau', 'pass', 'ant']
+2288 516 Name of father of x -1 Name of father of Maupassant Gustave de Maupassant Maupassant "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' child' ' is' ' a' ' girl' '.']" ", and the
+
+ Name of mother of the child .
+
+ The child is a girl ." False " the Guy de Maupassant story ""Fear""" 5 [' the', ' Guy', ' de', ' Mau', 'pass', 'ant']
+2289 516 Name of father of x -1 Name of father of Maupassant Gustave de Maupassant Maupassant "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' child' ' is' ' a' ' girl' '.']" ", and the
+
+ Name of mother of the child .
+
+ The child is a girl ." False compares him to Guy de Maupassant for their ability 7 [' compares', ' him', ' to', ' Guy', ' de', ' Mau', 'pass', 'ant']
+2290 516 Name of father of x -1 Name of father of Maupassant Gustave de Maupassant Maupassant "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' child' ' is' ' a' ' girl' '.']" ", and the
+
+ Name of mother of the child .
+
+ The child is a girl ." False " the Guy de Maupassant story ""Fear"" and" 5 [' the', ' Guy', ' de', ' Mau', 'pass', 'ant']
+2291 516 Name of father of x -1 Name of father of Maupassant Gustave de Maupassant Maupassant "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' child' ' is' ' a' ' girl' '.']" ", and the
+
+ Name of mother of the child .
+
+ The child is a girl ." False " by the Guy de Maupassant story ""Fear"" and" 6 [' by', ' the', ' Guy', ' de', ' Mau', 'pass', 'ant']
+2292 517 Name of father of x -1 Name of father of Elizabeth I of England Henry VIII of England Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' son' ' of' ' a' ' king'
+ ',' ' and' ' I' ' am' ' the' ' son' ' of' ' a']" ", and the
+
+ I am the son of a king , and I am the son of a" False and Queen Elizabeth I of England came up for discussion; 5 [' and', ' Queen', ' Elizabeth', ' I', ' of', ' England']
+2293 517 Name of father of x -1 Name of father of Elizabeth I of England Henry VIII of England Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' son' ' of' ' a' ' king'
+ ',' ' and' ' I' ' am' ' the' ' son' ' of' ' a']" ", and the
+
+ I am the son of a king , and I am the son of a" False news. Because Elizabeth I of England was never queen 6 [' news', '.', ' Because', ' Elizabeth', ' I', ' of', ' England']
+2294 517 Name of father of x -1 Name of father of Elizabeth I of England Henry VIII of England Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' son' ' of' ' a' ' king'
+ ',' ' and' ' I' ' am' ' the' ' son' ' of' ' a']" ", and the
+
+ I am the son of a king , and I am the son of a" False with Spain. Queen Elizabeth I of England chose to support 7 [' with', ' Spain', '.', ' Queen', ' Elizabeth', ' I', ' of', ' England']
+2295 517 Name of father of x -1 Name of father of Elizabeth I of England Henry VIII of England Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' son' ' of' ' a' ' king'
+ ',' ' and' ' I' ' am' ' the' ' son' ' of' ' a']" ", and the
+
+ I am the son of a king , and I am the son of a" False " Elizabeth I of England =
+" 3 [' Elizabeth', ' I', ' of', ' England']
+2296 517 Name of father of x -1 Name of father of Elizabeth I of England Henry VIII of England Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' son' ' of' ' a' ' king'
+ ',' ' and' ' I' ' am' ' the' ' son' ' of' ' a']" ", and the
+
+ I am the son of a king , and I am the son of a" False Queen Mary. In 1559 Elizabeth I of England granted the site 9 [' Queen', ' Mary', '.', ' In', ' 15', '59', ' Elizabeth', ' I', ' of', ' England']
+2297 518 Name of father of x -1 Name of father of Harriet Beecher Stowe Lyman Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False written by Harriet Beecher Stowe about her winters 7 [' written', ' by', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']
+2298 518 Name of father of x -1 Name of father of Harriet Beecher Stowe Lyman Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False By the time Harriet Beecher Stowe (1811 – 1896) moved 8 [' By', ' the', ' time', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']
+2299 518 Name of father of x -1 Name of father of Harriet Beecher Stowe Lyman Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False previously included by Harriet Beecher Stowe in Uncle Tom's 8 [' previously', ' included', ' by', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']
+2300 518 Name of father of x -1 Name of father of Harriet Beecher Stowe Lyman Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False Thulesius, author of Harriet Beecher Stowe in Florida, 11 [' Th', 'ules', 'ius', ',', ' author', ' of', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']
+2301 518 Name of father of x -1 Name of father of Harriet Beecher Stowe Lyman Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False previously included by Harriet Beecher Stowe in Uncle Tom's 8 [' previously', ' included', ' by', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']
+2302 519 Name of father of x -1 Name of father of Grace Kelly John B. Kelly Sr. Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' '\n' '\n' 'The'
+ ' daughter']" ", the daughter of a wealthy family , and the daughter of a wealthy family .
+
+ The daughter" False High Noon (1952) with Grace Kelly for United Artists. 8 [' High', ' Noon', ' (', '19', '52', ')', ' with', ' Grace', ' Kelly']
+2303 519 Name of father of x -1 Name of father of Grace Kelly John B. Kelly Sr. Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' '\n' '\n' 'The'
+ ' daughter']" ", the daughter of a wealthy family , and the daughter of a wealthy family .
+
+ The daughter" False company. He dated Grace Kelly and Joan Bennett 5 [' company', '.', ' He', ' dated', ' Grace', ' Kelly']
+2304 519 Name of father of x -1 Name of father of Grace Kelly John B. Kelly Sr. Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' '\n' '\n' 'The'
+ ' daughter']" ", the daughter of a wealthy family , and the daughter of a wealthy family .
+
+ The daughter" False gave actress Grace Kelly in the 1954 film 3 [' gave', ' actress', ' Grace', ' Kelly']
+2305 519 Name of father of x -1 Name of father of Grace Kelly John B. Kelly Sr. Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' '\n' '\n' 'The'
+ ' daughter']" ", the daughter of a wealthy family , and the daughter of a wealthy family .
+
+ The daughter" False alongside Bing Crosby and Grace Kelly in High Society 5 [' alongside', ' Bing', ' Crosby', ' and', ' Grace', ' Kelly']
+2306 519 Name of father of x -1 Name of father of Grace Kelly John B. Kelly Sr. Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' '\n' '\n' 'The'
+ ' daughter']" ", the daughter of a wealthy family , and the daughter of a wealthy family .
+
+ The daughter" False declared her' The New Grace Kelly '. At the 6 "[' declared', ' her', ""'"", ' The', ' New', ' Grace', ' Kelly']"
+2307 520 Name of father of x -1 Name of father of Dennis Hopper Jay Millard Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' famous' ' Ho' 'pper' ' family' '.'
+ '\n' '\n' 'The' ' Ho' 'pper' ' family' ' is' ' a' ' family' ' of']" ", the father of the famous Ho pper family .
+
+ The Ho pper family is a family of" False comeback of Dennis Hopper after a significant 4 [' comeback', ' of', ' Dennis', ' Ho', 'pper']
+2308 520 Name of father of x -1 Name of father of Dennis Hopper Jay Millard Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' famous' ' Ho' 'pper' ' family' '.'
+ '\n' '\n' 'The' ' Ho' 'pper' ' family' ' is' ' a' ' family' ' of']" ", the father of the famous Ho pper family .
+
+ The Ho pper family is a family of" False Sinatra, 2003), Dennis Hopper (The Night We 7 [' Sin', 'atra', ',', ' 2003', '),', ' Dennis', ' Ho', 'pper']
+2309 520 Name of father of x -1 Name of father of Dennis Hopper Jay Millard Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' famous' ' Ho' 'pper' ' family' '.'
+ '\n' '\n' 'The' ' Ho' 'pper' ' family' ' is' ' a' ' family' ' of']" ", the father of the famous Ho pper family .
+
+ The Ho pper family is a family of" False starring Peter Fonda, Dennis Hopper (Director) and Jack 7 [' starring', ' Peter', ' F', 'onda', ',', ' Dennis', ' Ho', 'pper']
+2310 520 Name of father of x -1 Name of father of Dennis Hopper Jay Millard Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' famous' ' Ho' 'pper' ' family' '.'
+ '\n' '\n' 'The' ' Ho' 'pper' ' family' ' is' ' a' ' family' ' of']" ", the father of the famous Ho pper family .
+
+ The Ho pper family is a family of" False Roberto Rossellini. Dennis Hopper was the biggest 7 [' Roberto', ' Ros', 'sell', 'ini', '.', ' Dennis', ' Ho', 'pper']
+2311 520 Name of father of x -1 Name of father of Dennis Hopper Jay Millard Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' famous' ' Ho' 'pper' ' family' '.'
+ '\n' '\n' 'The' ' Ho' 'pper' ' family' ' is' ' a' ' family' ' of']" ", the father of the famous Ho pper family .
+
+ The Ho pper family is a family of" False mid 1996, with Dennis Hopper attached to 6 [' mid', ' 1996', ',', ' with', ' Dennis', ' Ho', 'pper']
+2312 521 Name of father of x -1 Name of father of Geoffrey Chaucer John Chaucer Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False Isabel of Bavaria. Geoffrey Chaucer supervised preparations 7 [' Isabel', ' of', ' Bav', 'aria', '.', ' Geoffrey', ' Chau', 'cer']
+2313 521 Name of father of x -1 Name of father of Geoffrey Chaucer John Chaucer Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False of Bavaria. Geoffrey Chaucer supervised preparations 6 [' of', ' Bav', 'aria', '.', ' Geoffrey', ' Chau', 'cer']
+2314 521 Name of father of x -1 Name of father of Geoffrey Chaucer John Chaucer Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False " The Miller's Tale, Geoffrey Chaucer writes ""And prively" 7 "[' The', ' Miller', ""'s"", ' Tale', ',', ' Geoffrey', ' Chau', 'cer']"
+2315 521 Name of father of x -1 Name of father of Geoffrey Chaucer John Chaucer Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False exquisite. The poet Geoffrey Chaucer uses the windows 6 [' exquisite', '.', ' The', ' poet', ' Geoffrey', ' Chau', 'cer']
+2316 521 Name of father of x -1 Name of father of Geoffrey Chaucer John Chaucer Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False of Bavaria. Geoffrey Chaucer supervised preparations 6 [' of', ' Bav', 'aria', '.', ' Geoffrey', ' Chau', 'cer']
+2317 522 Name of father of x -1 Name of father of John Steinbeck John Steinbeck John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Prize-winner John Steinbeck wrote about Hurricane 5 [' Prize', '-', 'winner', ' John', ' Stein', 'beck']
+2318 522 Name of father of x -1 Name of father of John Steinbeck John Steinbeck John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Saroyan and John Steinbeck were known 6 [' Sar', 'oy', 'an', ' and', ' John', ' Stein', 'beck']
+2319 522 Name of father of x -1 Name of father of John Steinbeck John Steinbeck John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Tolkien, Robert Frost, John Steinbeck and E.M. Forster. 7 [' Tolkien', ',', ' Robert', ' Frost', ',', ' John', ' Stein', 'beck']
+2320 522 Name of father of x -1 Name of father of John Steinbeck John Steinbeck John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Jonathan Miller and John Steinbeck from a young 5 [' Jonathan', ' Miller', ' and', ' John', ' Stein', 'beck']
+2321 522 Name of father of x -1 Name of father of John Steinbeck John Steinbeck John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False American author John Steinbeck and published 4 [' American', ' author', ' John', ' Stein', 'beck']
+2322 523 Name of father of x -1 Name of father of Muhammad Abdullah ibn Abdul-Muttalib Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False historian Ibn Ishaq, Muhammad was involved with 6 [' historian', ' Ibn', ' Is', 'ha', 'q', ',', ' Muhammad']
+2323 523 Name of father of x -1 Name of father of Muhammad Abdullah ibn Abdul-Muttalib Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False arrived with Muhammad Ali, bringing the 2 [' arrived', ' with', ' Muhammad']
+2324 523 Name of father of x -1 Name of father of Muhammad Abdullah ibn Abdul-Muttalib Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False and veneration of Muhammad have been expressed 4 [' and', ' vener', 'ation', ' of', ' Muhammad']
+2325 523 Name of father of x -1 Name of father of Muhammad Abdullah ibn Abdul-Muttalib Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False societies, insulting Muhammad is considered 3 [' societies', ',', ' insulting', ' Muhammad']
+2326 523 Name of father of x -1 Name of father of Muhammad Abdullah ibn Abdul-Muttalib Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False government purge when Muhammad Shah came to power. 3 [' government', ' purge', ' when', ' Muhammad']
+2327 524 Name of father of x -1 Name of father of Marie Antoinette Francis I, Holy Roman Emperor Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' France' ',' ' the' ' Queen' ' of' ' France' ',' ' the']" ", the Queen of France , and the
+
+ Queen of France , the Queen of France , the" False staircase. Near this was the Marie Antoinette parlor, which 9 [' staircase', '.', ' Near', ' this', ' was', ' the', ' Marie', ' Ant', 'oin', 'ette']
+2328 524 Name of father of x -1 Name of father of Marie Antoinette Francis I, Holy Roman Emperor Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' France' ',' ' the' ' Queen' ' of' ' France' ',' ' the']" ", the Queen of France , and the
+
+ Queen of France , the Queen of France , the" False Burke's tears for Marie Antoinette and the monarchy of 7 "[' Burke', ""'s"", ' tears', ' for', ' Marie', ' Ant', 'oin', 'ette']"
+2329 524 Name of father of x -1 Name of father of Marie Antoinette Francis I, Holy Roman Emperor Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' France' ',' ' the' ' Queen' ' of' ' France' ',' ' the']" ", the Queen of France , and the
+
+ Queen of France , the Queen of France , the" False " Suite ====
+" 6 [' Suite', ' =', '===', 'Marie', ' Ant', 'oin', 'ette']
+2330 524 Name of father of x -1 Name of father of Marie Antoinette Francis I, Holy Roman Emperor Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' France' ',' ' the' ' Queen' ' of' ' France' ',' ' the']" ", the Queen of France , and the
+
+ Queen of France , the Queen of France , the" False his Queen, Marie Antoinette (who came tumbling 6 [' his', ' Queen', ',', ' Marie', ' Ant', 'oin', 'ette']
+2331 524 Name of father of x -1 Name of father of Marie Antoinette Francis I, Holy Roman Emperor Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' France' ',' ' the' ' Queen' ' of' ' France' ',' ' the']" ", the Queen of France , and the
+
+ Queen of France , the Queen of France , the" False French Baroque – Marie Antoinette Revival, Military 8 [' French', ' Bar', 'o', 'que', ' –', ' Marie', ' Ant', 'oin', 'ette']
+2332 525 Name of father of x -1 Name of father of James VI and I Henry Stuart, Lord Darnley James VI and I "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False granddaughter of James VI and I through his daughter 5 [' granddaughter', ' of', ' James', ' VI', ' and', ' I']
+2333 525 Name of father of x -1 Name of father of James VI and I Henry Stuart, Lord Darnley James VI and I "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " VI and I =
+" 7 [' VI', ' and', ' I', ' =', 'James', ' VI', ' and', ' I']
+2334 525 Name of father of x -1 Name of father of James VI and I Henry Stuart, Lord Darnley James VI and I "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " James VI and I =
+" 3 [' James', ' VI', ' and', ' I']
+2335 525 Name of father of x -1 Name of father of James VI and I Henry Stuart, Lord Darnley James VI and I "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False the auspices of King James VI and I the Authorised 8 [' the', ' ausp', 'ices', ' of', ' King', ' James', ' VI', ' and', ' I']
+2336 525 Name of father of x -1 Name of father of James VI and I Henry Stuart, Lord Darnley James VI and I "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " James VI and I =
+" 3 [' James', ' VI', ' and', ' I']
+2337 526 Name of father of x -1 Name of father of Friedrich Engels Friedrich Engels Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' ',' ' and' ' the'
+ ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of the proletariat , and the father of the proletariat .
+
+ The first thing to" False Hegel, Karl Marx and Friedrich Engels rather than later 6 [' Hegel', ',', ' Karl', ' Marx', ' and', ' Friedrich', ' Engels']
+2338 526 Name of father of x -1 Name of father of Friedrich Engels Friedrich Engels Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' ',' ' and' ' the'
+ ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of the proletariat , and the father of the proletariat .
+
+ The first thing to" False " referring to Friedrich Engels as ""General""," 3 [' referring', ' to', ' Friedrich', ' Engels']
+2339 526 Name of father of x -1 Name of father of Friedrich Engels Friedrich Engels Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' ',' ' and' ' the'
+ ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of the proletariat , and the father of the proletariat .
+
+ The first thing to" False safety protection. Friedrich Engels in his The Condition 4 [' safety', ' protection', '.', ' Friedrich', ' Engels']
+2340 526 Name of father of x -1 Name of father of Friedrich Engels Friedrich Engels Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' ',' ' and' ' the'
+ ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of the proletariat , and the father of the proletariat .
+
+ The first thing to" False " persuasions. Marxist Friedrich Engels wrote: ""I have learned" 5 [' persu', 'asions', '.', ' Marxist', ' Friedrich', ' Engels']
+2341 526 Name of father of x -1 Name of father of Friedrich Engels Friedrich Engels Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' ',' ' and' ' the'
+ ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of the proletariat , and the father of the proletariat .
+
+ The first thing to" False International, about whom Friedrich Engels once had written, 5 [' International', ',', ' about', ' whom', ' Friedrich', ' Engels']
+2342 528 Name of father of x -1 Name of father of Samuel Taylor Coleridge John Coleridge Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False Romantic poet Samuel Taylor Coleridge and the former 6 [' Romantic', ' poet', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2343 528 Name of father of x -1 Name of father of Samuel Taylor Coleridge John Coleridge Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False walks taken by poet Samuel Taylor Coleridge to Lynmouth, starting 8 [' walks', ' taken', ' by', ' poet', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2344 528 Name of father of x -1 Name of father of Samuel Taylor Coleridge John Coleridge Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False tribute to the Samuel Taylor Coleridge poem The Rime of the 7 [' tribute', ' to', ' the', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2345 528 Name of father of x -1 Name of father of Samuel Taylor Coleridge John Coleridge Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False saw the marriages of Samuel Taylor Coleridge to Sarah Fricker and 8 [' saw', ' the', ' marriages', ' of', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2346 528 Name of father of x -1 Name of father of Samuel Taylor Coleridge John Coleridge Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False Wordsworth and Samuel Taylor Coleridge that was both Wordsworth's 7 [' Word', 'sworth', ' and', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2347 529 Name of father of x -1 Name of father of Tennessee Williams Cornelius Coffin Williams Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False great-grandnephew of Tennessee Williams on his father's 8 [' great', '-', 'grand', 'n', 'ep', 'hew', ' of', ' Tennessee', ' Williams']
+2348 529 Name of father of x -1 Name of father of Tennessee Williams Cornelius Coffin Williams Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False over playwright Tennessee Williams and their experience 4 [' over', ' play', 'wright', ' Tennessee', ' Williams']
+2349 529 Name of father of x -1 Name of father of Tennessee Williams Cornelius Coffin Williams Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False " destroys me), the Tennessee Williams quote ""A prayer" 5 [' destroys', ' me', '),', ' the', ' Tennessee', ' Williams']
+2350 529 Name of father of x -1 Name of father of Tennessee Williams Cornelius Coffin Williams Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False by writers such as Tennessee Williams and William 5 [' by', ' writers', ' such', ' as', ' Tennessee', ' Williams']
+2351 529 Name of father of x -1 Name of father of Tennessee Williams Cornelius Coffin Williams Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False Best Actress. Tennessee Williams commented that Leigh 4 [' Best', ' Actress', '.', ' Tennessee', ' Williams']
+2352 530 Name of father of x -1 Name of father of George V Edward VII George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' premiere' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the premiere ." False Maori. Rodney, King George V and the destroyers 7 [' Ma', 'ori', '.', ' Rodney', ',', ' King', ' George', ' V']
+2353 530 Name of father of x -1 Name of father of George V Edward VII George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' premiere' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the premiere ." False " patent from King George V dated 5 June 1925.
+" 4 [' patent', ' from', ' King', ' George', ' V']
+2354 530 Name of father of x -1 Name of father of George V Edward VII George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' premiere' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the premiere ." False refuse has died: George V believed he could 5 [' refuse', ' has', ' died', ':', ' George', ' V']
+2355 530 Name of father of x -1 Name of father of George V Edward VII George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' premiere' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the premiere ." False that of the King George V class with a 5 [' that', ' of', ' the', ' King', ' George', ' V']
+2356 530 Name of father of x -1 Name of father of George V Edward VII George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' premiere' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the premiere ." False after the death of George V in 1936. She published 5 [' after', ' the', ' death', ' of', ' George', ' V']
+2357 531 Name of father of x -1 Name of father of Constantine the Great Constantius Chlorus Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' great'
+ ' em' 'perors' ' of' ' the' ' Roman' ' Empire' ',' ' who' ' was']" ", and the
+
+ The first of the three great em perors of the Roman Empire , who was" False Byzantium by Constantine the Great (r. 306 – 337) 6 [' Byz', 'ant', 'ium', ' by', ' Constantine', ' the', ' Great']
+2358 531 Name of father of x -1 Name of father of Constantine the Great Constantius Chlorus Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' great'
+ ' em' 'perors' ' of' ' the' ' Roman' ' Empire' ',' ' who' ' was']" ", and the
+
+ The first of the three great em perors of the Roman Empire , who was" False 313. Under Emperor Constantine the Great (306 – 337), 6 [' 313', '.', ' Under', ' Emperor', ' Constantine', ' the', ' Great']
+2359 531 Name of father of x -1 Name of father of Constantine the Great Constantius Chlorus Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' great'
+ ' em' 'perors' ' of' ' the' ' Roman' ' Empire' ',' ' who' ' was']" ", and the
+
+ The first of the three great em perors of the Roman Empire , who was" False under Emperor Constantine the Great (r. 306 – 337) to the 4 [' under', ' Emperor', ' Constantine', ' the', ' Great']
+2360 531 Name of father of x -1 Name of father of Constantine the Great Constantius Chlorus Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' great'
+ ' em' 'perors' ' of' ' the' ' Roman' ' Empire' ',' ' who' ' was']" ", and the
+
+ The first of the three great em perors of the Roman Empire , who was" False of Byzantium by Constantine the Great (r. 306 – 337) during 7 [' of', ' Byz', 'ant', 'ium', ' by', ' Constantine', ' the', ' Great']
+2361 531 Name of father of x -1 Name of father of Constantine the Great Constantius Chlorus Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' great'
+ ' em' 'perors' ' of' ' the' ' Roman' ' Empire' ',' ' who' ' was']" ", and the
+
+ The first of the three great em perors of the Roman Empire , who was" False Constantine I known as Constantine the Great — another 6 [' Constantine', ' I', ' known', ' as', ' Constantine', ' the', ' Great']
+2362 532 Name of father of x -1 Name of father of Tenzin Gyatso Choekyong Tsering Tenzin Gyatso "[',' ' the' ' Dalai' ' Lama' ',' ' the' ' 14' 'th' ' Dalai' ' Lama' ','
+ ' the' ' spiritual' ' leader' ' of' ' Tibet' ',' ' has' ' been' ' in']" , the Dalai Lama , the 14 th Dalai Lama , the spiritual leader of Tibet , has been in False 14th Dalai Lama, Tenzin Gyatso in Einsiedeln, 9 [' 14', 'th', ' Dalai', ' Lama', ',', ' Ten', 'zin', ' Gy', 'at', 'so']
+2363 532 Name of father of x -1 Name of father of Tenzin Gyatso Choekyong Tsering Tenzin Gyatso "[',' ' the' ' Dalai' ' Lama' ',' ' the' ' 14' 'th' ' Dalai' ' Lama' ','
+ ' the' ' spiritual' ' leader' ' of' ' Tibet' ',' ' has' ' been' ' in']" , the Dalai Lama , the 14 th Dalai Lama , the spiritual leader of Tibet , has been in False 14th Dalai Lama, Tenzin Gyatso in Einsiedeln, 9 [' 14', 'th', ' Dalai', ' Lama', ',', ' Ten', 'zin', ' Gy', 'at', 'so']
+2364 533 Name of father of x -1 Name of father of Francisco Franco Nicolás Franco Salgado-Araújo Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False from General Francisco Franco to rebels in Morocco. 3 [' from', ' General', ' Francisco', ' Franco']
+2365 533 Name of father of x -1 Name of father of Francisco Franco Nicolás Franco Salgado-Araújo Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False Gil-Robles and General Francisco Franco had approached him 7 [' Gil', '-', 'Rob', 'les', ' and', ' General', ' Francisco', ' Franco']
+2366 533 Name of father of x -1 Name of father of Francisco Franco Nicolás Franco Salgado-Araújo Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False especially of Francisco Franco (1939 – 1975), all 3 [' especially', ' of', ' Francisco', ' Franco']
+2367 533 Name of father of x -1 Name of father of Francisco Franco Nicolás Franco Salgado-Araújo Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False convinced dictator Francisco Franco that his town needed 3 [' convinced', ' dictator', ' Francisco', ' Franco']
+2368 533 Name of father of x -1 Name of father of Francisco Franco Nicolás Franco Salgado-Araújo Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False of state, such as Francisco Franco and Pope Pius XII. 6 [' of', ' state', ',', ' such', ' as', ' Francisco', ' Franco']
+2369 534 Name of father of x -1 Name of father of Tomáš Garrigue Masaryk Jozef Maszárik Tomáš Garrigue Masaryk "[',' ' the' ' Czech' 'oslov' 'ak' 'ian' ' states' 'man' ' and'
+ ' philosopher' ',' ' who' ' was' ' the' ' first' ' president' ' of'
+ ' Czech' 'oslov' 'akia']" , the Czech oslov ak ian states man and philosopher , who was the first president of Czech oslov akia False Gandhi of India, Tomáš Garrigue Masaryk of Czechoslovakia, 11 [' Gandhi', ' of', ' India', ',', ' Tom', 'á', 'š', ' Garr', 'igue', ' Mas', 'ary', 'k']
+2370 534 Name of father of x -1 Name of father of Tomáš Garrigue Masaryk Jozef Maszárik Tomáš Garrigue Masaryk "[',' ' the' ' Czech' 'oslov' 'ak' 'ian' ' states' 'man' ' and'
+ ' philosopher' ',' ' who' ' was' ' the' ' first' ' president' ' of'
+ ' Czech' 'oslov' 'akia']" , the Czech oslov ak ian states man and philosopher , who was the first president of Czech oslov akia False of India, Tomáš Garrigue Masaryk of Czechoslovakia, 10 [' of', ' India', ',', ' Tom', 'á', 'š', ' Garr', 'igue', ' Mas', 'ary', 'k']
+2371 535 Name of father of x -1 Name of father of Georges Bizet Adolphe Armand Bizet Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' '\n' '\n' 'The' ' first']" ", the French composer , was born in Paris in 18 38 . He was a
+
+ The first" False winners were Georges Bizet and Charles 6 [' winners', ' were', ' Georg', 'es', ' B', 'iz', 'et']
+2372 535 Name of father of x -1 Name of father of Georges Bizet Adolphe Armand Bizet Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' '\n' '\n' 'The' ' first']" ", the French composer , was born in Paris in 18 38 . He was a
+
+ The first" False 4 ['Georg', 'es', ' B', 'iz', 'et']
+2373 535 Name of father of x -1 Name of father of Georges Bizet Adolphe Armand Bizet Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' '\n' '\n' 'The' ' first']" ", the French composer , was born in Paris in 18 38 . He was a
+
+ The first" False work, the composer Georges Bizet was disappointed 8 [' work', ',', ' the', ' composer', ' Georg', 'es', ' B', 'iz', 'et']
+2374 535 Name of father of x -1 Name of father of Georges Bizet Adolphe Armand Bizet Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' '\n' '\n' 'The' ' first']" ", the French composer , was born in Paris in 18 38 . He was a
+
+ The first" False established a fund for a Georges Bizet prize, to be 9 [' established', ' a', ' fund', ' for', ' a', ' Georg', 'es', ' B', 'iz', 'et']
+2375 535 Name of father of x -1 Name of father of Georges Bizet Adolphe Armand Bizet Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' '\n' '\n' 'The' ' first']" ", the French composer , was born in Paris in 18 38 . He was a
+
+ The first" False " ""Habanera"" from the Georges Bizet opera Carmen," 11 "[' ""', 'H', 'aban', 'era', '""', ' from', ' the', ' Georg', 'es', ' B', 'iz', 'et']"
+2376 538 Name of father of x -1 Name of father of Pius XII Filippo Pacelli Pius XII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" '\n' '\n' 'The' ' Pope'
+ ""'s"" ' name' ' is' ' not' ' P' 'ius' ' XII' '.']" ", the Pope , and the Pope 's
+
+ The Pope 's name is not P ius XII ." False In the 1950s, Pope Pius XII told the most senior 8 [' In', ' the', ' 1950', 's', ',', ' Pope', ' P', 'ius', ' XII']
+2377 538 Name of father of x -1 Name of father of Pius XII Filippo Pacelli Pius XII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" '\n' '\n' 'The' ' Pope'
+ ""'s"" ' name' ' is' ' not' ' P' 'ius' ' XII' '.']" ", the Pope , and the Pope 's
+
+ The Pope 's name is not P ius XII ." False and writings of Pius XII were intended 5 [' and', ' writings', ' of', ' P', 'ius', ' XII']
+2378 538 Name of father of x -1 Name of father of Pius XII Filippo Pacelli Pius XII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" '\n' '\n' 'The' ' Pope'
+ ""'s"" ' name' ' is' ' not' ' P' 'ius' ' XII' '.']" ", the Pope , and the Pope 's
+
+ The Pope 's name is not P ius XII ." False intervened with Pope Pius XII to put an end to 5 [' intervened', ' with', ' Pope', ' P', 'ius', ' XII']
+2379 538 Name of father of x -1 Name of father of Pius XII Filippo Pacelli Pius XII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" '\n' '\n' 'The' ' Pope'
+ ""'s"" ' name' ' is' ' not' ' P' 'ius' ' XII' '.']" ", the Pope , and the Pope 's
+
+ The Pope 's name is not P ius XII ." False Pope: How Pope Pius XII Rescued Jews from the 6 [' Pope', ':', ' How', ' Pope', ' P', 'ius', ' XII']
+2380 538 Name of father of x -1 Name of father of Pius XII Filippo Pacelli Pius XII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" '\n' '\n' 'The' ' Pope'
+ ""'s"" ' name' ' is' ' not' ' P' 'ius' ' XII' '.']" ", the Pope , and the Pope 's
+
+ The Pope 's name is not P ius XII ." False foremost advisor to Pius XII during the writing 5 [' foremost', ' advisor', ' to', ' P', 'ius', ' XII']
+2381 539 Name of father of x -1 Name of father of Franz Marc Wilhelm Marc Franz Marc "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' son' ',' ' and'
+ ' the' ' painter' ""'s"" ' son' ""'s"" ' son' ',' ' and' ' the']" , the painter , and the painter 's son , and the painter 's son 's son , and the False Max Pechstein, Franz Marc and Erich Waske were 6 [' Max', ' Pe', 'ch', 'stein', ',', ' Franz', ' Marc']
+2382 540 Name of father of x -1 Name of father of Allen Ginsberg Louis Ginsberg Allen Ginsberg "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ""'s"" ' mother' ',' ' the']" , the poet , and his wife , the poet ess , and the poet ess 's mother , the False ideological fallacies. Allen Ginsberg states that, in 6 [' ideological', ' fall', 'acies', '.', ' Allen', ' Gins', 'berg']
+2383 540 Name of father of x -1 Name of father of Allen Ginsberg Louis Ginsberg Allen Ginsberg "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ""'s"" ' mother' ',' ' the']" , the poet , and his wife , the poet ess , and the poet ess 's mother , the False Burroughs, Jack Kerouac and Allen Ginsberg wrote about 11 [' Bur', 'rough', 's', ',', ' Jack', ' Ker', 'ou', 'ac', ' and', ' Allen', ' Gins', 'berg']
+2384 540 Name of father of x -1 Name of father of Allen Ginsberg Louis Ginsberg Allen Ginsberg "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ""'s"" ' mother' ',' ' the']" , the poet , and his wife , the poet ess , and the poet ess 's mother , the False Village resident Allen Ginsberg lived on Christopher 4 [' Village', ' resident', ' Allen', ' Gins', 'berg']
+2385 540 Name of father of x -1 Name of father of Allen Ginsberg Louis Ginsberg Allen Ginsberg "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ""'s"" ' mother' ',' ' the']" , the poet , and his wife , the poet ess , and the poet ess 's mother , the False Hoffman and Allen Ginsberg in Chicago 10 (2007). 4 [' Hoffman', ' and', ' Allen', ' Gins', 'berg']
+2386 540 Name of father of x -1 Name of father of Allen Ginsberg Louis Ginsberg Allen Ginsberg "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ""'s"" ' mother' ',' ' the']" , the poet , and his wife , the poet ess , and the poet ess 's mother , the False mindedness. Of them, Allen Ginsberg and William S. Burroughs 8 [' minded', 'ness', '.', ' Of', ' them', ',', ' Allen', ' Gins', 'berg']
+2387 541 Name of father of x -1 Name of father of John Locke John Locke John Locke "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' Locke' ',' ' who' ' was'
+ ' a' ' great' ' man' ',' ' and' ' a' ' great' ' man' ',']" , the son of the late John Locke , who was a great man , and a great man , True survivors led by John Locke (Terry O'Quinn). Intercut 4 [' survivors', ' led', ' by', ' John', ' Locke']
+2388 541 Name of father of x -1 Name of father of John Locke John Locke John Locke "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' Locke' ',' ' who' ' was'
+ ' a' ' great' ' man' ',' ' and' ' a' ' great' ' man' ',']" , the son of the late John Locke , who was a great man , and a great man , True " Locke (1632 – 1704)
+" 9 [' Locke', ' (', '16', '32', ' –', ' 17', '04', ')', 'John', ' Locke']
+2389 541 Name of father of x -1 Name of father of John Locke John Locke John Locke "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' Locke' ',' ' who' ' was'
+ ' a' ' great' ' man' ',' ' and' ' a' ' great' ' man' ',']" , the son of the late John Locke , who was a great man , and a great man , True Shephard (Matthew Fox) and John Locke (Terry O 'Quinn) 9 [' She', 'ph', 'ard', ' (', 'Matthew', ' Fox', ')', ' and', ' John', ' Locke']
+2390 541 Name of father of x -1 Name of father of John Locke John Locke John Locke "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' Locke' ',' ' who' ' was'
+ ' a' ' great' ' man' ',' ' and' ' a' ' great' ' man' ',']" , the son of the late John Locke , who was a great man , and a great man , True Kate meets up with John Locke (Terry O'Quinn) 5 [' Kate', ' meets', ' up', ' with', ' John', ' Locke']
+2391 541 Name of father of x -1 Name of father of John Locke John Locke John Locke "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' Locke' ',' ' who' ' was'
+ ' a' ' great' ' man' ',' ' and' ' a' ' great' ' man' ',']" , the son of the late John Locke , who was a great man , and a great man , True missing. Jack and John Locke (Terry O 'Quinn) 5 [' missing', '.', ' Jack', ' and', ' John', ' Locke']
+2392 542 Name of father of x -1 Name of father of Henry Fuseli Johann Caspar Füssli Henry Fuseli "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Henry' ' Fu' 'sel' 'i' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Henry Fu sel i , the painter ," False Priestley and the artist Henry Fuseli — two relationships 8 [' Priest', 'ley', ' and', ' the', ' artist', ' Henry', ' Fu', 'sel', 'i']
+2393 542 Name of father of x -1 Name of father of Henry Fuseli Johann Caspar Füssli Henry Fuseli "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Henry' ' Fu' 'sel' 'i' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Henry Fu sel i , the painter ," False with Candle. Though Henry Fuseli would later also 7 [' with', ' Candle', '.', ' Though', ' Henry', ' Fu', 'sel', 'i']
+2394 542 Name of father of x -1 Name of father of Henry Fuseli Johann Caspar Füssli Henry Fuseli "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Henry' ' Fu' 'sel' 'i' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Henry Fu sel i , the painter ," False with Candle. Though Henry Fuseli would later also 7 [' with', ' Candle', '.', ' Though', ' Henry', ' Fu', 'sel', 'i']
+2395 542 Name of father of x -1 Name of father of Henry Fuseli Johann Caspar Füssli Henry Fuseli "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Henry' ' Fu' 'sel' 'i' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Henry Fu sel i , the painter ," False Candle. Though Henry Fuseli would later 6 [' Candle', '.', ' Though', ' Henry', ' Fu', 'sel', 'i']
+2396 542 Name of father of x -1 Name of father of Henry Fuseli Johann Caspar Füssli Henry Fuseli "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Henry' ' Fu' 'sel' 'i' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Henry Fu sel i , the painter ," False nevertheless, both Henry Fuseli and Mary Wollstonecraft 6 [' nevertheless', ',', ' both', ' Henry', ' Fu', 'sel', 'i']
+2397 543 Name of father of x -1 Name of father of Benjamin Netanyahu Benzion Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False interpreted as a signal Benjamin Netanyahu was sending that Israel 5 [' interpreted', ' as', ' a', ' signal', ' Benjamin', ' Netanyahu']
+2398 543 Name of father of x -1 Name of father of Benjamin Netanyahu Benzion Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False Prime Minister Benjamin Netanyahu called the 3 [' Prime', ' Minister', ' Benjamin', ' Netanyahu']
+2399 543 Name of father of x -1 Name of father of Benjamin Netanyahu Benzion Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False 2 ['Ben', 'jamin', ' Netanyahu']
+2400 543 Name of father of x -1 Name of father of Benjamin Netanyahu Benzion Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False Prime Minister Benjamin Netanyahu who described them 3 [' Prime', ' Minister', ' Benjamin', ' Netanyahu']
+2401 543 Name of father of x -1 Name of father of Benjamin Netanyahu Benzion Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False as a signal Benjamin Netanyahu was sending 4 [' as', ' a', ' signal', ' Benjamin', ' Netanyahu']
+2402 544 Name of father of x -1 Name of father of Hieronymus Bosch Anthonis van Aken Hieronymus Bosch "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' father' ','
+ ' Hier' 'onym' 'us' ' Bos' 'ch' ',' ' the' ' painter' ',' ' and']" , the painter , and the painter 's father , Hier onym us Bos ch , the painter , and False (1395 – 1441) and Hieronymus Bosch (1450 – 1516) and the 12 [' (', '13', '95', ' –', ' 14', '41', ')', ' and', ' Hier', 'onym', 'us', ' Bos', 'ch']
+2403 544 Name of father of x -1 Name of father of Hieronymus Bosch Anthonis van Aken Hieronymus Bosch "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' father' ','
+ ' Hier' 'onym' 'us' ' Bos' 'ch' ',' ' the' ' painter' ',' ' and']" , the painter , and the painter 's father , Hier onym us Bos ch , the painter , and False the 1490s Hieronymus Bosch painted at least 8 [' the', ' 14', '90', 's', ' Hier', 'onym', 'us', ' Bos', 'ch']
+2404 544 Name of father of x -1 Name of father of Hieronymus Bosch Anthonis van Aken Hieronymus Bosch "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' father' ','
+ ' Hier' 'onym' 'us' ' Bos' 'ch' ',' ' the' ' painter' ',' ' and']" , the painter , and the painter 's father , Hier onym us Bos ch , the painter , and False the 1490s Hieronymus Bosch painted at least 8 [' the', ' 14', '90', 's', ' Hier', 'onym', 'us', ' Bos', 'ch']
+2405 544 Name of father of x -1 Name of father of Hieronymus Bosch Anthonis van Aken Hieronymus Bosch "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' father' ','
+ ' Hier' 'onym' 'us' ' Bos' 'ch' ',' ' the' ' painter' ',' ' and']" , the painter , and the painter 's father , Hier onym us Bos ch , the painter , and False strange work of Hieronymus Bosch to the everyday life 7 [' strange', ' work', ' of', ' Hier', 'onym', 'us', ' Bos', 'ch']
+2406 544 Name of father of x -1 Name of father of Hieronymus Bosch Anthonis van Aken Hieronymus Bosch "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' father' ','
+ ' Hier' 'onym' 'us' ' Bos' 'ch' ',' ' the' ' painter' ',' ' and']" , the painter , and the painter 's father , Hier onym us Bos ch , the painter , and False (1395 – 1441) and Hieronymus Bosch (1450 – 1516) 12 [' (', '13', '95', ' –', ' 14', '41', ')', ' and', ' Hier', 'onym', 'us', ' Bos', 'ch']
+2407 545 Name of father of x -1 Name of father of Alfred Tennyson George Clayton Tennyson Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False one to attend. Alfred Tennyson contributed a poem 6 [' one', ' to', ' attend', '.', ' Alfred', ' Tenn', 'yson']
+2408 545 Name of father of x -1 Name of father of Alfred Tennyson George Clayton Tennyson Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False Poet Laureate Alfred Tennyson appealed to the 6 [' Po', 'et', ' Laure', 'ate', ' Alfred', ' Tenn', 'yson']
+2409 545 Name of father of x -1 Name of father of Alfred Tennyson George Clayton Tennyson Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False Poet Laureate Alfred Tennyson appealed to the 6 [' Po', 'et', ' Laure', 'ate', ' Alfred', ' Tenn', 'yson']
+2410 545 Name of father of x -1 Name of father of Alfred Tennyson George Clayton Tennyson Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False Gladstone as well as Alfred Tennyson and Francis 7 [' Glad', 'stone', ' as', ' well', ' as', ' Alfred', ' Tenn', 'yson']
+2411 545 Name of father of x -1 Name of father of Alfred Tennyson George Clayton Tennyson Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False " Gladstone as well as Alfred Tennyson and Francis Parkman.
+" 7 [' Glad', 'stone', ' as', ' well', ' as', ' Alfred', ' Tenn', 'yson']
+2412 546 Name of father of x -1 Name of father of William Makepeace Thackeray Richmond Thackeray William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' ' father'
+ ' of' ' the' ' modern' ' novel' '.' '\n' '\n' 'The' ' first' ' edition']" ", the author of Vanity Fair , and the father of the modern novel .
+
+ The first edition" False " the writer William Makepeace Thackeray noted Nelson ""upon" 7 [' the', ' writer', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2413 546 Name of father of x -1 Name of father of William Makepeace Thackeray Richmond Thackeray William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' ' father'
+ ' of' ' the' ' modern' ' novel' '.' '\n' '\n' 'The' ' first' ' edition']" ", the author of Vanity Fair , and the father of the modern novel .
+
+ The first edition" False " unpopular Albert; William Makepeace Thackeray wrote in 1845: ""Think" 8 [' unpopular', ' Albert', ';', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2414 546 Name of father of x -1 Name of father of William Makepeace Thackeray Richmond Thackeray William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' ' father'
+ ' of' ' the' ' modern' ' novel' '.' '\n' '\n' 'The' ' first' ' edition']" ", the author of Vanity Fair , and the father of the modern novel .
+
+ The first edition" False Pretender James. William Makepeace Thackeray indicates such 9 [' Pret', 'ender', ' James', '.', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2415 546 Name of father of x -1 Name of father of William Makepeace Thackeray Richmond Thackeray William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' ' father'
+ ' of' ' the' ' modern' ' novel' '.' '\n' '\n' 'The' ' first' ' edition']" ", the author of Vanity Fair , and the father of the modern novel .
+
+ The first edition" False Irving's style. William Makepeace Thackeray was the first 9 "[' Irving', ""'s"", ' style', '.', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']"
+2416 546 Name of father of x -1 Name of father of William Makepeace Thackeray Richmond Thackeray William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' ' father'
+ ' of' ' the' ' modern' ' novel' '.' '\n' '\n' 'The' ' first' ' edition']" ", the author of Vanity Fair , and the father of the modern novel .
+
+ The first edition" False Pretender James. William Makepeace Thackeray indicates such 9 [' Pret', 'ender', ' James', '.', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2417 547 Name of father of x -1 Name of father of Deborah Kerr Arthur Charles Kerr Trimmer Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False co-starring with Deborah Kerr and Diana Rigg; Give 6 [' co', '-', 'star', 'ring', ' with', ' Deborah', ' Kerr']
+2418 547 Name of father of x -1 Name of father of Deborah Kerr Arthur Charles Kerr Trimmer Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Deborah Kerr =
+" 1 [' Deborah', ' Kerr']
+2419 547 Name of father of x -1 Name of father of Deborah Kerr Arthur Charles Kerr Trimmer Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Deborah Kerr =
+" 1 [' Deborah', ' Kerr']
+2420 547 Name of father of x -1 Name of father of Deborah Kerr Arthur Charles Kerr Trimmer Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 2 ['Deb', 'orah', ' Kerr']
+2421 547 Name of father of x -1 Name of father of Deborah Kerr Arthur Charles Kerr Trimmer Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False starring opposite Deborah Kerr in the romantic 3 [' starring', ' opposite', ' Deborah', ' Kerr']
+2422 548 Name of father of x -1 Name of father of Romy Schneider Wolf Albach-Retty Romy Schneider "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who played the role of the mother of the main character in the film .
+
+" False awarded France's Romy Schneider and Jean Gabin 5 "[' awarded', ' France', ""'s"", ' Rom', 'y', ' Schneider']"
+2423 548 Name of father of x -1 Name of father of Romy Schneider Wolf Albach-Retty Romy Schneider "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who played the role of the mother of the main character in the film .
+
+" False were awarded France's Romy Schneider and Jean Gabin Prizes 6 "[' were', ' awarded', ' France', ""'s"", ' Rom', 'y', ' Schneider']"
+2424 550 Name of father of x -1 Name of father of Robert Graves Alfred Perceval Graves Robert Graves "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' I' ' am' ' a' ' man' ' of' ' the' ' people']" ", the
+
+ I am a man of the people , and I am a man of the people" False 1 ['Robert', ' Graves']
+2425 550 Name of father of x -1 Name of father of Robert Graves Alfred Perceval Graves Robert Graves "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' I' ' am' ' a' ' man' ' of' ' the' ' people']" ", the
+
+ I am a man of the people , and I am a man of the people" False residents led by poet Robert Graves campaigned 5 [' residents', ' led', ' by', ' poet', ' Robert', ' Graves']
+2426 550 Name of father of x -1 Name of father of Robert Graves Alfred Perceval Graves Robert Graves "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' I' ' am' ' a' ' man' ' of' ' the' ' people']" ", the
+
+ I am a man of the people , and I am a man of the people" False the poets Robert Graves and Ted Hughes. 3 [' the', ' poets', ' Robert', ' Graves']
+2427 550 Name of father of x -1 Name of father of Robert Graves Alfred Perceval Graves Robert Graves "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' I' ' am' ' a' ' man' ' of' ' the' ' people']" ", the
+
+ I am a man of the people , and I am a man of the people" False English classicist Robert Graves and Italian ethnobotanist 4 [' English', ' classic', 'ist', ' Robert', ' Graves']
+2428 550 Name of father of x -1 Name of father of Robert Graves Alfred Perceval Graves Robert Graves "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' I' ' am' ' a' ' man' ' of' ' the' ' people']" ", the
+
+ I am a man of the people , and I am a man of the people" False Khayyam's Rubaiyat, by Robert Graves and Shah's older 12 "[' Kh', 'ay', 'y', 'am', ""'s"", ' Rub', 'ai', 'y', 'at', ',', ' by', ' Robert', ' Graves']"
+2429 551 Name of father of x -1 Name of father of Katharine Hepburn Thomas Norval Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' daughter' ' of' ' the'
+ ' famous' ' actress' ' Kath' 'arine' ' Hep' 'burn' '.' '\n' '\n' 'The'
+ ' daughter']" ", the actress , and the daughter of the famous actress Kath arine Hep burn .
+
+ The daughter" False 1989. Skelton and Katharine Hepburn were honored with 9 [' 1989', '.', ' S', 'kel', 'ton', ' and', ' Kath', 'arine', ' Hep', 'burn']
+2430 551 Name of father of x -1 Name of father of Katharine Hepburn Thomas Norval Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' daughter' ' of' ' the'
+ ' famous' ' actress' ' Kath' 'arine' ' Hep' 'burn' '.' '\n' '\n' 'The'
+ ' daughter']" ", the actress , and the daughter of the famous actress Kath arine Hep burn .
+
+ The daughter" False " said, ""Picture Katharine Hepburn in every movie she" 7 "[' said', ',', ' ""', 'Picture', ' Kath', 'arine', ' Hep', 'burn']"
+2431 551 Name of father of x -1 Name of father of Katharine Hepburn Thomas Norval Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' daughter' ' of' ' the'
+ ' famous' ' actress' ' Kath' 'arine' ' Hep' 'burn' '.' '\n' '\n' 'The'
+ ' daughter']" ", the actress , and the daughter of the famous actress Kath arine Hep burn .
+
+ The daughter" False Christopher Strong, Katharine Hepburn plays an aviator 6 [' Christopher', ' Strong', ',', ' Kath', 'arine', ' Hep', 'burn']
+2432 551 Name of father of x -1 Name of father of Katharine Hepburn Thomas Norval Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' daughter' ' of' ' the'
+ ' famous' ' actress' ' Kath' 'arine' ' Hep' 'burn' '.' '\n' '\n' 'The'
+ ' daughter']" ", the actress , and the daughter of the famous actress Kath arine Hep burn .
+
+ The daughter" False starred opposite Katharine Hepburn in the screwball 5 [' starred', ' opposite', ' Kath', 'arine', ' Hep', 'burn']
+2433 551 Name of father of x -1 Name of father of Katharine Hepburn Thomas Norval Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' daughter' ' of' ' the'
+ ' famous' ' actress' ' Kath' 'arine' ' Hep' 'burn' '.' '\n' '\n' 'The'
+ ' daughter']" ", the actress , and the daughter of the famous actress Kath arine Hep burn .
+
+ The daughter" False Rosalind Russell, Katharine Hepburn and Jean Arthur. When 8 [' Ros', 'al', 'ind', ' Russell', ',', ' Kath', 'arine', ' Hep', 'burn']
+2434 552 Name of father of x -1 Name of father of Benjamin West John West Benjamin West "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False American 3 [' America', 'Ben', 'jamin', ' West']
+2435 552 Name of father of x -1 Name of father of Benjamin West John West Benjamin West "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False appeared in paintings by Benjamin West and Arthur William 5 [' appeared', ' in', ' paintings', ' by', ' Benjamin', ' West']
+2436 552 Name of father of x -1 Name of father of Benjamin West John West Benjamin West "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False third portrait is by Benjamin West and was painted 5 [' third', ' portrait', ' is', ' by', ' Benjamin', ' West']
+2437 552 Name of father of x -1 Name of father of Benjamin West John West Benjamin West "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False became a protégé of Benjamin West with whom he studied 8 [' became', ' a', ' prot', 'é', 'g', 'é', ' of', ' Benjamin', ' West']
+2438 552 Name of father of x -1 Name of father of Benjamin West John West Benjamin West "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False " contributed, such as Benjamin West and Henry Fuseli.
+" 5 [' contributed', ',', ' such', ' as', ' Benjamin', ' West']
+2439 553 Name of father of x -1 Name of father of Philipp Melanchthon Georg Schwarzerdt Philipp Melanchthon "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False 1545 and signed by Philipp Melanchthon and others, much 9 [' 15', '45', ' and', ' signed', ' by', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2440 553 Name of father of x -1 Name of father of Philipp Melanchthon Georg Schwarzerdt Philipp Melanchthon "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False and signed by Philipp Melanchthon and others, 7 [' and', ' signed', ' by', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2441 553 Name of father of x -1 Name of father of Philipp Melanchthon Georg Schwarzerdt Philipp Melanchthon "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False and signed by Philipp Melanchthon and others, much 7 [' and', ' signed', ' by', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2442 553 Name of father of x -1 Name of father of Philipp Melanchthon Georg Schwarzerdt Philipp Melanchthon "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False with Luther and Philipp Melanchthon arriving shortly 7 [' with', ' Luther', ' and', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2443 553 Name of father of x -1 Name of father of Philipp Melanchthon Georg Schwarzerdt Philipp Melanchthon "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False Continental reformer Philipp Melanchthon was aware that 7 [' Continental', ' reform', 'er', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2444 554 Name of father of x -1 Name of father of David Hume Alex Pazos Bellon David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' a' ' man' ' of'
+ ' great' ' learning']" , the father of modern philosophy , and the father of modern science , was a man of great learning False " period. In 1741, David Hume claimed: ""The Elegance" 7 [' period', '.', ' In', ' 17', '41', ',', ' David', ' Hume']
+2445 554 Name of father of x -1 Name of father of David Hume Alex Pazos Bellon David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' a' ' man' ' of'
+ ' great' ' learning']" , the father of modern philosophy , and the father of modern science , was a man of great learning False " David Hume ===
+" 1 [' David', ' Hume']
+2446 554 Name of father of x -1 Name of father of David Hume Alex Pazos Bellon David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' a' ' man' ' of'
+ ' great' ' learning']" , the father of modern philosophy , and the father of modern science , was a man of great learning False perspective to the issue. David Hume denied that suicide 6 [' perspective', ' to', ' the', ' issue', '.', ' David', ' Hume']
+2447 554 Name of father of x -1 Name of father of David Hume Alex Pazos Bellon David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' a' ' man' ' of'
+ ' great' ' learning']" , the father of modern philosophy , and the father of modern science , was a man of great learning False Scottish empiricist David Hume made a similar 4 [' Scottish', ' empir', 'icist', ' David', ' Hume']
+2448 554 Name of father of x -1 Name of father of David Hume Alex Pazos Bellon David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' a' ' man' ' of'
+ ' great' ' learning']" , the father of modern philosophy , and the father of modern science , was a man of great learning False era in Scotland was David Hume (1711 – 76) whose Treatise 5 [' era', ' in', ' Scotland', ' was', ' David', ' Hume']
+2449 555 Name of father of x -1 Name of father of Stanisław Lem Samuel Lem Stanisław Lem "[',' ' the' ' Polish' ' writer' ',' ' who' ' was' ' born' ' in' ' Warsaw'
+ ' in' ' 1920' '.' '\n' '\n' 'The' ' author' ' of' ' the' ' book']" ", the Polish writer , who was born in Warsaw in 1920 .
+
+ The author of the book" False 4 ['Stan', 'is', 'ł', 'aw', ' Lem']
+2450 555 Name of father of x -1 Name of father of Stanisław Lem Samuel Lem Stanisław Lem "[',' ' the' ' Polish' ' writer' ',' ' who' ' was' ' born' ' in' ' Warsaw'
+ ' in' ' 1920' '.' '\n' '\n' 'The' ' author' ' of' ' the' ' book']" ", the Polish writer , who was born in Warsaw in 1920 .
+
+ The author of the book" False " Stanisław Lem =
+" 4 [' Stan', 'is', 'ł', 'aw', ' Lem']
+2451 555 Name of father of x -1 Name of father of Stanisław Lem Samuel Lem Stanisław Lem "[',' ' the' ' Polish' ' writer' ',' ' who' ' was' ' born' ' in' ' Warsaw'
+ ' in' ' 1920' '.' '\n' '\n' 'The' ' author' ' of' ' the' ' book']" ", the Polish writer , who was born in Warsaw in 1920 .
+
+ The author of the book" False 4 ['Stan', 'is', 'ł', 'aw', ' Lem']
+2452 555 Name of father of x -1 Name of father of Stanisław Lem Samuel Lem Stanisław Lem "[',' ' the' ' Polish' ' writer' ',' ' who' ' was' ' born' ' in' ' Warsaw'
+ ' in' ' 1920' '.' '\n' '\n' 'The' ' author' ' of' ' the' ' book']" ", the Polish writer , who was born in Warsaw in 1920 .
+
+ The author of the book" False 4 ['Stan', 'is', 'ł', 'aw', ' Lem']
+2453 555 Name of father of x -1 Name of father of Stanisław Lem Samuel Lem Stanisław Lem "[',' ' the' ' Polish' ' writer' ',' ' who' ' was' ' born' ' in' ' Warsaw'
+ ' in' ' 1920' '.' '\n' '\n' 'The' ' author' ' of' ' the' ' book']" ", the Polish writer , who was born in Warsaw in 1920 .
+
+ The author of the book" False " Stanisław Lem =
+" 4 [' Stan', 'is', 'ł', 'aw', ' Lem']
+2454 556 Name of father of x -1 Name of father of Al Gore Albert Arnold Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' T' 'ipper' ',' ' were' ' in'
+ ' the']" , the former vice president of the United States , and his wife , T ipper , were in the False Vice President Al Gore to attend 3 [' Vice', ' President', ' Al', ' Gore']
+2455 556 Name of father of x -1 Name of father of Al Gore Albert Arnold Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' T' 'ipper' ',' ' were' ' in'
+ ' the']" , the former vice president of the United States , and his wife , T ipper , were in the False presidential nominee Al Gore was going to 3 [' presidential', ' nominee', ' Al', ' Gore']
+2456 556 Name of father of x -1 Name of father of Al Gore Albert Arnold Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' T' 'ipper' ',' ' were' ' in'
+ ' the']" , the former vice president of the United States , and his wife , T ipper , were in the False Vice President Al Gore was hosted at the 3 [' Vice', ' President', ' Al', ' Gore']
+2457 556 Name of father of x -1 Name of father of Al Gore Albert Arnold Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' T' 'ipper' ',' ' were' ' in'
+ ' the']" , the former vice president of the United States , and his wife , T ipper , were in the False US vice president Al Gore said Baird's plan 4 [' US', ' vice', ' president', ' Al', ' Gore']
+2458 556 Name of father of x -1 Name of father of Al Gore Albert Arnold Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' his' ' wife' ',' ' T' 'ipper' ',' ' were' ' in'
+ ' the']" , the former vice president of the United States , and his wife , T ipper , were in the False vice president Al Gore said Baird's plan 3 [' vice', ' president', ' Al', ' Gore']
+2459 557 Name of father of x -1 Name of father of Martin Scorsese Charles Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False including Martin Scorsese, Steven Spielberg, 4 [' including', ' Martin', ' Sc', 'ors', 'ese']
+2460 557 Name of father of x -1 Name of father of Martin Scorsese Charles Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False Last Waltz, the Martin Scorsese documentary film 8 [' Last', ' Walt', 'z', ',', ' the', ' Martin', ' Sc', 'ors', 'ese']
+2461 557 Name of father of x -1 Name of father of Martin Scorsese Charles Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False Coppola rather than Martin Scorsese (though Robert 8 [' Co', 'pp', 'ola', ' rather', ' than', ' Martin', ' Sc', 'ors', 'ese']
+2462 557 Name of father of x -1 Name of father of Martin Scorsese Charles Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False " Beatle"". Filmmaker Martin Scorsese has written of the" 9 "[' Beat', 'le', '"".', ' Fil', 'mm', 'aker', ' Martin', ' Sc', 'ors', 'ese']"
+2463 557 Name of father of x -1 Name of father of Martin Scorsese Charles Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False (re-released on Martin Scorsese Presents the 8 [' (', 're', '-', 'released', ' on', ' Martin', ' Sc', 'ors', 'ese']
+2464 558 Name of father of x -1 Name of father of Gotthold Ephraim Lessing Johann Gottfried Lessing Gotthold Ephraim Lessing "[',' ' the' ' German' ' philosopher' ',' ' was' ' born' ' in' ' 17' '29'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the German philosopher , was born in 17 29 .
+
+ The name of the father of the" False discovered by Gotthold Ephraim Lessing in a Greek manuscript 10 [' discovered', ' by', ' Got', 'th', 'old', ' E', 'ph', 'ra', 'im', ' Less', 'ing']
+2465 558 Name of father of x -1 Name of father of Gotthold Ephraim Lessing Johann Gottfried Lessing Gotthold Ephraim Lessing "[',' ' the' ' German' ' philosopher' ',' ' was' ' born' ' in' ' 17' '29'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the German philosopher , was born in 17 29 .
+
+ The name of the father of the" False Friedrich Schiller, Gotthold Ephraim Lessing and Theodor Fontane. 12 [' Friedrich', ' Sch', 'iller', ',', ' Got', 'th', 'old', ' E', 'ph', 'ra', 'im', ' Less', 'ing']
+2466 558 Name of father of x -1 Name of father of Gotthold Ephraim Lessing Johann Gottfried Lessing Gotthold Ephraim Lessing "[',' ' the' ' German' ' philosopher' ',' ' was' ' born' ' in' ' 17' '29'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the German philosopher , was born in 17 29 .
+
+ The name of the father of the" False work was discovered by Gotthold Ephraim Lessing in a Greek manuscript 12 [' work', ' was', ' discovered', ' by', ' Got', 'th', 'old', ' E', 'ph', 'ra', 'im', ' Less', 'ing']
+2467 558 Name of father of x -1 Name of father of Gotthold Ephraim Lessing Johann Gottfried Lessing Gotthold Ephraim Lessing "[',' ' the' ' German' ' philosopher' ',' ' was' ' born' ' in' ' 17' '29'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the German philosopher , was born in 17 29 .
+
+ The name of the father of the" False discovered by Gotthold Ephraim Lessing in a Greek manuscript 10 [' discovered', ' by', ' Got', 'th', 'old', ' E', 'ph', 'ra', 'im', ' Less', 'ing']
+2468 558 Name of father of x -1 Name of father of Gotthold Ephraim Lessing Johann Gottfried Lessing Gotthold Ephraim Lessing "[',' ' the' ' German' ' philosopher' ',' ' was' ' born' ' in' ' 17' '29'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the German philosopher , was born in 17 29 .
+
+ The name of the father of the" False was discovered by Gotthold Ephraim Lessing in a Greek manuscript 11 [' was', ' discovered', ' by', ' Got', 'th', 'old', ' E', 'ph', 'ra', 'im', ' Less', 'ing']
+2469 559 Name of father of x -1 Name of father of Giacomo Leopardi Monaldo Leopardi Giacomo Leopardi "[',' ' the' ' father' ' of' ' the' ' modern' ' Italian' ' language' '.'
+ '\n' '\n' 'The' ' first' ' edition' ' of' ' the' ' book' ' was'
+ ' published' ' in']" ", the father of the modern Italian language .
+
+ The first edition of the book was published in" False Rabindranath Tagore, Giacomo Leopardi and pursued academic 11 [' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore', ',', ' Gi', 'ac', 'omo', ' Leopard', 'i']
+2470 559 Name of father of x -1 Name of father of Giacomo Leopardi Monaldo Leopardi Giacomo Leopardi "[',' ' the' ' father' ' of' ' the' ' modern' ' Italian' ' language' '.'
+ '\n' '\n' 'The' ' first' ' edition' ' of' ' the' ' book' ' was'
+ ' published' ' in']" ", the father of the modern Italian language .
+
+ The first edition of the book was published in" False Rabindranath Tagore, Giacomo Leopardi and pursued academic 11 [' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore', ',', ' Gi', 'ac', 'omo', ' Leopard', 'i']
+2471 559 Name of father of x -1 Name of father of Giacomo Leopardi Monaldo Leopardi Giacomo Leopardi "[',' ' the' ' father' ' of' ' the' ' modern' ' Italian' ' language' '.'
+ '\n' '\n' 'The' ' first' ' edition' ' of' ' the' ' book' ' was'
+ ' published' ' in']" ", the father of the modern Italian language .
+
+ The first edition of the book was published in" False – 28 when the poet Giacomo Leopardi had visited and met 9 [' –', ' 28', ' when', ' the', ' poet', ' Gi', 'ac', 'omo', ' Leopard', 'i']
+2472 560 Name of father of x -1 Name of father of John Cale William Arthur George Cale John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was']" "o , the father of the bride , and the father of the groom .
+
+ The wedding was" False " contributions from John Cale on two songs: ""Northern" 4 [' contributions', ' from', ' John', ' C', 'ale']
+2473 560 Name of father of x -1 Name of father of John Cale William Arthur George Cale John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was']" "o , the father of the bride , and the father of the groom .
+
+ The wedding was" False Wade, Baha Men, and John Cale (covering Leonard 9 [' Wade', ',', ' B', 'aha', ' Men', ',', ' and', ' John', ' C', 'ale']
+2474 560 Name of father of x -1 Name of father of John Cale William Arthur George Cale John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was']" "o , the father of the bride , and the father of the groom .
+
+ The wedding was" False contributions from John Cale on two songs: 4 [' contributions', ' from', ' John', ' C', 'ale']
+2475 560 Name of father of x -1 Name of father of John Cale William Arthur George Cale John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was']" "o , the father of the bride , and the father of the groom .
+
+ The wedding was" False " studio, musician John Cale reported, ""Morrison" 5 [' studio', ',', ' musician', ' John', ' C', 'ale']
+2476 560 Name of father of x -1 Name of father of John Cale William Arthur George Cale John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the'
+ ' father' ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was']" "o , the father of the bride , and the father of the groom .
+
+ The wedding was" False produced by John Cale for the major 4 [' produced', ' by', ' John', ' C', 'ale']
+2477 561 Name of father of x -1 Name of father of Elon Musk Errol Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False CEO and inventor Elon Musk complimented West 4 [' CEO', ' and', ' inventor', ' Elon', ' Musk']
+2478 561 Name of father of x -1 Name of father of Elon Musk Errol Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False Tesla Motors CEO Elon Musk and Oracle Corporation 4 [' Tesla', ' Motors', ' CEO', ' Elon', ' Musk']
+2479 561 Name of father of x -1 Name of father of Elon Musk Errol Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False 2012, CEO Elon Musk announced SpaceX's 4 [' 2012', ',', ' CEO', ' Elon', ' Musk']
+2480 561 Name of father of x -1 Name of father of Elon Musk Errol Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False thrust fluctuations. Elon Musk reported that this 4 [' thrust', ' fluctuations', '.', ' Elon', ' Musk']
+2481 561 Name of father of x -1 Name of father of Elon Musk Errol Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False counter gravity. Elon Musk indicated this 4 [' counter', ' gravity', '.', ' Elon', ' Musk']
+2482 562 Name of father of x -1 Name of father of Brad Pitt William Alvin Pitt Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False starred opposite Brad Pitt as a bored married 3 [' starred', ' opposite', ' Brad', ' Pitt']
+2483 562 Name of father of x -1 Name of father of Brad Pitt William Alvin Pitt Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False replica of Brad Pitt at Madame Tussauds 3 [' replica', ' of', ' Brad', ' Pitt']
+2484 562 Name of father of x -1 Name of father of Brad Pitt William Alvin Pitt Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False in a replica of Brad Pitt at Madame Tussauds 5 [' in', ' a', ' replica', ' of', ' Brad', ' Pitt']
+2485 562 Name of father of x -1 Name of father of Brad Pitt William Alvin Pitt Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False Fassbender, Cameron Diaz, Brad Pitt and husband Javier 8 [' F', 'ass', 'bender', ',', ' Cameron', ' Diaz', ',', ' Brad', ' Pitt']
+2486 562 Name of father of x -1 Name of father of Brad Pitt William Alvin Pitt Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False until Inception. Both Brad Pitt and Will Smith 6 [' until', ' In', 'ception', '.', ' Both', ' Brad', ' Pitt']
+2487 563 Name of father of x -1 Name of father of Fridtjof Nansen Baldur Fridtjof Nansen Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False Norwegian explorer Fridtjof Nansen aboard the Fram, 8 [' Norwegian', ' explorer', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2488 563 Name of father of x -1 Name of father of Fridtjof Nansen Baldur Fridtjof Nansen Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False Norwegian explorer Fridtjof Nansen aboard the Fram, which 8 [' Norwegian', ' explorer', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2489 563 Name of father of x -1 Name of father of Fridtjof Nansen Baldur Fridtjof Nansen Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False poet Valery Bryusov, Fridtjof Nansen (1925), English 13 [' poet', ' Val', 'ery', ' Bry', 'us', 'ov', ',', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2490 563 Name of father of x -1 Name of father of Fridtjof Nansen Baldur Fridtjof Nansen Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False are missing. Also, Fridtjof Nansen and his companion 11 [' are', ' missing', '.', ' Also', ',', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2491 563 Name of father of x -1 Name of father of Fridtjof Nansen Baldur Fridtjof Nansen Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False Norwegian explorer Fridtjof Nansen to reach the 8 [' Norwegian', ' explorer', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2492 564 Name of father of x -1 Name of father of Johnny Depp John Christopher Depp Johnny Depp "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False X. In June 1995, Johnny Depp was cast into 8 [' X', '.', ' In', ' June', ' 1995', ',', ' Johnny', ' De', 'pp']
+2493 564 Name of father of x -1 Name of father of Johnny Depp John Christopher Depp Johnny Depp "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False relationship with Johnny Depp in the early 1990s 4 [' relationship', ' with', ' Johnny', ' De', 'pp']
+2494 564 Name of father of x -1 Name of father of Johnny Depp John Christopher Depp Johnny Depp "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False performances from Johnny Depp and Helena Bonham Carter. 4 [' performances', ' from', ' Johnny', ' De', 'pp']
+2495 564 Name of father of x -1 Name of father of Johnny Depp John Christopher Depp Johnny Depp "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False " ""Jake"" Grimm: Johnny Depp was Gilliam's first" 7 "[' ""', 'Jake', '""', ' Grimm', ':', ' Johnny', ' De', 'pp']"
+2496 564 Name of father of x -1 Name of father of Johnny Depp John Christopher Depp Johnny Depp "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False In June 1995, Johnny Depp was cast into the 6 [' In', ' June', ' 1995', ',', ' Johnny', ' De', 'pp']
+2497 565 Name of father of x -1 Name of father of Max Weber Max Weber Max Weber "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' that' ' strikes' ' you' ' about' ' the' ' book' ' is']" ", the father of the modern world .
+
+ The first thing that strikes you about the book is" False cultural superstructure). Max Weber critiqued Marxist 5 [' cultural', ' superst', 'ructure', ').', ' Max', ' Weber']
+2498 565 Name of father of x -1 Name of father of Max Weber Max Weber Max Weber "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' that' ' strikes' ' you' ' about' ' the' ' book' ' is']" ", the father of the modern world .
+
+ The first thing that strikes you about the book is" False Habermas (7), Max Weber (8), and Bruno 7 [' Hab', 'er', 'mas', ' (', '7', '),', ' Max', ' Weber']
+2499 565 Name of father of x -1 Name of father of Max Weber Max Weber Max Weber "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' that' ' strikes' ' you' ' about' ' the' ' book' ' is']" ", the father of the modern world .
+
+ The first thing that strikes you about the book is" False of his home. Max Weber contracted the 5 [' of', ' his', ' home', '.', ' Max', ' Weber']
+2500 565 Name of father of x -1 Name of father of Max Weber Max Weber Max Weber "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' that' ' strikes' ' you' ' about' ' the' ' book' ' is']" ", the father of the modern world .
+
+ The first thing that strikes you about the book is" False Habermas (7), Max Weber (8), and Bruno 7 [' Hab', 'er', 'mas', ' (', '7', '),', ' Max', ' Weber']
+2501 565 Name of father of x -1 Name of father of Max Weber Max Weber Max Weber "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' that' ' strikes' ' you' ' about' ' the' ' book' ' is']" ", the father of the modern world .
+
+ The first thing that strikes you about the book is" False " = Max Weber =
+" 2 [' =', ' Max', ' Weber']
+2502 567 Name of father of x -1 Name of father of Socrates Sophroniscus Socrates "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Socrates' ','
+ ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Socrates' ',']" ", and the
+
+ Name of mother of Socrates , and the name of the father of Socrates ," False 1 ['S', 'ocrates']
+2503 567 Name of father of x -1 Name of father of Socrates Sophroniscus Socrates "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Socrates' ','
+ ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Socrates' ',']" ", and the
+
+ Name of mother of Socrates , and the name of the father of Socrates ," False timid lawyer, Socrates Poole (Christian 3 [' timid', ' lawyer', ',', ' Socrates']
+2504 567 Name of father of x -1 Name of father of Socrates Sophroniscus Socrates "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Socrates' ','
+ ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Socrates' ',']" ", and the
+
+ Name of mother of Socrates , and the name of the father of Socrates ," False 1 ['S', 'ocrates']
+2505 567 Name of father of x -1 Name of father of Socrates Sophroniscus Socrates "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Socrates' ','
+ ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Socrates' ',']" ", and the
+
+ Name of mother of Socrates , and the name of the father of Socrates ," False Pericles, and quotes Socrates as claiming 5 [' Per', 'icles', ',', ' and', ' quotes', ' Socrates']
+2506 567 Name of father of x -1 Name of father of Socrates Sophroniscus Socrates "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Socrates' ','
+ ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' Socrates' ',']" ", and the
+
+ Name of mother of Socrates , and the name of the father of Socrates ," False dialogue between Socrates and Euthyphro. 2 [' dialogue', ' between', ' Socrates']
+2507 568 Name of father of x -1 Name of father of M. C. Escher George Arnold Escher M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of father of M . C . Esc her , the" False graphic artist M. C. Escher and authors Raymond 7 [' graphic', ' artist', ' M', '.', ' C', '.', ' Esc', 'her']
+2508 568 Name of father of x -1 Name of father of M. C. Escher George Arnold Escher M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of father of M . C . Esc her , the" False mathematics, such as M. C. Escher (inspired by 9 [' mathematics', ',', ' such', ' as', ' M', '.', ' C', '.', ' Esc', 'her']
+2509 568 Name of father of x -1 Name of father of M. C. Escher George Arnold Escher M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of father of M . C . Esc her , the" False century, the work of M. C. Escher often made 10 [' century', ',', ' the', ' work', ' of', ' M', '.', ' C', '.', ' Esc', 'her']
+2510 568 Name of father of x -1 Name of father of M. C. Escher George Arnold Escher M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of father of M . C . Esc her , the" False century, the work of M. C. Escher often made use 10 [' century', ',', ' the', ' work', ' of', ' M', '.', ' C', '.', ' Esc', 'her']
+2511 568 Name of father of x -1 Name of father of M. C. Escher George Arnold Escher M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of father of M . C . Esc her , the" False graphic artist M. C. Escher and authors Raymond 7 [' graphic', ' artist', ' M', '.', ' C', '.', ' Esc', 'her']
+2512 569 Name of father of x -1 Name of father of Richard Nixon Francis A. Nixon Richard Nixon "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' member' ' of' ' the'
+ ' Communist' ' Party' ',' ' and' ' who' ' had' ' been' ' a' ' member'
+ ' of']" , the man who had been a member of the Communist Party , and who had been a member of False American president Richard Nixon and Italian dictator 3 [' American', ' president', ' Richard', ' Nixon']
+2513 569 Name of father of x -1 Name of father of Richard Nixon Francis A. Nixon Richard Nixon "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' member' ' of' ' the'
+ ' Communist' ' Party' ',' ' and' ' who' ' had' ' been' ' a' ' member'
+ ' of']" , the man who had been a member of the Communist Party , and who had been a member of False In 1969, President Richard Nixon appointed Koubek 5 [' In', ' 1969', ',', ' President', ' Richard', ' Nixon']
+2514 569 Name of father of x -1 Name of father of Richard Nixon Francis A. Nixon Richard Nixon "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' member' ' of' ' the'
+ ' Communist' ' Party' ',' ' and' ' who' ' had' ' been' ' a' ' member'
+ ' of']" , the man who had been a member of the Communist Party , and who had been a member of False 1973, President Richard Nixon signed legislation 4 [' 1973', ',', ' President', ' Richard', ' Nixon']
+2515 569 Name of father of x -1 Name of father of Richard Nixon Francis A. Nixon Richard Nixon "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' member' ' of' ' the'
+ ' Communist' ' Party' ',' ' and' ' who' ' had' ' been' ' a' ' member'
+ ' of']" , the man who had been a member of the Communist Party , and who had been a member of False was carried by Richard Nixon and the Republican 4 [' was', ' carried', ' by', ' Richard', ' Nixon']
+2516 569 Name of father of x -1 Name of father of Richard Nixon Francis A. Nixon Richard Nixon "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' member' ' of' ' the'
+ ' Communist' ' Party' ',' ' and' ' who' ' had' ' been' ' a' ' member'
+ ' of']" , the man who had been a member of the Communist Party , and who had been a member of False publicly supported Richard Nixon during the 1960 presidential 3 [' publicly', ' supported', ' Richard', ' Nixon']
+2517 570 Name of father of x -1 Name of father of Seamus Heaney Patrick Heaney Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Ma' 'ire' 'ad' ','
+ ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Ma ire ad , who was a poet ess .
+
+" False from Blake to Seamus Heaney he takes words back 6 [' from', ' Blake', ' to', ' Se', 'amus', ' He', 'aney']
+2518 570 Name of father of x -1 Name of father of Seamus Heaney Patrick Heaney Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Ma' 'ire' 'ad' ','
+ ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Ma ire ad , who was a poet ess .
+
+" False New York. When Seamus Heaney gave an Oxford 7 [' New', ' York', '.', ' When', ' Se', 'amus', ' He', 'aney']
+2519 570 Name of father of x -1 Name of father of Seamus Heaney Patrick Heaney Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Ma' 'ire' 'ad' ','
+ ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Ma ire ad , who was a poet ess .
+
+" False laureate poet Seamus Heaney in late 2010, 5 [' laureate', ' poet', ' Se', 'amus', ' He', 'aney']
+2520 570 Name of father of x -1 Name of father of Seamus Heaney Patrick Heaney Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Ma' 'ire' 'ad' ','
+ ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Ma ire ad , who was a poet ess .
+
+" False laureate poet Seamus Heaney in late 2010, but 5 [' laureate', ' poet', ' Se', 'amus', ' He', 'aney']
+2521 570 Name of father of x -1 Name of father of Seamus Heaney Patrick Heaney Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Ma' 'ire' 'ad' ','
+ ' who' ' was' ' a' ' poet' 'ess' '.' '\n' '\n']" ", the poet , and his wife , Ma ire ad , who was a poet ess .
+
+" False from Blake to Seamus Heaney he takes words 6 [' from', ' Blake', ' to', ' Se', 'amus', ' He', 'aney']
+2522 571 Name of father of x -1 Name of father of Justin Timberlake Randall Timberlake Justin Timberlake "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' member']" , who was a member of the band , and the band 's bass ist , who was a member False the chemistry between Justin Timberlake and Mila Kunis 5 [' the', ' chemistry', ' between', ' Justin', ' Timber', 'lake']
+2523 571 Name of father of x -1 Name of father of Justin Timberlake Randall Timberlake Justin Timberlake "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' member']" , who was a member of the band , and the band 's bass ist , who was a member False " pop singer Justin Timberlake ""could not go back" 4 [' pop', ' singer', ' Justin', ' Timber', 'lake']
+2524 571 Name of father of x -1 Name of father of Justin Timberlake Randall Timberlake Justin Timberlake "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' member']" , who was a member of the band , and the band 's bass ist , who was a member False Ciara's steamy romp with Justin Timberlake in their new clip for' 10 "[' Ci', 'ara', ""'s"", ' steam', 'y', ' r', 'omp', ' with', ' Justin', ' Timber', 'lake']"
+2525 571 Name of father of x -1 Name of father of Justin Timberlake Randall Timberlake Justin Timberlake "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' member']" , who was a member of the band , and the band 's bass ist , who was a member False New York, where Justin Timberlake and Beyoncé Knowles 6 [' New', ' York', ',', ' where', ' Justin', ' Timber', 'lake']
+2526 571 Name of father of x -1 Name of father of Justin Timberlake Randall Timberlake Justin Timberlake "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' member']" , who was a member of the band , and the band 's bass ist , who was a member False Together he and Justin Timberlake were confirmed to have 5 [' Together', ' he', ' and', ' Justin', ' Timber', 'lake']
+2527 572 Name of father of x -1 Name of father of Pius II Silvio Piccolomini Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' son' ',' ' the' '\n'
+ '\n' 'Pope' ""'s"" ' son' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's son , the
+
+ Pope 's son , the Pope 's" False town which Pope Pius II lamentably termed 5 [' town', ' which', ' Pope', ' P', 'ius', ' II']
+2528 572 Name of father of x -1 Name of father of Pius II Silvio Piccolomini Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' son' ',' ' the' '\n'
+ '\n' 'Pope' ""'s"" ' son' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's son , the
+
+ Pope 's son , the Pope 's" False engineers sent by Pope Pius II in 1463, at 6 [' engineers', ' sent', ' by', ' Pope', ' P', 'ius', ' II']
+2529 572 Name of father of x -1 Name of father of Pius II Silvio Piccolomini Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' son' ',' ' the' '\n'
+ '\n' 'Pope' ""'s"" ' son' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's son , the
+
+ Pope 's son , the Pope 's" False planned by Pope Pius II with Skanderbeg 5 [' planned', ' by', ' Pope', ' P', 'ius', ' II']
+2530 572 Name of father of x -1 Name of father of Pius II Silvio Piccolomini Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' son' ',' ' the' '\n'
+ '\n' 'Pope' ""'s"" ' son' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's son , the
+
+ Pope 's son , the Pope 's" False banned by Pope Pius II in a conflict over 5 [' banned', ' by', ' Pope', ' P', 'ius', ' II']
+2531 572 Name of father of x -1 Name of father of Pius II Silvio Piccolomini Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' son' ',' ' the' '\n'
+ '\n' 'Pope' ""'s"" ' son' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's son , the
+
+ Pope 's son , the Pope 's" False sent by Pope Pius II in 1463, at 5 [' sent', ' by', ' Pope', ' P', 'ius', ' II']
+2532 573 Name of father of x -1 Name of father of William Faulkner Murry Faulkner William Faulkner "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Sound'
+ ' and' ' the' ' Fury' '_' ',' ' and' ' the' ' author' ' of' ' _']" , the author of the famous novel _ The Sound and the Fury _ , and the author of _ False adaptation of two William Faulkner novels based 6 [' adaptation', ' of', ' two', ' William', ' Faul', 'k', 'ner']
+2533 573 Name of father of x -1 Name of father of William Faulkner Murry Faulkner William Faulkner "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Sound'
+ ' and' ' the' ' Fury' '_' ',' ' and' ' the' ' author' ' of' ' _']" , the author of the famous novel _ The Sound and the Fury _ , and the author of _ False Ernest Hemingway and William Faulkner in the development 8 [' Ernest', ' Hem', 'ing', 'way', ' and', ' William', ' Faul', 'k', 'ner']
+2534 573 Name of father of x -1 Name of father of William Faulkner Murry Faulkner William Faulkner "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Sound'
+ ' and' ' the' ' Fury' '_' ',' ' and' ' the' ' author' ' of' ' _']" , the author of the famous novel _ The Sound and the Fury _ , and the author of _ False " Faulkner ===
+" 7 [' Faul', 'k', 'ner', ' ===', 'William', ' Faul', 'k', 'ner']
+2535 573 Name of father of x -1 Name of father of William Faulkner Murry Faulkner William Faulkner "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Sound'
+ ' and' ' the' ' Fury' '_' ',' ' and' ' the' ' author' ' of' ' _']" , the author of the famous novel _ The Sound and the Fury _ , and the author of _ False Ernest Hemingway and William Faulkner in the development 8 [' Ernest', ' Hem', 'ing', 'way', ' and', ' William', ' Faul', 'k', 'ner']
+2536 573 Name of father of x -1 Name of father of William Faulkner Murry Faulkner William Faulkner "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' 'The' ' Sound'
+ ' and' ' the' ' Fury' '_' ',' ' and' ' the' ' author' ' of' ' _']" , the author of the famous novel _ The Sound and the Fury _ , and the author of _ False based on the William Faulkner novel Sanctuary. 6 [' based', ' on', ' the', ' William', ' Faul', 'k', 'ner']
+2537 574 Name of father of x -1 Name of father of Asia Argento Dario Argento Asia Argento "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Italian' ' parliament'
+ ',' ' and' ' a' ' member' ' of' ' the' ' Italian' ' Communist' ' Party'
+ '.' '\n']" ", who was a member of the Italian parliament , and a member of the Italian Communist Party .
+" False and over was cited by Asia Argento (Nicolodi's 7 [' and', ' over', ' was', ' cited', ' by', ' Asia', ' Argent', 'o']
+2538 574 Name of father of x -1 Name of father of Asia Argento Dario Argento Asia Argento "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Italian' ' parliament'
+ ',' ' and' ' a' ' member' ' of' ' the' ' Italian' ' Communist' ' Party'
+ '.' '\n']" ", who was a member of the Italian parliament , and a member of the Italian Communist Party .
+" False was cited by Asia Argento (Nicolodi's 5 [' was', ' cited', ' by', ' Asia', ' Argent', 'o']
+2539 577 Name of father of x -1 Name of father of Valéry Giscard d'Estaing Edmond Giscard d'Estaing Valéry Giscard d'Estaing "[',' ' President' ' of' ' the' ' French' ' Republic' '\n' '\n' 'The'
+ ' French' ' Republic' '\n' '\n' 'The' ' French' ' Republic' ' is' ' a'
+ ' constitutional' ' republic']" ", President of the French Republic
+
+ The French Republic
+
+ The French Republic is a constitutional republic" False President Valéry Giscard d'Estaing was at first reluctant 11 "[' President', ' Val', 'é', 'ry', ' G', 'isc', 'ard', ' d', ""'"", 'Est', 'a', 'ing']"
+2540 578 Name of father of x -1 Name of father of Jean de La Fontaine Charles de la fontaine Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False classical antiquity. Jean de La Fontaine began his collection 7 [' classical', ' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2541 578 Name of father of x -1 Name of father of Jean de La Fontaine Charles de la fontaine Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False antiquity. Jean de La Fontaine began his collection 6 [' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2542 578 Name of father of x -1 Name of father of Jean de La Fontaine Charles de la fontaine Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False classical antiquity. Jean de La Fontaine began his collection 7 [' classical', ' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2543 578 Name of father of x -1 Name of father of Jean de La Fontaine Charles de la fontaine Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False classical antiquity. Jean de La Fontaine began his collection 7 [' classical', ' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2544 579 Name of father of x -1 Name of father of Jacques Offenbach Isaac Offenbach Jacques Offenbach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'T' 'ales' ' of'
+ ' Hoff' 'mann' '""' ' and' ' ""' 'The' ' Tales' ' of' ' Hoff' 'mann']" ", the composer of the famous "" T ales of Hoff mann "" and "" The Tales of Hoff mann" False Works by or about Jacques Offenbach at Internet 7 [' Works', ' by', ' or', ' about', ' Jacques', ' Off', 'en', 'bach']
+2545 579 Name of father of x -1 Name of father of Jacques Offenbach Isaac Offenbach Jacques Offenbach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'T' 'ales' ' of'
+ ' Hoff' 'mann' '""' ' and' ' ""' 'The' ' Tales' ' of' ' Hoff' 'mann']" ", the composer of the famous "" T ales of Hoff mann "" and "" The Tales of Hoff mann" False Free scores by Jacques Offenbach in the Choral 6 [' Free', ' scores', ' by', ' Jacques', ' Off', 'en', 'bach']
+2546 579 Name of father of x -1 Name of father of Jacques Offenbach Isaac Offenbach Jacques Offenbach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'T' 'ales' ' of'
+ ' Hoff' 'mann' '""' ' and' ' ""' 'The' ' Tales' ' of' ' Hoff' 'mann']" ", the composer of the famous "" T ales of Hoff mann "" and "" The Tales of Hoff mann" False opera competition which Jacques Offenbach had organised 6 [' opera', ' competition', ' which', ' Jacques', ' Off', 'en', 'bach']
+2547 579 Name of father of x -1 Name of father of Jacques Offenbach Isaac Offenbach Jacques Offenbach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'T' 'ales' ' of'
+ ' Hoff' 'mann' '""' ' and' ' ""' 'The' ' Tales' ' of' ' Hoff' 'mann']" ", the composer of the famous "" T ales of Hoff mann "" and "" The Tales of Hoff mann" False " Offenbach =
+" 8 [' Off', 'en', 'bach', ' =', 'Jac', 'ques', ' Off', 'en', 'bach']
+2548 579 Name of father of x -1 Name of father of Jacques Offenbach Isaac Offenbach Jacques Offenbach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'T' 'ales' ' of'
+ ' Hoff' 'mann' '""' ' and' ' ""' 'The' ' Tales' ' of' ' Hoff' 'mann']" ", the composer of the famous "" T ales of Hoff mann "" and "" The Tales of Hoff mann" False " Offenbach =
+" 8 [' Off', 'en', 'bach', ' =', 'Jac', 'ques', ' Off', 'en', 'bach']
+2549 580 Name of father of x -1 Name of father of Apollo Zeus Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False control teams for Apollo 8 consisted of 3 [' control', ' teams', ' for', ' Apollo']
+2550 580 Name of father of x -1 Name of father of Apollo Zeus Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False flight crew during Apollo 15. Each track is 3 [' flight', ' crew', ' during', ' Apollo']
+2551 580 Name of father of x -1 Name of father of Apollo Zeus Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False 1 ['Ap', 'ollo']
+2552 580 Name of father of x -1 Name of father of Apollo Zeus Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False was in 1970, after Apollo 13, where 5 [' was', ' in', ' 1970', ',', ' after', ' Apollo']
+2553 580 Name of father of x -1 Name of father of Apollo Zeus Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False an exhibit on the Apollo Theater) on July 10, 4 [' an', ' exhibit', ' on', ' the', ' Apollo']
+2554 581 Name of father of x -1 Name of father of John Stuart Mill James Rodríguez John Stuart Mill "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' thing' ' that'
+ ' strikes']" ", the great philosopher , and the father of the modern world .
+
+ The first thing that strikes" False " situation"". He demeaned John Stuart Mill as a ""wooly man of" 8 "[' situation', '"".', ' He', ' dem', 'ean', 'ed', ' John', ' Stuart', ' Mill']"
+2555 581 Name of father of x -1 Name of father of John Stuart Mill James Rodríguez John Stuart Mill "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' thing' ' that'
+ ' strikes']" ", the great philosopher , and the father of the modern world .
+
+ The first thing that strikes" False and, arguably, John Stuart Mill – Mill's focus 6 [' and', ',', ' arguably', ',', ' John', ' Stuart', ' Mill']
+2556 581 Name of father of x -1 Name of father of John Stuart Mill James Rodríguez John Stuart Mill "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' thing' ' that'
+ ' strikes']" ", the great philosopher , and the father of the modern world .
+
+ The first thing that strikes" False empiricism outlined by John Stuart Mill in his An Examination 6 [' empir', 'icism', ' outlined', ' by', ' John', ' Stuart', ' Mill']
+2557 581 Name of father of x -1 Name of father of John Stuart Mill James Rodríguez John Stuart Mill "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' thing' ' that'
+ ' strikes']" ", the great philosopher , and the father of the modern world .
+
+ The first thing that strikes" False century writings of John Stuart Mill are also considered 5 [' century', ' writings', ' of', ' John', ' Stuart', ' Mill']
+2558 581 Name of father of x -1 Name of father of John Stuart Mill James Rodríguez John Stuart Mill "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' ' father' ' of'
+ ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first' ' thing' ' that'
+ ' strikes']" ", the great philosopher , and the father of the modern world .
+
+ The first thing that strikes" False continued through John Stuart Mill and Bertrand 4 [' continued', ' through', ' John', ' Stuart', ' Mill']
+2559 582 Name of father of x -1 Name of father of Gerard ter Borch Gerard ter Borch the Elder Gerard ter Borch "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Dutch' ' parliament' ','
+ ' and' ' a' ' member' ' of' ' the' ' Dutch' ' parliament' ',' ' and' ' a']" , who was a member of the Dutch parliament , and a member of the Dutch parliament , and a False 1936 in the style of Gerard ter Borch (unsold, on display 8 [' 1936', ' in', ' the', ' style', ' of', ' Gerard', ' ter', ' Bor', 'ch']
+2560 582 Name of father of x -1 Name of father of Gerard ter Borch Gerard ter Borch the Elder Gerard ter Borch "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Dutch' ' parliament' ','
+ ' and' ' a' ' member' ' of' ' the' ' Dutch' ' parliament' ',' ' and' ' a']" , who was a member of the Dutch parliament , and a member of the Dutch parliament , and a False the style of Gerard ter Borch (unsold, on display 6 [' the', ' style', ' of', ' Gerard', ' ter', ' Bor', 'ch']
+2561 582 Name of father of x -1 Name of father of Gerard ter Borch Gerard ter Borch the Elder Gerard ter Borch "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Dutch' ' parliament' ','
+ ' and' ' a' ' member' ' of' ' the' ' Dutch' ' parliament' ',' ' and' ' a']" , who was a member of the Dutch parliament , and a member of the Dutch parliament , and a False 1936 in the style of Gerard ter Borch (unsold, on display 8 [' 1936', ' in', ' the', ' style', ' of', ' Gerard', ' ter', ' Bor', 'ch']
+2562 583 Name of father of x -1 Name of father of Maurice Ravel Pierre-Joseph Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' daughter' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's daughter , the pian False " piece composed by Maurice Ravel in 1928 and is ""built" 5 [' piece', ' composed', ' by', ' Maurice', ' Ra', 'vel']
+2563 583 Name of father of x -1 Name of father of Maurice Ravel Pierre-Joseph Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' daughter' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's daughter , the pian False the maison-musée de Maurice Ravel remains open 9 [' the', ' ma', 'ison', '-', 'mus', 'ée', ' de', ' Maurice', ' Ra', 'vel']
+2564 583 Name of father of x -1 Name of father of Maurice Ravel Pierre-Joseph Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' daughter' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's daughter , the pian False Works by or about Maurice Ravel in libraries (WorldCat 6 [' Works', ' by', ' or', ' about', ' Maurice', ' Ra', 'vel']
+2565 583 Name of father of x -1 Name of father of Maurice Ravel Pierre-Joseph Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' daughter' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's daughter , the pian False language of Maurice Ravel and American 4 [' language', ' of', ' Maurice', ' Ra', 'vel']
+2566 583 Name of father of x -1 Name of father of Maurice Ravel Pierre-Joseph Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' daughter' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's daughter , the pian False 4 ['M', 'aur', 'ice', ' Ra', 'vel']
+2567 584 Name of father of x -1 Name of father of Paul Claudel Louis Prosper Claudel Paul Claudel "[',' ' the' ' great' ' poet' ',' ' the' ' great' ' poet' ',' ' the'
+ ' great' ' poet' ',' ' the' ' great' ' poet' ',' ' the' ' great' ' poet']" , the great poet , the great poet , the great poet , the great poet , the great poet False 2 ['Paul', ' Claud', 'el']
+2568 584 Name of father of x -1 Name of father of Paul Claudel Louis Prosper Claudel Paul Claudel "[',' ' the' ' great' ' poet' ',' ' the' ' great' ' poet' ',' ' the'
+ ' great' ' poet' ',' ' the' ' great' ' poet' ',' ' the' ' great' ' poet']" , the great poet , the great poet , the great poet , the great poet , the great poet False 2 ['Paul', ' Claud', 'el']
+2569 586 Name of father of x -1 Name of father of Emma Roberts Eric Roberts Emma Roberts "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' Roberts' ','
+ ' who' ' was' ' a' '\n' '\n' 'member' ' of' ' the' ' firm' ' of']" ", the daughter of the late Mr . Roberts , who was a
+
+ member of the firm of" False Pictures, with Emma Roberts as Nancy Drew, 4 [' Pictures', ',', ' with', ' Emma', ' Roberts']
+2570 586 Name of father of x -1 Name of father of Emma Roberts Eric Roberts Emma Roberts "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' Roberts' ','
+ ' who' ' was' ' a' '\n' '\n' 'member' ' of' ' the' ' firm' ' of']" ", the daughter of the late Mr . Roberts , who was a
+
+ member of the firm of" False May 2013 that Emma Roberts had been added to 4 [' May', ' 2013', ' that', ' Emma', ' Roberts']
+2571 586 Name of father of x -1 Name of father of Emma Roberts Eric Roberts Emma Roberts "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' Roberts' ','
+ ' who' ' was' ' a' '\n' '\n' 'member' ' of' ' the' ' firm' ' of']" ", the daughter of the late Mr . Roberts , who was a
+
+ member of the firm of" False in May 2013 that Emma Roberts had been added 5 [' in', ' May', ' 2013', ' that', ' Emma', ' Roberts']
+2572 586 Name of father of x -1 Name of father of Emma Roberts Eric Roberts Emma Roberts "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' Roberts' ','
+ ' who' ' was' ' a' '\n' '\n' 'member' ' of' ' the' ' firm' ' of']" ", the daughter of the late Mr . Roberts , who was a
+
+ member of the firm of" False Bros. Pictures, with Emma Roberts as Nancy Drew, 6 [' Bros', '.', ' Pictures', ',', ' with', ' Emma', ' Roberts']
+2573 586 Name of father of x -1 Name of father of Emma Roberts Eric Roberts Emma Roberts "[',' ' the' ' daughter' ' of' ' the' ' late' ' Mr' '.' ' Roberts' ','
+ ' who' ' was' ' a' '\n' '\n' 'member' ' of' ' the' ' firm' ' of']" ", the daughter of the late Mr . Roberts , who was a
+
+ member of the firm of" False The author Emma Roberts (c. 1794 – 1840) 3 [' The', ' author', ' Emma', ' Roberts']
+2574 587 Name of father of x -1 Name of father of Pythagoras Mnesarchus Pythagoras "[',' ' the' ' son' ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son'
+ ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son' ' of']" , the son of Pyth ag oras , and the son of Pyth ag oras , and the son of False philosophers from Pythagoras onwards was τὸ πᾶν 4 [' philosophers', ' from', ' Pyth', 'ag', 'oras']
+2575 587 Name of father of x -1 Name of father of Pythagoras Mnesarchus Pythagoras "[',' ' the' ' son' ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son'
+ ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son' ' of']" , the son of Pyth ag oras , and the son of Pyth ag oras , and the son of False tried to flee, but Pythagoras rallied the soldiers 7 [' tried', ' to', ' flee', ',', ' but', ' Pyth', 'ag', 'oras']
+2576 587 Name of father of x -1 Name of father of Pythagoras Mnesarchus Pythagoras "[',' ' the' ' son' ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son'
+ ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son' ' of']" , the son of Pyth ag oras , and the son of Pyth ag oras , and the son of False 3 ['P', 'yth', 'ag', 'oras']
+2577 587 Name of father of x -1 Name of father of Pythagoras Mnesarchus Pythagoras "[',' ' the' ' son' ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son'
+ ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son' ' of']" , the son of Pyth ag oras , and the son of Pyth ag oras , and the son of False of love, and Pythagoras and Heraclitus 6 [' of', ' love', ',', ' and', ' Pyth', 'ag', 'oras']
+2578 587 Name of father of x -1 Name of father of Pythagoras Mnesarchus Pythagoras "[',' ' the' ' son' ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son'
+ ' of' ' Pyth' 'ag' 'oras' ',' ' and' ' the' ' son' ' of']" , the son of Pyth ag oras , and the son of Pyth ag oras , and the son of False years before Pythagoras was born. (Yale's 4 [' years', ' before', ' Pyth', 'ag', 'oras']
+2579 589 Name of father of x -1 Name of father of Naomi Watts Peter Watts Naomi Watts "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' since' ' 2006' ' and'
+ ' have']" ", who is a former model and actress .
+
+ The couple have been married since 2006 and have" False co-starring Rachel Weisz, Naomi Watts and Marton 10 [' co', '-', 'star', 'ring', ' Rachel', ' We', 'is', 'z', ',', ' Naomi', ' Watts']
+2580 589 Name of father of x -1 Name of father of Naomi Watts Peter Watts Naomi Watts "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' since' ' 2006' ' and'
+ ' have']" ", who is a former model and actress .
+
+ The couple have been married since 2006 and have" False film co-starring Naomi Watts and Ewan McGregor. 6 [' film', ' co', '-', 'star', 'ring', ' Naomi', ' Watts']
+2581 589 Name of father of x -1 Name of father of Naomi Watts Peter Watts Naomi Watts "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' since' ' 2006' ' and'
+ ' have']" ", who is a former model and actress .
+
+ The couple have been married since 2006 and have" False July, actress Naomi Watts was cast opposite 4 [' July', ',', ' actress', ' Naomi', ' Watts']
+2582 589 Name of father of x -1 Name of father of Naomi Watts Peter Watts Naomi Watts "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' since' ' 2006' ' and'
+ ' have']" ", who is a former model and actress .
+
+ The couple have been married since 2006 and have" False " Eleanor Whitman
+" 4 [' Eleanor', ' Whitman', 'Na', 'omi', ' Watts']
+2583 589 Name of father of x -1 Name of father of Naomi Watts Peter Watts Naomi Watts "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' since' ' 2006' ' and'
+ ' have']" ", who is a former model and actress .
+
+ The couple have been married since 2006 and have" False Rachel Weisz, Naomi Watts and Marton Csokas. 6 [' Rachel', ' We', 'is', 'z', ',', ' Naomi', ' Watts']
+2584 590 Name of father of x -1 Name of father of Rita Ora Besnik Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False February 2012, Rita Ora covered the song 5 [' February', ' 2012', ',', ' Rita', ' O', 'ra']
+2585 590 Name of father of x -1 Name of father of Rita Ora Besnik Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False " and uneven"". Rita Ora covered ""Drunk in" 5 "[' and', ' uneven', '"".', ' Rita', ' O', 'ra']"
+2586 590 Name of father of x -1 Name of father of Rita Ora Besnik Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Girls mashed with Rita Ora and Chanel 5 [' Girls', ' mashed', ' with', ' Rita', ' O', 'ra']
+2587 590 Name of father of x -1 Name of father of Rita Ora Besnik Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False song was covered by Rita Ora at Radio 1's Big 6 [' song', ' was', ' covered', ' by', ' Rita', ' O', 'ra']
+2588 590 Name of father of x -1 Name of father of Rita Ora Besnik Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False " and uneven"". Rita Ora covered ""Drunk in" 5 "[' and', ' uneven', '"".', ' Rita', ' O', 'ra']"
+2589 591 Name of father of x -1 Name of father of Romain Rolland Émile Rolland Romain Rolland "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about']" ", the French writer , who was a friend of the
+
+ The first thing that strikes you about" False " restrictions."" Romain Rolland and Francis" 5 "[' restrictions', '.""', ' Rom', 'ain', ' Roll', 'and']"
+2590 591 Name of father of x -1 Name of father of Romain Rolland Émile Rolland Romain Rolland "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about']" ", the French writer , who was a friend of the
+
+ The first thing that strikes you about" False " restrictions."" Romain Rolland and Francis Poulenc" 5 "[' restrictions', '.""', ' Rom', 'ain', ' Roll', 'and']"
+2591 592 Name of father of x -1 Name of father of Andrew Lloyd Webber William Lloyd Webber Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False It was written by Andrew Lloyd Webber and Tim Rice, for 7 [' It', ' was', ' written', ' by', ' Andrew', ' Lloyd', ' Web', 'ber']
+2592 592 Name of father of x -1 Name of father of Andrew Lloyd Webber William Lloyd Webber Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False Tim Rice and Andrew Lloyd Webber in 1976, with 6 [' Tim', ' Rice', ' and', ' Andrew', ' Lloyd', ' Web', 'ber']
+2593 592 Name of father of x -1 Name of father of Andrew Lloyd Webber William Lloyd Webber Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False " collection of Andrew Lloyd Webber and the ""Narcissus" 5 [' collection', ' of', ' Andrew', ' Lloyd', ' Web', 'ber']
+2594 592 Name of father of x -1 Name of father of Andrew Lloyd Webber William Lloyd Webber Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False in the 2006 Andrew Lloyd Webber and David Ian stage 6 [' in', ' the', ' 2006', ' Andrew', ' Lloyd', ' Web', 'ber']
+2595 592 Name of father of x -1 Name of father of Andrew Lloyd Webber William Lloyd Webber Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False James Whitbourn and Andrew Lloyd Webber have all been involved 8 [' James', ' Whit', 'b', 'ourn', ' and', ' Andrew', ' Lloyd', ' Web', 'ber']
+2596 594 Name of father of x -1 Name of father of Frank Lloyd Wright William Carey Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False " Employment with Frank Lloyd Wright ===
+" 4 [' Employment', ' with', ' Frank', ' Lloyd', ' Wright']
+2597 594 Name of father of x -1 Name of father of Frank Lloyd Wright William Carey Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False Long, who was a Frank Lloyd Wright protege, according 7 [' Long', ',', ' who', ' was', ' a', ' Frank', ' Lloyd', ' Wright']
+2598 594 Name of father of x -1 Name of father of Frank Lloyd Wright William Carey Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False with nine other Frank Lloyd Wright properties to a tentative 5 [' with', ' nine', ' other', ' Frank', ' Lloyd', ' Wright']
+2599 594 Name of father of x -1 Name of father of Frank Lloyd Wright William Carey Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False Long, who was a Frank Lloyd Wright protege, according 7 [' Long', ',', ' who', ' was', ' a', ' Frank', ' Lloyd', ' Wright']
+2600 594 Name of father of x -1 Name of father of Frank Lloyd Wright William Carey Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False " Employment with Frank Lloyd Wright ===
+" 4 [' Employment', ' with', ' Frank', ' Lloyd', ' Wright']
+2601 595 Name of father of x -1 Name of father of Alexander III of Russia Alexander II of Russia Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' '.' '\n' '\n' 'The'
+ ' Russian' ' Empire' ' was' ' a' ' vast' ' country' ',' ' and' ' the']" ", and the
+
+ Russian Empire .
+
+ The Russian Empire was a vast country , and the" False ceremony was attended by Alexander III of Russia and then-Tsesarevich 7 [' ceremony', ' was', ' attended', ' by', ' Alexander', ' III', ' of', ' Russia']
+2602 595 Name of father of x -1 Name of father of Alexander III of Russia Alexander II of Russia Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' '.' '\n' '\n' 'The'
+ ' Russian' ' Empire' ' was' ' a' ' vast' ' country' ',' ' and' ' the']" ", and the
+
+ Russian Empire .
+
+ The Russian Empire was a vast country , and the" False by Emperor Alexander III of Russia for the Hermitage 5 [' by', ' Emperor', ' Alexander', ' III', ' of', ' Russia']
+2603 595 Name of father of x -1 Name of father of Alexander III of Russia Alexander II of Russia Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' '.' '\n' '\n' 'The'
+ ' Russian' ' Empire' ' was' ' a' ' vast' ' country' ',' ' and' ' the']" ", and the
+
+ Russian Empire .
+
+ The Russian Empire was a vast country , and the" False 1 November, Czar Alexander III of Russia died; Wilhelm II 8 [' 1', ' November', ',', ' C', 'zar', ' Alexander', ' III', ' of', ' Russia']
+2604 595 Name of father of x -1 Name of father of Alexander III of Russia Alexander II of Russia Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' '.' '\n' '\n' 'The'
+ ' Russian' ' Empire' ' was' ' a' ' vast' ' country' ',' ' and' ' the']" ", and the
+
+ Russian Empire .
+
+ The Russian Empire was a vast country , and the" False was attended by Alexander III of Russia and then-Tsesarevich 6 [' was', ' attended', ' by', ' Alexander', ' III', ' of', ' Russia']
+2605 595 Name of father of x -1 Name of father of Alexander III of Russia Alexander II of Russia Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' '.' '\n' '\n' 'The'
+ ' Russian' ' Empire' ' was' ' a' ' vast' ' country' ',' ' and' ' the']" ", and the
+
+ Russian Empire .
+
+ The Russian Empire was a vast country , and the" False purchased by Emperor Alexander III of Russia for the Hermitage Museum 6 [' purchased', ' by', ' Emperor', ' Alexander', ' III', ' of', ' Russia']
+2606 596 Name of father of x -1 Name of father of Tim Krabbé Maarten Krabbé Tim Krabbé "[',' ' the' ' father' ' of' ' the' ' modern' '-' 'day' ' ""' 'K' 'r' 'abb'
+ 'é' '""' '\n' '\n' 'The' ' name' ' Kr' 'abb']" ", the father of the modern - day "" K r abb é ""
+
+ The name Kr abb" False 3 ['Tim', ' Kr', 'abb', 'é']
+2607 597 Name of father of x -1 Name of father of Alphonse Daudet Vincent Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False Gustave Flaubert, Alphonse Daudet and Émile 10 [' Gust', 'ave', ' Fl', 'au', 'bert', ',', ' Alph', 'onse', ' D', 'aud', 'et']
+2608 597 Name of father of x -1 Name of father of Alphonse Daudet Vincent Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False by French dramatist Alphonse Daudet that Puccini thought 9 [' by', ' French', ' dram', 'at', 'ist', ' Alph', 'onse', ' D', 'aud', 'et']
+2609 597 Name of father of x -1 Name of father of Alphonse Daudet Vincent Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False dramatist Alphonse Daudet that Puccini 7 [' dram', 'at', 'ist', ' Alph', 'onse', ' D', 'aud', 'et']
+2610 597 Name of father of x -1 Name of father of Alphonse Daudet Vincent Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False French dramatist Alphonse Daudet that Puccini 8 [' French', ' dram', 'at', 'ist', ' Alph', 'onse', ' D', 'aud', 'et']
+2611 597 Name of father of x -1 Name of father of Alphonse Daudet Vincent Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False " Gustave Flaubert, Alphonse Daudet and Émile Zola.
+" 10 [' Gust', 'ave', ' Fl', 'au', 'bert', ',', ' Alph', 'onse', ' D', 'aud', 'et']
+2612 598 Name of father of x -1 Name of father of Roman Polanski Ryszard Polański Roman Polanski "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False of directors Roman Polanski and Alfred 4 [' of', ' directors', ' Roman', ' Pol', 'anski']
+2613 598 Name of father of x -1 Name of father of Roman Polanski Ryszard Polański Roman Polanski "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Prix in the Roman Polanski produced film 5 [' Prix', ' in', ' the', ' Roman', ' Pol', 'anski']
+2614 598 Name of father of x -1 Name of father of Roman Polanski Ryszard Polański Roman Polanski "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Jeremy Irons, Roman Polanski and many European, 6 [' Jeremy', ' Ir', 'ons', ',', ' Roman', ' Pol', 'anski']
+2615 598 Name of father of x -1 Name of father of Roman Polanski Ryszard Polański Roman Polanski "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False were inspired by Roman Polanski and Alfred Hitchcock 5 [' were', ' inspired', ' by', ' Roman', ' Pol', 'anski']
+2616 598 Name of father of x -1 Name of father of Roman Polanski Ryszard Polański Roman Polanski "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Sarandon, Jeremy Irons, Roman Polanski and many European, 9 [' Sar', 'andon', ',', ' Jeremy', ' Ir', 'ons', ',', ' Roman', ' Pol', 'anski']
+2617 599 Name of father of x -1 Name of father of Benjamin Britten Robert Victor Britten Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False about Britten: Benjamin Britten & his Festival 6 [' about', ' Br', 'itten', ':', ' Benjamin', ' Br', 'itten']
+2618 599 Name of father of x -1 Name of father of Benjamin Britten Robert Victor Britten Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False also signed Benjamin Britten and Aaron Copland. 4 [' also', ' signed', ' Benjamin', ' Br', 'itten']
+2619 599 Name of father of x -1 Name of father of Benjamin Britten Robert Victor Britten Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False " Benjamin Britten =
+" 2 [' Benjamin', ' Br', 'itten']
+2620 599 Name of father of x -1 Name of father of Benjamin Britten Robert Victor Britten Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False when the young Benjamin Britten found the performance 5 [' when', ' the', ' young', ' Benjamin', ' Br', 'itten']
+2621 599 Name of father of x -1 Name of father of Benjamin Britten Robert Victor Britten Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False been aware that Benjamin Britten had written 5 [' been', ' aware', ' that', ' Benjamin', ' Br', 'itten']
+2622 600 Name of father of x -1 Name of father of Doris Day Frederick Wilhelm Kappelhoff Doris Day "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' and' ' singer' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' and' ' singer' ' Dor'
+ 'is']" , the daughter of the late actor and singer , and the mother of the actress and singer Dor is False originally recorded by Doris Day for Alfred Hitchcock's 5 [' originally', ' recorded', ' by', ' Dor', 'is', ' Day']
+2623 600 Name of father of x -1 Name of father of Doris Day Frederick Wilhelm Kappelhoff Doris Day "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' and' ' singer' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' and' ' singer' ' Dor'
+ 'is']" , the daughter of the late actor and singer , and the mother of the actress and singer Dor is False in 1996, the 1995 Doris Day Music Award at the 7 [' in', ' 1996', ',', ' the', ' 1995', ' Dor', 'is', ' Day']
+2624 600 Name of father of x -1 Name of father of Doris Day Frederick Wilhelm Kappelhoff Doris Day "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' and' ' singer' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' and' ' singer' ' Dor'
+ 'is']" , the daughter of the late actor and singer , and the mother of the actress and singer Dor is False " released ""an album of Doris Day covers on [her]" 7 "[' released', ' ""', 'an', ' album', ' of', ' Dor', 'is', ' Day']"
+2625 600 Name of father of x -1 Name of father of Doris Day Frederick Wilhelm Kappelhoff Doris Day "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' and' ' singer' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' and' ' singer' ' Dor'
+ 'is']" , the daughter of the late actor and singer , and the mother of the actress and singer Dor is False originally recorded by Doris Day for Alfred Hitchcock's 5 [' originally', ' recorded', ' by', ' Dor', 'is', ' Day']
+2626 600 Name of father of x -1 Name of father of Doris Day Frederick Wilhelm Kappelhoff Doris Day "[',' ' the' ' daughter' ' of' ' the' ' late' ' actor' ' and' ' singer' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' and' ' singer' ' Dor'
+ 'is']" , the daughter of the late actor and singer , and the mother of the actress and singer Dor is False his album, Memories. Doris Day recorded the 7 [' his', ' album', ',', ' Memories', '.', ' Dor', 'is', ' Day']
+2627 601 Name of father of x -1 Name of father of Gabriel Fauré Toussaint Fauré Gabriel Fauré "[',' ' the' ' French' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of'
+ ' father' ' of' ' Gabriel' ' F' 'aur' 'é' ',' ' the' ' French']" ", the French composer , and the
+
+ Name of father of Gabriel F aur é , the French" False 4 ['Gab', 'riel', ' F', 'aur', 'é']
+2628 601 Name of father of x -1 Name of father of Gabriel Fauré Toussaint Fauré Gabriel Fauré "[',' ' the' ' French' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of'
+ ' father' ' of' ' Gabriel' ' F' 'aur' 'é' ',' ' the' ' French']" ", the French composer , and the
+
+ Name of father of Gabriel F aur é , the French" False Variations on the Name Gabriel Fauré (1949) for harp 8 [' Vari', 'ations', ' on', ' the', ' Name', ' Gabriel', ' F', 'aur', 'é']
+2629 601 Name of father of x -1 Name of father of Gabriel Fauré Toussaint Fauré Gabriel Fauré "[',' ' the' ' French' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of'
+ ' father' ' of' ' Gabriel' ' F' 'aur' 'é' ',' ' the' ' French']" ", the French composer , and the
+
+ Name of father of Gabriel F aur é , the French" False with Eugène Gigout, Gabriel Fauré and (after leaving 10 [' with', ' Eug', 'è', 'ne', ' Gig', 'out', ',', ' Gabriel', ' F', 'aur', 'é']
+2630 601 Name of father of x -1 Name of father of Gabriel Fauré Toussaint Fauré Gabriel Fauré "[',' ' the' ' French' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of'
+ ' father' ' of' ' Gabriel' ' F' 'aur' 'é' ',' ' the' ' French']" ", the French composer , and the
+
+ Name of father of Gabriel F aur é , the French" False The French composer Gabriel Fauré (1845 – 1924) wrote 6 [' The', ' French', ' composer', ' Gabriel', ' F', 'aur', 'é']
+2631 601 Name of father of x -1 Name of father of Gabriel Fauré Toussaint Fauré Gabriel Fauré "[',' ' the' ' French' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of'
+ ' father' ' of' ' Gabriel' ' F' 'aur' 'é' ',' ' the' ' French']" ", the French composer , and the
+
+ Name of father of Gabriel F aur é , the French" False Eugène Gigout, Gabriel Fauré and (after leaving 9 [' Eug', 'è', 'ne', ' Gig', 'out', ',', ' Gabriel', ' F', 'aur', 'é']
+2632 602 Name of father of x -1 Name of father of Sergei Eisenstein Mikhail Eisenstein Sergei Eisenstein "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' Sergei' ' Eisen' 'stein' ',' ' was' ' a']" , the Russian director of the film , and the film 's producer , Sergei Eisen stein , was a False most famously by Sergei Eisenstein in his 1925 silent 5 [' most', ' famously', ' by', ' Sergei', ' Eisen', 'stein']
+2633 602 Name of father of x -1 Name of father of Sergei Eisenstein Mikhail Eisenstein Sergei Eisenstein "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' Sergei' ' Eisen' 'stein' ',' ' was' ' a']" , the Russian director of the film , and the film 's producer , Sergei Eisen stein , was a False filmmakers such as Sergei Eisenstein and Charlie 5 [' filmmakers', ' such', ' as', ' Sergei', ' Eisen', 'stein']
+2634 602 Name of father of x -1 Name of father of Sergei Eisenstein Mikhail Eisenstein Sergei Eisenstein "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' Sergei' ' Eisen' 'stein' ',' ' was' ' a']" , the Russian director of the film , and the film 's producer , Sergei Eisen stein , was a False memorialized most famously by Sergei Eisenstein in his 1925 7 [' memorial', 'ized', ' most', ' famously', ' by', ' Sergei', ' Eisen', 'stein']
+2635 602 Name of father of x -1 Name of father of Sergei Eisenstein Mikhail Eisenstein Sergei Eisenstein "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' Sergei' ' Eisen' 'stein' ',' ' was' ' a']" , the Russian director of the film , and the film 's producer , Sergei Eisen stein , was a False Russian director Sergei Eisenstein especially disliked 4 [' Russian', ' director', ' Sergei', ' Eisen', 'stein']
+2636 602 Name of father of x -1 Name of father of Sergei Eisenstein Mikhail Eisenstein Sergei Eisenstein "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' Sergei' ' Eisen' 'stein' ',' ' was' ' a']" , the Russian director of the film , and the film 's producer , Sergei Eisen stein , was a False most famously by Sergei Eisenstein in his 1925 silent 5 [' most', ' famously', ' by', ' Sergei', ' Eisen', 'stein']
+2637 603 Name of father of x -1 Name of father of Frida Kahlo Guillermo Kahlo Frida Kahlo "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ',' ' Fr'
+ 'ida' ' Kah' 'lo' ',' ' was' ' a' ' great' ' admire' 'r']" , the painter , and the painter 's wife , Fr ida Kah lo , was a great admire r False Mexican painter Frida Kahlo was a notable wearer 5 [' Mexican', ' painter', ' Fr', 'ida', ' Kah', 'lo']
+2638 603 Name of father of x -1 Name of father of Frida Kahlo Guillermo Kahlo Frida Kahlo "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ',' ' Fr'
+ 'ida' ' Kah' 'lo' ',' ' was' ' a' ' great' ' admire' 'r']" , the painter , and the painter 's wife , Fr ida Kah lo , was a great admire r False as Remedios Varo, Frida Kahlo and Leonora Carrington. 10 [' as', ' Rem', 'ed', 'ios', ' V', 'aro', ',', ' Fr', 'ida', ' Kah', 'lo']
+2639 603 Name of father of x -1 Name of father of Frida Kahlo Guillermo Kahlo Frida Kahlo "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ',' ' Fr'
+ 'ida' ' Kah' 'lo' ',' ' was' ' a' ' great' ' admire' 'r']" , the painter , and the painter 's wife , Fr ida Kah lo , was a great admire r False and his wife Frida Kahlo lived and worked at 6 [' and', ' his', ' wife', ' Fr', 'ida', ' Kah', 'lo']
+2640 603 Name of father of x -1 Name of father of Frida Kahlo Guillermo Kahlo Frida Kahlo "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ',' ' Fr'
+ 'ida' ' Kah' 'lo' ',' ' was' ' a' ' great' ' admire' 'r']" , the painter , and the painter 's wife , Fr ida Kah lo , was a great admire r False The Mexican painter Frida Kahlo was a notable wearer 6 [' The', ' Mexican', ' painter', ' Fr', 'ida', ' Kah', 'lo']
+2641 603 Name of father of x -1 Name of father of Frida Kahlo Guillermo Kahlo Frida Kahlo "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ',' ' Fr'
+ 'ida' ' Kah' 'lo' ',' ' was' ' a' ' great' ' admire' 'r']" , the painter , and the painter 's wife , Fr ida Kah lo , was a great admire r False Remedios Varo, Frida Kahlo and Leonora Carrington. 9 [' Rem', 'ed', 'ios', ' V', 'aro', ',', ' Fr', 'ida', ' Kah', 'lo']
+2642 604 Name of father of x -1 Name of father of Charlotte Rampling Godfrey Rampling Charlotte Rampling "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Alida Valli and Charlotte Rampling ""in the [Miss" 7 [' Al', 'ida', ' V', 'alli', ' and', ' Charlotte', ' Ram', 'pling']
+2643 604 Name of father of x -1 Name of father of Charlotte Rampling Godfrey Rampling Charlotte Rampling "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Signoret, Alida Valli and Charlotte Rampling ""in the [Miss Blandish]" 10 [' Sign', 'oret', ',', ' Al', 'ida', ' V', 'alli', ' and', ' Charlotte', ' Ram', 'pling']
+2644 604 Name of father of x -1 Name of father of Charlotte Rampling Godfrey Rampling Charlotte Rampling "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Rampling as Miss Emily
+" 8 [' Ram', 'pling', ' as', ' Miss', ' Emily', 'Charl', 'otte', ' Ram', 'pling']
+2645 604 Name of father of x -1 Name of father of Charlotte Rampling Godfrey Rampling Charlotte Rampling "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " as Miss Emily
+" 6 [' as', ' Miss', ' Emily', 'Charl', 'otte', ' Ram', 'pling']
+2646 604 Name of father of x -1 Name of father of Charlotte Rampling Godfrey Rampling Charlotte Rampling "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False his second wife Charlotte Rampling at a dinner party 5 [' his', ' second', ' wife', ' Charlotte', ' Ram', 'pling']
+2647 606 Name of father of x -1 Name of father of Arthur Rimbaud Frédéric Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' the' ' poet' ' of' ' the' '\n' '\n' 'po' 'et'
+ ' of' ' the' ' poet' ',' ' the' ' poet' ' of' ' the']" ", the poet , the poet of the
+
+ po et of the poet , the poet of the" False appreciation of Arthur Rimbaud. In August 5 [' appreciation', ' of', ' Arthur', ' R', 'imb', 'aud']
+2648 606 Name of father of x -1 Name of father of Arthur Rimbaud Frédéric Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' the' ' poet' ' of' ' the' '\n' '\n' 'po' 'et'
+ ' of' ' the' ' poet' ',' ' the' ' poet' ' of' ' the']" ", the poet , the poet of the
+
+ po et of the poet , the poet of the" False diverse as French poet Arthur Rimbaud and Italian filmmaker 7 [' diverse', ' as', ' French', ' poet', ' Arthur', ' R', 'imb', 'aud']
+2649 606 Name of father of x -1 Name of father of Arthur Rimbaud Frédéric Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' the' ' poet' ' of' ' the' '\n' '\n' 'po' 'et'
+ ' of' ' the' ' poet' ',' ' the' ' poet' ' of' ' the']" ", the poet , the poet of the
+
+ po et of the poet , the poet of the" False Charles Baudelaire, Arthur Rimbaud and Stéphane Mallarmé. 9 [' Charles', ' B', 'aud', 'el', 'aire', ',', ' Arthur', ' R', 'imb', 'aud']
+2650 606 Name of father of x -1 Name of father of Arthur Rimbaud Frédéric Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' the' ' poet' ' of' ' the' '\n' '\n' 'po' 'et'
+ ' of' ' the' ' poet' ',' ' the' ' poet' ' of' ' the']" ", the poet , the poet of the
+
+ po et of the poet , the poet of the" False include James Joyce, Arthur Rimbaud and D. H. Lawrence. 7 [' include', ' James', ' Joyce', ',', ' Arthur', ' R', 'imb', 'aud']
+2651 606 Name of father of x -1 Name of father of Arthur Rimbaud Frédéric Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' the' ' poet' ' of' ' the' '\n' '\n' 'po' 'et'
+ ' of' ' the' ' poet' ',' ' the' ' poet' ' of' ' the']" ", the poet , the poet of the
+
+ po et of the poet , the poet of the" False Baudelaire, Arthur Rimbaud and Stéphane Mallarmé. 8 [' B', 'aud', 'el', 'aire', ',', ' Arthur', ' R', 'imb', 'aud']
+2652 607 Name of father of x -1 Name of father of Boris Yeltsin Nikolai Ignatievich Yeltsin Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Yel' 'ts' 'in' ""'s"" ' daughter' ','
+ ' Yel' 'ts']" , the Russian president , and the Russian president 's wife , Yel ts in 's daughter , Yel ts False of Belarus, met with Boris Yeltsin of Russia and Leonid 8 [' of', ' Belarus', ',', ' met', ' with', ' Boris', ' Yel', 'ts', 'in']
+2653 607 Name of father of x -1 Name of father of Boris Yeltsin Nikolai Ignatievich Yeltsin Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Yel' 'ts' 'in' ""'s"" ' daughter' ','
+ ' Yel' 'ts']" , the Russian president , and the Russian president 's wife , Yel ts in 's daughter , Yel ts False the electorate. Boris Yeltsin beat him decisively 6 [' the', ' electorate', '.', ' Boris', ' Yel', 'ts', 'in']
+2654 607 Name of father of x -1 Name of father of Boris Yeltsin Nikolai Ignatievich Yeltsin Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Yel' 'ts' 'in' ""'s"" ' daughter' ','
+ ' Yel' 'ts']" , the Russian president , and the Russian president 's wife , Yel ts in 's daughter , Yel ts False Coup of 1991 Boris Yeltsin and the Ministry 6 [' Coup', ' of', ' 1991', ' Boris', ' Yel', 'ts', 'in']
+2655 607 Name of father of x -1 Name of father of Boris Yeltsin Nikolai Ignatievich Yeltsin Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Yel' 'ts' 'in' ""'s"" ' daughter' ','
+ ' Yel' 'ts']" , the Russian president , and the Russian president 's wife , Yel ts in 's daughter , Yel ts False Union broke up, Boris Yeltsin ushered in 7 [' Union', ' broke', ' up', ',', ' Boris', ' Yel', 'ts', 'in']
+2656 607 Name of father of x -1 Name of father of Boris Yeltsin Nikolai Ignatievich Yeltsin Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ""'s"" ' wife' ',' ' Yel' 'ts' 'in' ""'s"" ' daughter' ','
+ ' Yel' 'ts']" , the Russian president , and the Russian president 's wife , Yel ts in 's daughter , Yel ts False " the funeral of Boris Yeltsin in 2007.
+" 6 [' the', ' funeral', ' of', ' Boris', ' Yel', 'ts', 'in']
+2657 608 Name of father of x -1 Name of father of Philip II of Spain Charles V Philip II of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False passed to his son Philip II of Spain on his abdication 7 [' passed', ' to', ' his', ' son', ' Philip', ' II', ' of', ' Spain']
+2658 608 Name of father of x -1 Name of father of Philip II of Spain Charles V Philip II of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False Portuguese Crown, Philip II of Spain briefly moved his 6 [' Portuguese', ' Crown', ',', ' Philip', ' II', ' of', ' Spain']
+2659 608 Name of father of x -1 Name of father of Philip II of Spain Charles V Philip II of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False the Portuguese Crown, Philip II of Spain briefly moved 7 [' the', ' Portuguese', ' Crown', ',', ' Philip', ' II', ' of', ' Spain']
+2660 608 Name of father of x -1 Name of father of Philip II of Spain Charles V Philip II of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False construction and Philip II of Spain was experiencing 5 [' construction', ' and', ' Philip', ' II', ' of', ' Spain']
+2661 608 Name of father of x -1 Name of father of Philip II of Spain Charles V Philip II of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False 4 ['Phil', 'ip', ' II', ' of', ' Spain']
+2662 609 Name of father of x -1 Name of father of Karen Blixen Wilhelm Dinesen Karen Blixen "[',' ' the' ' Danish' ' author' ' of' ' the' ' book' ' ""' 'Out' ' of'
+ ' Africa' '""' ' and' ' the' ' movie' ' ""' 'G' 'one' ' with' ' the']" ", the Danish author of the book "" Out of Africa "" and the movie "" G one with the" False the Danish writer Karen Blixen opposite Robert Redford's 6 [' the', ' Danish', ' writer', ' Karen', ' Bl', 'ix', 'en']
+2663 610 Name of father of x -1 Name of father of Louisa May Alcott Amos Bronson Alcott Louisa May Alcott "[',' ' the' ' author' ' of' ' Little' ' Women' ',' ' and' ' the' ' author'
+ ' of' ' Little' ' Women' ',' ' and' ' the' ' author' ' of' ' Little'
+ ' Women']" , the author of Little Women , and the author of Little Women , and the author of Little Women False version described by Louisa May Alcott in her 1879 story, 7 [' version', ' described', ' by', ' Louis', 'a', ' May', ' Al', 'cott']
+2664 610 Name of father of x -1 Name of father of Louisa May Alcott Amos Bronson Alcott Louisa May Alcott "[',' ' the' ' author' ' of' ' Little' ' Women' ',' ' and' ' the' ' author'
+ ' of' ' Little' ' Women' ',' ' and' ' the' ' author' ' of' ' Little'
+ ' Women']" , the author of Little Women , and the author of Little Women , and the author of Little Women False version described by Louisa May Alcott in her 1879 7 [' version', ' described', ' by', ' Louis', 'a', ' May', ' Al', 'cott']
+2665 610 Name of father of x -1 Name of father of Louisa May Alcott Amos Bronson Alcott Louisa May Alcott "[',' ' the' ' author' ' of' ' Little' ' Women' ',' ' and' ' the' ' author'
+ ' of' ' Little' ' Women' ',' ' and' ' the' ' author' ' of' ' Little'
+ ' Women']" , the author of Little Women , and the author of Little Women , and the author of Little Women False Little Women by Louisa May Alcott and dedicated 7 [' Little', ' Women', ' by', ' Louis', 'a', ' May', ' Al', 'cott']
+2666 610 Name of father of x -1 Name of father of Louisa May Alcott Amos Bronson Alcott Louisa May Alcott "[',' ' the' ' author' ' of' ' Little' ' Women' ',' ' and' ' the' ' author'
+ ' of' ' Little' ' Women' ',' ' and' ' the' ' author' ' of' ' Little'
+ ' Women']" , the author of Little Women , and the author of Little Women , and the author of Little Women False of Little Women by Louisa May Alcott and dedicated her performance 8 [' of', ' Little', ' Women', ' by', ' Louis', 'a', ' May', ' Al', 'cott']
+2667 611 Name of father of x -1 Name of father of Francis Ford Coppola Carmine Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' director' ' of' ' The'
+ ' God' 'father' ' Part' ' II' ',' ' The' ' God' 'father' ' Part' ' III']" 's The God father , and the director of The God father Part II , The God father Part III False including Martin Scorsese, Francis Ford Coppola and Steven Spielberg. 10 [' including', ' Martin', ' Sc', 'ors', 'ese', ',', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2668 611 Name of father of x -1 Name of father of Francis Ford Coppola Carmine Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' director' ' of' ' The'
+ ' God' 'father' ' Part' ' II' ',' ' The' ' God' 'father' ' Part' ' III']" 's The God father , and the director of The God father Part II , The God father Part III False worked with producer Francis Ford Coppola on The Godfather. 7 [' worked', ' with', ' producer', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2669 611 Name of father of x -1 Name of father of Francis Ford Coppola Carmine Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' director' ' of' ' The'
+ ' God' 'father' ' Part' ' II' ',' ' The' ' God' 'father' ' Part' ' III']" 's The God father , and the director of The God father Part II , The God father Part III False George Lucas and Francis Ford Coppola on the 17-minute 7 [' George', ' Lucas', ' and', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2670 611 Name of father of x -1 Name of father of Francis Ford Coppola Carmine Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' director' ' of' ' The'
+ ' God' 'father' ' Part' ' II' ',' ' The' ' God' 'father' ' Part' ' III']" 's The God father , and the director of The God father Part II , The God father Part III False Tarantino and Francis Ford Coppola — and combines 7 [' Tarant', 'ino', ' and', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2671 611 Name of father of x -1 Name of father of Francis Ford Coppola Carmine Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' director' ' of' ' The'
+ ' God' 'father' ' Part' ' II' ',' ' The' ' God' 'father' ' Part' ' III']" 's The God father , and the director of The God father Part II , The God father Part III False Martin Scorsese, Francis Ford Coppola and Steven 9 [' Martin', ' Sc', 'ors', 'ese', ',', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2672 612 Name of father of x -1 Name of father of Bruce Lee Lee Hoi-chuen Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False personality to that of Bruce Lee and Noel Gallagher. 5 [' personality', ' to', ' that', ' of', ' Bruce', ' Lee']
+2673 612 Name of father of x -1 Name of father of Bruce Lee Lee Hoi-chuen Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False " ""At the time, Bruce Lee was knocking out" 6 "[' ""', 'At', ' the', ' time', ',', ' Bruce', ' Lee']"
+2674 612 Name of father of x -1 Name of father of Bruce Lee Lee Hoi-chuen Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False introduces a young Bruce Lee prior to becoming 4 [' introduces', ' a', ' young', ' Bruce', ' Lee']
+2675 612 Name of father of x -1 Name of father of Bruce Lee Lee Hoi-chuen Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False 1 ['Bruce', ' Lee']
+2676 612 Name of father of x -1 Name of father of Bruce Lee Lee Hoi-chuen Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False " Dragon"" after Bruce Lee and Jackie Chan. After" 4 "[' Dragon', '""', ' after', ' Bruce', ' Lee']"
+2677 613 Name of father of x -1 Name of father of Benjamin Disraeli Isaac D'Israeli Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' ' great' ' poet' ','
+ ' the' ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first']" ", the great states man , the great poet , the great
+
+ 1 .
+
+ The first" False distinguished history; Benjamin Disraeli had been a pupil 6 [' distinguished', ' history', ';', ' Benjamin', ' Dis', 'rael', 'i']
+2678 613 Name of father of x -1 Name of father of Benjamin Disraeli Isaac D'Israeli Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' ' great' ' poet' ','
+ ' the' ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first']" ", the great states man , the great poet , the great
+
+ 1 .
+
+ The first" False replaced Hall, and Benjamin Disraeli was appointed 7 [' replaced', ' Hall', ',', ' and', ' Benjamin', ' Dis', 'rael', 'i']
+2679 613 Name of father of x -1 Name of father of Benjamin Disraeli Isaac D'Israeli Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' ' great' ' poet' ','
+ ' the' ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first']" ", the great states man , the great poet , the great
+
+ 1 .
+
+ The first" False distinguished history; Benjamin Disraeli had been a 6 [' distinguished', ' history', ';', ' Benjamin', ' Dis', 'rael', 'i']
+2680 613 Name of father of x -1 Name of father of Benjamin Disraeli Isaac D'Israeli Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' ' great' ' poet' ','
+ ' the' ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first']" ", the great states man , the great poet , the great
+
+ 1 .
+
+ The first" False " Disraeli =
+" 8 [' Dis', 'rael', 'i', ' =', 'Ben', 'jamin', ' Dis', 'rael', 'i']
+2681 613 Name of father of x -1 Name of father of Benjamin Disraeli Isaac D'Israeli Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' ' great' ' poet' ','
+ ' the' ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first']" ", the great states man , the great poet , the great
+
+ 1 .
+
+ The first" False " Disraeli =
+" 8 [' Dis', 'rael', 'i', ' =', 'Ben', 'jamin', ' Dis', 'rael', 'i']
+2682 614 Name of father of x -1 Name of father of Silvio Berlusconi Luigi Berlusconi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False Carfagna and Silvio Berlusconi engaged in a telephone 9 [' Car', 'f', 'agna', ' and', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2683 614 Name of father of x -1 Name of father of Silvio Berlusconi Luigi Berlusconi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False from 2008, Silvio Berlusconi adopted a similar 8 [' from', ' 2008', ',', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2684 614 Name of father of x -1 Name of father of Silvio Berlusconi Luigi Berlusconi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False President and owner Silvio Berlusconi was equally delighted 8 [' President', ' and', ' owner', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2685 614 Name of father of x -1 Name of father of Silvio Berlusconi Luigi Berlusconi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False then-Prime Minister Silvio Berlusconi inaugurated the 9 [' then', '-', 'Prime', ' Minister', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2686 614 Name of father of x -1 Name of father of Silvio Berlusconi Luigi Berlusconi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False again from 2008, Silvio Berlusconi adopted a similar 9 [' again', ' from', ' 2008', ',', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2687 615 Name of father of x -1 Name of father of Charles Gounod François-Louis Gounod Charles Gounod "[',' ' the' ' composer' ' of' ' the' ' opera' ' _' 'Fa' 'ust' '_' ','
+ ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the composer of the opera _ Fa ust _ , and the
+ " False composers such as Charles Gounod and Félicien 7 [' compos', 'ers', ' such', ' as', ' Charles', ' G', 'oun', 'od']
+2688 615 Name of father of x -1 Name of father of Charles Gounod François-Louis Gounod Charles Gounod "[',' ' the' ' composer' ' of' ' the' ' opera' ' _' 'Fa' 'ust' '_' ','
+ ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the composer of the opera _ Fa ust _ , and the
+ " False the French composer Charles Gounod in Bennett's music 6 [' the', ' French', ' composer', ' Charles', ' G', 'oun', 'od']
+2689 615 Name of father of x -1 Name of father of Charles Gounod François-Louis Gounod Charles Gounod "[',' ' the' ' composer' ' of' ' the' ' opera' ' _' 'Fa' 'ust' '_' ','
+ ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the composer of the opera _ Fa ust _ , and the
+ " False composers such as Charles Gounod and Félicien 7 [' compos', 'ers', ' such', ' as', ' Charles', ' G', 'oun', 'od']
+2690 615 Name of father of x -1 Name of father of Charles Gounod François-Louis Gounod Charles Gounod "[',' ' the' ' composer' ' of' ' the' ' opera' ' _' 'Fa' 'ust' '_' ','
+ ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the composer of the opera _ Fa ust _ , and the
+ " False French composers such as Charles Gounod and Félicien David. 8 [' French', ' compos', 'ers', ' such', ' as', ' Charles', ' G', 'oun', 'od']
+2691 615 Name of father of x -1 Name of father of Charles Gounod François-Louis Gounod Charles Gounod "[',' ' the' ' composer' ' of' ' the' ' opera' ' _' 'Fa' 'ust' '_' ','
+ ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the composer of the opera _ Fa ust _ , and the
+ " False French composer Charles Gounod in Bennett's music 5 [' French', ' composer', ' Charles', ' G', 'oun', 'od']
+2692 616 Name of father of x -1 Name of father of Johannes Vermeer Reijnier Janszoon Vermeer Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the']" , the painter , and his wife , the painter 's wife , and the painter 's mother , the False yellows preferred by Johannes Vermeer and other Dutch 7 [' yell', 'ows', ' preferred', ' by', ' Johannes', ' Ver', 'me', 'er']
+2693 616 Name of father of x -1 Name of father of Johannes Vermeer Reijnier Janszoon Vermeer Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the']" , the painter , and his wife , the painter 's wife , and the painter 's mother , the False Andalee. After a famous Johannes Vermeer painting is stolen 10 [' And', 'ale', 'e', '.', ' After', ' a', ' famous', ' Johannes', ' Ver', 'me', 'er']
+2694 616 Name of father of x -1 Name of father of Johannes Vermeer Reijnier Janszoon Vermeer Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the']" , the painter , and his wife , the painter 's wife , and the painter 's mother , the False been inspired by Johannes Vermeer and showed the 6 [' been', ' inspired', ' by', ' Johannes', ' Ver', 'me', 'er']
+2695 616 Name of father of x -1 Name of father of Johannes Vermeer Reijnier Janszoon Vermeer Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the']" , the painter , and his wife , the painter 's wife , and the painter 's mother , the False Dutch painter Johannes Vermeer (played by 5 [' Dutch', ' painter', ' Johannes', ' Ver', 'me', 'er']
+2696 616 Name of father of x -1 Name of father of Johannes Vermeer Reijnier Janszoon Vermeer Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the']" , the painter , and his wife , the painter 's wife , and the painter 's mother , the False discover the artist Johannes Vermeer and his paintings, 6 [' discover', ' the', ' artist', ' Johannes', ' Ver', 'me', 'er']
+2697 618 Name of father of x -1 Name of father of Hugo Grotius Jan Cornets de Groot Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' friend' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Dutch jur ist , who was a friend of the
+ " False " 17th-century Dutch jurist Hugo Grotius that ""the" 10 [' 17', 'th', '-', 'century', ' Dutch', ' jur', 'ist', ' Hugo', ' Gro', 't', 'ius']
+2698 618 Name of father of x -1 Name of father of Hugo Grotius Jan Cornets de Groot Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' friend' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Dutch jur ist , who was a friend of the
+ " False 4 ['Hug', 'o', ' Gro', 't', 'ius']
+2699 618 Name of father of x -1 Name of father of Hugo Grotius Jan Cornets de Groot Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' friend' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Dutch jur ist , who was a friend of the
+ " False " 17th-century Dutch jurist Hugo Grotius that ""the purpose" 10 [' 17', 'th', '-', 'century', ' Dutch', ' jur', 'ist', ' Hugo', ' Gro', 't', 'ius']
+2700 618 Name of father of x -1 Name of father of Hugo Grotius Jan Cornets de Groot Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' friend' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Dutch jur ist , who was a friend of the
+ " False 4 ['Hug', 'o', ' Gro', 't', 'ius']
+2701 618 Name of father of x -1 Name of father of Hugo Grotius Jan Cornets de Groot Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' friend' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great Dutch jur ist , who was a friend of the
+ " False 4 ['Hug', 'o', ' Gro', 't', 'ius']
+2702 619 Name of father of x -1 Name of father of Olivier Messiaen Pierre Messiaen Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s"" ' wife' ',']" , the composer , and his wife , the painter , and his daughter , the painter 's wife , False habitat in Olivier Messiaen ’ s Catalogue 5 [' habitat', ' in', ' Olivier', ' Mess', 'ia', 'en']
+2703 619 Name of father of x -1 Name of father of Olivier Messiaen Pierre Messiaen Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s"" ' wife' ',']" , the composer , and his wife , the painter , and his daughter , the painter 's wife , False feature on Olivier Messiaen by Radio France 5 [' feature', ' on', ' Olivier', ' Mess', 'ia', 'en']
+2704 619 Name of father of x -1 Name of father of Olivier Messiaen Pierre Messiaen Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s"" ' wife' ',']" , the composer , and his wife , the painter , and his daughter , the painter 's wife , False blackcap. The composer Olivier Messiaen used the song 8 [' black', 'cap', '.', ' The', ' composer', ' Olivier', ' Mess', 'ia', 'en']
+2705 619 Name of father of x -1 Name of father of Olivier Messiaen Pierre Messiaen Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s"" ' wife' ',']" , the composer , and his wife , the painter , and his daughter , the painter 's wife , False of The Rite. For Olivier Messiaen The Rite was 8 [' of', ' The', ' Rite', '.', ' For', ' Olivier', ' Mess', 'ia', 'en']
+2706 619 Name of father of x -1 Name of father of Olivier Messiaen Pierre Messiaen Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' his' ' daughter' ',' ' the' ' painter' ""'s"" ' wife' ',']" , the composer , and his wife , the painter , and his daughter , the painter 's wife , False In-depth feature on Olivier Messiaen by Radio France 8 [' In', '-', 'depth', ' feature', ' on', ' Olivier', ' Mess', 'ia', 'en']
+2707 620 Name of father of x -1 Name of father of Helmut Kohl Hans Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False chancellor Helmut Kohl took full responsibility 4 [' chancellor', ' Hel', 'mut', ' Koh', 'l']
+2708 620 Name of father of x -1 Name of father of Helmut Kohl Hans Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False time, chancellor Helmut Kohl had accepted an invitation 6 [' time', ',', ' chancellor', ' Hel', 'mut', ' Koh', 'l']
+2709 620 Name of father of x -1 Name of father of Helmut Kohl Hans Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False CDU chairman Helmut Kohl included him in his 6 [' CD', 'U', ' chairman', ' Hel', 'mut', ' Koh', 'l']
+2710 620 Name of father of x -1 Name of father of Helmut Kohl Hans Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False German Chancellor Helmut Kohl said he would 5 [' German', ' Chancellor', ' Hel', 'mut', ' Koh', 'l']
+2711 620 Name of father of x -1 Name of father of Helmut Kohl Hans Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False time, chancellor Helmut Kohl had accepted 6 [' time', ',', ' chancellor', ' Hel', 'mut', ' Koh', 'l']
+2712 622 Name of father of x -1 Name of father of Anders Zorn Leonhard Zorn Anders Zorn "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' parliament' ','
+ ' and' ' a' ' member' ' of' ' the' ' German' ' parliament' ',' ' and'
+ ' a']" , who was a member of the German parliament , and a member of the German parliament , and a False Carl Larsson, Anders Zorn and Bruno Liljefors. 6 [' Carl', ' Lar', 'sson', ',', ' Anders', ' Z', 'orn']
+2713 623 Name of father of x -1 Name of father of Harry S. Truman John Anderson Truman Harry S. Truman "[',' ' the' ' only' ' son' ' of' ' a' ' poor' ' family' ',' ' and' ' a'
+ ' poor' ' family' ' at' ' that' '.' '\n' '\n' 'The' ' family']" ", the only son of a poor family , and a poor family at that .
+
+ The family" False " Eleanor Roosevelt and Harry S. Truman"". Harry S." 6 [' Eleanor', ' Roosevelt', ' and', ' Harry', ' S', '.', ' Truman']
+2714 623 Name of father of x -1 Name of father of Harry S. Truman John Anderson Truman Harry S. Truman "[',' ' the' ' only' ' son' ' of' ' a' ' poor' ' family' ',' ' and' ' a'
+ ' poor' ' family' ' at' ' that' '.' '\n' '\n' 'The' ' family']" ", the only son of a poor family , and a poor family at that .
+
+ The family" False 3 ['Harry', ' S', '.', ' Truman']
+2715 623 Name of father of x -1 Name of father of Harry S. Truman John Anderson Truman Harry S. Truman "[',' ' the' ' only' ' son' ' of' ' a' ' poor' ' family' ',' ' and' ' a'
+ ' poor' ' family' ' at' ' that' '.' '\n' '\n' 'The' ' family']" ", the only son of a poor family , and a poor family at that .
+
+ The family" False 3 ['Harry', ' S', '.', ' Truman']
+2716 623 Name of father of x -1 Name of father of Harry S. Truman John Anderson Truman Harry S. Truman "[',' ' the' ' only' ' son' ' of' ' a' ' poor' ' family' ',' ' and' ' a'
+ ' poor' ' family' ' at' ' that' '.' '\n' '\n' 'The' ' family']" ", the only son of a poor family , and a poor family at that .
+
+ The family" False destroyed, and US President Harry S. Truman ordered General 8 [' destroyed', ',', ' and', ' US', ' President', ' Harry', ' S', '.', ' Truman']
+2717 623 Name of father of x -1 Name of father of Harry S. Truman John Anderson Truman Harry S. Truman "[',' ' the' ' only' ' son' ' of' ' a' ' poor' ' family' ',' ' and' ' a'
+ ' poor' ' family' ' at' ' that' '.' '\n' '\n' 'The' ' family']" ", the only son of a poor family , and a poor family at that .
+
+ The family" False States President Harry S. Truman subsequently ordered 5 [' States', ' President', ' Harry', ' S', '.', ' Truman']
+2718 624 Name of father of x -1 Name of father of Nicholas I of Russia Paul I of Russia Nicholas I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False portrait of Tsar Nicholas I of Russia and another of the 7 [' portrait', ' of', ' Ts', 'ar', ' Nicholas', ' I', ' of', ' Russia']
+2719 624 Name of father of x -1 Name of father of Nicholas I of Russia Paul I of Russia Nicholas I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False royalty, such as Nicholas I of Russia (1837), Prince Charles 7 [' royalty', ',', ' such', ' as', ' Nicholas', ' I', ' of', ' Russia']
+2720 624 Name of father of x -1 Name of father of Nicholas I of Russia Paul I of Russia Nicholas I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False and royalty, such as Nicholas I of Russia (1837), Prince 8 [' and', ' royalty', ',', ' such', ' as', ' Nicholas', ' I', ' of', ' Russia']
+2721 625 Name of father of x -1 Name of father of Mary, Queen of Scots James V of Scotland Mary, Queen of Scots "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False 4 ['Mary', ',', ' Queen', ' of', ' Scots']
+2722 625 Name of father of x -1 Name of father of Mary, Queen of Scots James V of Scotland Mary, Queen of Scots "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False the imprisoned Mary, Queen of Scots and her 140 retainers, 6 [' the', ' imprisoned', ' Mary', ',', ' Queen', ' of', ' Scots']
+2723 625 Name of father of x -1 Name of father of Mary, Queen of Scots James V of Scotland Mary, Queen of Scots "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False 4 ['Mary', ',', ' Queen', ' of', ' Scots']
+2724 625 Name of father of x -1 Name of father of Mary, Queen of Scots James V of Scotland Mary, Queen of Scots "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False keeping the imprisoned Mary, Queen of Scots and her 140 retainers, 7 [' keeping', ' the', ' imprisoned', ' Mary', ',', ' Queen', ' of', ' Scots']
+2725 625 Name of father of x -1 Name of father of Mary, Queen of Scots James V of Scotland Mary, Queen of Scots "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False 1558 marriage of Mary, Queen of Scots and Francis, 8 [' 15', '58', ' marriage', ' of', ' Mary', ',', ' Queen', ' of', ' Scots']
+2726 627 Name of father of x -1 Name of father of Caspar David Friedrich Adolph Gottlieb Friedrich Caspar David Friedrich "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 17' '74' ' in' ' the'
+ ' town' ' of' ' K' 'ön' 'igs' 'berg' ',' ' Pr' 'ussia']" , the painter , was born in 17 74 in the town of K ön igs berg , Pr ussia False the Baroque, Caspar David Friedrich and Carl Spitzweg of 8 [' the', ' Bar', 'o', 'que', ',', ' Cas', 'par', ' David', ' Friedrich']
+2727 627 Name of father of x -1 Name of father of Caspar David Friedrich Adolph Gottlieb Friedrich Caspar David Friedrich "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 17' '74' ' in' ' the'
+ ' town' ' of' ' K' 'ön' 'igs' 'berg' ',' ' Pr' 'ussia']" , the painter , was born in 17 74 in the town of K ön igs berg , Pr ussia False 3 ['Cas', 'par', ' David', ' Friedrich']
+2728 627 Name of father of x -1 Name of father of Caspar David Friedrich Adolph Gottlieb Friedrich Caspar David Friedrich "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 17' '74' ' in' ' the'
+ ' town' ' of' ' K' 'ön' 'igs' 'berg' ',' ' Pr' 'ussia']" , the painter , was born in 17 74 in the town of K ön igs berg , Pr ussia False of the Baroque, Caspar David Friedrich and Carl Spitzweg 9 [' of', ' the', ' Bar', 'o', 'que', ',', ' Cas', 'par', ' David', ' Friedrich']
+2729 627 Name of father of x -1 Name of father of Caspar David Friedrich Adolph Gottlieb Friedrich Caspar David Friedrich "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 17' '74' ' in' ' the'
+ ' town' ' of' ' K' 'ön' 'igs' 'berg' ',' ' Pr' 'ussia']" , the painter , was born in 17 74 in the town of K ön igs berg , Pr ussia False " Friedrich =
+" 5 [' Friedrich', ' =', 'Cas', 'par', ' David', ' Friedrich']
+2730 627 Name of father of x -1 Name of father of Caspar David Friedrich Adolph Gottlieb Friedrich Caspar David Friedrich "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 17' '74' ' in' ' the'
+ ' town' ' of' ' K' 'ön' 'igs' 'berg' ',' ' Pr' 'ussia']" , the painter , was born in 17 74 in the town of K ön igs berg , Pr ussia False 3 ['Cas', 'par', ' David', ' Friedrich']
+2731 630 Name of father of x -1 Name of father of Friedrich Dürrenmatt Reinhold Dürrenmatt Friedrich Dürrenmatt "[""'s"" ' _' 'The' ' Visit' '_' ',' ' and' ' the' ' _' 'B' 'ild' 'ung' 's'
+ 'roman' '_' ' of' ' the' ' _' 'B' 'ild']" 's _ The Visit _ , and the _ B ild ung s roman _ of the _ B ild False Frisch (1911 – 91) and Friedrich Dürrenmatt (1921 – 90), 14 [' Fr', 'isch', ' (', '19', '11', ' –', ' 91', ')', ' and', ' Friedrich', ' D', 'ür', 'ren', 'm', 'att']
+2732 630 Name of father of x -1 Name of father of Friedrich Dürrenmatt Reinhold Dürrenmatt Friedrich Dürrenmatt "[""'s"" ' _' 'The' ' Visit' '_' ',' ' and' ' the' ' _' 'B' 'ild' 'ung' 's'
+ 'roman' '_' ' of' ' the' ' _' 'B' 'ild']" 's _ The Visit _ , and the _ B ild ung s roman _ of the _ B ild False (1911 – 91) and Friedrich Dürrenmatt (1921 – 90), 12 [' (', '19', '11', ' –', ' 91', ')', ' and', ' Friedrich', ' D', 'ür', 'ren', 'm', 'att']
+2733 630 Name of father of x -1 Name of father of Friedrich Dürrenmatt Reinhold Dürrenmatt Friedrich Dürrenmatt "[""'s"" ' _' 'The' ' Visit' '_' ',' ' and' ' the' ' _' 'B' 'ild' 'ung' 's'
+ 'roman' '_' ' of' ' the' ' _' 'B' 'ild']" 's _ The Visit _ , and the _ B ild ung s roman _ of the _ B ild False (1911 – 91) and Friedrich Dürrenmatt (1921 – 90), 12 [' (', '19', '11', ' –', ' 91', ')', ' and', ' Friedrich', ' D', 'ür', 'ren', 'm', 'att']
+2734 631 Name of father of x -1 Name of father of Ignatius of Loyola Beltrán II Ibáñez de Loyola Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jesuit' ' Order' '.' '\n' '\n' 'The'
+ ' Jes' 'uits' ' are' ' the' ' most' ' powerful' ' and' ' influential'
+ ' religious']" ", the founder of the Jesuit Order .
+
+ The Jes uits are the most powerful and influential religious" False Exercises of St. Ignatius of Loyola (1491 – 1556), the 12 [' Ex', 'erc', 'ises', ' of', ' St', '.', ' Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2735 631 Name of father of x -1 Name of father of Ignatius of Loyola Beltrán II Ibáñez de Loyola Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jesuit' ' Order' '.' '\n' '\n' 'The'
+ ' Jes' 'uits' ' are' ' the' ' most' ' powerful' ' and' ' influential'
+ ' religious']" ", the founder of the Jesuit Order .
+
+ The Jes uits are the most powerful and influential religious" False Exercises of Ignatius of Loyola use meditative mental 10 [' Ex', 'erc', 'ises', ' of', ' Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2736 631 Name of father of x -1 Name of father of Ignatius of Loyola Beltrán II Ibáñez de Loyola Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jesuit' ' Order' '.' '\n' '\n' 'The'
+ ' Jes' 'uits' ' are' ' the' ' most' ' powerful' ' and' ' influential'
+ ' religious']" ", the founder of the Jesuit Order .
+
+ The Jes uits are the most powerful and influential religious" False Exercises of St. Ignatius of Loyola (1491 – 1556), the 12 [' Ex', 'erc', 'ises', ' of', ' St', '.', ' Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2737 631 Name of father of x -1 Name of father of Ignatius of Loyola Beltrán II Ibáñez de Loyola Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jesuit' ' Order' '.' '\n' '\n' 'The'
+ ' Jes' 'uits' ' are' ' the' ' most' ' powerful' ' and' ' influential'
+ ' religious']" ", the founder of the Jesuit Order .
+
+ The Jes uits are the most powerful and influential religious" False Spiritual Exercises of Ignatius of Loyola use meditative mental 11 [' Spiritual', ' Ex', 'erc', 'ises', ' of', ' Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2738 631 Name of father of x -1 Name of father of Ignatius of Loyola Beltrán II Ibáñez de Loyola Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jesuit' ' Order' '.' '\n' '\n' 'The'
+ ' Jes' 'uits' ' are' ' the' ' most' ' powerful' ' and' ' influential'
+ ' religious']" ", the founder of the Jesuit Order .
+
+ The Jes uits are the most powerful and influential religious" False Exercises of St. Ignatius of Loyola (1491 – 1556), 12 [' Ex', 'erc', 'ises', ' of', ' St', '.', ' Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2739 632 Name of father of x -1 Name of father of Alexander I of Russia Paul I of Russia Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False author of On War, and Alexander I of Russia noting that the 9 [' author', ' of', ' On', ' War', ',', ' and', ' Alexander', ' I', ' of', ' Russia']
+2740 632 Name of father of x -1 Name of father of Alexander I of Russia Paul I of Russia Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False by the decree of Alexander I of Russia in 1817, and 7 [' by', ' the', ' decree', ' of', ' Alexander', ' I', ' of', ' Russia']
+2741 632 Name of father of x -1 Name of father of Alexander I of Russia Paul I of Russia Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False to collect Tsar Alexander I of Russia and King Frederick 7 [' to', ' collect', ' Ts', 'ar', ' Alexander', ' I', ' of', ' Russia']
+2742 632 Name of father of x -1 Name of father of Alexander I of Russia Paul I of Russia Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False established by the decree of Alexander I of Russia in 1817, and by 1850, 8 [' established', ' by', ' the', ' decree', ' of', ' Alexander', ' I', ' of', ' Russia']
+2743 632 Name of father of x -1 Name of father of Alexander I of Russia Paul I of Russia Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False boarding house for Tsar Alexander I of Russia during his 8 [' boarding', ' house', ' for', ' Ts', 'ar', ' Alexander', ' I', ' of', ' Russia']
+2744 633 Name of father of x -1 Name of father of Godfrey Kneller Zacharias Kniller Godfrey Kneller "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' God' 'frey' ' Kn'
+ 'eller' ',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", the
+
+ Name of father of God frey Kn eller , the
+
+ Name of father of" False Dyck, Peter Lely, Godfrey Kneller and William 10 [' Dy', 'ck', ',', ' Peter', ' Le', 'ly', ',', ' God', 'frey', ' Kn', 'eller']
+2745 633 Name of father of x -1 Name of father of Godfrey Kneller Zacharias Kniller Godfrey Kneller "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' God' 'frey' ' Kn'
+ 'eller' ',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", the
+
+ Name of father of God frey Kn eller , the
+
+ Name of father of" False Dyck, Peter Lely, Godfrey Kneller and William 10 [' Dy', 'ck', ',', ' Peter', ' Le', 'ly', ',', ' God', 'frey', ' Kn', 'eller']
+2746 633 Name of father of x -1 Name of father of Godfrey Kneller Zacharias Kniller Godfrey Kneller "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' God' 'frey' ' Kn'
+ 'eller' ',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", the
+
+ Name of father of God frey Kn eller , the
+
+ Name of father of" False 1696. Portraits by Sir Godfrey Kneller are at the National 10 [' 16', '96', '.', ' Port', 'raits', ' by', ' Sir', ' God', 'frey', ' Kn', 'eller']
+2747 633 Name of father of x -1 Name of father of Godfrey Kneller Zacharias Kniller Godfrey Kneller "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' God' 'frey' ' Kn'
+ 'eller' ',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", the
+
+ Name of father of God frey Kn eller , the
+
+ Name of father of" False Portraits by Sir Godfrey Kneller are at the National 7 [' Port', 'raits', ' by', ' Sir', ' God', 'frey', ' Kn', 'eller']
+2748 633 Name of father of x -1 Name of father of Godfrey Kneller Zacharias Kniller Godfrey Kneller "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' God' 'frey' ' Kn'
+ 'eller' ',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", the
+
+ Name of father of God frey Kn eller , the
+
+ Name of father of" False Portraits by Sir Godfrey Kneller are at the National 7 [' Port', 'raits', ' by', ' Sir', ' God', 'frey', ' Kn', 'eller']
+2749 635 Name of father of x -1 Name of father of Peter Ustinov Jona von Ustinov Peter Ustinov "[',' ' the' ' actor' ',' ' who' ' was' ' a' ' friend' ' of' ' the'
+ ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large' ',']" ", the actor , who was a friend of the family .
+
+ The house was a large ," False " and Emily are by Peter Ustinov and Emily Osborne.
+" 8 [' and', ' Emily', ' are', ' by', ' Peter', ' U', 'st', 'in', 'ov']
+2750 635 Name of father of x -1 Name of father of Peter Ustinov Jona von Ustinov Peter Ustinov "[',' ' the' ' actor' ',' ' who' ' was' ' a' ' friend' ' of' ' the'
+ ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large' ',']" ", the actor , who was a friend of the family .
+
+ The house was a large ," False non-Chinese actors, Peter Ustinov and Angie Dickinson, 9 [' non', '-', 'Chinese', ' actors', ',', ' Peter', ' U', 'st', 'in', 'ov']
+2751 635 Name of father of x -1 Name of father of Peter Ustinov Jona von Ustinov Peter Ustinov "[',' ' the' ' actor' ',' ' who' ' was' ' a' ' friend' ' of' ' the'
+ ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large' ',']" ", the actor , who was a friend of the family .
+
+ The house was a large ," False " a sense of duty."" Peter Ustinov described her" 9 "[' a', ' sense', ' of', ' duty', '.""', ' Peter', ' U', 'st', 'in', 'ov']"
+2752 635 Name of father of x -1 Name of father of Peter Ustinov Jona von Ustinov Peter Ustinov "[',' ' the' ' actor' ',' ' who' ' was' ' a' ' friend' ' of' ' the'
+ ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large' ',']" ", the actor , who was a friend of the family .
+
+ The house was a large ," False (Lyric, 1972), a Peter Ustinov comedy, Overheard 11 [' (', 'Ly', 'ric', ',', ' 1972', '),', ' a', ' Peter', ' U', 'st', 'in', 'ov']
+2753 635 Name of father of x -1 Name of father of Peter Ustinov Jona von Ustinov Peter Ustinov "[',' ' the' ' actor' ',' ' who' ' was' ' a' ' friend' ' of' ' the'
+ ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large' ',']" ", the actor , who was a friend of the family .
+
+ The house was a large ," False (Lyric, 1972), a Peter Ustinov comedy, Overheard 11 [' (', 'Ly', 'ric', ',', ' 1972', '),', ' a', ' Peter', ' U', 'st', 'in', 'ov']
+2754 636 Name of father of x -1 Name of father of Shinzō Abe Shintarō Abe Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'iko' ',' ' who' ' is' ' a' ' former' ' Olympic' ' gold' ' medal']" , the Japanese prime minister , and his wife , Ak iko , who is a former Olympic gold medal False Prime Minister Shinzō Abe said Japan wanted 5 [' Prime', ' Minister', ' Shin', 'z', 'ō', ' Abe']
+2755 636 Name of father of x -1 Name of father of Shinzō Abe Shintarō Abe Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'iko' ',' ' who' ' is' ' a' ' former' ' Olympic' ' gold' ' medal']" , the Japanese prime minister , and his wife , Ak iko , who is a former Olympic gold medal False general election, Shinzō Abe replaced Yoshihiko 6 [' general', ' election', ',', ' Shin', 'z', 'ō', ' Abe']
+2756 636 Name of father of x -1 Name of father of Shinzō Abe Shintarō Abe Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'iko' ',' ' who' ' is' ' a' ' former' ' Olympic' ' gold' ' medal']" , the Japanese prime minister , and his wife , Ak iko , who is a former Olympic gold medal False Prime Minister Shinzō Abe that the sea be called 5 [' Prime', ' Minister', ' Shin', 'z', 'ō', ' Abe']
+2757 636 Name of father of x -1 Name of father of Shinzō Abe Shintarō Abe Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'iko' ',' ' who' ' is' ' a' ' former' ' Olympic' ' gold' ' medal']" , the Japanese prime minister , and his wife , Ak iko , who is a former Olympic gold medal False general election, Shinzō Abe replaced Yoshihiko 6 [' general', ' election', ',', ' Shin', 'z', 'ō', ' Abe']
+2758 636 Name of father of x -1 Name of father of Shinzō Abe Shintarō Abe Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'iko' ',' ' who' ' is' ' a' ' former' ' Olympic' ' gold' ' medal']" , the Japanese prime minister , and his wife , Ak iko , who is a former Olympic gold medal False Japanese Prime Minister Shinzō Abe that the sea be 6 [' Japanese', ' Prime', ' Minister', ' Shin', 'z', 'ō', ' Abe']
+2759 637 Name of father of x -1 Name of father of Louis Aragon Louis Andrieux Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' ' revolution'
+ ',' ' and' ' the' ' poet' ' of' ' the' ' revolution' ',' ' and' ' the']" , the poet , and the poet of the revolution , and the poet of the revolution , and the False Communist poet Louis Aragon in 1928. Berni continued 4 [' Communist', ' poet', ' Louis', ' Ar', 'agon']
+2760 637 Name of father of x -1 Name of father of Louis Aragon Louis Andrieux Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' ' revolution'
+ ',' ' and' ' the' ' poet' ' of' ' the' ' revolution' ',' ' and' ' the']" , the poet , and the poet of the revolution , and the poet of the revolution , and the False was Picasso and Louis Aragon who told Koppelman 6 [' was', ' Pic', 'asso', ' and', ' Louis', ' Ar', 'agon']
+2761 637 Name of father of x -1 Name of father of Louis Aragon Louis Andrieux Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' ' revolution'
+ ',' ' and' ' the' ' poet' ' of' ' the' ' revolution' ',' ' and' ' the']" , the poet , and the poet of the revolution , and the poet of the revolution , and the False Deux poèmes de Louis Aragon (1943), titled 8 [' De', 'ux', ' po', 'è', 'mes', ' de', ' Louis', ' Ar', 'agon']
+2762 637 Name of father of x -1 Name of father of Louis Aragon Louis Andrieux Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' ' revolution'
+ ',' ' and' ' the' ' poet' ' of' ' the' ' revolution' ',' ' and' ' the']" , the poet , and the poet of the revolution , and the poet of the revolution , and the False Max Jacob, Louis Aragon and Jean Cocteau along 5 [' Max', ' Jacob', ',', ' Louis', ' Ar', 'agon']
+2763 637 Name of father of x -1 Name of father of Louis Aragon Louis Andrieux Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' ' revolution'
+ ',' ' and' ' the' ' poet' ' of' ' the' ' revolution' ',' ' and' ' the']" , the poet , and the poet of the revolution , and the poet of the revolution , and the False the Deux poèmes de Louis Aragon (1943), titled 9 [' the', ' De', 'ux', ' po', 'è', 'mes', ' de', ' Louis', ' Ar', 'agon']
+2764 638 Name of father of x -1 Name of father of Basil of Caesarea Basil the Elder Basil of Caesarea "[',' ' and' ' the' ' other' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the other , the
+ " False examples of Basil of Caesarea (who is the Greek 6 [' examples', ' of', ' Basil', ' of', ' Ca', 'es', 'area']
+2765 638 Name of father of x -1 Name of father of Basil of Caesarea Basil the Elder Basil of Caesarea "[',' ' and' ' the' ' other' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the other , the
+ " False be identified as Basil of Caesarea and John Chrysostom. 7 [' be', ' identified', ' as', ' Basil', ' of', ' Ca', 'es', 'area']
+2766 638 Name of father of x -1 Name of father of Basil of Caesarea Basil the Elder Basil of Caesarea "[',' ' and' ' the' ' other' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the other , the
+ " False examples of Basil of Caesarea (who is the Greek 6 [' examples', ' of', ' Basil', ' of', ' Ca', 'es', 'area']
+2767 638 Name of father of x -1 Name of father of Basil of Caesarea Basil the Elder Basil of Caesarea "[',' ' and' ' the' ' other' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", and the other , the
+ " False identified as Basil of Caesarea and John Chrysostom. 6 [' identified', ' as', ' Basil', ' of', ' Ca', 'es', 'area']
+2768 639 Name of father of x -1 Name of father of Fiona Apple Brandon Maggart Fiona Apple "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Fiona' ' Apple' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the singer , and the father of the singer Fiona Apple .
+
+ I 'm not sure if" False collaborated with) singer Fiona Apple for several years 5 [' collaborated', ' with', ')', ' singer', ' Fiona', ' Apple']
+2769 639 Name of father of x -1 Name of father of Fiona Apple Brandon Maggart Fiona Apple "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Fiona' ' Apple' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the singer , and the father of the singer Fiona Apple .
+
+ I 'm not sure if" False including PJ Harvey, Fiona Apple and The Distillers. 5 [' including', ' PJ', ' Harvey', ',', ' Fiona', ' Apple']
+2770 639 Name of father of x -1 Name of father of Fiona Apple Brandon Maggart Fiona Apple "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Fiona' ' Apple' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the singer , and the father of the singer Fiona Apple .
+
+ I 'm not sure if" False with) singer Fiona Apple for several years 4 [' with', ')', ' singer', ' Fiona', ' Apple']
+2771 639 Name of father of x -1 Name of father of Fiona Apple Brandon Maggart Fiona Apple "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Fiona' ' Apple' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the singer , and the father of the singer Fiona Apple .
+
+ I 'm not sure if" False " singer-songwriter Fiona Apple with the line, ""Well," 5 [' singer', '-', 'song', 'writer', ' Fiona', ' Apple']
+2772 639 Name of father of x -1 Name of father of Fiona Apple Brandon Maggart Fiona Apple "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Fiona' ' Apple' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the singer , and the father of the singer Fiona Apple .
+
+ I 'm not sure if" False " for ""Criminal"" by Fiona Apple and summarizing" 7 "[' for', ' ""', 'C', 'riminal', '""', ' by', ' Fiona', ' Apple']"
+2773 640 Name of father of x -1 Name of father of Kurt Vonnegut Kurt Vonnegut, Sr. Kurt Vonnegut "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Kurt' ' Von'
+ 'neg' 'ut' ',' ' Jr' '.' ' and' ' his']" ", Jr .
+
+ I am a big fan of Kurt Von neg ut , Jr . and his" False are taught by Kurt Vonnegut at the University 6 [' are', ' taught', ' by', ' Kurt', ' Von', 'neg', 'ut']
+2774 640 Name of father of x -1 Name of father of Kurt Vonnegut Kurt Vonnegut, Sr. Kurt Vonnegut "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Kurt' ' Von'
+ 'neg' 'ut' ',' ' Jr' '.' ' and' ' his']" ", Jr .
+
+ I am a big fan of Kurt Von neg ut , Jr . and his" False 4 ['K', 'urt', ' Von', 'neg', 'ut']
+2775 640 Name of father of x -1 Name of father of Kurt Vonnegut Kurt Vonnegut, Sr. Kurt Vonnegut "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Kurt' ' Von'
+ 'neg' 'ut' ',' ' Jr' '.' ' and' ' his']" ", Jr .
+
+ I am a big fan of Kurt Von neg ut , Jr . and his" False " and satirist Kurt Vonnegut are notable alumni.
+" 6 [' and', ' satir', 'ist', ' Kurt', ' Von', 'neg', 'ut']
+2776 640 Name of father of x -1 Name of father of Kurt Vonnegut Kurt Vonnegut, Sr. Kurt Vonnegut "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Kurt' ' Von'
+ 'neg' 'ut' ',' ' Jr' '.' ' and' ' his']" ", Jr .
+
+ I am a big fan of Kurt Von neg ut , Jr . and his" False libraries, the Kurt Vonnegut Memorial Library 6 [' libraries', ',', ' the', ' Kurt', ' Von', 'neg', 'ut']
+2777 640 Name of father of x -1 Name of father of Kurt Vonnegut Kurt Vonnegut, Sr. Kurt Vonnegut "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Kurt' ' Von'
+ 'neg' 'ut' ',' ' Jr' '.' ' and' ' his']" ", Jr .
+
+ I am a big fan of Kurt Von neg ut , Jr . and his" False 4 ['K', 'urt', ' Von', 'neg', 'ut']
+2778 641 Name of father of x -1 Name of father of Tony Blair Leo Blair Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False " a ""flagship Tony Blair pledge to allow" 5 "[' a', ' ""', 'flags', 'hip', ' Tony', ' Blair']"
+2779 641 Name of father of x -1 Name of father of Tony Blair Leo Blair Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False a speech made by Tony Blair at the Labour Party 5 [' a', ' speech', ' made', ' by', ' Tony', ' Blair']
+2780 641 Name of father of x -1 Name of father of Tony Blair Leo Blair Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False leadership after Tony Blair became favourite, 3 [' leadership', ' after', ' Tony', ' Blair']
+2781 641 Name of father of x -1 Name of father of Tony Blair Leo Blair Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False 1 ['Tony', ' Blair']
+2782 641 Name of father of x -1 Name of father of Tony Blair Leo Blair Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False UK Prime Minister Tony Blair met with Gaddafi 4 [' UK', ' Prime', ' Minister', ' Tony', ' Blair']
+2783 642 Name of father of x -1 Name of father of John XXIII Giovanni Battista Roncalli John XXIII "[',' ' the' ' pope' ' who' ' was' ' elected' ' in' ' 1958' ',' ' and'
+ ' who' ' died' ' in' ' 1963' '.' '\n' '\n' 'The' ' pope' ' who']" ", the pope who was elected in 1958 , and who died in 1963 .
+
+ The pope who" False explained to Pope John XXIII that the shield 5 [' explained', ' to', ' Pope', ' John', ' XX', 'III']
+2784 642 Name of father of x -1 Name of father of John XXIII Giovanni Battista Roncalli John XXIII "[',' ' the' ' pope' ' who' ' was' ' elected' ' in' ' 1958' ',' ' and'
+ ' who' ' died' ' in' ' 1963' '.' '\n' '\n' 'The' ' pope' ' who']" ", the pope who was elected in 1958 , and who died in 1963 .
+
+ The pope who" False methods, prompted John XXIII to form a committee 5 [' methods', ',', ' prompted', ' John', ' XX', 'III']
+2785 642 Name of father of x -1 Name of father of John XXIII Giovanni Battista Roncalli John XXIII "[',' ' the' ' pope' ' who' ' was' ' elected' ' in' ' 1958' ',' ' and'
+ ' who' ' died' ' in' ' 1963' '.' '\n' '\n' 'The' ' pope' ' who']" ", the pope who was elected in 1958 , and who died in 1963 .
+
+ The pope who" False succeeded Alexander V as John XXIII in 1410. John 6 [' succeeded', ' Alexander', ' V', ' as', ' John', ' XX', 'III']
+2786 642 Name of father of x -1 Name of father of John XXIII Giovanni Battista Roncalli John XXIII "[',' ' the' ' pope' ' who' ' was' ' elected' ' in' ' 1958' ',' ' and'
+ ' who' ' died' ' in' ' 1963' '.' '\n' '\n' 'The' ' pope' ' who']" ", the pope who was elected in 1958 , and who died in 1963 .
+
+ The pope who" False telegrams from Pope John XXIII and Queen Elizabeth 7 [' te', 'leg', 'rams', ' from', ' Pope', ' John', ' XX', 'III']
+2787 642 Name of father of x -1 Name of father of John XXIII Giovanni Battista Roncalli John XXIII "[',' ' the' ' pope' ' who' ' was' ' elected' ' in' ' 1958' ',' ' and'
+ ' who' ' died' ' in' ' 1963' '.' '\n' '\n' 'The' ' pope' ' who']" ", the pope who was elected in 1958 , and who died in 1963 .
+
+ The pope who" False methods, prompted John XXIII to form a committee 5 [' methods', ',', ' prompted', ' John', ' XX', 'III']
+2788 643 Name of father of x -1 Name of father of Czesław Miłosz Aleksander Miłosz Czesław Miłosz "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' C' 'zes' 'ł' 'aw' ' Mi' 'ł'
+ 'os' 'z' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of C zes ł aw Mi ł os z
+
+ ! Name of" False was considered by Czesław Miłosz to be the great 10 [' was', ' considered', ' by', ' C', 'zes', 'ł', 'aw', ' Mi', 'ł', 'os', 'z']
+2789 643 Name of father of x -1 Name of father of Czesław Miłosz Aleksander Miłosz Czesław Miłosz "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' C' 'zes' 'ł' 'aw' ' Mi' 'ł'
+ 'os' 'z' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of C zes ł aw Mi ł os z
+
+ ! Name of" False Prus canon, and Czesław Miłosz has called Children 12 [' Pr', 'us', ' canon', ',', ' and', ' C', 'zes', 'ł', 'aw', ' Mi', 'ł', 'os', 'z']
+2790 643 Name of father of x -1 Name of father of Czesław Miłosz Aleksander Miłosz Czesław Miłosz "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' C' 'zes' 'ł' 'aw' ' Mi' 'ł'
+ 'os' 'z' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of C zes ł aw Mi ł os z
+
+ ! Name of" False canon, and Czesław Miłosz has called Children 10 [' canon', ',', ' and', ' C', 'zes', 'ł', 'aw', ' Mi', 'ł', 'os', 'z']
+2791 643 Name of father of x -1 Name of father of Czesław Miłosz Aleksander Miłosz Czesław Miłosz "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' C' 'zes' 'ł' 'aw' ' Mi' 'ł'
+ 'os' 'z' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of C zes ł aw Mi ł os z
+
+ ! Name of" False considered by Czesław Miłosz to be the great 9 [' considered', ' by', ' C', 'zes', 'ł', 'aw', ' Mi', 'ł', 'os', 'z']
+2792 643 Name of father of x -1 Name of father of Czesław Miłosz Aleksander Miłosz Czesław Miłosz "['\n' '\n' '!' 'Name' ' of' ' mother' ' of' ' C' 'zes' 'ł' 'aw' ' Mi' 'ł'
+ 'os' 'z' '\n' '\n' '!' 'Name' ' of']" "
+
+ ! Name of mother of C zes ł aw Mi ł os z
+
+ ! Name of" False was considered by Czesław Miłosz to be the great 10 [' was', ' considered', ' by', ' C', 'zes', 'ł', 'aw', ' Mi', 'ł', 'os', 'z']
+2793 645 Name of father of x -1 Name of father of David Lloyd George William George David Lloyd George "[',' ' the' ' Prime' ' Minister' ' of' ' the' ' United' ' Kingdom' ','
+ ' and' ' his' ' wife' ',' ' Mary' ',' ' who' ' was' ' the' ' daughter'
+ ' of']" , the Prime Minister of the United Kingdom , and his wife , Mary , who was the daughter of False British Prime Minister David Lloyd George had been bribed 5 [' British', ' Prime', ' Minister', ' David', ' Lloyd', ' George']
+2794 645 Name of father of x -1 Name of father of David Lloyd George William George David Lloyd George "[',' ' the' ' Prime' ' Minister' ' of' ' the' ' United' ' Kingdom' ','
+ ' and' ' his' ' wife' ',' ' Mary' ',' ' who' ' was' ' the' ' daughter'
+ ' of']" , the Prime Minister of the United Kingdom , and his wife , Mary , who was the daughter of False resistance to the EFF and David Lloyd George was still a 7 [' resistance', ' to', ' the', ' EFF', ' and', ' David', ' Lloyd', ' George']
+2795 645 Name of father of x -1 Name of father of David Lloyd George William George David Lloyd George "[',' ' the' ' Prime' ' Minister' ' of' ' the' ' United' ' Kingdom' ','
+ ' and' ' his' ' wife' ',' ' Mary' ',' ' who' ' was' ' the' ' daughter'
+ ' of']" , the Prime Minister of the United Kingdom , and his wife , Mary , who was the daughter of False of the Exchequer David Lloyd George in 1912. As 7 [' of', ' the', ' Ex', 'che', 'quer', ' David', ' Lloyd', ' George']
+2796 645 Name of father of x -1 Name of father of David Lloyd George William George David Lloyd George "[',' ' the' ' Prime' ' Minister' ' of' ' the' ' United' ' Kingdom' ','
+ ' and' ' his' ' wife' ',' ' Mary' ',' ' who' ' was' ' the' ' daughter'
+ ' of']" , the Prime Minister of the United Kingdom , and his wife , Mary , who was the daughter of False British Prime Minister David Lloyd George had been bribed 5 [' British', ' Prime', ' Minister', ' David', ' Lloyd', ' George']
+2797 645 Name of father of x -1 Name of father of David Lloyd George William George David Lloyd George "[',' ' the' ' Prime' ' Minister' ' of' ' the' ' United' ' Kingdom' ','
+ ' and' ' his' ' wife' ',' ' Mary' ',' ' who' ' was' ' the' ' daughter'
+ ' of']" , the Prime Minister of the United Kingdom , and his wife , Mary , who was the daughter of False particular. David Lloyd George held a number 4 [' particular', '.', ' David', ' Lloyd', ' George']
+2798 646 Name of father of x -1 Name of father of Winslow Homer Charles Savage Homer Winslow Homer "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' son' ',' ' Homer' ','
+ ' the' ' poet' '.' '\n' '\n' 'The' ' house' ' is' ' a']" ", the painter , and of his son , Homer , the poet .
+
+ The house is a" False American artist Winslow Homer (1836 – 1910), replicates 4 [' American', ' artist', ' Wins', 'low', ' Homer']
+2799 646 Name of father of x -1 Name of father of Winslow Homer Charles Savage Homer Winslow Homer "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' son' ',' ' Homer' ','
+ ' the' ' poet' '.' '\n' '\n' 'The' ' house' ' is' ' a']" ", the painter , and of his son , Homer , the poet .
+
+ The house is a" False United States. Winslow Homer (1836 – 1910) depicted 5 [' United', ' States', '.', ' Wins', 'low', ' Homer']
+2800 646 Name of father of x -1 Name of father of Winslow Homer Charles Savage Homer Winslow Homer "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' son' ',' ' Homer' ','
+ ' the' ' poet' '.' '\n' '\n' 'The' ' house' ' is' ' a']" ", the painter , and of his son , Homer , the poet .
+
+ The house is a" False American artist Winslow Homer (1836 – 1910), replicates 4 [' American', ' artist', ' Wins', 'low', ' Homer']
+2801 647 Name of father of x -1 Name of father of Sarah Jessica Parker Stephen Parker Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' plays' ' Carrie' ' Brad' 'shaw' ' in'
+ ' the' ' TV' ' series' ' Sex' ' and' ' the' ' City' '.' '\n' '\n' 'The']" ", the actress who plays Carrie Brad shaw in the TV series Sex and the City .
+
+ The" False Cartman also claims Sarah Jessica Parker has also been 6 [' Cart', 'man', ' also', ' claims', ' Sarah', ' Jessica', ' Parker']
+2802 647 Name of father of x -1 Name of father of Sarah Jessica Parker Stephen Parker Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' plays' ' Carrie' ' Brad' 'shaw' ' in'
+ ' the' ' TV' ' series' ' Sex' ' and' ' the' ' City' '.' '\n' '\n' 'The']" ", the actress who plays Carrie Brad shaw in the TV series Sex and the City .
+
+ The" False 2 ['Sarah', ' Jessica', ' Parker']
+2803 647 Name of father of x -1 Name of father of Sarah Jessica Parker Stephen Parker Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' plays' ' Carrie' ' Brad' 'shaw' ' in'
+ ' the' ' TV' ' series' ' Sex' ' and' ' the' ' City' '.' '\n' '\n' 'The']" ", the actress who plays Carrie Brad shaw in the TV series Sex and the City .
+
+ The" False 2 ['Sarah', ' Jessica', ' Parker']
+2804 647 Name of father of x -1 Name of father of Sarah Jessica Parker Stephen Parker Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' plays' ' Carrie' ' Brad' 'shaw' ' in'
+ ' the' ' TV' ' series' ' Sex' ' and' ' the' ' City' '.' '\n' '\n' 'The']" ", the actress who plays Carrie Brad shaw in the TV series Sex and the City .
+
+ The" False starred opposite Sarah Jessica Parker in the Marc 4 [' starred', ' opposite', ' Sarah', ' Jessica', ' Parker']
+2805 647 Name of father of x -1 Name of father of Sarah Jessica Parker Stephen Parker Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' plays' ' Carrie' ' Brad' 'shaw' ' in'
+ ' the' ' TV' ' series' ' Sex' ' and' ' the' ' City' '.' '\n' '\n' 'The']" ", the actress who plays Carrie Brad shaw in the TV series Sex and the City .
+
+ The" False 2 ['Sarah', ' Jessica', ' Parker']
+2806 649 Name of father of x -1 Name of father of Sven Hedin Abraham Ludvig Hedin Sven Hedin "[',' ' the' ' explorer' ',' ' who' ' was' ' born' ' in' ' 18' '66' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the explorer , who was born in 18 66 .
+
+ The name of the father of the" False to Swedish explorer Sven Hedin in 1902; this may well 5 [' to', ' Swedish', ' explorer', ' Sven', ' H', 'edin']
+2807 649 Name of father of x -1 Name of father of Sven Hedin Abraham Ludvig Hedin Sven Hedin "[',' ' the' ' explorer' ',' ' who' ' was' ' born' ' in' ' 18' '66' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the explorer , who was born in 18 66 .
+
+ The name of the father of the" False Swedish explorer Sven Hedin in 1902; this may 4 [' Swedish', ' explorer', ' Sven', ' H', 'edin']
+2808 649 Name of father of x -1 Name of father of Sven Hedin Abraham Ludvig Hedin Sven Hedin "[',' ' the' ' explorer' ',' ' who' ' was' ' born' ' in' ' 18' '66' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the explorer , who was born in 18 66 .
+
+ The name of the father of the" False Swedish explorer Sven Hedin in 1902; this may 4 [' Swedish', ' explorer', ' Sven', ' H', 'edin']
+2809 650 Name of father of x -1 Name of father of Irving Berlin Moses Beilin Irving Berlin "[',' ' the' ' composer' ' of' ' ""' 'White' ' Christmas' ',""' ' ""' 'An'
+ 'ch' 'ors' ' A' 'we' 'igh' ',""' ' ""' 'White' ' Christmas' ',""']" ", the composer of "" White Christmas ,"" "" An ch ors A we igh ,"" "" White Christmas ,""" False deface as many Irving Berlin songs as you like, 5 [' def', 'ace', ' as', ' many', ' Irving', ' Berlin']
+2810 650 Name of father of x -1 Name of father of Irving Berlin Moses Beilin Irving Berlin "[',' ' the' ' composer' ' of' ' ""' 'White' ' Christmas' ',""' ' ""' 'An'
+ 'ch' 'ors' ' A' 'we' 'igh' ',""' ' ""' 'White' ' Christmas' ',""']" ", the composer of "" White Christmas ,"" "" An ch ors A we igh ,"" "" White Christmas ,""" False song's composer Irving Berlin to have the song, 4 "[' song', ""'s"", ' composer', ' Irving', ' Berlin']"
+2811 650 Name of father of x -1 Name of father of Irving Berlin Moses Beilin Irving Berlin "[',' ' the' ' composer' ' of' ' ""' 'White' ' Christmas' ',""' ' ""' 'An'
+ 'ch' 'ors' ' A' 'we' 'igh' ',""' ' ""' 'White' ' Christmas' ',""']" ", the composer of "" White Christmas ,"" "" An ch ors A we igh ,"" "" White Christmas ,""" False " show that the Irving Berlin song ""Puttin 'On" 4 [' show', ' that', ' the', ' Irving', ' Berlin']
+2812 650 Name of father of x -1 Name of father of Irving Berlin Moses Beilin Irving Berlin "[',' ' the' ' composer' ' of' ' ""' 'White' ' Christmas' ',""' ' ""' 'An'
+ 'ch' 'ors' ' A' 'we' 'igh' ',""' ' ""' 'White' ' Christmas' ',""']" ", the composer of "" White Christmas ,"" "" An ch ors A we igh ,"" "" White Christmas ,""" False song written by Irving Berlin in 1918 that 4 [' song', ' written', ' by', ' Irving', ' Berlin']
+2813 650 Name of father of x -1 Name of father of Irving Berlin Moses Beilin Irving Berlin "[',' ' the' ' composer' ' of' ' ""' 'White' ' Christmas' ',""' ' ""' 'An'
+ 'ch' 'ors' ' A' 'we' 'igh' ',""' ' ""' 'White' ' Christmas' ',""']" ", the composer of "" White Christmas ,"" "" An ch ors A we igh ,"" "" White Christmas ,""" False Vincent Youmans, Irving Berlin and Jerome Kern. 5 [' Vincent', ' You', 'mans', ',', ' Irving', ' Berlin']
+2814 651 Name of father of x -1 Name of father of Noël Coward Arthur Sabin Coward Noël Coward "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' No' 'ë' 'l' ' Cow'
+ 'ard' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of No ë l Cow ard is not known .
+
+ The" False Irving Berlin, Noël Coward and more. Wainwright 7 [' Irving', ' Berlin', ',', ' No', 'ë', 'l', ' Cow', 'ard']
+2815 651 Name of father of x -1 Name of father of Noël Coward Arthur Sabin Coward Noël Coward "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' No' 'ë' 'l' ' Cow'
+ 'ard' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of No ë l Cow ard is not known .
+
+ The" False governor opposite Noël Coward in the 1969 Peter 6 [' governor', ' opposite', ' No', 'ë', 'l', ' Cow', 'ard']
+2816 651 Name of father of x -1 Name of father of Noël Coward Arthur Sabin Coward Noël Coward "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' No' 'ë' 'l' ' Cow'
+ 'ard' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of No ë l Cow ard is not known .
+
+ The" False Greece against Turkey. Noël Coward was a notable 8 [' Greece', ' against', ' Turkey', '.', ' No', 'ë', 'l', ' Cow', 'ard']
+2817 651 Name of father of x -1 Name of father of Noël Coward Arthur Sabin Coward Noël Coward "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' No' 'ë' 'l' ' Cow'
+ 'ard' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of No ë l Cow ard is not known .
+
+ The" False " quite mad"". Noël Coward expressed surprise" 7 "[' quite', ' mad', '"".', ' No', 'ë', 'l', ' Cow', 'ard']"
+2818 651 Name of father of x -1 Name of father of Noël Coward Arthur Sabin Coward Noël Coward "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' No' 'ë' 'l' ' Cow'
+ 'ard' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of No ë l Cow ard is not known .
+
+ The" False hundred songs. The Noël Coward Society's 8 [' hundred', ' songs', '.', ' The', ' No', 'ë', 'l', ' Cow', 'ard']
+2819 654 Name of father of x -1 Name of father of Madeleine Albright Josef Korbel Madeleine Albright "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Secretary of State Madeleine Albright tried to force Yugoslavia 7 [' Secretary', ' of', ' State', ' Made', 'le', 'ine', ' Al', 'bright']
+2820 654 Name of father of x -1 Name of father of Madeleine Albright Josef Korbel Madeleine Albright "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False president and Madeleine Albright were angered 6 [' president', ' and', ' Made', 'le', 'ine', ' Al', 'bright']
+2821 654 Name of father of x -1 Name of father of Madeleine Albright Josef Korbel Madeleine Albright "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Secretary of State Madeleine Albright tried to force 7 [' Secretary', ' of', ' State', ' Made', 'le', 'ine', ' Al', 'bright']
+2822 654 Name of father of x -1 Name of father of Madeleine Albright Josef Korbel Madeleine Albright "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False mistakes a photo of Madeleine Albright for Leslie's grandmother, 8 [' mistakes', ' a', ' photo', ' of', ' Made', 'le', 'ine', ' Al', 'bright']
+2823 654 Name of father of x -1 Name of father of Madeleine Albright Josef Korbel Madeleine Albright "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' grandmother' ' of' ' two' ',' ' a' ' great' ' grandmother']" "
+
+ I am a mother of two , a wife , a grandmother of two , a great grandmother" False Secretary of State Madeleine Albright and former Nigerian 7 [' Secretary', ' of', ' State', ' Made', 'le', 'ine', ' Al', 'bright']
+2824 655 Name of father of x -1 Name of father of Robert Burns William Burnes Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ',' ' and'
+ ' was' ' educated' ' at' ' the' ' University' ' of' ' Glasgow' '.']" , the poet , was born in A yr shire , and was educated at the University of Glasgow . False 1 ['Robert', ' Burns']
+2825 655 Name of father of x -1 Name of father of Robert Burns William Burnes Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ',' ' and'
+ ' was' ' educated' ' at' ' the' ' University' ' of' ' Glasgow' '.']" , the poet , was born in A yr shire , and was educated at the University of Glasgow . False Greek god Pan. Robert Burns refers to 5 [' Greek', ' god', ' Pan', '.', ' Robert', ' Burns']
+2826 655 Name of father of x -1 Name of father of Robert Burns William Burnes Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ',' ' and'
+ ' was' ' educated' ' at' ' the' ' University' ' of' ' Glasgow' '.']" , the poet , was born in A yr shire , and was educated at the University of Glasgow . False as such by Robert Burns in the poem Holy 4 [' as', ' such', ' by', ' Robert', ' Burns']
+2827 655 Name of father of x -1 Name of father of Robert Burns William Burnes Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ',' ' and'
+ ' was' ' educated' ' at' ' the' ' University' ' of' ' Glasgow' '.']" , the poet , was born in A yr shire , and was educated at the University of Glasgow . False later be used by Robert Burns as a poetic 5 [' later', ' be', ' used', ' by', ' Robert', ' Burns']
+2828 655 Name of father of x -1 Name of father of Robert Burns William Burnes Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ',' ' and'
+ ' was' ' educated' ' at' ' the' ' University' ' of' ' Glasgow' '.']" , the poet , was born in A yr shire , and was educated at the University of Glasgow . False 1 ['Robert', ' Burns']
+2829 657 Name of father of x -1 Name of father of Nikolai Rimsky-Korsakov Andrey Rimsky-Korsakov Nikolai Rimsky-Korsakov "['\n' '\n' 'Nik' 'ol' 'ai' ' R' 'ims' 'ky' '-' 'K' 'ors' 'akov' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ Nik ol ai R ims ky - K ors akov ( ; ; ; ; ; ; ;" False Alexander Glazunov, Nikolai Rimsky-Korsakov and Nikolai Tcherepnin, 14 [' Alexander', ' Gl', 'az', 'un', 'ov', ',', ' Nikol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2830 657 Name of father of x -1 Name of father of Nikolai Rimsky-Korsakov Andrey Rimsky-Korsakov Nikolai Rimsky-Korsakov "['\n' '\n' 'Nik' 'ol' 'ai' ' R' 'ims' 'ky' '-' 'K' 'ors' 'akov' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ Nik ol ai R ims ky - K ors akov ( ; ; ; ; ; ; ;" False " Rimsky-Korsakov =
+" 17 [' R', 'ims', 'ky', '-', 'K', 'ors', 'akov', ' =', 'Nik', 'ol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2831 657 Name of father of x -1 Name of father of Nikolai Rimsky-Korsakov Andrey Rimsky-Korsakov Nikolai Rimsky-Korsakov "['\n' '\n' 'Nik' 'ol' 'ai' ' R' 'ims' 'ky' '-' 'K' 'ors' 'akov' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ Nik ol ai R ims ky - K ors akov ( ; ; ; ; ; ; ;" False Modest Mussorgsky, Nikolai Rimsky-Korsakov and Alexander Borodin 13 [' Modest', ' Muss', 'org', 'sky', ',', ' Nikol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2832 657 Name of father of x -1 Name of father of Nikolai Rimsky-Korsakov Andrey Rimsky-Korsakov Nikolai Rimsky-Korsakov "['\n' '\n' 'Nik' 'ol' 'ai' ' R' 'ims' 'ky' '-' 'K' 'ors' 'akov' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ Nik ol ai R ims ky - K ors akov ( ; ; ; ; ; ; ;" False " Rimsky-Korsakov =
+" 17 [' R', 'ims', 'ky', '-', 'K', 'ors', 'akov', ' =', 'Nik', 'ol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2833 657 Name of father of x -1 Name of father of Nikolai Rimsky-Korsakov Andrey Rimsky-Korsakov Nikolai Rimsky-Korsakov "['\n' '\n' 'Nik' 'ol' 'ai' ' R' 'ims' 'ky' '-' 'K' 'ors' 'akov' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ Nik ol ai R ims ky - K ors akov ( ; ; ; ; ; ; ;" False Modest Mussorgsky, Nikolai Rimsky-Korsakov and Alexander 13 [' Modest', ' Muss', 'org', 'sky', ',', ' Nikol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2834 658 Name of father of x -1 Name of father of Emmanuel Macron Jean-Michel Macron Emmanuel Macron "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ' Brig'
+ 'itte' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have' ' been'
+ ' married']" ", the French president , and his wife Brig itte .
+
+ The couple , who have been married" False " years with Emmanuel Macron from France
+" 3 [' years', ' with', ' Emmanuel', ' Macron']
+2835 659 Name of father of x -1 Name of father of Paul Verlaine Nicolas Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' the' ' poet'
+ 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the poet , and of his wife , the poet ess .
+
+ The name of the mother" False with lyrics by Paul Verlaine was aborted by the 6 [' with', ' lyrics', ' by', ' Paul', ' Ver', 'l', 'aine']
+2836 659 Name of father of x -1 Name of father of Paul Verlaine Nicolas Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' the' ' poet'
+ 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the poet , and of his wife , the poet ess .
+
+ The name of the mother" False Mallarmé and Paul Verlaine worshipped Wagner. 7 [' Mall', 'arm', 'é', ' and', ' Paul', ' Ver', 'l', 'aine']
+2837 659 Name of father of x -1 Name of father of Paul Verlaine Nicolas Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' the' ' poet'
+ 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the poet , and of his wife , the poet ess .
+
+ The name of the mother" False " l'espinette"" to words by Paul Verlaine and Clément Marot," 12 "[' l', ""'"", 'esp', 'in', 'ette', '""', ' to', ' words', ' by', ' Paul', ' Ver', 'l', 'aine']"
+2838 659 Name of father of x -1 Name of father of Paul Verlaine Nicolas Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' the' ' poet'
+ 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the poet , and of his wife , the poet ess .
+
+ The name of the mother" False Stéphane Mallarmé and Paul Verlaine worshipped Wagner. 11 [' St', 'é', 'ph', 'ane', ' Mall', 'arm', 'é', ' and', ' Paul', ' Ver', 'l', 'aine']
+2839 659 Name of father of x -1 Name of father of Paul Verlaine Nicolas Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' the' ' poet'
+ 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' mother']" ", the poet , and of his wife , the poet ess .
+
+ The name of the mother" False lyrics by Paul Verlaine was aborted 5 [' lyrics', ' by', ' Paul', ' Ver', 'l', 'aine']
+2840 660 Name of father of x -1 Name of father of Muammar Gaddafi Mohammed Abdus-Salam Muammar Gaddafi "[',' ' the' ' late' ' Libyan' ' leader' ',' ' and' ' his' ' son' ' Sa'
+ 'if' ' al' '-' 'Islam' ' Gaddafi' ',' ' who' ' is' ' currently' ' in']" , the late Libyan leader , and his son Sa if al - Islam Gaddafi , who is currently in False government under Colonel Muammar Gaddafi escalated, and 6 [' government', ' under', ' Colonel', ' Mu', 'am', 'mar', ' Gaddafi']
+2841 660 Name of father of x -1 Name of father of Muammar Gaddafi Mohammed Abdus-Salam Muammar Gaddafi "[',' ' the' ' late' ' Libyan' ' leader' ',' ' and' ' his' ' son' ' Sa'
+ 'if' ' al' '-' 'Islam' ' Gaddafi' ',' ' who' ' is' ' currently' ' in']" , the late Libyan leader , and his son Sa if al - Islam Gaddafi , who is currently in False 3 ['Mu', 'am', 'mar', ' Gaddafi']
+2842 660 Name of father of x -1 Name of father of Muammar Gaddafi Mohammed Abdus-Salam Muammar Gaddafi "[',' ' the' ' late' ' Libyan' ' leader' ',' ' and' ' his' ' son' ' Sa'
+ 'if' ' al' '-' 'Islam' ' Gaddafi' ',' ' who' ' is' ' currently' ' in']" , the late Libyan leader , and his son Sa if al - Islam Gaddafi , who is currently in False another, he echoed Muammar Gaddafi by accusing protesters 7 [' another', ',', ' he', ' echoed', ' Mu', 'am', 'mar', ' Gaddafi']
+2843 660 Name of father of x -1 Name of father of Muammar Gaddafi Mohammed Abdus-Salam Muammar Gaddafi "[',' ' the' ' late' ' Libyan' ' leader' ',' ' and' ' his' ' son' ' Sa'
+ 'if' ' al' '-' 'Islam' ' Gaddafi' ',' ' who' ' is' ' currently' ' in']" , the late Libyan leader , and his son Sa if al - Islam Gaddafi , who is currently in False 3 ['Mu', 'am', 'mar', ' Gaddafi']
+2844 660 Name of father of x -1 Name of father of Muammar Gaddafi Mohammed Abdus-Salam Muammar Gaddafi "[',' ' the' ' late' ' Libyan' ' leader' ',' ' and' ' his' ' son' ' Sa'
+ 'if' ' al' '-' 'Islam' ' Gaddafi' ',' ' who' ' is' ' currently' ' in']" , the late Libyan leader , and his son Sa if al - Islam Gaddafi , who is currently in False meeting with leader Muammar Gaddafi in Tripoli. In May 6 [' meeting', ' with', ' leader', ' Mu', 'am', 'mar', ' Gaddafi']
+2845 661 Name of father of x -1 Name of father of George IV of the United Kingdom George III of Great Britain George IV of the United Kingdom "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' George' ' IV'
+ ' of' ' the' ' United' ' Kingdom' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of George IV of the United Kingdom is not known .
+
+" False Londonderry) and King George IV of the United Kingdom in Hanover in October. 12 [' L', 'ond', 'ond', 'erry', ')', ' and', ' King', ' George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2846 661 Name of father of x -1 Name of father of George IV of the United Kingdom George III of Great Britain George IV of the United Kingdom "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' George' ' IV'
+ ' of' ' the' ' United' ' Kingdom' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of George IV of the United Kingdom is not known .
+
+" False " United Kingdom =
+" 8 [' United', ' Kingdom', ' =', 'George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2847 661 Name of father of x -1 Name of father of George IV of the United Kingdom George III of Great Britain George IV of the United Kingdom "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' George' ' IV'
+ ' of' ' the' ' United' ' Kingdom' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of George IV of the United Kingdom is not known .
+
+" False " United Kingdom =
+" 8 [' United', ' Kingdom', ' =', 'George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2848 661 Name of father of x -1 Name of father of George IV of the United Kingdom George III of Great Britain George IV of the United Kingdom "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' George' ' IV'
+ ' of' ' the' ' United' ' Kingdom' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of George IV of the United Kingdom is not known .
+
+" False " Kingdom =
+" 7 [' Kingdom', ' =', 'George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2849 662 Name of father of x -1 Name of father of Paul Dirac Charles Dirac Paul Dirac "[',' ' the' ' physicist' ' who' ' discovered' ' the' ' Dir' 'ac'
+ ' equation' ',' ' and' ' the' ' Dir' 'ac' ' equation' ' is' ' the'
+ ' equation' ' that' ' describes']" , the physicist who discovered the Dir ac equation , and the Dir ac equation is the equation that describes False surprised to discover that Paul Dirac had been thinking 6 [' surprised', ' to', ' discover', ' that', ' Paul', ' Dir', 'ac']
+2850 662 Name of father of x -1 Name of father of Paul Dirac Charles Dirac Paul Dirac "[',' ' the' ' physicist' ' who' ' discovered' ' the' ' Dir' 'ac'
+ ' equation' ',' ' and' ' the' ' Dir' 'ac' ' equation' ' is' ' the'
+ ' equation' ' that' ' describes']" , the physicist who discovered the Dir ac equation , and the Dir ac equation is the equation that describes False to discover that Paul Dirac had been thinking 5 [' to', ' discover', ' that', ' Paul', ' Dir', 'ac']
+2851 662 Name of father of x -1 Name of father of Paul Dirac Charles Dirac Paul Dirac "[',' ' the' ' physicist' ' who' ' discovered' ' the' ' Dir' 'ac'
+ ' equation' ',' ' and' ' the' ' Dir' 'ac' ' equation' ' is' ' the'
+ ' equation' ' that' ' describes']" , the physicist who discovered the Dir ac equation , and the Dir ac equation is the equation that describes False positron, after a paper by Paul Dirac proposed that electrons 9 [' posit', 'ron', ',', ' after', ' a', ' paper', ' by', ' Paul', ' Dir', 'ac']
+2852 662 Name of father of x -1 Name of father of Paul Dirac Charles Dirac Paul Dirac "[',' ' the' ' physicist' ' who' ' discovered' ' the' ' Dir' 'ac'
+ ' equation' ',' ' and' ' the' ' Dir' 'ac' ' equation' ' is' ' the'
+ ' equation' ' that' ' describes']" , the physicist who discovered the Dir ac equation , and the Dir ac equation is the equation that describes False surprised to discover that Paul Dirac had been thinking 6 [' surprised', ' to', ' discover', ' that', ' Paul', ' Dir', 'ac']
+2853 662 Name of father of x -1 Name of father of Paul Dirac Charles Dirac Paul Dirac "[',' ' the' ' physicist' ' who' ' discovered' ' the' ' Dir' 'ac'
+ ' equation' ',' ' and' ' the' ' Dir' 'ac' ' equation' ' is' ' the'
+ ' equation' ' that' ' describes']" , the physicist who discovered the Dir ac equation , and the Dir ac equation is the equation that describes False In fact, in 1926, Paul Dirac derived the 8 [' In', ' fact', ',', ' in', ' 1926', ',', ' Paul', ' Dir', 'ac']
+2854 665 Name of father of x -1 Name of father of François Guizot André Guizot François Guizot "[',' ' the' ' French' ' states' 'man' ',' ' who' ' was' ' born' ' in'
+ ' 17' '87' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the French states man , who was born in 17 87 .
+
+ The name of the father" False " Saint-Simonians such as François Guizot and Augustin Thierry.
+" 9 [' Saint', '-', 'Simon', 'ians', ' such', ' as', ' François', ' Gu', 'iz', 'ot']
+2855 665 Name of father of x -1 Name of father of François Guizot André Guizot François Guizot "[',' ' the' ' French' ' states' 'man' ',' ' who' ' was' ' born' ' in'
+ ' 17' '87' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the French states man , who was born in 17 87 .
+
+ The name of the father" False " Saint-Simonians such as François Guizot and Augustin Thierry.
+" 9 [' Saint', '-', 'Simon', 'ians', ' such', ' as', ' François', ' Gu', 'iz', 'ot']
+2856 665 Name of father of x -1 Name of father of François Guizot André Guizot François Guizot "[',' ' the' ' French' ' states' 'man' ',' ' who' ' was' ' born' ' in'
+ ' 17' '87' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the French states man , who was born in 17 87 .
+
+ The name of the father" False " Saint-Simonians such as François Guizot and Augustin Thierry.
+" 9 [' Saint', '-', 'Simon', 'ians', ' such', ' as', ' François', ' Gu', 'iz', 'ot']
+2857 665 Name of father of x -1 Name of father of François Guizot André Guizot François Guizot "[',' ' the' ' French' ' states' 'man' ',' ' who' ' was' ' born' ' in'
+ ' 17' '87' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the French states man , who was born in 17 87 .
+
+ The name of the father" False in years from François Guizot over the Swiss Civil 6 [' in', ' years', ' from', ' François', ' Gu', 'iz', 'ot']
+2858 665 Name of father of x -1 Name of father of François Guizot André Guizot François Guizot "[',' ' the' ' French' ' states' 'man' ',' ' who' ' was' ' born' ' in'
+ ' 17' '87' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father']" ", the French states man , who was born in 17 87 .
+
+ The name of the father" False time in years from François Guizot over the Swiss 7 [' time', ' in', ' years', ' from', ' François', ' Gu', 'iz', 'ot']
+2859 666 Name of father of x -1 Name of father of Freddie Mercury Bomi Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' band' ' Queen'
+ '.' ' I' ' have' ' been' ' since' ' I' ' was' ' a']" ".
+
+ I am a big fan of the band Queen . I have been since I was a" False 2 ['Fred', 'die', ' Mercury']
+2860 666 Name of father of x -1 Name of father of Freddie Mercury Bomi Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' band' ' Queen'
+ '.' ' I' ' have' ' been' ' since' ' I' ' was' ' a']" ".
+
+ I am a big fan of the band Queen . I have been since I was a" False 2 ['Fred', 'die', ' Mercury']
+2861 666 Name of father of x -1 Name of father of Freddie Mercury Bomi Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' band' ' Queen'
+ '.' ' I' ' have' ' been' ' since' ' I' ' was' ' a']" ".
+
+ I am a big fan of the band Queen . I have been since I was a" False appeared at The Freddie Mercury Tribute Concert 4 [' appeared', ' at', ' The', ' Freddie', ' Mercury']
+2862 666 Name of father of x -1 Name of father of Freddie Mercury Bomi Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' band' ' Queen'
+ '.' ' I' ' have' ' been' ' since' ' I' ' was' ' a']" ".
+
+ I am a big fan of the band Queen . I have been since I was a" False " Mercury =
+" 4 [' Mercury', ' =', 'Fred', 'die', ' Mercury']
+2863 666 Name of father of x -1 Name of father of Freddie Mercury Bomi Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' band' ' Queen'
+ '.' ' I' ' have' ' been' ' since' ' I' ' was' ' a']" ".
+
+ I am a big fan of the band Queen . I have been since I was a" False knelt on stage at The Freddie Mercury Tribute Concert and 7 [' kn', 'elt', ' on', ' stage', ' at', ' The', ' Freddie', ' Mercury']
+2864 667 Name of father of x -1 Name of father of Buckminster Fuller Richard Buckminster Fuller Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' ge' 'odes' 'ic' ' dome' '.' '\n' '\n' 'The']" ", the father of modern architecture , and the father of the ge odes ic dome .
+
+ The" False Pallot's project Buckminster Fuller EP in February 2009. 6 "[' Pall', 'ot', ""'s"", ' project', ' Buck', 'minster', ' Fuller']"
+2865 667 Name of father of x -1 Name of father of Buckminster Fuller Richard Buckminster Fuller Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' ge' 'odes' 'ic' ' dome' '.' '\n' '\n' 'The']" ", the father of modern architecture , and the father of the ge odes ic dome .
+
+ The" False even invited R. Buckminster Fuller to design the structure; 6 [' even', ' invited', ' R', '.', ' Buck', 'minster', ' Fuller']
+2866 667 Name of father of x -1 Name of father of Buckminster Fuller Richard Buckminster Fuller Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' ge' 'odes' 'ic' ' dome' '.' '\n' '\n' 'The']" ", the father of modern architecture , and the father of the ge odes ic dome .
+
+ The" False architect Richard Buckminster Fuller is famous for designing 4 [' architect', ' Richard', ' Buck', 'minster', ' Fuller']
+2867 667 Name of father of x -1 Name of father of Buckminster Fuller Richard Buckminster Fuller Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' ge' 'odes' 'ic' ' dome' '.' '\n' '\n' 'The']" ", the father of modern architecture , and the father of the ge odes ic dome .
+
+ The" False soccer ball, Buckminster Fuller geodesic dome, 5 [' soccer', ' ball', ',', ' Buck', 'minster', ' Fuller']
+2868 667 Name of father of x -1 Name of father of Buckminster Fuller Richard Buckminster Fuller Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' ge' 'odes' 'ic' ' dome' '.' '\n' '\n' 'The']" ", the father of modern architecture , and the father of the ge odes ic dome .
+
+ The" False Airport, while Richard Buckminster Fuller pioneered the 6 [' Airport', ',', ' while', ' Richard', ' Buck', 'minster', ' Fuller']
+2869 668 Name of father of x -1 Name of father of John McCain John S. McCain, Jr. John McCain "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' to' ' say' ' that'
+ ' the' ' United' ' States' ' should' ' not' ' be' ' involved' ' in'
+ ' the' ' Vietnam']" , the man who was the first to say that the United States should not be involved in the Vietnam False spokeswoman for the John McCain campaign argued 4 [' spokeswoman', ' for', ' the', ' John', ' McCain']
+2870 668 Name of father of x -1 Name of father of John McCain John S. McCain, Jr. John McCain "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' to' ' say' ' that'
+ ' the' ' United' ' States' ' should' ' not' ' be' ' involved' ' in'
+ ' the' ' Vietnam']" , the man who was the first to say that the United States should not be involved in the Vietnam False Darrell Hammond as John McCain and Will Ferrell 4 [' Darrell', ' Hammond', ' as', ' John', ' McCain']
+2871 668 Name of father of x -1 Name of father of John McCain John S. McCain, Jr. John McCain "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' to' ' say' ' that'
+ ' the' ' United' ' States' ' should' ' not' ' be' ' involved' ' in'
+ ' the' ' Vietnam']" , the man who was the first to say that the United States should not be involved in the Vietnam False percent, and Senator John McCain with 23.6 percent. 5 [' percent', ',', ' and', ' Senator', ' John', ' McCain']
+2872 668 Name of father of x -1 Name of father of John McCain John S. McCain, Jr. John McCain "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' to' ' say' ' that'
+ ' the' ' United' ' States' ' should' ' not' ' be' ' involved' ' in'
+ ' the' ' Vietnam']" , the man who was the first to say that the United States should not be involved in the Vietnam False " Sarah Palin and John McCain had the word ""nope""." 4 [' Sarah', ' Palin', ' and', ' John', ' McCain']
+2873 668 Name of father of x -1 Name of father of John McCain John S. McCain, Jr. John McCain "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' to' ' say' ' that'
+ ' the' ' United' ' States' ' should' ' not' ' be' ' involved' ' in'
+ ' the' ' Vietnam']" , the man who was the first to say that the United States should not be involved in the Vietnam False presidential candidate John McCain calling him 3 [' presidential', ' candidate', ' John', ' McCain']
+2874 669 Name of father of x -1 Name of father of Mustafa Kemal Atatürk Ali Rıza Efendi Mustafa Kemal Atatürk "['\n' '\n' 'Must' 'afa' ' Kem' 'al' ' At' 'at' 'ür' 'k' ' was' ' born'
+ ' on' ' November' ' 10' ',' ' 18' '81' ',' ' in']" "
+
+ Must afa Kem al At at ür k was born on November 10 , 18 81 , in" False carried the remains of Mustafa Kemal Atatürk from Istanbul 11 [' carried', ' the', ' remains', ' of', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2875 669 Name of father of x -1 Name of father of Mustafa Kemal Atatürk Ali Rıza Efendi Mustafa Kemal Atatürk "['\n' '\n' 'Must' 'afa' ' Kem' 'al' ' At' 'at' 'ür' 'k' ' was' ' born'
+ ' on' ' November' ' 10' ',' ' 18' '81' ',' ' in']" "
+
+ Must afa Kem al At at ür k was born on November 10 , 18 81 , in" False the remains of Mustafa Kemal Atatürk from Istanbul to 10 [' the', ' remains', ' of', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2876 669 Name of father of x -1 Name of father of Mustafa Kemal Atatürk Ali Rıza Efendi Mustafa Kemal Atatürk "['\n' '\n' 'Must' 'afa' ' Kem' 'al' ' At' 'at' 'ür' 'k' ' was' ' born'
+ ' on' ' November' ' 10' ',' ' 18' '81' ',' ' in']" "
+
+ Must afa Kem al At at ür k was born on November 10 , 18 81 , in" False the remains of Mustafa Kemal Atatürk from Istanbul to 10 [' the', ' remains', ' of', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2877 669 Name of father of x -1 Name of father of Mustafa Kemal Atatürk Ali Rıza Efendi Mustafa Kemal Atatürk "['\n' '\n' 'Must' 'afa' ' Kem' 'al' ' At' 'at' 'ür' 'k' ' was' ' born'
+ ' on' ' November' ' 10' ',' ' 18' '81' ',' ' in']" "
+
+ Must afa Kem al At at ür k was born on November 10 , 18 81 , in" False the remains of Mustafa Kemal Atatürk from Istanbul 10 [' the', ' remains', ' of', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2878 669 Name of father of x -1 Name of father of Mustafa Kemal Atatürk Ali Rıza Efendi Mustafa Kemal Atatürk "['\n' '\n' 'Must' 'afa' ' Kem' 'al' ' At' 'at' 'ür' 'k' ' was' ' born'
+ ' on' ' November' ' 10' ',' ' 18' '81' ',' ' in']" "
+
+ Must afa Kem al At at ür k was born on November 10 , 18 81 , in" False deemed insulting to Mustafa Kemal Atatürk and some material 10 [' deemed', ' insulting', ' to', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2879 670 Name of father of x -1 Name of father of Max Born Gustav Jacob Born Max Born "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' born' ' in'
+ ' Germany' ' in' ' 18' '82' '.' '\n' '\n' 'The' ' first' ' of' ' the']" ", the famous physicist , who was born in Germany in 18 82 .
+
+ The first of the" False colleagues, including Max Born and Richard 4 [' colleagues', ',', ' including', ' Max', ' Born']
+2880 670 Name of father of x -1 Name of father of Max Born Gustav Jacob Born Max Born "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' born' ' in'
+ ' Germany' ' in' ' 18' '82' '.' '\n' '\n' 'The' ' first' ' of' ' the']" ", the famous physicist , who was born in Germany in 18 82 .
+
+ The first of the" False as early as 1909, Max Born had given a 6 [' as', ' early', ' as', ' 1909', ',', ' Max', ' Born']
+2881 670 Name of father of x -1 Name of father of Max Born Gustav Jacob Born Max Born "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' born' ' in'
+ ' Germany' ' in' ' 18' '82' '.' '\n' '\n' 'The' ' first' ' of' ' the']" ", the famous physicist , who was born in Germany in 18 82 .
+
+ The first of the" False 1 ['Max', ' Born']
+2882 670 Name of father of x -1 Name of father of Max Born Gustav Jacob Born Max Born "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' born' ' in'
+ ' Germany' ' in' ' 18' '82' '.' '\n' '\n' 'The' ' first' ' of' ' the']" ", the famous physicist , who was born in Germany in 18 82 .
+
+ The first of the" False showed his results to Max Born in Göttingen, 5 [' showed', ' his', ' results', ' to', ' Max', ' Born']
+2883 670 Name of father of x -1 Name of father of Max Born Gustav Jacob Born Max Born "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' born' ' in'
+ ' Germany' ' in' ' 18' '82' '.' '\n' '\n' 'The' ' first' ' of' ' the']" ", the famous physicist , who was born in Germany in 18 82 .
+
+ The first of the" False research into physics. Max Born became a naturalised 5 [' research', ' into', ' physics', '.', ' Max', ' Born']
+2884 671 Name of father of x -1 Name of father of John Adams John Adams, Sr. John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' John']" ", the first president of the United States .
+
+ The first president of the United States , John" False commentators such as John Adams observed that 4 [' commentators', ' such', ' as', ' John', ' Adams']
+2885 671 Name of father of x -1 Name of father of John Adams John Adams, Sr. John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' John']" ", the first president of the United States .
+
+ The first president of the United States , John" False 1962 he enrolled at John Adams High School, but 5 [' 1962', ' he', ' enrolled', ' at', ' John', ' Adams']
+2886 671 Name of father of x -1 Name of father of John Adams John Adams, Sr. John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' John']" ", the first president of the United States .
+
+ The first president of the United States , John" False by Walton Music, John Adams (as the chorus 5 [' by', ' Walton', ' Music', ',', ' John', ' Adams']
+2887 671 Name of father of x -1 Name of father of John Adams John Adams, Sr. John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' John']" ", the first president of the United States .
+
+ The first president of the United States , John" False Jefferson, in a letter to John Adams, would describe 7 [' Jefferson', ',', ' in', ' a', ' letter', ' to', ' John', ' Adams']
+2888 671 Name of father of x -1 Name of father of John Adams John Adams, Sr. John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' John']" ", the first president of the United States .
+
+ The first president of the United States , John" False Longford (2006) and in John Adams (2008), a seven-part 8 [' Long', 'ford', ' (', '2006', ')', ' and', ' in', ' John', ' Adams']
+2889 672 Name of father of x -1 Name of father of Teresa of Ávila Alonso Sánchez de Cepeda Teresa of Ávila "['\n' '\n' 'The' ' name' ' Teresa' ' of' ' �' '�' 'vil' 'a' ' is' ' a'
+ ' feminine' ' given' ' name' '.' ' It' ' is' ' a' ' variant']" "
+
+ The name Teresa of � � vil a is a feminine given name . It is a variant" False Luis de León, Teresa of Ávila and John of the Cross, 10 [' Luis', ' de', ' Le', 'ón', ',', ' Teresa', ' of', ' �', '�', 'vil', 'a']
+2890 672 Name of father of x -1 Name of father of Teresa of Ávila Alonso Sánchez de Cepeda Teresa of Ávila "['\n' '\n' 'The' ' name' ' Teresa' ' of' ' �' '�' 'vil' 'a' ' is' ' a'
+ ' feminine' ' given' ' name' '.' ' It' ' is' ' a' ' variant']" "
+
+ The name Teresa of � � vil a is a feminine given name . It is a variant" False fray Luis de León, Teresa of Ávila and John of 11 [' fray', ' Luis', ' de', ' Le', 'ón', ',', ' Teresa', ' of', ' �', '�', 'vil', 'a']
+2891 672 Name of father of x -1 Name of father of Teresa of Ávila Alonso Sánchez de Cepeda Teresa of Ávila "['\n' '\n' 'The' ' name' ' Teresa' ' of' ' �' '�' 'vil' 'a' ' is' ' a'
+ ' feminine' ' given' ' name' '.' ' It' ' is' ' a' ' variant']" "
+
+ The name Teresa of � � vil a is a feminine given name . It is a variant" False Luis de León, Teresa of Ávila and John of 10 [' Luis', ' de', ' Le', 'ón', ',', ' Teresa', ' of', ' �', '�', 'vil', 'a']
+2892 672 Name of father of x -1 Name of father of Teresa of Ávila Alonso Sánchez de Cepeda Teresa of Ávila "['\n' '\n' 'The' ' name' ' Teresa' ' of' ' �' '�' 'vil' 'a' ' is' ' a'
+ ' feminine' ' given' ' name' '.' ' It' ' is' ' a' ' variant']" "
+
+ The name Teresa of � � vil a is a feminine given name . It is a variant" False other of Saint Teresa of Ávila with a quill 8 [' other', ' of', ' Saint', ' Teresa', ' of', ' �', '�', 'vil', 'a']
+2893 672 Name of father of x -1 Name of father of Teresa of Ávila Alonso Sánchez de Cepeda Teresa of Ávila "['\n' '\n' 'The' ' name' ' Teresa' ' of' ' �' '�' 'vil' 'a' ' is' ' a'
+ ' feminine' ' given' ' name' '.' ' It' ' is' ' a' ' variant']" "
+
+ The name Teresa of � � vil a is a feminine given name . It is a variant" False 6 ['Te', 'resa', ' of', ' �', '�', 'vil', 'a']
+2894 673 Name of father of x -1 Name of father of William Wordsworth John Wordsworth William Wordsworth "[',' ' the' ' poet' ',' ' was' ' born' ' at' ' Cock' 'erm' 'outh' ','
+ ' in' ' Cumber' 'land' ',' ' in' ' 17' '70' '.' ' He']" , the poet , was born at Cock erm outh , in Cumber land , in 17 70 . He False exhibited in 1832, William Wordsworth (1833), Jules Verne 7 [' exhibited', ' in', ' 18', '32', ',', ' William', ' Word', 'sworth']
+2895 673 Name of father of x -1 Name of father of William Wordsworth John Wordsworth William Wordsworth "[',' ' the' ' poet' ',' ' was' ' born' ' at' ' Cock' 'erm' 'outh' ','
+ ' in' ' Cumber' 'land' ',' ' in' ' 17' '70' '.' ' He']" , the poet , was born at Cock erm outh , in Cumber land , in 17 70 . He False " inspiring, for example, William Wordsworth to write ""The Egyptian" 7 [' inspiring', ',', ' for', ' example', ',', ' William', ' Word', 'sworth']
+2896 673 Name of father of x -1 Name of father of William Wordsworth John Wordsworth William Wordsworth "[',' ' the' ' poet' ',' ' was' ' born' ' at' ' Cock' 'erm' 'outh' ','
+ ' in' ' Cumber' 'land' ',' ' in' ' 17' '70' '.' ' He']" , the poet , was born at Cock erm outh , in Cumber land , in 17 70 . He False writers such as William Wordsworth and Samuel Taylor 5 [' writers', ' such', ' as', ' William', ' Word', 'sworth']
+2897 673 Name of father of x -1 Name of father of William Wordsworth John Wordsworth William Wordsworth "[',' ' the' ' poet' ',' ' was' ' born' ' at' ' Cock' 'erm' 'outh' ','
+ ' in' ' Cumber' 'land' ',' ' in' ' 17' '70' '.' ' He']" , the poet , was born at Cock erm outh , in Cumber land , in 17 70 . He False poem The Prelude, William Wordsworth recounted exploring 6 [' poem', ' The', ' Prelude', ',', ' William', ' Word', 'sworth']
+2898 673 Name of father of x -1 Name of father of William Wordsworth John Wordsworth William Wordsworth "[',' ' the' ' poet' ',' ' was' ' born' ' at' ' Cock' 'erm' 'outh' ','
+ ' in' ' Cumber' 'land' ',' ' in' ' 17' '70' '.' ' He']" , the poet , was born at Cock erm outh , in Cumber land , in 17 70 . He False Harp' [...] to ' To William Wordsworth ] [...] and 9 "[' Har', 'p', ""'"", ' [...]', ' to', "" '"", ' To', ' William', ' Word', 'sworth']"
+2899 674 Name of father of x -1 Name of father of Karlheinz Stockhausen Simon Stockhausen Karlheinz Stockhausen "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl' 'he' 'in' 'z' ' Stock'
+ 'haus' 'en' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl']" "
+
+ Name of mother of Karl he in z Stock haus en
+
+ Name of mother of Karl" False early modern eras. Karlheinz Stockhausen and Hans Zimmer are 10 [' early', ' modern', ' eras', '.', ' Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2900 674 Name of father of x -1 Name of father of Karlheinz Stockhausen Simon Stockhausen Karlheinz Stockhausen "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl' 'he' 'in' 'z' ' Stock'
+ 'haus' 'en' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl']" "
+
+ Name of mother of Karl he in z Stock haus en
+
+ Name of mother of Karl" False Thelonious Monk and Karlheinz Stockhausen each performed 11 [' The', 'lon', 'ious', ' Monk', ' and', ' Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2901 674 Name of father of x -1 Name of father of Karlheinz Stockhausen Simon Stockhausen Karlheinz Stockhausen "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl' 'he' 'in' 'z' ' Stock'
+ 'haus' 'en' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl']" "
+
+ Name of mother of Karl he in z Stock haus en
+
+ Name of mother of Karl" False 6 ['Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2902 674 Name of father of x -1 Name of father of Karlheinz Stockhausen Simon Stockhausen Karlheinz Stockhausen "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl' 'he' 'in' 'z' ' Stock'
+ 'haus' 'en' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl']" "
+
+ Name of mother of Karl he in z Stock haus en
+
+ Name of mother of Karl" False " German composer Karlheinz Stockhausen in Cologne.
+" 8 [' German', ' composer', ' Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2903 674 Name of father of x -1 Name of father of Karlheinz Stockhausen Simon Stockhausen Karlheinz Stockhausen "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl' 'he' 'in' 'z' ' Stock'
+ 'haus' 'en' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Karl']" "
+
+ Name of mother of Karl he in z Stock haus en
+
+ Name of mother of Karl" False " Stockhausen =
+" 10 [' Stock', 'haus', 'en', ' =', 'Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2904 675 Name of father of x -1 Name of father of Pelé Dondinho Pelé "[',' ' the' ' Brazilian' ' soccer' ' player' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' Pel' 'é' ',' ' the' ' Brazilian' ' soccer'
+ ' player']" ", the Brazilian soccer player
+
+ The name of the father of Pel é , the Brazilian soccer player" False " footballing icon Pelé tweeted ""I always" 4 [' football', 'ing', ' icon', ' Pel', 'é']
+2905 675 Name of father of x -1 Name of father of Pelé Dondinho Pelé "[',' ' the' ' Brazilian' ' soccer' ' player' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' Pel' 'é' ',' ' the' ' Brazilian' ' soccer'
+ ' player']" ", the Brazilian soccer player
+
+ The name of the father of Pel é , the Brazilian soccer player" False he was named by Pelé as one of the world's 5 [' he', ' was', ' named', ' by', ' Pel', 'é']
+2906 675 Name of father of x -1 Name of father of Pelé Dondinho Pelé "[',' ' the' ' Brazilian' ' soccer' ' player' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' Pel' 'é' ',' ' the' ' Brazilian' ' soccer'
+ ' player']" ", the Brazilian soccer player
+
+ The name of the father of Pel é , the Brazilian soccer player" False Soccer Federation, Pelé's bicycle kick 4 [' Soccer', ' Federation', ',', ' Pel', 'é']
+2907 675 Name of father of x -1 Name of father of Pelé Dondinho Pelé "[',' ' the' ' Brazilian' ' soccer' ' player' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' Pel' 'é' ',' ' the' ' Brazilian' ' soccer'
+ ' player']" ", the Brazilian soccer player
+
+ The name of the father of Pel é , the Brazilian soccer player" False 2 ['P', 'el', 'é']
+2908 675 Name of father of x -1 Name of father of Pelé Dondinho Pelé "[',' ' the' ' Brazilian' ' soccer' ' player' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' Pel' 'é' ',' ' the' ' Brazilian' ' soccer'
+ ' player']" ", the Brazilian soccer player
+
+ The name of the father of Pel é , the Brazilian soccer player" False Stallone and Pelé in the 1981 prisoner 4 [' Stall', 'one', ' and', ' Pel', 'é']
+2909 676 Name of father of x -1 Name of father of Elizabeth Barrett Browning Edward Moulton-Barrett Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' poet' 'ess' ' is' ' not' ' known' '.']" ", the poet ess .
+
+ The name of the father of the poet ess is not known ." False works of Shakespeare, Elizabeth Barrett Browning and her husband 7 [' works', ' of', ' Shakespeare', ',', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2910 676 Name of father of x -1 Name of father of Elizabeth Barrett Browning Edward Moulton-Barrett Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' poet' 'ess' ' is' ' not' ' known' '.']" ", the poet ess .
+
+ The name of the father of the poet ess is not known ." False he considered Elizabeth Barrett Browning the best contemporary 5 [' he', ' considered', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2911 676 Name of father of x -1 Name of father of Elizabeth Barrett Browning Edward Moulton-Barrett Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' poet' 'ess' ' is' ' not' ' known' '.']" ", the poet ess .
+
+ The name of the father of the poet ess is not known ." False literature. Elizabeth Barrett Browning was also a strong 5 [' literature', '.', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2912 676 Name of father of x -1 Name of father of Elizabeth Barrett Browning Edward Moulton-Barrett Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' poet' 'ess' ' is' ' not' ' known' '.']" ", the poet ess .
+
+ The name of the father of the poet ess is not known ." False he considered Elizabeth Barrett Browning the best contemporary 5 [' he', ' considered', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2913 676 Name of father of x -1 Name of father of Elizabeth Barrett Browning Edward Moulton-Barrett Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' father' ' of' ' the' ' poet' 'ess' ' is' ' not' ' known' '.']" ", the poet ess .
+
+ The name of the father of the poet ess is not known ." False literature. Elizabeth Barrett Browning was also a strong 5 [' literature', '.', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2914 677 Name of father of x -1 Name of father of Joan of Arc Jacques d'Arc Joan of Arc "[',' ' the' ' Maid' ' of' ' Orleans' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Maid of Orleans , and the
+ " False alleged that Joan of Arc and Gilles de Rais 4 [' alleged', ' that', ' Joan', ' of', ' Arc']
+2915 677 Name of father of x -1 Name of father of Joan of Arc Jacques d'Arc Joan of Arc "[',' ' the' ' Maid' ' of' ' Orleans' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Maid of Orleans , and the
+ " False 3 ['Jo', 'an', ' of', ' Arc']
+2916 677 Name of father of x -1 Name of father of Joan of Arc Jacques d'Arc Joan of Arc "[',' ' the' ' Maid' ' of' ' Orleans' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Maid of Orleans , and the
+ " False window depicting Joan of Arc was donated 4 [' window', ' depicting', ' Joan', ' of', ' Arc']
+2917 677 Name of father of x -1 Name of father of Joan of Arc Jacques d'Arc Joan of Arc "[',' ' the' ' Maid' ' of' ' Orleans' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Maid of Orleans , and the
+ " False " abuse"" and ""like Joan of Arc coming back" 7 "[' abuse', '""', ' and', ' ""', 'like', ' Joan', ' of', ' Arc']"
+2918 677 Name of father of x -1 Name of father of Joan of Arc Jacques d'Arc Joan of Arc "[',' ' the' ' Maid' ' of' ' Orleans' ',' ' and' ' the' '\n' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Maid of Orleans , and the
+ " False The figure of Joan of Arc represents both 5 [' The', ' figure', ' of', ' Joan', ' of', ' Arc']
+2919 678 Name of father of x -1 Name of father of Sylvester Stallone Frank Stallone Sylvester Stallone "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False appeared alongside Sylvester Stallone and Antonio 6 [' appeared', ' alongside', ' Sy', 'lves', 'ter', ' Stall', 'one']
+2920 678 Name of father of x -1 Name of father of Sylvester Stallone Frank Stallone Sylvester Stallone "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False 5 ['S', 'yl', 'ves', 'ter', ' Stall', 'one']
+2921 678 Name of father of x -1 Name of father of Sylvester Stallone Frank Stallone Sylvester Stallone "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False 1930s Cleveland, Sylvester Stallone leads a local labor 8 [' 1930', 's', ' Cleveland', ',', ' Sy', 'lves', 'ter', ' Stall', 'one']
+2922 678 Name of father of x -1 Name of father of Sylvester Stallone Frank Stallone Sylvester Stallone "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False style of actor Sylvester Stallone as the character 7 [' style', ' of', ' actor', ' Sy', 'lves', 'ter', ' Stall', 'one']
+2923 678 Name of father of x -1 Name of father of Sylvester Stallone Frank Stallone Sylvester Stallone "['\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' this' ' is' ' the' ' right'
+ ' place' ' to' ' post' ' this' ',' ' but' ' I' ""'m"" ' not']" "
+
+ I 'm not sure if this is the right place to post this , but I 'm not" False American actor Sylvester Stallone was interested 6 [' American', ' actor', ' Sy', 'lves', 'ter', ' Stall', 'one']
+2924 679 Name of father of x -1 Name of father of Jules Massenet Alexis Massenet Jules Massenet "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' J' 'ules' ' Mass' 'en' 'et' '.']" ", the composer .
+
+ The name of the father of the composer J ules Mass en et ." False " Jules Massenet =
+" 4 [' J', 'ules', ' Mass', 'en', 'et']
+2925 679 Name of father of x -1 Name of father of Jules Massenet Alexis Massenet Jules Massenet "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' J' 'ules' ' Mass' 'en' 'et' '.']" ", the composer .
+
+ The name of the father of the composer J ules Mass en et ." False " Franck and Jules Massenet among its founder-members.
+" 7 [' Fran', 'ck', ' and', ' J', 'ules', ' Mass', 'en', 'et']
+2926 679 Name of father of x -1 Name of father of Jules Massenet Alexis Massenet Jules Massenet "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' J' 'ules' ' Mass' 'en' 'et' '.']" ", the composer .
+
+ The name of the father of the composer J ules Mass en et ." False seventeen-year-old Jules Massenet was the orchestra's 9 [' seventeen', '-', 'year', '-', 'old', ' J', 'ules', ' Mass', 'en', 'et']
+2927 679 Name of father of x -1 Name of father of Jules Massenet Alexis Massenet Jules Massenet "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' J' 'ules' ' Mass' 'en' 'et' '.']" ", the composer .
+
+ The name of the father of the composer J ules Mass en et ." False " Jules Massenet =
+" 4 [' J', 'ules', ' Mass', 'en', 'et']
+2928 679 Name of father of x -1 Name of father of Jules Massenet Alexis Massenet Jules Massenet "[',' ' the' ' composer' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' composer' ' J' 'ules' ' Mass' 'en' 'et' '.']" ", the composer .
+
+ The name of the father of the composer J ules Mass en et ." False by Franz Lachner, Jules Massenet and Joachim Raff 10 [' by', ' Franz', ' L', 'ach', 'ner', ',', ' J', 'ules', ' Mass', 'en', 'et']
+2929 680 Name of father of x -1 Name of father of Kofi Annan Henry Reginald Annan Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' who' ' was'
+ ' born' ' in' ' Ghana' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the former UN Secretary - General , who was born in Ghana .
+
+ The name of the" False secretary general, Kofi Annan met separately 6 [' secretary', ' general', ',', ' K', 'ofi', ' Ann', 'an']
+2930 680 Name of father of x -1 Name of father of Kofi Annan Henry Reginald Annan Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' who' ' was'
+ ' born' ' in' ' Ghana' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the former UN Secretary - General , who was born in Ghana .
+
+ The name of the" False Secretary General Kofi Annan (2004), Late Night 5 [' Secretary', ' General', ' K', 'ofi', ' Ann', 'an']
+2931 680 Name of father of x -1 Name of father of Kofi Annan Henry Reginald Annan Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' who' ' was'
+ ' born' ' in' ' Ghana' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the former UN Secretary - General , who was born in Ghana .
+
+ The name of the" False Secretary-General Kofi Annan later recognised 6 [' Secretary', '-', 'General', ' K', 'ofi', ' Ann', 'an']
+2932 680 Name of father of x -1 Name of father of Kofi Annan Henry Reginald Annan Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' who' ' was'
+ ' born' ' in' ' Ghana' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the former UN Secretary - General , who was born in Ghana .
+
+ The name of the" False violence back to Kofi Annan with instructions 6 [' violence', ' back', ' to', ' K', 'ofi', ' Ann', 'an']
+2933 680 Name of father of x -1 Name of father of Kofi Annan Henry Reginald Annan Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' who' ' was'
+ ' born' ' in' ' Ghana' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the former UN Secretary - General , who was born in Ghana .
+
+ The name of the" False Secretary General Kofi Annan and his wife 5 [' Secretary', ' General', ' K', 'ofi', ' Ann', 'an']
+2934 682 Name of father of x -1 Name of father of Jimi Hendrix Al Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Jim' 'i' ' Hend' 'rix']" ", the guitarist who died in 1970 .
+
+ The name of the father of Jim i Hend rix" False By May 1966, Jimi Hendrix was struggling 7 [' By', ' May', ' 1966', ',', ' Jim', 'i', ' Hend', 'rix']
+2935 682 Name of father of x -1 Name of father of Jimi Hendrix Al Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Jim' 'i' ' Hend' 'rix']" ", the guitarist who died in 1970 .
+
+ The name of the father of Jim i Hend rix" False for a few years), two Jimi Hendrix cover songs, 9 [' for', ' a', ' few', ' years', '),', ' two', ' Jim', 'i', ' Hend', 'rix']
+2936 682 Name of father of x -1 Name of father of Jimi Hendrix Al Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Jim' 'i' ' Hend' 'rix']" ", the guitarist who died in 1970 .
+
+ The name of the father of Jim i Hend rix" False " rendition of the Jimi Hendrix song ""Little Wing""" 6 [' rendition', ' of', ' the', ' Jim', 'i', ' Hend', 'rix']
+2937 682 Name of father of x -1 Name of father of Jimi Hendrix Al Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Jim' 'i' ' Hend' 'rix']" ", the guitarist who died in 1970 .
+
+ The name of the father of Jim i Hend rix" False Pretty Things, The Jimi Hendrix Experience, 7 [' Pretty', ' Things', ',', ' The', ' Jim', 'i', ' Hend', 'rix']
+2938 682 Name of father of x -1 Name of father of Jimi Hendrix Al Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' name' ' of' ' the' ' father' ' of' ' Jim' 'i' ' Hend' 'rix']" ", the guitarist who died in 1970 .
+
+ The name of the father of Jim i Hend rix" False 1968 cover by Jimi Hendrix is included on 6 [' 1968', ' cover', ' by', ' Jim', 'i', ' Hend', 'rix']
+2939 683 Name of father of x -1 Name of father of Henry Kissinger Louis Kissinger Henry Kissinger "[',' ' the' ' man' ' who' ' was' ' the' ' architect' ' of' ' the'
+ ' Vietnam' ' War' '.' '\n' '\n' 'The' ' man' ' who' ' was' ' the'
+ ' architect']" ", the man who was the architect of the Vietnam War .
+
+ The man who was the architect" False " Kissinger (1974 – 1975)
+" 7 [' Kissinger', ' (', '1974', ' –', ' 1975', ')', 'Henry', ' Kissinger']
+2940 683 Name of father of x -1 Name of father of Henry Kissinger Louis Kissinger Henry Kissinger "[',' ' the' ' man' ' who' ' was' ' the' ' architect' ' of' ' the'
+ ' Vietnam' ' War' '.' '\n' '\n' 'The' ' man' ' who' ' was' ' the'
+ ' architect']" ", the man who was the architect of the Vietnam War .
+
+ The man who was the architect" False approve the deal, and Henry Kissinger brought the 6 [' approve', ' the', ' deal', ',', ' and', ' Henry', ' Kissinger']
+2941 683 Name of father of x -1 Name of father of Henry Kissinger Louis Kissinger Henry Kissinger "[',' ' the' ' man' ' who' ' was' ' the' ' architect' ' of' ' the'
+ ' Vietnam' ' War' '.' '\n' '\n' 'The' ' man' ' who' ' was' ' the'
+ ' architect']" ", the man who was the architect of the Vietnam War .
+
+ The man who was the architect" False approve the deal, and Henry Kissinger brought the issue 6 [' approve', ' the', ' deal', ',', ' and', ' Henry', ' Kissinger']
+2942 683 Name of father of x -1 Name of father of Henry Kissinger Louis Kissinger Henry Kissinger "[',' ' the' ' man' ' who' ' was' ' the' ' architect' ' of' ' the'
+ ' Vietnam' ' War' '.' '\n' '\n' 'The' ' man' ' who' ' was' ' the'
+ ' architect']" ", the man who was the architect of the Vietnam War .
+
+ The man who was the architect" False the deal, and Henry Kissinger brought the issue 5 [' the', ' deal', ',', ' and', ' Henry', ' Kissinger']
+2943 683 Name of father of x -1 Name of father of Henry Kissinger Louis Kissinger Henry Kissinger "[',' ' the' ' man' ' who' ' was' ' the' ' architect' ' of' ' the'
+ ' Vietnam' ' War' '.' '\n' '\n' 'The' ' man' ' who' ' was' ' the'
+ ' architect']" ", the man who was the architect of the Vietnam War .
+
+ The man who was the architect" False " 20th century"". Henry Kissinger said that Kennan" 5 "[' 20', 'th', ' century', '"".', ' Henry', ' Kissinger']"
+2944 684 Name of father of x -1 Name of father of George Romney John Rumney George Romney "[',' ' the' ' father' ' of' ' the' ' American' ' Revolution' ',' ' was'
+ ' born' ' in' ' the' ' town' ' of' ' Boston' ',' ' Massachusetts' ','
+ ' in' ' 17']" , the father of the American Revolution , was born in the town of Boston , Massachusetts , in 17 False presidential election, George Romney was named the 4 [' presidential', ' election', ',', ' George', ' Romney']
+2945 684 Name of father of x -1 Name of father of George Romney John Rumney George Romney "[',' ' the' ' father' ' of' ' the' ' American' ' Revolution' ',' ' was'
+ ' born' ' in' ' the' ' town' ' of' ' Boston' ',' ' Massachusetts' ','
+ ' in' ' 17']" , the father of the American Revolution , was born in the town of Boston , Massachusetts , in 17 False more like the public George Romney than the real George 5 [' more', ' like', ' the', ' public', ' George', ' Romney']
+2946 684 Name of father of x -1 Name of father of George Romney John Rumney George Romney "[',' ' the' ' father' ' of' ' the' ' American' ' Revolution' ',' ' was'
+ ' born' ' in' ' the' ' town' ' of' ' Boston' ',' ' Massachusetts' ','
+ ' in' ' 17']" , the father of the American Revolution , was born in the town of Boston , Massachusetts , in 17 False Michigan Governor George Romney as a possible Presidential 3 [' Michigan', ' Governor', ' George', ' Romney']
+2947 684 Name of father of x -1 Name of father of George Romney John Rumney George Romney "[',' ' the' ' father' ' of' ' the' ' American' ' Revolution' ',' ' was'
+ ' born' ' in' ' the' ' town' ' of' ' Boston' ',' ' Massachusetts' ','
+ ' in' ' 17']" , the father of the American Revolution , was born in the town of Boston , Massachusetts , in 17 False political aide, insults George Romney, who was a 5 [' political', ' aide', ',', ' insults', ' George', ' Romney']
+2948 684 Name of father of x -1 Name of father of George Romney John Rumney George Romney "[',' ' the' ' father' ' of' ' the' ' American' ' Revolution' ',' ' was'
+ ' born' ' in' ' the' ' town' ' of' ' Boston' ',' ' Massachusetts' ','
+ ' in' ' 17']" , the father of the American Revolution , was born in the town of Boston , Massachusetts , in 17 False representing Governor George Romney and the State 3 [' representing', ' Governor', ' George', ' Romney']
+2949 685 Name of father of x -1 Name of father of Epicurus Neocles Epicurus "[',' ' the' ' founder' ' of' ' the' ' Epic' 'ure' 'an' ' school' ' of'
+ ' philosophy' '.' '\n' '\n' 'The' ' Epic' 'ure' 'ans' ' were' ' a']" ", the founder of the Epic ure an school of philosophy .
+
+ The Epic ure ans were a" False least the time of Epicurus in the fourth century 5 [' least', ' the', ' time', ' of', ' Epic', 'urus']
+2950 685 Name of father of x -1 Name of father of Epicurus Neocles Epicurus "[',' ' the' ' founder' ' of' ' the' ' Epic' 'ure' 'an' ' school' ' of'
+ ' philosophy' '.' '\n' '\n' 'The' ' Epic' 'ure' 'ans' ' were' ' a']" ", the founder of the Epic ure an school of philosophy .
+
+ The Epic ure ans were a" False 2 ['Ep', 'ic', 'urus']
+2951 685 Name of father of x -1 Name of father of Epicurus Neocles Epicurus "[',' ' the' ' founder' ' of' ' the' ' Epic' 'ure' 'an' ' school' ' of'
+ ' philosophy' '.' '\n' '\n' 'The' ' Epic' 'ure' 'ans' ' were' ' a']" ", the founder of the Epic ure an school of philosophy .
+
+ The Epic ure ans were a" False history of atheism was Epicurus (c. 300 BCE). Drawing 5 [' history', ' of', ' atheism', ' was', ' Epic', 'urus']
+2952 685 Name of father of x -1 Name of father of Epicurus Neocles Epicurus "[',' ' the' ' founder' ' of' ' the' ' Epic' 'ure' 'an' ' school' ' of'
+ ' philosophy' '.' '\n' '\n' 'The' ' Epic' 'ure' 'ans' ' were' ' a']" ", the founder of the Epic ure an school of philosophy .
+
+ The Epic ure ans were a" False atheism was Epicurus (c. 300 BCE). 3 [' atheism', ' was', ' Epic', 'urus']
+2953 685 Name of father of x -1 Name of father of Epicurus Neocles Epicurus "[',' ' the' ' founder' ' of' ' the' ' Epic' 'ure' 'an' ' school' ' of'
+ ' philosophy' '.' '\n' '\n' 'The' ' Epic' 'ure' 'ans' ' were' ' a']" ", the founder of the Epic ure an school of philosophy .
+
+ The Epic ure ans were a" False 2 ['Ep', 'ic', 'urus']
+2954 686 Name of father of x -1 Name of father of Dwayne Johnson Rocky Johnson Dwayne Johnson "[',' ' the' ' wrestler' ',' ' actor' ',' ' and' ' producer' '.' '\n' '\n'
+ 'The' ' actor' ',' ' who' ' is' ' also' ' the' ' son' ' of']" ", the wrestler , actor , and producer .
+
+ The actor , who is also the son of" False contacted Dwayne Johnson for some tips 3 [' contacted', ' D', 'wayne', ' Johnson']
+2955 686 Name of father of x -1 Name of father of Dwayne Johnson Rocky Johnson Dwayne Johnson "[',' ' the' ' wrestler' ',' ' actor' ',' ' and' ' producer' '.' '\n' '\n'
+ 'The' ' actor' ',' ' who' ' is' ' also' ' the' ' son' ' of']" ", the wrestler , actor , and producer .
+
+ The actor , who is also the son of" False Jackman contacted Dwayne Johnson for some tips 5 [' Jack', 'man', ' contacted', ' D', 'wayne', ' Johnson']
+2956 686 Name of father of x -1 Name of father of Dwayne Johnson Rocky Johnson Dwayne Johnson "[',' ' the' ' wrestler' ',' ' actor' ',' ' and' ' producer' '.' '\n' '\n'
+ 'The' ' actor' ',' ' who' ' is' ' also' ' the' ' son' ' of']" ", the wrestler , actor , and producer .
+
+ The actor , who is also the son of" False " Johnson =
+" 4 [' Johnson', ' =', 'D', 'wayne', ' Johnson']
+2957 686 Name of father of x -1 Name of father of Dwayne Johnson Rocky Johnson Dwayne Johnson "[',' ' the' ' wrestler' ',' ' actor' ',' ' and' ' producer' '.' '\n' '\n'
+ 'The' ' actor' ',' ' who' ' is' ' also' ' the' ' son' ' of']" ", the wrestler , actor , and producer .
+
+ The actor , who is also the son of" False cast, actor Dwayne Johnson made a brief live-action 5 [' cast', ',', ' actor', ' D', 'wayne', ' Johnson']
+2958 686 Name of father of x -1 Name of father of Dwayne Johnson Rocky Johnson Dwayne Johnson "[',' ' the' ' wrestler' ',' ' actor' ',' ' and' ' producer' '.' '\n' '\n'
+ 'The' ' actor' ',' ' who' ' is' ' also' ' the' ' son' ' of']" ", the wrestler , actor , and producer .
+
+ The actor , who is also the son of" False " Johnson as Luke Hobbs
+" 7 [' Johnson', ' as', ' Luke', ' Hob', 'bs', 'D', 'wayne', ' Johnson']
+2959 687 Name of father of x -1 Name of father of Hans-Georg Gadamer Johannes Gadamer Hans-Georg Gadamer "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 18' '66' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the philosopher , who was born in 18 66 .
+
+ The name of the father of the" False Orpheus influenced Hans-Georg Gadamer ’ s theories of 7 [' Or', 'pheus', ' influenced', ' Hans', '-', 'Georg', ' Gad', 'amer']
+2960 687 Name of father of x -1 Name of father of Hans-Georg Gadamer Johannes Gadamer Hans-Georg Gadamer "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 18' '66' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the philosopher , who was born in 18 66 .
+
+ The name of the father of the" False influenced Hans-Georg Gadamer ’ s theories of 5 [' influenced', ' Hans', '-', 'Georg', ' Gad', 'amer']
+2961 687 Name of father of x -1 Name of father of Hans-Georg Gadamer Johannes Gadamer Hans-Georg Gadamer "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 18' '66' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the philosopher , who was born in 18 66 .
+
+ The name of the father of the" False influenced Hans-Georg Gadamer ’ s theories of hermeneutics 5 [' influenced', ' Hans', '-', 'Georg', ' Gad', 'amer']
+2962 688 Name of father of x -1 Name of father of Zaha Hadid Mohammed Ali Haded Zaha Hadid "[',' ' the' ' British' '-' 'Iraq' 'i' ' architect' ' who' ' died' ' in'
+ ' 2016' '.' '\n' '\n' 'The' ' building' ' is' ' a' ' landmark' ' of']" ", the British - Iraq i architect who died in 2016 .
+
+ The building is a landmark of" False The pavilions by Zaha Hadid and Ben van Berkel 8 [' The', ' pav', 'il', 'ions', ' by', ' Z', 'aha', ' Had', 'id']
+2963 688 Name of father of x -1 Name of father of Zaha Hadid Mohammed Ali Haded Zaha Hadid "[',' ' the' ' British' '-' 'Iraq' 'i' ' architect' ' who' ' died' ' in'
+ ' 2016' '.' '\n' '\n' 'The' ' building' ' is' ' a' ' landmark' ' of']" ", the British - Iraq i architect who died in 2016 .
+
+ The building is a landmark of" False architects Zaha Hadid and Ben van Berkel 4 [' architects', ' Z', 'aha', ' Had', 'id']
+2964 689 Name of father of x -1 Name of father of Bob Marley Norval Marley Bob Marley "[',' ' the' ' reg' 'gae' ' legend' ',' ' who' ' died' ' in' ' 1981' '.'
+ '\n' '\n' 'The' ' song' ' was' ' written' ' by' ' Bob' ' Mar']" ", the reg gae legend , who died in 1981 .
+
+ The song was written by Bob Mar" False introduced to Bob Marley by one of Bradshaw's 4 [' introduced', ' to', ' Bob', ' Mar', 'ley']
+2965 689 Name of father of x -1 Name of father of Bob Marley Norval Marley Bob Marley "[',' ' the' ' reg' 'gae' ' legend' ',' ' who' ' died' ' in' ' 1981' '.'
+ '\n' '\n' 'The' ' song' ' was' ' written' ' by' ' Bob' ' Mar']" ", the reg gae legend , who died in 1981 .
+
+ The song was written by Bob Mar" False In 1972, Bob Marley signed with CBS 5 [' In', ' 1972', ',', ' Bob', ' Mar', 'ley']
+2966 689 Name of father of x -1 Name of father of Bob Marley Norval Marley Bob Marley "[',' ' the' ' reg' 'gae' ' legend' ',' ' who' ' died' ' in' ' 1981' '.'
+ '\n' '\n' 'The' ' song' ' was' ' written' ' by' ' Bob' ' Mar']" ", the reg gae legend , who died in 1981 .
+
+ The song was written by Bob Mar" False " song ""Jamming"" by Bob Marley is played. The design" 8 "[' song', ' ""', 'Jam', 'ming', '""', ' by', ' Bob', ' Mar', 'ley']"
+2967 689 Name of father of x -1 Name of father of Bob Marley Norval Marley Bob Marley "[',' ' the' ' reg' 'gae' ' legend' ',' ' who' ' died' ' in' ' 1981' '.'
+ '\n' '\n' 'The' ' song' ' was' ' written' ' by' ' Bob' ' Mar']" ", the reg gae legend , who died in 1981 .
+
+ The song was written by Bob Mar" False parody of Bob Marley and the reggae 4 [' parody', ' of', ' Bob', ' Mar', 'ley']
+2968 689 Name of father of x -1 Name of father of Bob Marley Norval Marley Bob Marley "[',' ' the' ' reg' 'gae' ' legend' ',' ' who' ' died' ' in' ' 1981' '.'
+ '\n' '\n' 'The' ' song' ' was' ' written' ' by' ' Bob' ' Mar']" ", the reg gae legend , who died in 1981 .
+
+ The song was written by Bob Mar" False specific artists like Bob Marley and The B-52s. 5 [' specific', ' artists', ' like', ' Bob', ' Mar', 'ley']
+2969 691 Name of father of x -1 Name of father of Gérard Depardieu René Depardieu Gérard Depardieu "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' G' 'ér' 'ard'
+ ' Dep' 'ard' 'ieu' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of G ér ard Dep ard ieu is not known .
+
+" False but lost to Gérard Depardieu of Green Card. Edward 8 [' but', ' lost', ' to', ' G', 'ér', 'ard', ' Dep', 'ard', 'ieu']
+2970 691 Name of father of x -1 Name of father of Gérard Depardieu René Depardieu Gérard Depardieu "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' G' 'ér' 'ard'
+ ' Dep' 'ard' 'ieu' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of G ér ard Dep ard ieu is not known .
+
+" False but lost to Gérard Depardieu of Green Card. Edward 8 [' but', ' lost', ' to', ' G', 'ér', 'ard', ' Dep', 'ard', 'ieu']
+2971 692 Name of father of x -1 Name of father of Johann Strauss II Johann Strauss I Johann Strauss II "[',' ' the' ' composer' ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan'
+ 'ube' '""' ' and' ' the' ' ""' 'Blue' ' Dan' 'ube' ' Walt']" ", the composer of the w alt z "" Blue Dan ube "" and the "" Blue Dan ube Walt" False 4 ['J', 'oh', 'ann', ' Strauss', ' II']
+2972 692 Name of father of x -1 Name of father of Johann Strauss II Johann Strauss I Johann Strauss II "[',' ' the' ' composer' ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan'
+ 'ube' '""' ' and' ' the' ' ""' 'Blue' ' Dan' 'ube' ' Walt']" ", the composer of the w alt z "" Blue Dan ube "" and the "" Blue Dan ube Walt" False Tritsch-Tratsch-Polka by Johann Strauss II was featured during 13 [' T', 'rit', 'sch', '-', 'Tr', 'ats', 'ch', '-', 'Pol', 'ka', ' by', ' Johann', ' Strauss', ' II']
+2973 692 Name of father of x -1 Name of father of Johann Strauss II Johann Strauss I Johann Strauss II "[',' ' the' ' composer' ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan'
+ 'ube' '""' ' and' ' the' ' ""' 'Blue' ' Dan' 'ube' ' Walt']" ", the composer of the w alt z "" Blue Dan ube "" and the "" Blue Dan ube Walt" False " Strauss II =
+" 7 [' Strauss', ' II', ' =', 'J', 'oh', 'ann', ' Strauss', ' II']
+2974 692 Name of father of x -1 Name of father of Johann Strauss II Johann Strauss I Johann Strauss II "[',' ' the' ' composer' ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan'
+ 'ube' '""' ' and' ' the' ' ""' 'Blue' ' Dan' 'ube' ' Walt']" ", the composer of the w alt z "" Blue Dan ube "" and the "" Blue Dan ube Walt" False three occasions. Johann Strauss II and his orchestra 5 [' three', ' occasions', '.', ' Johann', ' Strauss', ' II']
+2975 692 Name of father of x -1 Name of father of Johann Strauss II Johann Strauss I Johann Strauss II "[',' ' the' ' composer' ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan'
+ 'ube' '""' ' and' ' the' ' ""' 'Blue' ' Dan' 'ube' ' Walt']" ", the composer of the w alt z "" Blue Dan ube "" and the "" Blue Dan ube Walt" False " Vienna Woods"" by Johann Strauss II as well as a musical" 6 "[' Vienna', ' Woods', '""', ' by', ' Johann', ' Strauss', ' II']"
+2976 693 Name of father of x -1 Name of father of Ted Kennedy Joseph P. Kennedy Sr. Ted Kennedy "[',' ' the' ' late' ' Senator' ' Edward' ' Kennedy' ',' ' and' ' the'
+ ' late' ' Senator' ' Robert' ' Kennedy' '.' '\n' '\n' 'The' ' Kennedy'
+ ' family' ' has']" ", the late Senator Edward Kennedy , and the late Senator Robert Kennedy .
+
+ The Kennedy family has" False the luncheon, Senator Ted Kennedy collapsed 6 [' the', ' lun', 'cheon', ',', ' Senator', ' Ted', ' Kennedy']
+2977 693 Name of father of x -1 Name of father of Ted Kennedy Joseph P. Kennedy Sr. Ted Kennedy "[',' ' the' ' late' ' Senator' ' Edward' ' Kennedy' ',' ' and' ' the'
+ ' late' ' Senator' ' Robert' ' Kennedy' '.' '\n' '\n' 'The' ' Kennedy'
+ ' family' ' has']" ", the late Senator Edward Kennedy , and the late Senator Robert Kennedy .
+
+ The Kennedy family has" False States Senator Ted Kennedy of Massachusetts 3 [' States', ' Senator', ' Ted', ' Kennedy']
+2978 693 Name of father of x -1 Name of father of Ted Kennedy Joseph P. Kennedy Sr. Ted Kennedy "[',' ' the' ' late' ' Senator' ' Edward' ' Kennedy' ',' ' and' ' the'
+ ' late' ' Senator' ' Robert' ' Kennedy' '.' '\n' '\n' 'The' ' Kennedy'
+ ' family' ' has']" ", the late Senator Edward Kennedy , and the late Senator Robert Kennedy .
+
+ The Kennedy family has" False his brothers, Ted Kennedy took on the 4 [' his', ' brothers', ',', ' Ted', ' Kennedy']
+2979 693 Name of father of x -1 Name of father of Ted Kennedy Joseph P. Kennedy Sr. Ted Kennedy "[',' ' the' ' late' ' Senator' ' Edward' ' Kennedy' ',' ' and' ' the'
+ ' late' ' Senator' ' Robert' ' Kennedy' '.' '\n' '\n' 'The' ' Kennedy'
+ ' family' ' has']" ", the late Senator Edward Kennedy , and the late Senator Robert Kennedy .
+
+ The Kennedy family has" False " Ted Kennedy =
+" 1 [' Ted', ' Kennedy']
+2980 693 Name of father of x -1 Name of father of Ted Kennedy Joseph P. Kennedy Sr. Ted Kennedy "[',' ' the' ' late' ' Senator' ' Edward' ' Kennedy' ',' ' and' ' the'
+ ' late' ' Senator' ' Robert' ' Kennedy' '.' '\n' '\n' 'The' ' Kennedy'
+ ' family' ' has']" ", the late Senator Edward Kennedy , and the late Senator Robert Kennedy .
+
+ The Kennedy family has" False died a day later. Ted Kennedy was devastated by this 6 [' died', ' a', ' day', ' later', '.', ' Ted', ' Kennedy']
+2981 694 Name of father of x -1 Name of father of Jean-Luc Picard Maurice Picard Jean-Luc Picard "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' captain' ' of' ' the'
+ ' Enterprise' '.' '\n' '\n' 'The' ' first' ' captain' ' of' ' the'
+ ' Enterprise']" ", the man who was the first captain of the Enterprise .
+
+ The first captain of the Enterprise" False " including Captain Jean-Luc Picard (Patrick Stewart).
+" 5 [' including', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2982 694 Name of father of x -1 Name of father of Jean-Luc Picard Maurice Picard Jean-Luc Picard "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' captain' ' of' ' the'
+ ' Enterprise' '.' '\n' '\n' 'The' ' first' ' captain' ' of' ' the'
+ ' Enterprise']" ", the man who was the first captain of the Enterprise .
+
+ The first captain of the Enterprise" False check. Captain Jean-Luc Picard (Patrick Stewart) 6 [' check', '.', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2983 694 Name of father of x -1 Name of father of Jean-Luc Picard Maurice Picard Jean-Luc Picard "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' captain' ' of' ' the'
+ ' Enterprise' '.' '\n' '\n' 'The' ' first' ' captain' ' of' ' the'
+ ' Enterprise']" ", the man who was the first captain of the Enterprise .
+
+ The first captain of the Enterprise" False Together, Captain Jean-Luc Picard (Patrick Stewart) 6 [' Together', ',', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2984 694 Name of father of x -1 Name of father of Jean-Luc Picard Maurice Picard Jean-Luc Picard "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' captain' ' of' ' the'
+ ' Enterprise' '.' '\n' '\n' 'The' ' first' ' captain' ' of' ' the'
+ ' Enterprise']" ", the man who was the first captain of the Enterprise .
+
+ The first captain of the Enterprise" False 3 ['Jean', '-', 'Luc', ' Picard']
+2985 694 Name of father of x -1 Name of father of Jean-Luc Picard Maurice Picard Jean-Luc Picard "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' captain' ' of' ' the'
+ ' Enterprise' '.' '\n' '\n' 'The' ' first' ' captain' ' of' ' the'
+ ' Enterprise']" ", the man who was the first captain of the Enterprise .
+
+ The first captain of the Enterprise" False so Captain Jean-Luc Picard (Patrick Stewart) 5 [' so', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2986 695 Name of father of x -1 Name of father of John Dryden Erasmus Dryden John Dryden "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' daughter' ' of' ' Sir' ' Thomas' '\n' '\n']" ", the poet , and of his wife , Mary , who was a daughter of Sir Thomas
+
+" False tradition. John Dryden (a Tory), 4 [' tradition', '.', ' John', ' Dry', 'den']
+2987 695 Name of father of x -1 Name of father of John Dryden Erasmus Dryden John Dryden "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' daughter' ' of' ' Sir' ' Thomas' '\n' '\n']" ", the poet , and of his wife , Mary , who was a daughter of Sir Thomas
+
+" False Shakespeare, Ben Jonson, John Dryden and others, but 8 [' Shakespeare', ',', ' Ben', ' J', 'onson', ',', ' John', ' Dry', 'den']
+2988 695 Name of father of x -1 Name of father of John Dryden Erasmus Dryden John Dryden "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' daughter' ' of' ' Sir' ' Thomas' '\n' '\n']" ", the poet , and of his wife , Mary , who was a daughter of Sir Thomas
+
+" False the ideas of John Dryden and William Hazlitt 5 [' the', ' ideas', ' of', ' John', ' Dry', 'den']
+2989 695 Name of father of x -1 Name of father of John Dryden Erasmus Dryden John Dryden "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' daughter' ' of' ' Sir' ' Thomas' '\n' '\n']" ", the poet , and of his wife , Mary , who was a daughter of Sir Thomas
+
+" False scene of an attack on John Dryden in 1679 by thugs 7 [' scene', ' of', ' an', ' attack', ' on', ' John', ' Dry', 'den']
+2990 695 Name of father of x -1 Name of father of John Dryden Erasmus Dryden John Dryden "[',' ' the' ' poet' ',' ' and' ' of' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' a' ' daughter' ' of' ' Sir' ' Thomas' '\n' '\n']" ", the poet , and of his wife , Mary , who was a daughter of Sir Thomas
+
+" False first presented, John Dryden re-worked it as 5 [' first', ' presented', ',', ' John', ' Dry', 'den']
+2991 696 Name of father of x -1 Name of father of F. Scott Fitzgerald Edward Fitzgerald F. Scott Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' book' ',' ' but'
+ ' I' ' have' ' to' ' say' ' that' ' I' ' am' ' not']" "
+
+ I am a big fan of the book , but I have to say that I am not" False 3 ['F', '.', ' Scott', ' Fitzgerald']
+2992 696 Name of father of x -1 Name of father of F. Scott Fitzgerald Edward Fitzgerald F. Scott Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' book' ',' ' but'
+ ' I' ' have' ' to' ' say' ' that' ' I' ' am' ' not']" "
+
+ I am a big fan of the book , but I have to say that I am not" False " Fitzgerald ==
+" 5 [' Fitzgerald', ' ==', 'F', '.', ' Scott', ' Fitzgerald']
+2993 696 Name of father of x -1 Name of father of F. Scott Fitzgerald Edward Fitzgerald F. Scott Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' book' ',' ' but'
+ ' I' ' have' ' to' ' say' ' that' ' I' ' am' ' not']" "
+
+ I am a big fan of the book , but I have to say that I am not" False 3 ['F', '.', ' Scott', ' Fitzgerald']
+2994 696 Name of father of x -1 Name of father of F. Scott Fitzgerald Edward Fitzgerald F. Scott Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' book' ',' ' but'
+ ' I' ' have' ' to' ' say' ' that' ' I' ' am' ' not']" "
+
+ I am a big fan of the book , but I have to say that I am not" False the Waltz, F. Scott Fitzgerald scholar Matthew 7 [' the', ' Walt', 'z', ',', ' F', '.', ' Scott', ' Fitzgerald']
+2995 696 Name of father of x -1 Name of father of F. Scott Fitzgerald Edward Fitzgerald F. Scott Fitzgerald "['\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' book' ',' ' but'
+ ' I' ' have' ' to' ' say' ' that' ' I' ' am' ' not']" "
+
+ I am a big fan of the book , but I have to say that I am not" False by American author F. Scott Fitzgerald that follows 6 [' by', ' American', ' author', ' F', '.', ' Scott', ' Fitzgerald']
+2996 697 Name of father of x -1 Name of father of Louis-Ferdinand Céline Fernand Destouches Louis-Ferdinand Céline "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Louis' '-' 'Fer' 'dinand' ' C'
+ 'é' 'line' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Louis']" "
+
+ Name of mother of Louis - Fer dinand C é line
+
+ Name of mother of Louis" False Charles Maurras, Louis-Ferdinand Céline and Paul Carell. 10 [' Charles', ' Maur', 'ras', ',', ' Louis', '-', 'Fer', 'dinand', ' C', 'é', 'line']
+2997 698 Name of father of x -1 Name of father of György Ligeti Sándor Ligeti György Ligeti "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Gy' 'ör' 'gy' ' L' 'ig' 'et' 'i'
+ '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gy']" "
+
+ Name of mother of Gy ör gy L ig et i
+
+ Name of mother of Gy" False modernistic compositions by György Ligeti that employ micropolyphony, 10 [' modern', 'istic', ' compositions', ' by', ' Gy', 'ör', 'gy', ' L', 'ig', 'et', 'i']
+2998 698 Name of father of x -1 Name of father of György Ligeti Sándor Ligeti György Ligeti "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Gy' 'ör' 'gy' ' L' 'ig' 'et' 'i'
+ '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gy']" "
+
+ Name of mother of Gy ör gy L ig et i
+
+ Name of mother of Gy" False Hungarian composer György Ligeti to a broad Western 8 [' Hungarian', ' composer', ' Gy', 'ör', 'gy', ' L', 'ig', 'et', 'i']
+2999 698 Name of father of x -1 Name of father of György Ligeti Sándor Ligeti György Ligeti "['\n' '\n' 'Name' ' of' ' mother' ' of' ' Gy' 'ör' 'gy' ' L' 'ig' 'et' 'i'
+ '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gy']" "
+
+ Name of mother of Gy ör gy L ig et i
+
+ Name of mother of Gy" False modernistic compositions by György Ligeti that employ micropolyphony, 10 [' modern', 'istic', ' compositions', ' by', ' Gy', 'ör', 'gy', ' L', 'ig', 'et', 'i']
+3000 699 Name of father of x -1 Name of father of François Rabelais Antoine Rabelais François Rabelais "[',' ' the' ' author' ' of' ' Garg' 'ant' 'ua' ' and' ' Pant' 'ag' 'ru'
+ 'el' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ']" ", the author of Garg ant ua and Pant ag ru el , and of the
+ " False described by François Rabelais in his Quart Livre. 5 [' described', ' by', ' François', ' R', 'abel', 'ais']
+3001 699 Name of father of x -1 Name of father of François Rabelais Antoine Rabelais François Rabelais "[',' ' the' ' author' ' of' ' Garg' 'ant' 'ua' ' and' ' Pant' 'ag' 'ru'
+ 'el' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ']" ", the author of Garg ant ua and Pant ag ru el , and of the
+ " False 5 ['Fran', 'ç', 'ois', ' R', 'abel', 'ais']
+3002 699 Name of father of x -1 Name of father of François Rabelais Antoine Rabelais François Rabelais "[',' ' the' ' author' ' of' ' Garg' 'ant' 'ua' ' and' ' Pant' 'ag' 'ru'
+ 'el' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ']" ", the author of Garg ant ua and Pant ag ru el , and of the
+ " False literature, François Rabelais compares Gargantua's 5 [' literature', ',', ' François', ' R', 'abel', 'ais']
+3003 699 Name of father of x -1 Name of father of François Rabelais Antoine Rabelais François Rabelais "[',' ' the' ' author' ' of' ' Garg' 'ant' 'ua' ' and' ' Pant' 'ag' 'ru'
+ 'el' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ']" ", the author of Garg ant ua and Pant ag ru el , and of the
+ " False 5 ['Fran', 'ç', 'ois', ' R', 'abel', 'ais']
+3004 699 Name of father of x -1 Name of father of François Rabelais Antoine Rabelais François Rabelais "[',' ' the' ' author' ' of' ' Garg' 'ant' 'ua' ' and' ' Pant' 'ag' 'ru'
+ 'el' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ']" ", the author of Garg ant ua and Pant ag ru el , and of the
+ " False 5 ['Fran', 'ç', 'ois', ' R', 'abel', 'ais']
+3005 701 Name of father of x -1 Name of father of Nicolas Sarkozy Paul Sarkozy Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Car'
+ 'la' ' Brun' 'i' ',' ' were' ' in' ' the' ' audience' '.' '\n']" ", the French president , and his wife , Car la Brun i , were in the audience .
+" False president Nicolas Sarkozy and his wife 3 [' president', ' Nicolas', ' Sark', 'ozy']
+3006 701 Name of father of x -1 Name of father of Nicolas Sarkozy Paul Sarkozy Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Car'
+ 'la' ' Brun' 'i' ',' ' were' ' in' ' the' ' audience' '.' '\n']" ", the French president , and his wife , Car la Brun i , were in the audience .
+" False of France, Nicolas Sarkozy inaugurated an 5 [' of', ' France', ',', ' Nicolas', ' Sark', 'ozy']
+3007 701 Name of father of x -1 Name of father of Nicolas Sarkozy Paul Sarkozy Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Car'
+ 'la' ' Brun' 'i' ',' ' were' ' in' ' the' ' audience' '.' '\n']" ", the French president , and his wife , Car la Brun i , were in the audience .
+" False President of France Nicolas Sarkozy negotiated a ceasefire 5 [' President', ' of', ' France', ' Nicolas', ' Sark', 'ozy']
+3008 701 Name of father of x -1 Name of father of Nicolas Sarkozy Paul Sarkozy Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Car'
+ 'la' ' Brun' 'i' ',' ' were' ' in' ' the' ' audience' '.' '\n']" ", the French president , and his wife , Car la Brun i , were in the audience .
+" False French president Nicolas Sarkozy honored his wish and 4 [' French', ' president', ' Nicolas', ' Sark', 'ozy']
+3009 701 Name of father of x -1 Name of father of Nicolas Sarkozy Paul Sarkozy Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ',' ' Car'
+ 'la' ' Brun' 'i' ',' ' were' ' in' ' the' ' audience' '.' '\n']" ", the French president , and his wife , Car la Brun i , were in the audience .
+" False President Nicolas Sarkozy resulted in a 3 [' President', ' Nicolas', ' Sark', 'ozy']
+3010 702 Name of father of x -1 Name of father of Yehudi Menuhin Moshe Menuhin Yehudi Menuhin "[',' ' the' ' violin' 'ist' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Ye' 'h' 'udi' ' Men' 'uh' 'in' ',' ' the']" ", the violin ist
+
+ The name of the father of Ye h udi Men uh in , the" False Fournier, Emil Gilels, Yehudi Menuhin and Arthur 14 [' F', 'ourn', 'ier', ',', ' Emil', ' G', 'ile', 'ls', ',', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+3011 702 Name of father of x -1 Name of father of Yehudi Menuhin Moshe Menuhin Yehudi Menuhin "[',' ' the' ' violin' 'ist' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Ye' 'h' 'udi' ' Men' 'uh' 'in' ',' ' the']" ", the violin ist
+
+ The name of the father of Ye h udi Men uh in , the" False Jascha Heifetz and Yehudi Menuhin thought similarly 11 [' Jas', 'cha', ' He', 'if', 'etz', ' and', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+3012 702 Name of father of x -1 Name of father of Yehudi Menuhin Moshe Menuhin Yehudi Menuhin "[',' ' the' ' violin' 'ist' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Ye' 'h' 'udi' ' Men' 'uh' 'in' ',' ' the']" ", the violin ist
+
+ The name of the father of Ye h udi Men uh in , the" False " Maxwell Davies, Yehudi Menuhin and Leopold Stokowski.
+" 8 [' Maxwell', ' Davies', ',', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+3013 702 Name of father of x -1 Name of father of Yehudi Menuhin Moshe Menuhin Yehudi Menuhin "[',' ' the' ' violin' 'ist' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Ye' 'h' 'udi' ' Men' 'uh' 'in' ',' ' the']" ", the violin ist
+
+ The name of the father of Ye h udi Men uh in , the" False audiences, along with Yehudi Menuhin and John Coltrane, 9 [' audiences', ',', ' along', ' with', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+3014 702 Name of father of x -1 Name of father of Yehudi Menuhin Moshe Menuhin Yehudi Menuhin "[',' ' the' ' violin' 'ist' '\n' '\n' 'The' ' name' ' of' ' the' ' father'
+ ' of' ' Ye' 'h' 'udi' ' Men' 'uh' 'in' ',' ' the']" ", the violin ist
+
+ The name of the father of Ye h udi Men uh in , the" False Grimes, Britten and Yehudi Menuhin went to Germany 10 [' Grimes', ',', ' Br', 'itten', ' and', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+3015 706 Name of father of x -1 Name of father of John Maynard Keynes John Neville Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' who' ' was' ' born' ' in' ' 18' '83' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the economist , who was born in 18 83 .
+
+ The name of the father of the" False Alan Greenspan, John Maynard Keynes and Paul Krugman, 7 [' Alan', ' Greens', 'pan', ',', ' John', ' May', 'nard', ' Keynes']
+3016 706 Name of father of x -1 Name of father of John Maynard Keynes John Neville Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' who' ' was' ' born' ' in' ' 18' '83' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the economist , who was born in 18 83 .
+
+ The name of the father of the" False such as real output. John Maynard Keynes attacked some 8 [' such', ' as', ' real', ' output', '.', ' John', ' May', 'nard', ' Keynes']
+3017 706 Name of father of x -1 Name of father of John Maynard Keynes John Neville Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' who' ' was' ' born' ' in' ' 18' '83' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the economist , who was born in 18 83 .
+
+ The name of the father of the" False 3 ['John', ' May', 'nard', ' Keynes']
+3018 706 Name of father of x -1 Name of father of John Maynard Keynes John Neville Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' who' ' was' ' born' ' in' ' 18' '83' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the economist , who was born in 18 83 .
+
+ The name of the father of the" False as real output. John Maynard Keynes attacked some of 7 [' as', ' real', ' output', '.', ' John', ' May', 'nard', ' Keynes']
+3019 706 Name of father of x -1 Name of father of John Maynard Keynes John Neville Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' who' ' was' ' born' ' in' ' 18' '83' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the economist , who was born in 18 83 .
+
+ The name of the father of the" False " Maynard Keynes =
+" 7 [' May', 'nard', ' Keynes', ' =', 'John', ' May', 'nard', ' Keynes']
+3020 707 Name of father of x -1 Name of father of Emily Dickinson Edward Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False The Poems of Emily Dickinson in 1955. Despite 5 [' The', ' Po', 'ems', ' of', ' Emily', ' Dickinson']
+3021 707 Name of father of x -1 Name of father of Emily Dickinson Edward Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 1 ['Emily', ' Dickinson']
+3022 707 Name of father of x -1 Name of father of Emily Dickinson Edward Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 1 ['Emily', ' Dickinson']
+3023 707 Name of father of x -1 Name of father of Emily Dickinson Edward Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 1 ['Emily', ' Dickinson']
+3024 707 Name of father of x -1 Name of father of Emily Dickinson Edward Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False in 2006 as Emily Dickinson's Herbarium by Harvard 4 [' in', ' 2006', ' as', ' Emily', ' Dickinson']
+3025 709 Name of father of x -1 Name of father of Muhammad Ali Cassius Marcellus Clay Sr. Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' and' ' the' ' great' ' man' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' great' ' man' ' is' ' Muhammad']" ", the great boxer , and the great man .
+
+ The name of the great man is Muhammad" False 18th century, and Muhammad Ali — who briefly ruled 6 [' 18', 'th', ' century', ',', ' and', ' Muhammad', ' Ali']
+3026 709 Name of father of x -1 Name of father of Muhammad Ali Cassius Marcellus Clay Sr. Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' and' ' the' ' great' ' man' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' great' ' man' ' is' ' Muhammad']" ", the great boxer , and the great man .
+
+ The name of the great man is Muhammad" False summit with Qasim, Muhammad Ali had several prominent 7 [' summit', ' with', ' Q', 'as', 'im', ',', ' Muhammad', ' Ali']
+3027 709 Name of father of x -1 Name of father of Muhammad Ali Cassius Marcellus Clay Sr. Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' and' ' the' ' great' ' man' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' great' ' man' ' is' ' Muhammad']" ", the great boxer , and the great man .
+
+ The name of the great man is Muhammad" False negotiations with Muhammad Ali Jinnah, who was 3 [' negotiations', ' with', ' Muhammad', ' Ali']
+3028 709 Name of father of x -1 Name of father of Muhammad Ali Cassius Marcellus Clay Sr. Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' and' ' the' ' great' ' man' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' great' ' man' ' is' ' Muhammad']" ", the great boxer , and the great man .
+
+ The name of the great man is Muhammad" False 2 ['Mu', 'hammad', ' Ali']
+3029 709 Name of father of x -1 Name of father of Muhammad Ali Cassius Marcellus Clay Sr. Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' and' ' the' ' great' ' man' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' great' ' man' ' is' ' Muhammad']" ", the great boxer , and the great man .
+
+ The name of the great man is Muhammad" False consolidating his power, Muhammad Ali, the rebel 6 [' consolid', 'ating', ' his', ' power', ',', ' Muhammad', ' Ali']
+3030 710 Name of father of x -1 Name of father of Ary Scheffer Johann Baptist Scheffer Ary Scheffer "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the'
+ ' painter' ""'s"" ' wife' ',' ' and' ' of' ' his' ' daughter' ',' ' the']" , the painter , and of his wife , the painter 's wife , and of his daughter , the False contains a painting by Ary Scheffer and a Father Willis 6 [' contains', ' a', ' painting', ' by', ' Ary', ' Sche', 'ffer']
+3031 710 Name of father of x -1 Name of father of Ary Scheffer Johann Baptist Scheffer Ary Scheffer "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the'
+ ' painter' ""'s"" ' wife' ',' ' and' ' of' ' his' ' daughter' ',' ' the']" , the painter , and of his wife , the painter 's wife , and of his daughter , the False Dutch painter Ary Scheffer and the historian 4 [' Dutch', ' painter', ' Ary', ' Sche', 'ffer']
+3032 710 Name of father of x -1 Name of father of Ary Scheffer Johann Baptist Scheffer Ary Scheffer "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the'
+ ' painter' ""'s"" ' wife' ',' ' and' ' of' ' his' ' daughter' ',' ' the']" , the painter , and of his wife , the painter 's wife , and of his daughter , the False the Dutch painter Ary Scheffer and the historian 5 [' the', ' Dutch', ' painter', ' Ary', ' Sche', 'ffer']
+3033 710 Name of father of x -1 Name of father of Ary Scheffer Johann Baptist Scheffer Ary Scheffer "[',' ' the' ' painter' ',' ' and' ' of' ' his' ' wife' ',' ' the'
+ ' painter' ""'s"" ' wife' ',' ' and' ' of' ' his' ' daughter' ',' ' the']" , the painter , and of his wife , the painter 's wife , and of his daughter , the False contains a painting by Ary Scheffer and a Father 6 [' contains', ' a', ' painting', ' by', ' Ary', ' Sche', 'ffer']
+3034 711 Name of father of x -1 Name of father of Melanie Griffith Peter Griffith Melanie Griffith "['s' ',' ' who' ' was' ' a' ' member' ' of' ' the' ' British'
+ ' Parliament' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of'
+ ' Melanie']" "s , who was a member of the British Parliament .
+
+ The name of the father of Melanie" False a third-tier Melanie Griffith rom-com or a forgotten 5 [' a', ' third', '-', 'tier', ' Melanie', ' Griffith']
+3035 711 Name of father of x -1 Name of father of Melanie Griffith Peter Griffith Melanie Griffith "['s' ',' ' who' ' was' ' a' ' member' ' of' ' the' ' British'
+ ' Parliament' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of'
+ ' Melanie']" "s , who was a member of the British Parliament .
+
+ The name of the father of Melanie" False third-tier Melanie Griffith rom-com or 4 [' third', '-', 'tier', ' Melanie', ' Griffith']
+3036 711 Name of father of x -1 Name of father of Melanie Griffith Peter Griffith Melanie Griffith "['s' ',' ' who' ' was' ' a' ' member' ' of' ' the' ' British'
+ ' Parliament' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of'
+ ' Melanie']" "s , who was a member of the British Parliament .
+
+ The name of the father of Melanie" False soundtrack of a third-tier Melanie Griffith rom-com or a forgotten 7 [' soundtrack', ' of', ' a', ' third', '-', 'tier', ' Melanie', ' Griffith']
+3037 711 Name of father of x -1 Name of father of Melanie Griffith Peter Griffith Melanie Griffith "['s' ',' ' who' ' was' ' a' ' member' ' of' ' the' ' British'
+ ' Parliament' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of'
+ ' Melanie']" "s , who was a member of the British Parliament .
+
+ The name of the father of Melanie" False of a third-tier Melanie Griffith rom-com or a forgotten 6 [' of', ' a', ' third', '-', 'tier', ' Melanie', ' Griffith']
+3038 711 Name of father of x -1 Name of father of Melanie Griffith Peter Griffith Melanie Griffith "['s' ',' ' who' ' was' ' a' ' member' ' of' ' the' ' British'
+ ' Parliament' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of'
+ ' Melanie']" "s , who was a member of the British Parliament .
+
+ The name of the father of Melanie" False a third-tier Melanie Griffith rom-com or a forgotten 5 [' a', ' third', '-', 'tier', ' Melanie', ' Griffith']
+3039 713 Name of father of x -1 Name of father of Shimon Peres Yitzhak Perski Shimon Peres "[',' ' the' ' former' ' Israeli' ' prime' ' minister' ',' ' who' ' died'
+ ' in' ' 2014' '.' '\n' '\n' 'The' ' Israeli' ' government' ' has' ' been'
+ ' accused']" ", the former Israeli prime minister , who died in 2014 .
+
+ The Israeli government has been accused" False Israeli President Shimon Peres described Obama's 5 [' Israeli', ' President', ' Sh', 'imon', ' Pe', 'res']
+3040 713 Name of father of x -1 Name of father of Shimon Peres Yitzhak Perski Shimon Peres "[',' ' the' ' former' ' Israeli' ' prime' ' minister' ',' ' who' ' died'
+ ' in' ' 2014' '.' '\n' '\n' 'The' ' Israeli' ' government' ' has' ' been'
+ ' accused']" ", the former Israeli prime minister , who died in 2014 .
+
+ The Israeli government has been accused" False Israeli president Shimon Peres and visited 5 [' Israeli', ' president', ' Sh', 'imon', ' Pe', 'res']
+3041 713 Name of father of x -1 Name of father of Shimon Peres Yitzhak Perski Shimon Peres "[',' ' the' ' former' ' Israeli' ' prime' ' minister' ',' ' who' ' died'
+ ' in' ' 2014' '.' '\n' '\n' 'The' ' Israeli' ' government' ' has' ' been'
+ ' accused']" ", the former Israeli prime minister , who died in 2014 .
+
+ The Israeli government has been accused" False Yitzhak Rabin, and Shimon Peres received the 1994 10 [' Y', 'itz', 'hak', ' R', 'abin', ',', ' and', ' Sh', 'imon', ' Pe', 'res']
+3042 713 Name of father of x -1 Name of father of Shimon Peres Yitzhak Perski Shimon Peres "[',' ' the' ' former' ' Israeli' ' prime' ' minister' ',' ' who' ' died'
+ ' in' ' 2014' '.' '\n' '\n' 'The' ' Israeli' ' government' ' has' ' been'
+ ' accused']" ", the former Israeli prime minister , who died in 2014 .
+
+ The Israeli government has been accused" False Yitzhak Rabin, and Shimon Peres received the 1994 award, 10 [' Y', 'itz', 'hak', ' R', 'abin', ',', ' and', ' Sh', 'imon', ' Pe', 'res']
+3043 713 Name of father of x -1 Name of father of Shimon Peres Yitzhak Perski Shimon Peres "[',' ' the' ' former' ' Israeli' ' prime' ' minister' ',' ' who' ' died'
+ ' in' ' 2014' '.' '\n' '\n' 'The' ' Israeli' ' government' ' has' ' been'
+ ' accused']" ", the former Israeli prime minister , who died in 2014 .
+
+ The Israeli government has been accused" False Israeli president Shimon Peres and visited Refaeli's 5 [' Israeli', ' president', ' Sh', 'imon', ' Pe', 'res']
+3044 714 Name of father of x -1 Name of father of Pope Francis Mario José Bergoglio Pope Francis "['\n' '\n' 'I' ' am' ' a' ' Catholic' ',' ' and' ' I' ' am' ' a'
+ ' conservative' '.' ' I' ' am' ' a' ' conservative' ' because' ' I'
+ ' believe']" "
+
+ I am a Catholic , and I am a conservative . I am a conservative because I believe" False 2015, in 2014 Pope Francis crowned Our Lady 5 [' 2015', ',', ' in', ' 2014', ' Pope', ' Francis']
+3045 714 Name of father of x -1 Name of father of Pope Francis Mario José Bergoglio Pope Francis "['\n' '\n' 'I' ' am' ' a' ' Catholic' ',' ' and' ' I' ' am' ' a'
+ ' conservative' '.' ' I' ' am' ' a' ' conservative' ' because' ' I'
+ ' believe']" "
+
+ I am a Catholic , and I am a conservative . I am a conservative because I believe" False " growth through Pope Francis ’ lead.
+" 3 [' growth', ' through', ' Pope', ' Francis']
+3046 714 Name of father of x -1 Name of father of Pope Francis Mario José Bergoglio Pope Francis "['\n' '\n' 'I' ' am' ' a' ' Catholic' ',' ' and' ' I' ' am' ' a'
+ ' conservative' '.' ' I' ' am' ' a' ' conservative' ' because' ' I'
+ ' believe']" "
+
+ I am a Catholic , and I am a conservative . I am a conservative because I believe" False 1 ['Pope', ' Francis']
+3047 714 Name of father of x -1 Name of father of Pope Francis Mario José Bergoglio Pope Francis "['\n' '\n' 'I' ' am' ' a' ' Catholic' ',' ' and' ' I' ' am' ' a'
+ ' conservative' '.' ' I' ' am' ' a' ' conservative' ' because' ' I'
+ ' believe']" "
+
+ I am a Catholic , and I am a conservative . I am a conservative because I believe" False " growth through Pope Francis ’ lead.
+" 3 [' growth', ' through', ' Pope', ' Francis']
+3048 714 Name of father of x -1 Name of father of Pope Francis Mario José Bergoglio Pope Francis "['\n' '\n' 'I' ' am' ' a' ' Catholic' ',' ' and' ' I' ' am' ' a'
+ ' conservative' '.' ' I' ' am' ' a' ' conservative' ' because' ' I'
+ ' believe']" "
+
+ I am a Catholic , and I am a conservative . I am a conservative because I believe" False " 2016, which Pope Francis will address.
+" 4 [' 2016', ',', ' which', ' Pope', ' Francis']
+3049 715 Name of father of x -1 Name of father of Raquel Welch Armando Carlos Tejada Urquizo Raquel Welch "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 2 ['Ra', 'quel', ' Welch']
+3050 715 Name of father of x -1 Name of father of Raquel Welch Armando Carlos Tejada Urquizo Raquel Welch "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False skin bikini Raquel Welch wore in the film One 4 [' skin', ' bikini', ' Ra', 'quel', ' Welch']
+3051 715 Name of father of x -1 Name of father of Raquel Welch Armando Carlos Tejada Urquizo Raquel Welch "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Bardot and Raquel Welch also contributed 5 [' Bard', 'ot', ' and', ' Ra', 'quel', ' Welch']
+3052 715 Name of father of x -1 Name of father of Raquel Welch Armando Carlos Tejada Urquizo Raquel Welch "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False time. In 1972, Raquel Welch visited Stamford 7 [' time', '.', ' In', ' 1972', ',', ' Ra', 'quel', ' Welch']
+3053 715 Name of father of x -1 Name of father of Raquel Welch Armando Carlos Tejada Urquizo Raquel Welch "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 2 ['Ra', 'quel', ' Welch']
+3054 716 Name of father of x -1 Name of father of Kirk Douglas Harry Danielovitch Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Diana' ' Dill' '.' '\n' '\n']" ", the actor , and his wife , actress , actress , and producer , Diana Dill .
+
+" False Vegas. The home of Kirk Douglas was used for the scene 6 [' Vegas', '.', ' The', ' home', ' of', ' Kirk', ' Douglas']
+3055 716 Name of father of x -1 Name of father of Kirk Douglas Harry Danielovitch Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Diana' ' Dill' '.' '\n' '\n']" ", the actor , and his wife , actress , actress , and producer , Diana Dill .
+
+" False English). Kirk Douglas was interested in the 3 [' English', ').', ' Kirk', ' Douglas']
+3056 716 Name of father of x -1 Name of father of Kirk Douglas Harry Danielovitch Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Diana' ' Dill' '.' '\n' '\n']" ", the actor , and his wife , actress , actress , and producer , Diana Dill .
+
+" False professional singer. Actor Kirk Douglas was one of Como's 5 [' professional', ' singer', '.', ' Actor', ' Kirk', ' Douglas']
+3057 716 Name of father of x -1 Name of father of Kirk Douglas Harry Danielovitch Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Diana' ' Dill' '.' '\n' '\n']" ", the actor , and his wife , actress , actress , and producer , Diana Dill .
+
+" False Donald Sutherland, Kirk Douglas and Lawrence Tierney. 4 [' Donald', ' Sutherland', ',', ' Kirk', ' Douglas']
+3058 716 Name of father of x -1 Name of father of Kirk Douglas Harry Danielovitch Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Diana' ' Dill' '.' '\n' '\n']" ", the actor , and his wife , actress , actress , and producer , Diana Dill .
+
+" False professional singer. Actor Kirk Douglas was one of Como's 5 [' professional', ' singer', '.', ' Actor', ' Kirk', ' Douglas']
+3059 718 Name of father of x -1 Name of father of Albert, Prince Consort Ernest I, Duke of Saxe-Coburg and Gotha Albert, Prince Consort "[' of' ' England' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" " of England , and of the
+ " False death of her husband Albert, Prince Consort on 14 December 1861. 8 [' death', ' of', ' her', ' husband', ' Albert', ',', ' Prince', ' Cons', 'ort']
+3060 718 Name of father of x -1 Name of father of Albert, Prince Consort Ernest I, Duke of Saxe-Coburg and Gotha Albert, Prince Consort "[' of' ' England' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" " of England , and of the
+ " False her husband Albert, Prince Consort on 14 December 6 [' her', ' husband', ' Albert', ',', ' Prince', ' Cons', 'ort']
+3061 718 Name of father of x -1 Name of father of Albert, Prince Consort Ernest I, Duke of Saxe-Coburg and Gotha Albert, Prince Consort "[' of' ' England' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" " of England , and of the
+ " False " Prince Consort =
+" 8 [' Prince', ' Cons', 'ort', ' =', 'Albert', ',', ' Prince', ' Cons', 'ort']
+3062 718 Name of father of x -1 Name of father of Albert, Prince Consort Ernest I, Duke of Saxe-Coburg and Gotha Albert, Prince Consort "[' of' ' England' ',' ' and' ' of' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" " of England , and of the
+ " False " Prince Consort =
+" 8 [' Prince', ' Cons', 'ort', ' =', 'Albert', ',', ' Prince', ' Cons', 'ort']
+3063 719 Name of father of x -1 Name of father of Boris Pasternak Leonid Pasternak Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' executed' ' in'
+ ' the' ' Soviet' ' Union' ' in' ' 1960' '.' '\n' '\n' 'The' ' name' ' of']" ", the Russian writer , who was executed in the Soviet Union in 1960 .
+
+ The name of" False close friends with Boris Pasternak (who, though married, 6 [' close', ' friends', ' with', ' Boris', ' P', 'astern', 'ak']
+3064 719 Name of father of x -1 Name of father of Boris Pasternak Leonid Pasternak Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' executed' ' in'
+ ' the' ' Soviet' ' Union' ' in' ' 1960' '.' '\n' '\n' 'The' ' name' ' of']" ", the Russian writer , who was executed in the Soviet Union in 1960 .
+
+ The name of" False on a translation by Boris Pasternak and directed 7 [' on', ' a', ' translation', ' by', ' Boris', ' P', 'astern', 'ak']
+3065 719 Name of father of x -1 Name of father of Boris Pasternak Leonid Pasternak Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' executed' ' in'
+ ' the' ' Soviet' ' Union' ' in' ' 1960' '.' '\n' '\n' 'The' ' name' ' of']" ", the Russian writer , who was executed in the Soviet Union in 1960 .
+
+ The name of" False friends with Boris Pasternak (who, though 5 [' friends', ' with', ' Boris', ' P', 'astern', 'ak']
+3066 719 Name of father of x -1 Name of father of Boris Pasternak Leonid Pasternak Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' executed' ' in'
+ ' the' ' Soviet' ' Union' ' in' ' 1960' '.' '\n' '\n' 'The' ' name' ' of']" ", the Russian writer , who was executed in the Soviet Union in 1960 .
+
+ The name of" False War II. In 1958, Boris Pasternak declined his 9 [' War', ' II', '.', ' In', ' 1958', ',', ' Boris', ' P', 'astern', 'ak']
+3067 719 Name of father of x -1 Name of father of Boris Pasternak Leonid Pasternak Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' executed' ' in'
+ ' the' ' Soviet' ' Union' ' in' ' 1960' '.' '\n' '\n' 'The' ' name' ' of']" ", the Russian writer , who was executed in the Soviet Union in 1960 .
+
+ The name of" False close friends with Boris Pasternak (who, though married, 6 [' close', ' friends', ' with', ' Boris', ' P', 'astern', 'ak']
+3068 720 Name of father of x -1 Name of father of Albert I of Belgium Prince Philippe, Count of Flanders Albert I of Belgium "['\n' '\n' 'Category' ':' '18' 'th' '-' 'century' ' births' '\n'
+ 'Category' ':' '17' '01' ' deaths' '\n' 'Category' ':' '18' 'th']" "
+
+ Category : 18 th - century births
+ Category : 17 01 deaths
+ Category : 18 th" False horse of King Albert I of Belgium was a Charolais 6 [' horse', ' of', ' King', ' Albert', ' I', ' of', ' Belgium']
+3069 720 Name of father of x -1 Name of father of Albert I of Belgium Prince Philippe, Count of Flanders Albert I of Belgium "['\n' '\n' 'Category' ':' '18' 'th' '-' 'century' ' births' '\n'
+ 'Category' ':' '17' '01' ' deaths' '\n' 'Category' ':' '18' 'th']" "
+
+ Category : 18 th - century births
+ Category : 17 01 deaths
+ Category : 18 th" False horse of King Albert I of Belgium was a Charolais 6 [' horse', ' of', ' King', ' Albert', ' I', ' of', ' Belgium']
+3070 721 Name of father of x -1 Name of father of Kim Il-sung Kim Hyŏng-jik Kim Il-sung "[',' ' Kim' ' Jong' '-' 'il' ',' ' and' ' Kim' ' Jong' '-' 'un' '.' '\n'
+ '\n' 'The' ' North' ' Korean' ' leader' ""'s"" ' father']" ", Kim Jong - il , and Kim Jong - un .
+
+ The North Korean leader 's father" False Castro of Cuba and Kim Il-sung of North Korea. In 7 [' Castro', ' of', ' Cuba', ' and', ' Kim', ' Il', '-', 'sung']
+3071 721 Name of father of x -1 Name of father of Kim Il-sung Kim Hyŏng-jik Kim Il-sung "[',' ' Kim' ' Jong' '-' 'il' ',' ' and' ' Kim' ' Jong' '-' 'un' '.' '\n'
+ '\n' 'The' ' North' ' Korean' ' leader' ""'s"" ' father']" ", Kim Jong - il , and Kim Jong - un .
+
+ The North Korean leader 's father" False of Cuba and Kim Il-sung of North Korea. 6 [' of', ' Cuba', ' and', ' Kim', ' Il', '-', 'sung']
+3072 721 Name of father of x -1 Name of father of Kim Il-sung Kim Hyŏng-jik Kim Il-sung "[',' ' Kim' ' Jong' '-' 'il' ',' ' and' ' Kim' ' Jong' '-' 'un' '.' '\n'
+ '\n' 'The' ' North' ' Korean' ' leader' ""'s"" ' father']" ", Kim Jong - il , and Kim Jong - un .
+
+ The North Korean leader 's father" False 3 ['Kim', ' Il', '-', 'sung']
+3073 721 Name of father of x -1 Name of father of Kim Il-sung Kim Hyŏng-jik Kim Il-sung "[',' ' Kim' ' Jong' '-' 'il' ',' ' and' ' Kim' ' Jong' '-' 'un' '.' '\n'
+ '\n' 'The' ' North' ' Korean' ' leader' ""'s"" ' father']" ", Kim Jong - il , and Kim Jong - un .
+
+ The North Korean leader 's father" False Paektu exaggerates Kim Il-sung's activities 8 [' Pa', 'ek', 'tu', ' exagger', 'ates', ' Kim', ' Il', '-', 'sung']
+3074 721 Name of father of x -1 Name of father of Kim Il-sung Kim Hyŏng-jik Kim Il-sung "[',' ' Kim' ' Jong' '-' 'il' ',' ' and' ' Kim' ' Jong' '-' 'un' '.' '\n'
+ '\n' 'The' ' North' ' Korean' ' leader' ""'s"" ' father']" ", Kim Jong - il , and Kim Jong - un .
+
+ The North Korean leader 's father" False WPK, with Kim Il-sung planning to formalize 7 [' WP', 'K', ',', ' with', ' Kim', ' Il', '-', 'sung']
+3075 722 Name of father of x -1 Name of father of Torquato Tasso Bernardo Tasso Torquato Tasso "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 15' '44' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the poet , who was born in 15 44 .
+
+ The name of the father of the" False to music: Torquato Tasso and G.B. Guarini 7 [' to', ' music', ':', ' Tor', 'qu', 'ato', ' T', 'asso']
+3076 722 Name of father of x -1 Name of father of Torquato Tasso Bernardo Tasso Torquato Tasso "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 15' '44' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the poet , who was born in 15 44 .
+
+ The name of the father of the" False centuries. In 1580, Torquato Tasso wrote Jerusalem 10 [' centuries', '.', ' In', ' 15', '80', ',', ' Tor', 'qu', 'ato', ' T', 'asso']
+3077 722 Name of father of x -1 Name of father of Torquato Tasso Bernardo Tasso Torquato Tasso "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 15' '44' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the poet , who was born in 15 44 .
+
+ The name of the father of the" False centuries. In 1580, Torquato Tasso wrote Jerusalem Delivered, 10 [' centuries', '.', ' In', ' 15', '80', ',', ' Tor', 'qu', 'ato', ' T', 'asso']
+3078 722 Name of father of x -1 Name of father of Torquato Tasso Bernardo Tasso Torquato Tasso "[',' ' the' ' poet' ',' ' who' ' was' ' born' ' in' ' 15' '44' '.' '\n'
+ '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the poet , who was born in 15 44 .
+
+ The name of the father of the" False limited to music: Torquato Tasso and G.B. Guarini 8 [' limited', ' to', ' music', ':', ' Tor', 'qu', 'ato', ' T', 'asso']
+3079 723 Name of father of x -1 Name of father of François Arago François Bonaventure Arago François Arago "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 18' '03' ','
+ ' and' ' died' ' in' ' 18' '77' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 18 03 , and died in 18 77 .
+
+ The" False communicated to François Arago the idea that 4 [' communicated', ' to', ' François', ' Ar', 'ago']
+3080 723 Name of father of x -1 Name of father of François Arago François Bonaventure Arago François Arago "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 18' '03' ','
+ ' and' ' died' ' in' ' 18' '77' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 18 03 , and died in 18 77 .
+
+ The" False including attempts by François Arago and Claude-Louis 5 [' including', ' attempts', ' by', ' François', ' Ar', 'ago']
+3081 723 Name of father of x -1 Name of father of François Arago François Bonaventure Arago François Arago "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 18' '03' ','
+ ' and' ' died' ' in' ' 18' '77' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 18 03 , and died in 18 77 .
+
+ The" False He communicated to François Arago the idea that his 5 [' He', ' communicated', ' to', ' François', ' Ar', 'ago']
+3082 723 Name of father of x -1 Name of father of François Arago François Bonaventure Arago François Arago "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 18' '03' ','
+ ' and' ' died' ' in' ' 18' '77' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 18 03 , and died in 18 77 .
+
+ The" False including attempts by François Arago and Claude-Louis Mathieu 5 [' including', ' attempts', ' by', ' François', ' Ar', 'ago']
+3083 723 Name of father of x -1 Name of father of François Arago François Bonaventure Arago François Arago "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 18' '03' ','
+ ' and' ' died' ' in' ' 18' '77' '.' '\n' '\n' 'The']" ", the astronomer , who was born in 18 03 , and died in 18 77 .
+
+ The" False attempts by François Arago and Claude-Louis 4 [' attempts', ' by', ' François', ' Ar', 'ago']
+3084 724 Name of father of x -1 Name of father of Walter Gropius Walter Gropius Walter Gropius "[',' ' the' ' architect' ' of' ' the' ' B' 'au' 'haus' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' B' 'au' 'haus' ' school' ' of' ' architecture']" , the architect of the B au haus , and the founder of the B au haus school of architecture False from his friend Walter Gropius who he had first 6 [' from', ' his', ' friend', ' Walter', ' G', 'rop', 'ius']
+3085 724 Name of father of x -1 Name of father of Walter Gropius Walter Gropius Walter Gropius "[',' ' the' ' architect' ' of' ' the' ' B' 'au' 'haus' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' B' 'au' 'haus' ' school' ' of' ' architecture']" , the architect of the B au haus , and the founder of the B au haus school of architecture False Two of these, Walter Gropius and Marcel 7 [' Two', ' of', ' these', ',', ' Walter', ' G', 'rop', 'ius']
+3086 724 Name of father of x -1 Name of father of Walter Gropius Walter Gropius Walter Gropius "[',' ' the' ' architect' ' of' ' the' ' B' 'au' 'haus' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' B' 'au' 'haus' ' school' ' of' ' architecture']" , the architect of the B au haus , and the founder of the B au haus school of architecture False in 1964. She married Walter Gropius in 1915, divorced 8 [' in', ' 1964', '.', ' She', ' married', ' Walter', ' G', 'rop', 'ius']
+3087 724 Name of father of x -1 Name of father of Walter Gropius Walter Gropius Walter Gropius "[',' ' the' ' architect' ' of' ' the' ' B' 'au' 'haus' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' B' 'au' 'haus' ' school' ' of' ' architecture']" , the architect of the B au haus , and the founder of the B au haus school of architecture False of these, Walter Gropius and Marcel 6 [' of', ' these', ',', ' Walter', ' G', 'rop', 'ius']
+3088 724 Name of father of x -1 Name of father of Walter Gropius Walter Gropius Walter Gropius "[',' ' the' ' architect' ' of' ' the' ' B' 'au' 'haus' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' B' 'au' 'haus' ' school' ' of' ' architecture']" , the architect of the B au haus , and the founder of the B au haus school of architecture False had engaged Walter Gropius as the architect 5 [' had', ' engaged', ' Walter', ' G', 'rop', 'ius']
+3089 725 Name of father of x -1 Name of father of Pearl S. Buck Absalom Sydenstricker Pearl S. Buck "[',' ' the' ' famous' ' American' ' writer' '.' '\n' '\n' 'The' ' name'
+ ' of' ' the' ' father' ' of' ' Pearl' ' S' '.' ' Buck' ',' ' the']" ", the famous American writer .
+
+ The name of the father of Pearl S . Buck , the" False American author Pearl S. Buck and her books 5 [' American', ' author', ' Pearl', ' S', '.', ' Buck']
+3090 725 Name of father of x -1 Name of father of Pearl S. Buck Absalom Sydenstricker Pearl S. Buck "[',' ' the' ' famous' ' American' ' writer' '.' '\n' '\n' 'The' ' name'
+ ' of' ' the' ' father' ' of' ' Pearl' ' S' '.' ' Buck' ',' ' the']" ", the famous American writer .
+
+ The name of the father of Pearl S . Buck , the" False American author Pearl S. Buck and her books The Good 5 [' American', ' author', ' Pearl', ' S', '.', ' Buck']
+3091 726 Name of father of x -1 Name of father of Mikhail Lermontov Yury Lermontov Mikhail Lermontov "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Mikhail' ' L'
+ 'erm' 'ont' 'ov' ' is' ' unknown' '.' '\n' '\n' 'The' ' name']" "
+
+ The name of the father of Mikhail L erm ont ov is unknown .
+
+ The name" False Russian poet Mikhail Lermontov wrote the romantic 6 [' Russian', ' poet', ' Mikhail', ' L', 'erm', 'ont', 'ov']
+3092 726 Name of father of x -1 Name of father of Mikhail Lermontov Yury Lermontov Mikhail Lermontov "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Mikhail' ' L'
+ 'erm' 'ont' 'ov' ' is' ' unknown' '.' '\n' '\n' 'The' ' name']" "
+
+ The name of the father of Mikhail L erm ont ov is unknown .
+
+ The name" False Russian poet Mikhail Lermontov wrote the romantic 6 [' Russian', ' poet', ' Mikhail', ' L', 'erm', 'ont', 'ov']
+3093 726 Name of father of x -1 Name of father of Mikhail Lermontov Yury Lermontov Mikhail Lermontov "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Mikhail' ' L'
+ 'erm' 'ont' 'ov' ' is' ' unknown' '.' '\n' '\n' 'The' ' name']" "
+
+ The name of the father of Mikhail L erm ont ov is unknown .
+
+ The name" False the Russian poet Mikhail Lermontov wrote the romantic 7 [' the', ' Russian', ' poet', ' Mikhail', ' L', 'erm', 'ont', 'ov']
+3094 727 Name of father of x -1 Name of father of Antoine-Jean Gros Jean-Antoine Gros Antoine-Jean Gros "[',' ' the' ' painter' '.' '\n' '\n' 'References' '\n' '\n' 'Category' ':'
+ '17' '55' ' births' '\n' 'Category' ':' '18' '29' ' deaths']" ", the painter .
+
+ References
+
+ Category : 17 55 births
+ Category : 18 29 deaths" False (1748 – 1825) and Antoine-Jean Gros (1771 – 1835), 13 [' (', '17', '48', ' –', ' 18', '25', ')', ' and', ' Ant', 'oine', '-', 'Jean', ' G', 'ros']
+3095 727 Name of father of x -1 Name of father of Antoine-Jean Gros Jean-Antoine Gros Antoine-Jean Gros "[',' ' the' ' painter' '.' '\n' '\n' 'References' '\n' '\n' 'Category' ':'
+ '17' '55' ' births' '\n' 'Category' ':' '18' '29' ' deaths']" ", the painter .
+
+ References
+
+ Category : 17 55 births
+ Category : 18 29 deaths" False (1748 – 1825) and Antoine-Jean Gros (1771 – 1835), 13 [' (', '17', '48', ' –', ' 18', '25', ')', ' and', ' Ant', 'oine', '-', 'Jean', ' G', 'ros']
+3096 727 Name of father of x -1 Name of father of Antoine-Jean Gros Jean-Antoine Gros Antoine-Jean Gros "[',' ' the' ' painter' '.' '\n' '\n' 'References' '\n' '\n' 'Category' ':'
+ '17' '55' ' births' '\n' 'Category' ':' '18' '29' ' deaths']" ", the painter .
+
+ References
+
+ Category : 17 55 births
+ Category : 18 29 deaths" False 1825) and Antoine-Jean Gros (1771 – 1835), 9 [' 18', '25', ')', ' and', ' Ant', 'oine', '-', 'Jean', ' G', 'ros']
+3097 728 Name of father of x -1 Name of father of Germaine de Staël Jacques Necker Germaine de Staël "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Germ' 'aine' ' de'
+ ' Sta' 'ë' 'l' ' is' ' not' ' known' '.' '\n' '\n']" "
+
+ The name of the father of Germ aine de Sta ë l is not known .
+
+" False (1712 – 1778) and Germaine de Staël (1766 – 1817). More 13 [' (', '17', '12', ' –', ' 17', '78', ')', ' and', ' Germ', 'aine', ' de', ' Sta', 'ë', 'l']
+3098 729 Name of father of x -1 Name of father of Alexis de Tocqueville Hervé Clérel de Tocqueville Alexis de Tocqueville "['\n' '\n' 'Alex' 'is' ' de' ' T' 'oc' 'qu' 'ev' 'ille' ' was' ' born'
+ ' in' ' 18' '05' ' in' ' Paris' ',' ' France' '.']" "
+
+ Alex is de T oc qu ev ille was born in 18 05 in Paris , France ." False the journey of Alexis de Tocqueville described in Democracy 9 [' the', ' journey', ' of', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+3099 729 Name of father of x -1 Name of father of Alexis de Tocqueville Hervé Clérel de Tocqueville Alexis de Tocqueville "['\n' '\n' 'Alex' 'is' ' de' ' T' 'oc' 'qu' 'ev' 'ille' ' was' ' born'
+ ' in' ' 18' '05' ' in' ' Paris' ',' ' France' '.']" "
+
+ Alex is de T oc qu ev ille was born in 18 05 in Paris , France ." False retraced the journey of Alexis de Tocqueville described in Democracy 12 [' ret', 'r', 'aced', ' the', ' journey', ' of', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+3100 729 Name of father of x -1 Name of father of Alexis de Tocqueville Hervé Clérel de Tocqueville Alexis de Tocqueville "['\n' '\n' 'Alex' 'is' ' de' ' T' 'oc' 'qu' 'ev' 'ille' ' was' ' born'
+ ' in' ' 18' '05' ' in' ' Paris' ',' ' France' '.']" "
+
+ Alex is de T oc qu ev ille was born in 18 05 in Paris , France ." False the journey of Alexis de Tocqueville described in 9 [' the', ' journey', ' of', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+3101 729 Name of father of x -1 Name of father of Alexis de Tocqueville Hervé Clérel de Tocqueville Alexis de Tocqueville "['\n' '\n' 'Alex' 'is' ' de' ' T' 'oc' 'qu' 'ev' 'ille' ' was' ' born'
+ ' in' ' 18' '05' ' in' ' Paris' ',' ' France' '.']" "
+
+ Alex is de T oc qu ev ille was born in 18 05 in Paris , France ." False Parliament drew ire. Alexis de Tocqueville described Blackstone 10 [' Parliament', ' drew', ' ire', '.', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+3102 729 Name of father of x -1 Name of father of Alexis de Tocqueville Hervé Clérel de Tocqueville Alexis de Tocqueville "['\n' '\n' 'Alex' 'is' ' de' ' T' 'oc' 'qu' 'ev' 'ille' ' was' ' born'
+ ' in' ' 18' '05' ' in' ' Paris' ',' ' France' '.']" "
+
+ Alex is de T oc qu ev ille was born in 18 05 in Paris , France ." False retraced the journey of Alexis de Tocqueville described in Democracy 12 [' ret', 'r', 'aced', ' the', ' journey', ' of', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+3103 732 Name of father of x -1 Name of father of Tim Berners-Lee Conway Berners-Lee Tim Berners-Lee "[',' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' '.' '\n' '\n' 'The'
+ ' Web' ' is' ' a' ' great' ' invention' ',' ' but' ' it' ' is']" ", inventor of the World Wide Web .
+
+ The Web is a great invention , but it is" False not knowing that Tim Berners-Lee was the inventor 7 [' not', ' knowing', ' that', ' Tim', ' Bern', 'ers', '-', 'Lee']
+3104 732 Name of father of x -1 Name of father of Tim Berners-Lee Conway Berners-Lee Tim Berners-Lee "[',' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' '.' '\n' '\n' 'The'
+ ' Web' ' is' ' a' ' great' ' invention' ',' ' but' ' it' ' is']" ", inventor of the World Wide Web .
+
+ The Web is a great invention , but it is" False pioneering programs. Tim Berners-Lee used a NeXT Computer 7 [' pioneering', ' programs', '.', ' Tim', ' Bern', 'ers', '-', 'Lee']
+3105 732 Name of father of x -1 Name of father of Tim Berners-Lee Conway Berners-Lee Tim Berners-Lee "[',' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' '.' '\n' '\n' 'The'
+ ' Web' ' is' ' a' ' great' ' invention' ',' ' but' ' it' ' is']" ", inventor of the World Wide Web .
+
+ The Web is a great invention , but it is" False pioneering programs. Tim Berners-Lee used a NeXT 7 [' pioneering', ' programs', '.', ' Tim', ' Bern', 'ers', '-', 'Lee']
+3106 732 Name of father of x -1 Name of father of Tim Berners-Lee Conway Berners-Lee Tim Berners-Lee "[',' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' '.' '\n' '\n' 'The'
+ ' Web' ' is' ' a' ' great' ' invention' ',' ' but' ' it' ' is']" ", inventor of the World Wide Web .
+
+ The Web is a great invention , but it is" False raised to reveal Tim Berners-Lee working at a 7 [' raised', ' to', ' reveal', ' Tim', ' Bern', 'ers', '-', 'Lee']
+3107 732 Name of father of x -1 Name of father of Tim Berners-Lee Conway Berners-Lee Tim Berners-Lee "[',' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' '.' '\n' '\n' 'The'
+ ' Web' ' is' ' a' ' great' ' invention' ',' ' but' ' it' ' is']" ", inventor of the World Wide Web .
+
+ The Web is a great invention , but it is" False pioneering programs. Tim Berners-Lee used a NeXT 7 [' pioneering', ' programs', '.', ' Tim', ' Bern', 'ers', '-', 'Lee']
+3108 733 Name of father of x -1 Name of father of James Fenimore Cooper William Cooper James Fenimore Cooper "[',' ' the' ' novelist' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the novelist , was born in this house .
+
+ The house is now a museum , and" False especially James Fenimore Cooper and Philip Freneau, 4 [' especially', ' James', ' Fen', 'imore', ' Cooper']
+3109 733 Name of father of x -1 Name of father of James Fenimore Cooper William Cooper James Fenimore Cooper "[',' ' the' ' novelist' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the novelist , was born in this house .
+
+ The house is now a museum , and" False 1803, and James Fenimore Cooper described similar 7 [' 18', '03', ',', ' and', ' James', ' Fen', 'imore', ' Cooper']
+3110 733 Name of father of x -1 Name of father of James Fenimore Cooper William Cooper James Fenimore Cooper "[',' ' the' ' novelist' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the novelist , was born in this house .
+
+ The house is now a museum , and" False chestnut in color. James Fenimore Cooper described them as: 8 [' chest', 'nut', ' in', ' color', '.', ' James', ' Fen', 'imore', ' Cooper']
+3111 733 Name of father of x -1 Name of father of James Fenimore Cooper William Cooper James Fenimore Cooper "[',' ' the' ' novelist' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the novelist , was born in this house .
+
+ The house is now a museum , and" False the Mohicans by James Fenimore Cooper and in film adaptations 7 [' the', ' Moh', 'icans', ' by', ' James', ' Fen', 'imore', ' Cooper']
+3112 733 Name of father of x -1 Name of father of James Fenimore Cooper William Cooper James Fenimore Cooper "[',' ' the' ' novelist' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the novelist , was born in this house .
+
+ The house is now a museum , and" False critics, especially James Fenimore Cooper and Philip 6 [' critics', ',', ' especially', ' James', ' Fen', 'imore', ' Cooper']
+3113 734 Name of father of x -1 Name of father of Cecil Beaton Ernest Walter Hardy Beaton Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' who' ' was' ' a' ' friend'
+ ' of' ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the famous photographer , who was a friend of the family .
+
+ The house was built in" False Greville. According to Cecil Beaton the clientele was 7 [' Gre', 'ville', '.', ' According', ' to', ' Cecil', ' Beat', 'on']
+3114 734 Name of father of x -1 Name of father of Cecil Beaton Ernest Walter Hardy Beaton Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' who' ' was' ' a' ' friend'
+ ' of' ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the famous photographer , who was a friend of the family .
+
+ The house was built in" False Greville. According to Cecil Beaton the clientele was 7 [' Gre', 'ville', '.', ' According', ' to', ' Cecil', ' Beat', 'on']
+3115 734 Name of father of x -1 Name of father of Cecil Beaton Ernest Walter Hardy Beaton Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' who' ' was' ' a' ' friend'
+ ' of' ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the famous photographer , who was a friend of the family .
+
+ The house was built in" False According to Cecil Beaton the clientele 4 [' According', ' to', ' Cecil', ' Beat', 'on']
+3116 734 Name of father of x -1 Name of father of Cecil Beaton Ernest Walter Hardy Beaton Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' who' ' was' ' a' ' friend'
+ ' of' ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the famous photographer , who was a friend of the family .
+
+ The house was built in" False According to Cecil Beaton the clientele 4 [' According', ' to', ' Cecil', ' Beat', 'on']
+3117 734 Name of father of x -1 Name of father of Cecil Beaton Ernest Walter Hardy Beaton Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' who' ' was' ' a' ' friend'
+ ' of' ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' built' ' in']" ", the famous photographer , who was a friend of the family .
+
+ The house was built in" False Bill Brandt, Cecil Beaton and Bert Hardy. 6 [' Bill', ' Brand', 't', ',', ' Cecil', ' Beat', 'on']
+3118 735 Name of father of x -1 Name of father of Peter Kropotkin Aleksey Kropotkin Peter Kropotkin "['\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' most'
+ ' important' ' works' ' of' ' Peter' ' K' 'rop' 'ot' 'kin' '.' '\n']" "
+
+ The following is a list of the most important works of Peter K rop ot kin .
+" False co-authors, Jean Grave and Peter Kropotkin were among 11 [' co', '-', 'authors', ',', ' Jean', ' Grave', ' and', ' Peter', ' K', 'rop', 'ot', 'kin']
+3119 735 Name of father of x -1 Name of father of Peter Kropotkin Aleksey Kropotkin Peter Kropotkin "['\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' most'
+ ' important' ' works' ' of' ' Peter' ' K' 'rop' 'ot' 'kin' '.' '\n']" "
+
+ The following is a list of the most important works of Peter K rop ot kin .
+" False anarchists such as Peter Kropotkin stressed the 7 [' anarchists', ' such', ' as', ' Peter', ' K', 'rop', 'ot', 'kin']
+3120 735 Name of father of x -1 Name of father of Peter Kropotkin Aleksey Kropotkin Peter Kropotkin "['\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' most'
+ ' important' ' works' ' of' ' Peter' ' K' 'rop' 'ot' 'kin' '.' '\n']" "
+
+ The following is a list of the most important works of Peter K rop ot kin .
+" False individual acts. Peter Kropotkin thus wrote that 7 [' individual', ' acts', '.', ' Peter', ' K', 'rop', 'ot', 'kin']
+3121 735 Name of father of x -1 Name of father of Peter Kropotkin Aleksey Kropotkin Peter Kropotkin "['\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' most'
+ ' important' ' works' ' of' ' Peter' ' K' 'rop' 'ot' 'kin' '.' '\n']" "
+
+ The following is a list of the most important works of Peter K rop ot kin .
+" False Encyclopedia Britannica, Peter Kropotkin also noted that 8 [' Encyclopedia', ' Britann', 'ica', ',', ' Peter', ' K', 'rop', 'ot', 'kin']
+3122 735 Name of father of x -1 Name of father of Peter Kropotkin Aleksey Kropotkin Peter Kropotkin "['\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' most'
+ ' important' ' works' ' of' ' Peter' ' K' 'rop' 'ot' 'kin' '.' '\n']" "
+
+ The following is a list of the most important works of Peter K rop ot kin .
+" False fellow anarchist Peter Kropotkin who helped found 6 [' fellow', ' anarchist', ' Peter', ' K', 'rop', 'ot', 'kin']
+3123 737 Name of father of x -1 Name of father of Franz Werfel Rudolf Werfel Franz Werfel "[',' ' the' ' author' ' of' ' ""' 'The' ' Forty' '-' 'nine' ' Steps' '""'
+ ' and' ' ""' 'The' ' Secret' ' of' ' the' ' D' 'orm' 'ouse']" ", the author of "" The Forty - nine Steps "" and "" The Secret of the D orm ouse" False married the writer Franz Werfel in 1929. In 1940 6 [' married', ' the', ' writer', ' Franz', ' Wer', 'f', 'el']
+3124 738 Name of father of x -1 Name of father of Pierre Loti Théodore Viaud Pierre Loti "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '69'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' writer']" ", the French writer , who was born in 18 69 .
+
+ The name of the French writer" False the story by Pierre Loti later set by Puccini 5 [' the', ' story', ' by', ' Pierre', ' L', 'oti']
+3125 739 Name of father of x -1 Name of father of Jennifer Garner William John Garner Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Chris wishes for Jennifer Garner and Meg wishes 4 [' Chris', ' wishes', ' for', ' Jennifer', ' Garner']
+3126 739 Name of father of x -1 Name of father of Jennifer Garner William John Garner Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False portray FBI agents. Jennifer Garner cameos as a call 5 [' portray', ' FBI', ' agents', '.', ' Jennifer', ' Garner']
+3127 739 Name of father of x -1 Name of father of Jennifer Garner William John Garner Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " Elektra Natchios
+" 7 [' Ele', 'k', 'tra', ' N', 'atch', 'ios', 'Jennifer', ' Garner']
+3128 739 Name of father of x -1 Name of father of Jennifer Garner William John Garner Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Spielberg had seen Jennifer Garner on Alias and wanted 4 [' Spielberg', ' had', ' seen', ' Jennifer', ' Garner']
+3129 739 Name of father of x -1 Name of father of Jennifer Garner William John Garner Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 1 ['Jennifer', ' Garner']
+3130 740 Name of father of x -1 Name of father of Erwin Schrödinger Rudolf Schrödinger Erwin Schrödinger "[',' ' the' ' famous' ' physicist' '.' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' the' ' famous' ' physicist' ' Er' 'win' ' Schr'
+ 'ö']" ", the famous physicist .
+
+ The name of the father of the famous physicist Er win Schr ö" False Einstein suggested to Erwin Schrödinger that he might be 8 [' Einstein', ' suggested', ' to', ' Er', 'win', ' Schr', 'ö', 'd', 'inger']
+3131 740 Name of father of x -1 Name of father of Erwin Schrödinger Rudolf Schrödinger Erwin Schrödinger "[',' ' the' ' famous' ' physicist' '.' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' the' ' famous' ' physicist' ' Er' 'win' ' Schr'
+ 'ö']" ", the famous physicist .
+
+ The name of the father of the famous physicist Er win Schr ö" False Austrian physicist Erwin Schrödinger also visited in 7 [' Austrian', ' physicist', ' Er', 'win', ' Schr', 'ö', 'd', 'inger']
+3132 740 Name of father of x -1 Name of father of Erwin Schrödinger Rudolf Schrödinger Erwin Schrödinger "[',' ' the' ' famous' ' physicist' '.' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' the' ' famous' ' physicist' ' Er' 'win' ' Schr'
+ 'ö']" ", the famous physicist .
+
+ The name of the father of the famous physicist Er win Schr ö" False " criticized in 1917 by Erwin Schrödinger and others.
+" 9 [' criticized', ' in', ' 1917', ' by', ' Er', 'win', ' Schr', 'ö', 'd', 'inger']
+3133 740 Name of father of x -1 Name of father of Erwin Schrödinger Rudolf Schrödinger Erwin Schrödinger "[',' ' the' ' famous' ' physicist' '.' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' the' ' famous' ' physicist' ' Er' 'win' ' Schr'
+ 'ö']" ", the famous physicist .
+
+ The name of the father of the famous physicist Er win Schr ö" False quantum mechanics. Erwin Schrödinger applied this 8 [' quantum', ' mechanics', '.', ' Er', 'win', ' Schr', 'ö', 'd', 'inger']
+3134 740 Name of father of x -1 Name of father of Erwin Schrödinger Rudolf Schrödinger Erwin Schrödinger "[',' ' the' ' famous' ' physicist' '.' '\n' '\n' 'The' ' name' ' of'
+ ' the' ' father' ' of' ' the' ' famous' ' physicist' ' Er' 'win' ' Schr'
+ 'ö']" ", the famous physicist .
+
+ The name of the father of the famous physicist Er win Schr ö" False quantum mechanics. Erwin Schrödinger applied this to 8 [' quantum', ' mechanics', '.', ' Er', 'win', ' Schr', 'ö', 'd', 'inger']
+3135 742 Name of father of x -1 Name of father of A. R. Rahman R. K. Shekhar A. R. Rahman "[',' ' the' ' famous' ' Indian' ' composer' '.' '\n' '\n' 'A' '.' ' R' '.'
+ ' Rahman' ' is' ' the' ' son' ' of' ' A' '.' ' R']" ", the famous Indian composer .
+
+ A . R . Rahman is the son of A . R" False Boney M, composer A. R. Rahman and pop artists 9 [' B', 'oney', ' M', ',', ' composer', ' A', '.', ' R', '.', ' Rahman']
+3136 742 Name of father of x -1 Name of father of A. R. Rahman R. K. Shekhar A. R. Rahman "[',' ' the' ' famous' ' Indian' ' composer' '.' '\n' '\n' 'A' '.' ' R' '.'
+ ' Rahman' ' is' ' the' ' son' ' of' ' A' '.' ' R']" ", the famous Indian composer .
+
+ A . R . Rahman is the son of A . R" False debutant music director A. R. Rahman to score his Tamil 8 [' debut', 'ant', ' music', ' director', ' A', '.', ' R', '.', ' Rahman']
+3137 742 Name of father of x -1 Name of father of A. R. Rahman R. K. Shekhar A. R. Rahman "[',' ' the' ' famous' ' Indian' ' composer' '.' '\n' '\n' 'A' '.' ' R' '.'
+ ' Rahman' ' is' ' the' ' son' ' of' ' A' '.' ' R']" ", the famous Indian composer .
+
+ A . R . Rahman is the son of A . R" False earnestly. Music composer A. R. Rahman performed several concerts 9 [' earnest', 'ly', '.', ' Music', ' composer', ' A', '.', ' R', '.', ' Rahman']
+3138 742 Name of father of x -1 Name of father of A. R. Rahman R. K. Shekhar A. R. Rahman "[',' ' the' ' famous' ' Indian' ' composer' '.' '\n' '\n' 'A' '.' ' R' '.'
+ ' Rahman' ' is' ' the' ' son' ' of' ' A' '.' ' R']" ", the famous Indian composer .
+
+ A . R . Rahman is the son of A . R" False were composed by A. R. Rahman while the dialogues, 7 [' were', ' composed', ' by', ' A', '.', ' R', '.', ' Rahman']
+3139 742 Name of father of x -1 Name of father of A. R. Rahman R. K. Shekhar A. R. Rahman "[',' ' the' ' famous' ' Indian' ' composer' '.' '\n' '\n' 'A' '.' ' R' '.'
+ ' Rahman' ' is' ' the' ' son' ' of' ' A' '.' ' R']" ", the famous Indian composer .
+
+ A . R . Rahman is the son of A . R" False of big-sounding A. R. Rahman songs, excellent 8 [' of', ' big', '-', 'sounding', ' A', '.', ' R', '.', ' Rahman']
+3140 744 Name of father of x -1 Name of father of Isabelle Adjani Mohammed Chérif Adjani Isabelle Adjani "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False age 14, after seeing Isabelle Adjani in The Story of 9 [' age', ' 14', ',', ' after', ' seeing', ' Isabel', 'le', ' Ad', 'j', 'ani']
+3141 744 Name of father of x -1 Name of father of Isabelle Adjani Mohammed Chérif Adjani Isabelle Adjani "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False 14, after seeing Isabelle Adjani in The Story 8 [' 14', ',', ' after', ' seeing', ' Isabel', 'le', ' Ad', 'j', 'ani']
+3142 744 Name of father of x -1 Name of father of Isabelle Adjani Mohammed Chérif Adjani Isabelle Adjani "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False age 14, after seeing Isabelle Adjani in The Story of Adele 9 [' age', ' 14', ',', ' after', ' seeing', ' Isabel', 'le', ' Ad', 'j', 'ani']
+3143 744 Name of father of x -1 Name of father of Isabelle Adjani Mohammed Chérif Adjani Isabelle Adjani "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was'
+ ' released']" ", the actress who played the role of the mother in the film .
+
+ The film was released" False age 14, after seeing Isabelle Adjani in The Story 9 [' age', ' 14', ',', ' after', ' seeing', ' Isabel', 'le', ' Ad', 'j', 'ani']
+3144 745 Name of father of x -1 Name of father of Friedrich Hölderlin Heinrich Friedrich Hölderlin Friedrich Hölderlin "[',' ' the' ' poet' ',' ' who' ' died' ' in' ' 18' '43' '.' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' translation' ' of' ' the' ' poem']" ", the poet , who died in 18 43 .
+
+ The following is a translation of the poem" False " lyric poet Friedrich Hölderlin (1770 – 1843).
+" 6 [' lyric', ' poet', ' Friedrich', ' H', 'ö', 'lder', 'lin']
+3145 745 Name of father of x -1 Name of father of Friedrich Hölderlin Heinrich Friedrich Hölderlin Friedrich Hölderlin "[',' ' the' ' poet' ',' ' who' ' died' ' in' ' 18' '43' '.' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' translation' ' of' ' the' ' poem']" ", the poet , who died in 18 43 .
+
+ The following is a translation of the poem" False " German lyric poet Friedrich Hölderlin (1770 – 1843).
+" 7 [' German', ' lyric', ' poet', ' Friedrich', ' H', 'ö', 'lder', 'lin']
+3146 745 Name of father of x -1 Name of father of Friedrich Hölderlin Heinrich Friedrich Hölderlin Friedrich Hölderlin "[',' ' the' ' poet' ',' ' who' ' died' ' in' ' 18' '43' '.' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' translation' ' of' ' the' ' poem']" ", the poet , who died in 18 43 .
+
+ The following is a translation of the poem" False " German lyric poet Friedrich Hölderlin (1770 – 1843).
+" 7 [' German', ' lyric', ' poet', ' Friedrich', ' H', 'ö', 'lder', 'lin']
+3147 746 Name of father of x -1 Name of father of Helmut Schmidt Gustav Ludwig Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Nazi' ' party' '.' '\n' '\n' 'The' ' German' ' chancellor' ',']" ", the German chancellor , who was a member of the Nazi party .
+
+ The German chancellor ," False representative for chancellor Helmut Schmidt during a debt 5 [' representative', ' for', ' chancellor', ' Hel', 'mut', ' Schmidt']
+3148 746 Name of father of x -1 Name of father of Helmut Schmidt Gustav Ludwig Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Nazi' ' party' '.' '\n' '\n' 'The' ' German' ' chancellor' ',']" ", the German chancellor , who was a member of the Nazi party .
+
+ The German chancellor ," False for former chancellor Helmut Schmidt (born in Hamburg), 5 [' for', ' former', ' chancellor', ' Hel', 'mut', ' Schmidt']
+3149 746 Name of father of x -1 Name of father of Helmut Schmidt Gustav Ludwig Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Nazi' ' party' '.' '\n' '\n' 'The' ' German' ' chancellor' ',']" ", the German chancellor , who was a member of the Nazi party .
+
+ The German chancellor ," False former chancellors Helmut Schmidt and Gerhard Schröder. 6 [' former', ' chance', 'll', 'ors', ' Hel', 'mut', ' Schmidt']
+3150 746 Name of father of x -1 Name of father of Helmut Schmidt Gustav Ludwig Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Nazi' ' party' '.' '\n' '\n' 'The' ' German' ' chancellor' ',']" ", the German chancellor , who was a member of the Nazi party .
+
+ The German chancellor ," False former chancellor Helmut Schmidt (born in Hamburg), 4 [' former', ' chancellor', ' Hel', 'mut', ' Schmidt']
+3151 746 Name of father of x -1 Name of father of Helmut Schmidt Gustav Ludwig Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Nazi' ' party' '.' '\n' '\n' 'The' ' German' ' chancellor' ',']" ", the German chancellor , who was a member of the Nazi party .
+
+ The German chancellor ," False former chancellors Helmut Schmidt and Gerhard 6 [' former', ' chance', 'll', 'ors', ' Hel', 'mut', ' Schmidt']
+3152 747 Name of father of x -1 Name of father of Ernst Haeckel Carl Haeckel Ernst Haeckel "[',' ' the' ' German' ' biologist' ' and' ' philosopher' ',' ' who' ' was'
+ ' a' ' leading' ' figure' ' in' ' the' ' development' ' of' ' the'
+ ' theory' ' of' ' evolution']" , the German biologist and philosopher , who was a leading figure in the development of the theory of evolution False political radicals. Ernst Haeckel was particularly 6 [' political', ' radicals', '.', ' Ernst', ' Ha', 'ec', 'kel']
+3153 747 Name of father of x -1 Name of father of Ernst Haeckel Carl Haeckel Ernst Haeckel "[',' ' the' ' German' ' biologist' ' and' ' philosopher' ',' ' who' ' was'
+ ' a' ' leading' ' figure' ' in' ' the' ' development' ' of' ' the'
+ ' theory' ' of' ' evolution']" , the German biologist and philosopher , who was a leading figure in the development of the theory of evolution False zoologist and artist Ernst Haeckel painted teleosts 7 [' zo', 'ologist', ' and', ' artist', ' Ernst', ' Ha', 'ec', 'kel']
+3154 747 Name of father of x -1 Name of father of Ernst Haeckel Carl Haeckel Ernst Haeckel "[',' ' the' ' German' ' biologist' ' and' ' philosopher' ',' ' who' ' was'
+ ' a' ' leading' ' figure' ' in' ' the' ' development' ' of' ' the'
+ ' theory' ' of' ' evolution']" , the German biologist and philosopher , who was a leading figure in the development of the theory of evolution False 5 ['Er', 'n', 'st', ' Ha', 'ec', 'kel']
+3155 747 Name of father of x -1 Name of father of Ernst Haeckel Carl Haeckel Ernst Haeckel "[',' ' the' ' German' ' biologist' ' and' ' philosopher' ',' ' who' ' was'
+ ' a' ' leading' ' figure' ' in' ' the' ' development' ' of' ' the'
+ ' theory' ' of' ' evolution']" , the German biologist and philosopher , who was a leading figure in the development of the theory of evolution False Saint-Hilaire, and Ernst Haeckel suggested that Madagascar 10 [' Saint', '-', 'H', 'il', 'aire', ',', ' and', ' Ernst', ' Ha', 'ec', 'kel']
+3156 747 Name of father of x -1 Name of father of Ernst Haeckel Carl Haeckel Ernst Haeckel "[',' ' the' ' German' ' biologist' ' and' ' philosopher' ',' ' who' ' was'
+ ' a' ' leading' ' figure' ' in' ' the' ' development' ' of' ' the'
+ ' theory' ' of' ' evolution']" , the German biologist and philosopher , who was a leading figure in the development of the theory of evolution False Saint-Hilaire, and Ernst Haeckel suggested that 10 [' Saint', '-', 'H', 'il', 'aire', ',', ' and', ' Ernst', ' Ha', 'ec', 'kel']
+3157 749 Name of father of x -1 Name of father of Narendra Modi Damodardas Mulchand Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' and' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ',']" , the BJP � � s prime ministerial candidate , and the BJP � � s prime ministerial candidate , False Gujarat Chief Minister Narendra Modi led it to a landslide 4 [' Gujarat', ' Chief', ' Minister', ' Narendra', ' Modi']
+3158 749 Name of father of x -1 Name of father of Narendra Modi Damodardas Mulchand Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' and' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ',']" , the BJP � � s prime ministerial candidate , and the BJP � � s prime ministerial candidate , False " government and Narendra Modi ==
+" 3 [' government', ' and', ' Narendra', ' Modi']
+3159 749 Name of father of x -1 Name of father of Narendra Modi Damodardas Mulchand Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' and' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ',']" , the BJP � � s prime ministerial candidate , and the BJP � � s prime ministerial candidate , False Prime Minister of India Narendra Modi who won the Lok 5 [' Prime', ' Minister', ' of', ' India', ' Narendra', ' Modi']
+3160 749 Name of father of x -1 Name of father of Narendra Modi Damodardas Mulchand Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' and' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ',']" , the BJP � � s prime ministerial candidate , and the BJP � � s prime ministerial candidate , False Chief Minister Narendra Modi led it to a landslide 3 [' Chief', ' Minister', ' Narendra', ' Modi']
+3161 749 Name of father of x -1 Name of father of Narendra Modi Damodardas Mulchand Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' and' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ',']" , the BJP � � s prime ministerial candidate , and the BJP � � s prime ministerial candidate , False Prime Minister Narendra Modi and a large number 3 [' Prime', ' Minister', ' Narendra', ' Modi']
+3162 750 Name of father of x -1 Name of father of Mary Wollstonecraft Edward John Wollstonecraft Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '.' '\n' '\n' 'W' 'oll' 'stone' 'craft']" ", the author of the V ind ication of the Rights of Woman .
+
+ W oll stone craft" False between Johnson and Mary Wollstonecraft was pivotal 7 [' between', ' Johnson', ' and', ' Mary', ' W', 'oll', 'stone', 'craft']
+3163 750 Name of father of x -1 Name of father of Mary Wollstonecraft Edward John Wollstonecraft Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '.' '\n' '\n' 'W' 'oll' 'stone' 'craft']" ", the author of the V ind ication of the Rights of Woman .
+
+ W oll stone craft" False extraordinary pair: Mary Wollstonecraft and William Godwin, 7 [' extraordinary', ' pair', ':', ' Mary', ' W', 'oll', 'stone', 'craft']
+3164 750 Name of father of x -1 Name of father of Mary Wollstonecraft Edward John Wollstonecraft Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '.' '\n' '\n' 'W' 'oll' 'stone' 'craft']" ", the author of the V ind ication of the Rights of Woman .
+
+ W oll stone craft" False Moral Mothers: Mary Wollstonecraft and the Female 7 [' Moral', ' Mothers', ':', ' Mary', ' W', 'oll', 'stone', 'craft']
+3165 750 Name of father of x -1 Name of father of Mary Wollstonecraft Edward John Wollstonecraft Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '.' '\n' '\n' 'W' 'oll' 'stone' 'craft']" ", the author of the V ind ication of the Rights of Woman .
+
+ W oll stone craft" False the period such as Mary Wollstonecraft argued for co-educational 8 [' the', ' period', ' such', ' as', ' Mary', ' W', 'oll', 'stone', 'craft']
+3166 750 Name of father of x -1 Name of father of Mary Wollstonecraft Edward John Wollstonecraft Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '.' '\n' '\n' 'W' 'oll' 'stone' 'craft']" ", the author of the V ind ication of the Rights of Woman .
+
+ W oll stone craft" False Godwin could replace Mary Wollstonecraft with her. Fanny 8 [' God', 'win', ' could', ' replace', ' Mary', ' W', 'oll', 'stone', 'craft']
+3167 751 Name of father of x -1 Name of father of Norman Foster Robert Foster Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' '.' '\n' '\n'
+ 'The' ' stadium' ' is' ' a' ' great' ' idea' ',' ' but' ' it' ' is']" ", the architect of the new stadium .
+
+ The stadium is a great idea , but it is" False designed by Sir Norman Foster and Chris Wise at 4 [' designed', ' by', ' Sir', ' Norman', ' Foster']
+3168 751 Name of father of x -1 Name of father of Norman Foster Robert Foster Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' '.' '\n' '\n'
+ 'The' ' stadium' ' is' ' a' ' great' ' idea' ',' ' but' ' it' ' is']" ", the architect of the new stadium .
+
+ The stadium is a great idea , but it is" False architect Sir Norman Foster for a site on Hartlepool's 3 [' architect', ' Sir', ' Norman', ' Foster']
+3169 751 Name of father of x -1 Name of father of Norman Foster Robert Foster Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' '.' '\n' '\n'
+ 'The' ' stadium' ' is' ' a' ' great' ' idea' ',' ' but' ' it' ' is']" ", the architect of the new stadium .
+
+ The stadium is a great idea , but it is" False the architect Sir Norman Foster was commissioned 4 [' the', ' architect', ' Sir', ' Norman', ' Foster']
+3170 751 Name of father of x -1 Name of father of Norman Foster Robert Foster Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' '.' '\n' '\n'
+ 'The' ' stadium' ' is' ' a' ' great' ' idea' ',' ' but' ' it' ' is']" ", the architect of the new stadium .
+
+ The stadium is a great idea , but it is" False (1997), designed by Norman Foster (1935 –) and known 6 [' (', '1997', '),', ' designed', ' by', ' Norman', ' Foster']
+3171 751 Name of father of x -1 Name of father of Norman Foster Robert Foster Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' '.' '\n' '\n'
+ 'The' ' stadium' ' is' ' a' ' great' ' idea' ',' ' but' ' it' ' is']" ", the architect of the new stadium .
+
+ The stadium is a great idea , but it is" False (1997), designed by Norman Foster (1935 –) and known 6 [' (', '1997', '),', ' designed', ' by', ' Norman', ' Foster']
+3172 752 Name of father of x -1 Name of father of Walter Benjamin Emil Benjamin Walter Benjamin "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' two']" ", the great German philosopher , who was born in 18 92 .
+
+ The first of the two" False 2 ['Wal', 'ter', ' Benjamin']
+3173 752 Name of father of x -1 Name of father of Walter Benjamin Emil Benjamin Walter Benjamin "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' two']" ", the great German philosopher , who was born in 18 92 .
+
+ The first of the two" False Wagner critics. Walter Benjamin gave Wagner 4 [' Wagner', ' critics', '.', ' Walter', ' Benjamin']
+3174 752 Name of father of x -1 Name of father of Walter Benjamin Emil Benjamin Walter Benjamin "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' two']" ", the great German philosopher , who was born in 18 92 .
+
+ The first of the two" False 2 ['Wal', 'ter', ' Benjamin']
+3175 752 Name of father of x -1 Name of father of Walter Benjamin Emil Benjamin Walter Benjamin "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' born'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' two']" ", the great German philosopher , who was born in 18 92 .
+
+ The first of the two" False German philosopher Walter Benjamin published in 3 [' German', ' philosopher', ' Walter', ' Benjamin']
+3176 753 Name of father of x -1 Name of father of Harold Pinter Jack Haim Pinter Harold Pinter "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Viv' 'ien' ' Leigh' ',' ' who' ' was' ' the' ' star' ' of']" , the play wright , and his wife , the actress Viv ien Leigh , who was the star of False 3 ['Har', 'old', ' P', 'inter']
+3177 753 Name of father of x -1 Name of father of Harold Pinter Jack Haim Pinter Harold Pinter "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Viv' 'ien' ' Leigh' ',' ' who' ' was' ' the' ' star' ' of']" , the play wright , and his wife , the actress Viv ien Leigh , who was the star of False London 5 [' L', 'ondo', 'Har', 'old', ' P', 'inter']
+3178 753 Name of father of x -1 Name of father of Harold Pinter Jack Haim Pinter Harold Pinter "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Viv' 'ien' ' Leigh' ',' ' who' ' was' ' the' ' star' ' of']" , the play wright , and his wife , the actress Viv ien Leigh , who was the star of False him are held in the Harold Pinter Archive in the 7 [' him', ' are', ' held', ' in', ' the', ' Harold', ' P', 'inter']
+3179 753 Name of father of x -1 Name of father of Harold Pinter Jack Haim Pinter Harold Pinter "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Viv' 'ien' ' Leigh' ',' ' who' ' was' ' the' ' star' ' of']" , the play wright , and his wife , the actress Viv ien Leigh , who was the star of False of the plays of Harold Pinter and backed Pinter's 6 [' of', ' the', ' plays', ' of', ' Harold', ' P', 'inter']
+3180 753 Name of father of x -1 Name of father of Harold Pinter Jack Haim Pinter Harold Pinter "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Viv' 'ien' ' Leigh' ',' ' who' ' was' ' the' ' star' ' of']" , the play wright , and his wife , the actress Viv ien Leigh , who was the star of False Betrayal as part of a Harold Pinter tribute evening at 9 [' Bet', 'ray', 'al', ' as', ' part', ' of', ' a', ' Harold', ' P', 'inter']
+3181 754 Name of father of x -1 Name of father of Mila Kunis Mark Kunis Mila Kunis "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False played by Mila Kunis in the film 5 [' played', ' by', ' Mil', 'a', ' Kun', 'is']
+3182 754 Name of father of x -1 Name of father of Mila Kunis Mark Kunis Mila Kunis "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Portman, Vincent Cassel, Mila Kunis and Winona Ryder. 10 [' Port', 'man', ',', ' Vincent', ' Cas', 'sel', ',', ' Mil', 'a', ' Kun', 'is']
+3183 754 Name of father of x -1 Name of father of Mila Kunis Mark Kunis Mila Kunis "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False " was played by Mila Kunis in the film adaptation.
+" 6 [' was', ' played', ' by', ' Mil', 'a', ' Kun', 'is']
+3184 754 Name of father of x -1 Name of father of Mila Kunis Mark Kunis Mila Kunis "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False that aired to feature Mila Kunis as the voice of Meg 7 [' that', ' aired', ' to', ' feature', ' Mil', 'a', ' Kun', 'is']
+3185 754 Name of father of x -1 Name of father of Mila Kunis Mark Kunis Mila Kunis "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Timberlake and Mila Kunis is almost enough to 6 [' Timber', 'lake', ' and', ' Mil', 'a', ' Kun', 'is']
+3186 755 Name of father of x -1 Name of father of John Quincy Adams John Adams John Quincy Adams "[',' ' the' ' great' ' American' ' president' '.' '\n' '\n' 'The' ' first'
+ ' time' ' I' ' saw' ' the' ' name' ' of' ' John' ' Quincy' ' Adams' ' on']" ", the great American president .
+
+ The first time I saw the name of John Quincy Adams on" False Judah Lord wrote to John Quincy Adams (then United States 6 [' Judah', ' Lord', ' wrote', ' to', ' John', ' Quincy', ' Adams']
+3187 755 Name of father of x -1 Name of father of John Quincy Adams John Adams John Quincy Adams "[',' ' the' ' great' ' American' ' president' '.' '\n' '\n' 'The' ' first'
+ ' time' ' I' ' saw' ' the' ' name' ' of' ' John' ' Quincy' ' Adams' ' on']" ", the great American president .
+
+ The first time I saw the name of John Quincy Adams on" False President Monroe and John Quincy Adams ... On the reverse 5 [' President', ' Monroe', ' and', ' John', ' Quincy', ' Adams']
+3188 755 Name of father of x -1 Name of father of John Quincy Adams John Adams John Quincy Adams "[',' ' the' ' great' ' American' ' president' '.' '\n' '\n' 'The' ' first'
+ ' time' ' I' ' saw' ' the' ' name' ' of' ' John' ' Quincy' ' Adams' ' on']" ", the great American president .
+
+ The first time I saw the name of John Quincy Adams on" False to President John Quincy Adams was General Andrew 4 [' to', ' President', ' John', ' Quincy', ' Adams']
+3189 755 Name of father of x -1 Name of father of John Quincy Adams John Adams John Quincy Adams "[',' ' the' ' great' ' American' ' president' '.' '\n' '\n' 'The' ' first'
+ ' time' ' I' ' saw' ' the' ' name' ' of' ' John' ' Quincy' ' Adams' ' on']" ", the great American president .
+
+ The first time I saw the name of John Quincy Adams on" False 2 ['John', ' Quincy', ' Adams']
+3190 755 Name of father of x -1 Name of father of John Quincy Adams John Adams John Quincy Adams "[',' ' the' ' great' ' American' ' president' '.' '\n' '\n' 'The' ' first'
+ ' time' ' I' ' saw' ' the' ' name' ' of' ' John' ' Quincy' ' Adams' ' on']" ", the great American president .
+
+ The first time I saw the name of John Quincy Adams on" False future president John Quincy Adams in 1767, Susanna 4 [' future', ' president', ' John', ' Quincy', ' Adams']
+3191 756 Name of father of x -1 Name of father of Gerald Ford Leslie Lynch King, Sr. Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Betty' ',' ' who' ' was' ' the' ' mother'
+ ' of']" , the former president of the United States , and his wife , Betty , who was the mother of False 2 ['G', 'erald', ' Ford']
+3192 756 Name of father of x -1 Name of father of Gerald Ford Leslie Lynch King, Sr. Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Betty' ',' ' who' ' was' ' the' ' mother'
+ ' of']" , the former president of the United States , and his wife , Betty , who was the mother of False hosted by President Gerald Ford at the Dorado 4 [' hosted', ' by', ' President', ' Gerald', ' Ford']
+3193 756 Name of father of x -1 Name of father of Gerald Ford Leslie Lynch King, Sr. Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Betty' ',' ' who' ' was' ' the' ' mother'
+ ' of']" , the former president of the United States , and his wife , Betty , who was the mother of False the episode, Gerald Ford moves into 4 [' the', ' episode', ',', ' Gerald', ' Ford']
+3194 756 Name of father of x -1 Name of father of Gerald Ford Leslie Lynch King, Sr. Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Betty' ',' ' who' ' was' ' the' ' mother'
+ ' of']" , the former president of the United States , and his wife , Betty , who was the mother of False 2 ['G', 'erald', ' Ford']
+3195 756 Name of father of x -1 Name of father of Gerald Ford Leslie Lynch King, Sr. Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Betty' ',' ' who' ' was' ' the' ' mother'
+ ' of']" , the former president of the United States , and his wife , Betty , who was the mother of False Appointment of Gerald Ford as Vice President 4 [' App', 'ointment', ' of', ' Gerald', ' Ford']
+3196 759 Name of father of x -1 Name of father of Luis Buñuel Leonardo Buñuel González Luis Buñuel "[',' ' the' ' famous' ' Spanish' ' director' ',' ' who' ' was' ' born'
+ ' in' ' Mexico' ' City' ' in' ' 1899' '.' '\n' '\n' 'The' ' name' ' of']" ", the famous Spanish director , who was born in Mexico City in 1899 .
+
+ The name of" False (1818 – 19). Luis Buñuel identified 9 [' (', '18', '18', ' –', ' 19', ').', ' Luis', ' Bu', 'ñ', 'uel']
+3197 759 Name of father of x -1 Name of father of Luis Buñuel Leonardo Buñuel González Luis Buñuel "[',' ' the' ' famous' ' Spanish' ' director' ',' ' who' ' was' ' born'
+ ' in' ' Mexico' ' City' ' in' ' 1899' '.' '\n' '\n' 'The' ' name' ' of']" ", the famous Spanish director , who was born in Mexico City in 1899 .
+
+ The name of" False Medusa (1818 – 19). Luis Buñuel identified 11 [' Med', 'usa', ' (', '18', '18', ' –', ' 19', ').', ' Luis', ' Bu', 'ñ', 'uel']
+3198 759 Name of father of x -1 Name of father of Luis Buñuel Leonardo Buñuel González Luis Buñuel "[',' ' the' ' famous' ' Spanish' ' director' ',' ' who' ' was' ' born'
+ ' in' ' Mexico' ' City' ' in' ' 1899' '.' '\n' '\n' 'The' ' name' ' of']" ", the famous Spanish director , who was born in Mexico City in 1899 .
+
+ The name of" False both Fellini and Luis Buñuel declined the cameo. 7 [' both', ' Fell', 'ini', ' and', ' Luis', ' Bu', 'ñ', 'uel']
+3199 759 Name of father of x -1 Name of father of Luis Buñuel Leonardo Buñuel González Luis Buñuel "[',' ' the' ' famous' ' Spanish' ' director' ',' ' who' ' was' ' born'
+ ' in' ' Mexico' ' City' ' in' ' 1899' '.' '\n' '\n' 'The' ' name' ' of']" ", the famous Spanish director , who was born in Mexico City in 1899 .
+
+ The name of" False Medusa (1818 – 19). Luis Buñuel identified with Goya's 11 [' Med', 'usa', ' (', '18', '18', ' –', ' 19', ').', ' Luis', ' Bu', 'ñ', 'uel']
+3200 759 Name of father of x -1 Name of father of Luis Buñuel Leonardo Buñuel González Luis Buñuel "[',' ' the' ' famous' ' Spanish' ' director' ',' ' who' ' was' ' born'
+ ' in' ' Mexico' ' City' ' in' ' 1899' '.' '\n' '\n' 'The' ' name' ' of']" ", the famous Spanish director , who was born in Mexico City in 1899 .
+
+ The name of" False " ""Debaser"" alludes to Luis Buñuel and Salvador" 10 "[' ""', 'Deb', 'aser', '""', ' all', 'udes', ' to', ' Luis', ' Bu', 'ñ', 'uel']"
+3201 760 Name of father of x -1 Name of father of Neil Armstrong Stephen Koenig Armstrong Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False rather than Neil Armstrong as the first 3 [' rather', ' than', ' Neil', ' Armstrong']
+3202 760 Name of father of x -1 Name of father of Neil Armstrong Stephen Koenig Armstrong Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False " ""That happened because Neil Armstrong was a team player," 5 "[' ""', 'That', ' happened', ' because', ' Neil', ' Armstrong']"
+3203 760 Name of father of x -1 Name of father of Neil Armstrong Stephen Koenig Armstrong Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False as backup LMP. Neil Armstrong went on to 6 [' as', ' backup', ' L', 'MP', '.', ' Neil', ' Armstrong']
+3204 760 Name of father of x -1 Name of father of Neil Armstrong Stephen Koenig Armstrong Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False " romantic image of Neil Armstrong taking"" one giant" 4 [' romantic', ' image', ' of', ' Neil', ' Armstrong']
+3205 760 Name of father of x -1 Name of father of Neil Armstrong Stephen Koenig Armstrong Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False Moon. Americans Neil Armstrong and Buzz Aldrin landed 4 [' Moon', '.', ' Americans', ' Neil', ' Armstrong']
+3206 762 Name of father of x -1 Name of father of Richard von Weizsäcker Ernst von Weizsäcker Richard von Weizsäcker "['\n' '\n' 'Category' ':' '18' '86' ' births' '\n' 'Category' ':' '19'
+ '57' ' deaths' '\n' 'Category' ':' 'People' ' from' ' the' ' Province']" "
+
+ Category : 18 86 births
+ Category : 19 57 deaths
+ Category : People from the Province" False occupied France, Richard von Weizsäcker served as his 9 [' occupied', ' France', ',', ' Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+3207 762 Name of father of x -1 Name of father of Richard von Weizsäcker Ernst von Weizsäcker Richard von Weizsäcker "['\n' '\n' 'Category' ':' '18' '86' ' births' '\n' 'Category' ':' '19'
+ '57' ' deaths' '\n' 'Category' ':' 'People' ' from' ' the' ' Province']" "
+
+ Category : 18 86 births
+ Category : 19 57 deaths
+ Category : People from the Province" False " Weizsäcker =
+" 12 [' We', 'iz', 's', 'ä', 'cker', ' =', 'Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+3208 762 Name of father of x -1 Name of father of Richard von Weizsäcker Ernst von Weizsäcker Richard von Weizsäcker "['\n' '\n' 'Category' ':' '18' '86' ' births' '\n' 'Category' ':' '19'
+ '57' ' deaths' '\n' 'Category' ':' 'People' ' from' ' the' ' Province']" "
+
+ Category : 18 86 births
+ Category : 19 57 deaths
+ Category : People from the Province" False 6 ['Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+3209 762 Name of father of x -1 Name of father of Richard von Weizsäcker Ernst von Weizsäcker Richard von Weizsäcker "['\n' '\n' 'Category' ':' '18' '86' ' births' '\n' 'Category' ':' '19'
+ '57' ' deaths' '\n' 'Category' ':' 'People' ' from' ' the' ' Province']" "
+
+ Category : 18 86 births
+ Category : 19 57 deaths
+ Category : People from the Province" False 6 ['Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+3210 762 Name of father of x -1 Name of father of Richard von Weizsäcker Ernst von Weizsäcker Richard von Weizsäcker "['\n' '\n' 'Category' ':' '18' '86' ' births' '\n' 'Category' ':' '19'
+ '57' ' deaths' '\n' 'Category' ':' 'People' ' from' ' the' ' Province']" "
+
+ Category : 18 86 births
+ Category : 19 57 deaths
+ Category : People from the Province" False " von Weizsäcker =
+" 13 [' von', ' We', 'iz', 's', 'ä', 'cker', ' =', 'Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+3211 763 Name of father of x -1 Name of father of Juan Carlos I of Spain Infante Juan, Count of Barcelona Juan Carlos I of Spain "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Juan' ' Carlos'
+ ' I' ' of' ' Spain' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Juan Carlos I of Spain is not known .
+
+ The" False House, the King Juan Carlos I of Spain Center, the Hagop 8 [' House', ',', ' the', ' King', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+3212 763 Name of father of x -1 Name of father of Juan Carlos I of Spain Infante Juan, Count of Barcelona Juan Carlos I of Spain "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Juan' ' Carlos'
+ ' I' ' of' ' Spain' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Juan Carlos I of Spain is not known .
+
+ The" False " Greece and Denmark and Juan Carlos I of Spain (1962)
+" 8 [' Greece', ' and', ' Denmark', ' and', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+3213 763 Name of father of x -1 Name of father of Juan Carlos I of Spain Infante Juan, Count of Barcelona Juan Carlos I of Spain "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Juan' ' Carlos'
+ ' I' ' of' ' Spain' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Juan Carlos I of Spain is not known .
+
+ The" False Ireland House, the King Juan Carlos I of Spain Center, the Hagop 9 [' Ireland', ' House', ',', ' the', ' King', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+3214 763 Name of father of x -1 Name of father of Juan Carlos I of Spain Infante Juan, Count of Barcelona Juan Carlos I of Spain "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Juan' ' Carlos'
+ ' I' ' of' ' Spain' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Juan Carlos I of Spain is not known .
+
+ The" False " Greece and Denmark and Juan Carlos I of Spain (1962)
+" 8 [' Greece', ' and', ' Denmark', ' and', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+3215 763 Name of father of x -1 Name of father of Juan Carlos I of Spain Infante Juan, Count of Barcelona Juan Carlos I of Spain "['\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' Juan' ' Carlos'
+ ' I' ' of' ' Spain' ' is' ' not' ' known' '.' '\n' '\n' 'The']" "
+
+ The name of the father of Juan Carlos I of Spain is not known .
+
+ The" False " Denmark and Juan Carlos I of Spain (1962)
+" 6 [' Denmark', ' and', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+3216 764 Name of father of x -1 Name of father of Joan Baez Albert Baez Joan Baez "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Joan' ' B' 'aez' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the singer , and the father of the singer Joan B aez .
+
+ The name of the" False consequently it did not chart. Joan Baez included a gender-switched 8 [' consequently', ' it', ' did', ' not', ' chart', '.', ' Joan', ' B', 'aez']
+3217 764 Name of father of x -1 Name of father of Joan Baez Albert Baez Joan Baez "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Joan' ' B' 'aez' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the singer , and the father of the singer Joan B aez .
+
+ The name of the" False Dylan performed with Joan Baez at the Monterey Folk 5 [' Dylan', ' performed', ' with', ' Joan', ' B', 'aez']
+3218 764 Name of father of x -1 Name of father of Joan Baez Albert Baez Joan Baez "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Joan' ' B' 'aez' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the singer , and the father of the singer Joan B aez .
+
+ The name of the" False others. Ochs and Joan Baez sang a duet 8 [' others', '.', ' O', 'ch', 's', ' and', ' Joan', ' B', 'aez']
+3219 764 Name of father of x -1 Name of father of Joan Baez Albert Baez Joan Baez "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Joan' ' B' 'aez' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the singer , and the father of the singer Joan B aez .
+
+ The name of the" False figures such as Joan Baez and Bob Dylan had 5 [' figures', ' such', ' as', ' Joan', ' B', 'aez']
+3220 764 Name of father of x -1 Name of father of Joan Baez Albert Baez Joan Baez "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ' Joan' ' B' 'aez' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", the singer , and the father of the singer Joan B aez .
+
+ The name of the" False Hey Joe and Joan Baez included it on 5 [' Hey', ' Joe', ' and', ' Joan', ' B', 'aez']
+3221 765 Name of father of x -1 Name of father of Bernardo Bertolucci Attilio Bertolucci Bernardo Bertolucci "[',' ' the' ' director' ' of' ' The' ' Last' ' Emperor' ',' ' and' ' the'
+ ' father' ' of' ' the' ' actor' ' Mar' 'cell' 'o' ' Mast' 'ro' 'ian']" , the director of The Last Emperor , and the father of the actor Mar cell o Mast ro ian False debut, when director Bernardo Bertolucci cast her for the role 8 [' debut', ',', ' when', ' director', ' Bern', 'ardo', ' Bert', 'ol', 'ucci']
+3222 765 Name of father of x -1 Name of father of Bernardo Bertolucci Attilio Bertolucci Bernardo Bertolucci "[',' ' the' ' director' ' of' ' The' ' Last' ' Emperor' ',' ' and' ' the'
+ ' father' ' of' ' the' ' actor' ' Mar' 'cell' 'o' ' Mast' 'ro' 'ian']" , the director of The Last Emperor , and the father of the actor Mar cell o Mast ro ian False for his role in the Bernardo Bertolucci film The Dreamers. 9 [' for', ' his', ' role', ' in', ' the', ' Bern', 'ardo', ' Bert', 'ol', 'ucci']
+3223 765 Name of father of x -1 Name of father of Bernardo Bertolucci Attilio Bertolucci Bernardo Bertolucci "[',' ' the' ' director' ' of' ' The' ' Last' ' Emperor' ',' ' and' ' the'
+ ' father' ' of' ' the' ' actor' ' Mar' 'cell' 'o' ' Mast' 'ro' 'ian']" , the director of The Last Emperor , and the father of the actor Mar cell o Mast ro ian False his role in the Bernardo Bertolucci film The Dreamers. 8 [' his', ' role', ' in', ' the', ' Bern', 'ardo', ' Bert', 'ol', 'ucci']
+3224 765 Name of father of x -1 Name of father of Bernardo Bertolucci Attilio Bertolucci Bernardo Bertolucci "[',' ' the' ' director' ' of' ' The' ' Last' ' Emperor' ',' ' and' ' the'
+ ' father' ' of' ' the' ' actor' ' Mar' 'cell' 'o' ' Mast' 'ro' 'ian']" , the director of The Last Emperor , and the father of the actor Mar cell o Mast ro ian False debut, when director Bernardo Bertolucci cast her for the 8 [' debut', ',', ' when', ' director', ' Bern', 'ardo', ' Bert', 'ol', 'ucci']
+3225 765 Name of father of x -1 Name of father of Bernardo Bertolucci Attilio Bertolucci Bernardo Bertolucci "[',' ' the' ' director' ' of' ' The' ' Last' ' Emperor' ',' ' and' ' the'
+ ' father' ' of' ' the' ' actor' ' Mar' 'cell' 'o' ' Mast' 'ro' 'ian']" , the director of The Last Emperor , and the father of the actor Mar cell o Mast ro ian False written for director Bernardo Bertolucci for his film Stealing 7 [' written', ' for', ' director', ' Bern', 'ardo', ' Bert', 'ol', 'ucci']
+3226 766 Name of father of x -1 Name of father of Tycho Brahe Otte Brahe Tycho Brahe "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '46' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the astronomer , who was born in 15 46 .
+
+ The name of the father of the" False seventeenth century, Tycho Brahe measured the diurnal 7 [' sevent', 'eenth', ' century', ',', ' Ty', 'cho', ' Bra', 'he']
+3227 766 Name of father of x -1 Name of father of Tycho Brahe Otte Brahe Tycho Brahe "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '46' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the astronomer , who was born in 15 46 .
+
+ The name of the father of the" False astronomers such as Tycho Brahe identified 6 [' astronomers', ' such', ' as', ' Ty', 'cho', ' Bra', 'he']
+3228 766 Name of father of x -1 Name of father of Tycho Brahe Otte Brahe Tycho Brahe "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '46' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the astronomer , who was born in 15 46 .
+
+ The name of the father of the" False system later proposed by Tycho Brahe in the late 16th 7 [' system', ' later', ' proposed', ' by', ' Ty', 'cho', ' Bra', 'he']
+3229 766 Name of father of x -1 Name of father of Tycho Brahe Otte Brahe Tycho Brahe "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '46' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the astronomer , who was born in 15 46 .
+
+ The name of the father of the" False 16th-century astronomer Tycho Brahe used clocks with 8 [' 16', 'th', '-', 'century', ' astronomer', ' Ty', 'cho', ' Bra', 'he']
+3230 766 Name of father of x -1 Name of father of Tycho Brahe Otte Brahe Tycho Brahe "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' 15' '46' '.'
+ '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the']" ", the astronomer , who was born in 15 46 .
+
+ The name of the father of the" False 3 ['Ty', 'cho', ' Bra', 'he']
+3231 767 Name of father of x -1 Name of father of Edmond Rostand Eugène Rostand Edmond Rostand "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' 18' '94'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' poet']" ", the French poet , who was born in 18 94 .
+
+ The name of the French poet" False articles denounced Edmond Rostand as a kitsch author, 5 [' articles', ' denounced', ' Ed', 'mond', ' Rost', 'and']
+3232 767 Name of father of x -1 Name of father of Edmond Rostand Eugène Rostand Edmond Rostand "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' 18' '94'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' French' ' poet']" ", the French poet , who was born in 18 94 .
+
+ The name of the French poet" False denounced Edmond Rostand as a kitsch author, 4 [' denounced', ' Ed', 'mond', ' Rost', 'and']
+3233 769 Name of father of x -1 Name of father of Nicolas Cage August Coppola Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' big' ' fan' ' of' ' Nicolas']" ".
+
+ I am a big fan of the show , and I am a big fan of Nicolas" False 2 ['Nic', 'olas', ' Cage']
+3234 769 Name of father of x -1 Name of father of Nicolas Cage August Coppola Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' big' ' fan' ' of' ' Nicolas']" ".
+
+ I am a big fan of the show , and I am a big fan of Nicolas" False also approached Nicolas Cage for the role of 3 [' also', ' approached', ' Nicolas', ' Cage']
+3235 769 Name of father of x -1 Name of father of Nicolas Cage August Coppola Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' big' ' fan' ' of' ' Nicolas']" ".
+
+ I am a big fan of the show , and I am a big fan of Nicolas" False approached Nicolas Cage for the role of 2 [' approached', ' Nicolas', ' Cage']
+3236 769 Name of father of x -1 Name of father of Nicolas Cage August Coppola Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' big' ' fan' ' of' ' Nicolas']" ".
+
+ I am a big fan of the show , and I am a big fan of Nicolas" False Supporting Actor, while Nicolas Cage (Actor in a Leading 5 [' Supporting', ' Actor', ',', ' while', ' Nicolas', ' Cage']
+3237 769 Name of father of x -1 Name of father of Nicolas Cage August Coppola Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ','
+ ' and' ' I' ' am' ' a' ' big' ' fan' ' of' ' Nicolas']" ".
+
+ I am a big fan of the show , and I am a big fan of Nicolas" False film had been. Nicolas Cage was under discussions 5 [' film', ' had', ' been', '.', ' Nicolas', ' Cage']
+3238 770 Name of father of x -1 Name of father of David Attenborough Frederick Attenborough David Attenborough "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' BBC' ""'s""
+ ' Planet' ' Earth' ' series' ',' ' and' ' I' ' have' ' watched']" ".
+
+ I am a big fan of the BBC 's Planet Earth series , and I have watched" False " Joyce Grenfell, David Attenborough and Richard Baker.
+" 7 [' Joyce', ' Gren', 'fell', ',', ' David', ' Att', 'en', 'borough']
+3239 770 Name of father of x -1 Name of father of David Attenborough Frederick Attenborough David Attenborough "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' BBC' ""'s""
+ ' Planet' ' Earth' ' series' ',' ' and' ' I' ' have' ' watched']" ".
+
+ I am a big fan of the BBC 's Planet Earth series , and I have watched" False and naturalist Sir David Attenborough in July 2014. It was 7 [' and', ' natural', 'ist', ' Sir', ' David', ' Att', 'en', 'borough']
+3240 770 Name of father of x -1 Name of father of David Attenborough Frederick Attenborough David Attenborough "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' BBC' ""'s""
+ ' Planet' ' Earth' ' series' ',' ' and' ' I' ' have' ' watched']" ".
+
+ I am a big fan of the BBC 's Planet Earth series , and I have watched" False independence, David Attenborough introduced 5 [' independence', ',', ' David', ' Att', 'en', 'borough']
+3241 770 Name of father of x -1 Name of father of David Attenborough Frederick Attenborough David Attenborough "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' BBC' ""'s""
+ ' Planet' ' Earth' ' series' ',' ' and' ' I' ' have' ' watched']" ".
+
+ I am a big fan of the BBC 's Planet Earth series , and I have watched" False known worldwide. Sir David Attenborough praised Irwin 7 [' known', ' worldwide', '.', ' Sir', ' David', ' Att', 'en', 'borough']
+3242 770 Name of father of x -1 Name of father of David Attenborough Frederick Attenborough David Attenborough "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' BBC' ""'s""
+ ' Planet' ' Earth' ' series' ',' ' and' ' I' ' have' ' watched']" ".
+
+ I am a big fan of the BBC 's Planet Earth series , and I have watched" False sequence with Sir David Attenborough in The Life of 6 [' sequence', ' with', ' Sir', ' David', ' Att', 'en', 'borough']
+3243 771 Name of father of x -1 Name of father of Marlon Brando Marlon Brando, Sr. Marlon Brando "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Maria' ' Schneider'
+ ',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' famous' ' director'
+ ',']" , the actor , and his wife , Maria Schneider , who is the daughter of the famous director , False Julius Caesar. Marlon Brando (Mark Antony) was 6 [' Julius', ' Caesar', '.', ' Mar', 'lon', ' Brand', 'o']
+3244 771 Name of father of x -1 Name of father of Marlon Brando Marlon Brando, Sr. Marlon Brando "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Maria' ' Schneider'
+ ',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' famous' ' director'
+ ',']" , the actor , and his wife , Maria Schneider , who is the daughter of the famous director , False reference point, Marlon Brando was re-created 6 [' reference', ' point', ',', ' Mar', 'lon', ' Brand', 'o']
+3245 771 Name of father of x -1 Name of father of Marlon Brando Marlon Brando, Sr. Marlon Brando "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Maria' ' Schneider'
+ ',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' famous' ' director'
+ ',']" , the actor , and his wife , Maria Schneider , who is the daughter of the famous director , False which went to Marlon Brando for The Island 6 [' which', ' went', ' to', ' Mar', 'lon', ' Brand', 'o']
+3246 771 Name of father of x -1 Name of father of Marlon Brando Marlon Brando, Sr. Marlon Brando "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Maria' ' Schneider'
+ ',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' famous' ' director'
+ ',']" , the actor , and his wife , Maria Schneider , who is the daughter of the famous director , False originally with Marlon Brando in the lead role. Starting 5 [' originally', ' with', ' Mar', 'lon', ' Brand', 'o']
+3247 771 Name of father of x -1 Name of father of Marlon Brando Marlon Brando, Sr. Marlon Brando "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Maria' ' Schneider'
+ ',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' famous' ' director'
+ ',']" , the actor , and his wife , Maria Schneider , who is the daughter of the famous director , False " avoids pressure to cast Marlon Brando in the lead role.
+" 7 [' avoids', ' pressure', ' to', ' cast', ' Mar', 'lon', ' Brand', 'o']
+3248 772 Name of father of x -1 Name of father of George Berkeley William Berkeley George Berkeley "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' peace' ',' ' and' ' I'
+ ' am' ' not' ' a' ' man' ' of' ' war' '.']" ", the
+
+ I am a man of peace , and I am not a man of war ." False university; only George Berkeley bowled more deliveries. 4 [' university', ';', ' only', ' George', ' Berkeley']
+3249 772 Name of father of x -1 Name of father of George Berkeley William Berkeley George Berkeley "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' peace' ',' ' and' ' I'
+ ' am' ' not' ' a' ' man' ' of' ' war' '.']" ", the
+
+ I am a man of peace , and I am not a man of war ." False Anglican Irish Bishop George Berkeley advanced a 5 [' Anglic', 'an', ' Irish', ' Bishop', ' George', ' Berkeley']
+3250 772 Name of father of x -1 Name of father of George Berkeley William Berkeley George Berkeley "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' peace' ',' ' and' ' I'
+ ' am' ' not' ' a' ' man' ' of' ' war' '.']" ", the
+
+ I am a man of peace , and I am not a man of war ." False by Vice-Admiral George Berkeley of the Royal Navy, 7 [' by', ' Vice', '-', 'Ad', 'mir', 'al', ' George', ' Berkeley']
+3251 772 Name of father of x -1 Name of father of George Berkeley William Berkeley George Berkeley "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' peace' ',' ' and' ' I'
+ ' am' ' not' ' a' ' man' ' of' ' war' '.']" ", the
+
+ I am a man of peace , and I am not a man of war ." False Anglican Irish Bishop George Berkeley advanced a form of 5 [' Anglic', 'an', ' Irish', ' Bishop', ' George', ' Berkeley']
+3252 772 Name of father of x -1 Name of father of George Berkeley William Berkeley George Berkeley "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' man' ' of' ' peace' ',' ' and' ' I'
+ ' am' ' not' ' a' ' man' ' of' ' war' '.']" ", the
+
+ I am a man of peace , and I am not a man of war ." False by Vice-Admiral George Berkeley of the Royal 7 [' by', ' Vice', '-', 'Ad', 'mir', 'al', ' George', ' Berkeley']
+3253 774 Name of father of x -1 Name of father of Diana, Princess of Wales John Spencer, 8th Earl Spencer Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have' ' been'
+ ' married']" ", and the mother of Prince William and Prince Harry .
+
+ The couple , who have been married" False " Princess of Wales ===
+" 9 [' Princess', ' of', ' Wales', ' ===', 'D', 'iana', ',', ' Princess', ' of', ' Wales']
+3254 774 Name of father of x -1 Name of father of Diana, Princess of Wales John Spencer, 8th Earl Spencer Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have' ' been'
+ ' married']" ", and the mother of Prince William and Prince Harry .
+
+ The couple , who have been married" False 31 August 1997, Diana, Princess of Wales and Al-Fayed's 8 [' 31', ' August', ' 1997', ',', ' Diana', ',', ' Princess', ' of', ' Wales']
+3255 774 Name of father of x -1 Name of father of Diana, Princess of Wales John Spencer, 8th Earl Spencer Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have' ' been'
+ ' married']" ", and the mother of Prince William and Prince Harry .
+
+ The couple , who have been married" False of a visit by Diana, Princess of Wales to the gardens 8 [' of', ' a', ' visit', ' by', ' Diana', ',', ' Princess', ' of', ' Wales']
+3256 774 Name of father of x -1 Name of father of Diana, Princess of Wales John Spencer, 8th Earl Spencer Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have' ' been'
+ ' married']" ", and the mother of Prince William and Prince Harry .
+
+ The couple , who have been married" False " Walk, and the Diana, Princess of Wales Memorial Walk.
+" 8 [' Walk', ',', ' and', ' the', ' Diana', ',', ' Princess', ' of', ' Wales']
+3257 774 Name of father of x -1 Name of father of Diana, Princess of Wales John Spencer, 8th Earl Spencer Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' couple' ',' ' who' ' have' ' been'
+ ' married']" ", and the mother of Prince William and Prince Harry .
+
+ The couple , who have been married" False In August 1997, Diana, Princess of Wales and Al-Fayed's 8 [' In', ' August', ' 1997', ',', ' Diana', ',', ' Princess', ' of', ' Wales']
+3258 775 Name of father of x -1 Name of father of Lyndon B. Johnson Samuel Ealy Johnson Lyndon B. Johnson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' president' ' of' ' the' ' United' ' States' ' is' ' a' ' man'
+ ' who' ' has']" , the president of the United States , and the president of the United States is a man who has False 1965, President Lyndon B. Johnson signed the Medicare 6 [' 1965', ',', ' President', ' Lyndon', ' B', '.', ' Johnson']
+3259 775 Name of father of x -1 Name of father of Lyndon B. Johnson Samuel Ealy Johnson Lyndon B. Johnson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' president' ' of' ' the' ' United' ' States' ' is' ' a' ' man'
+ ' who' ' has']" , the president of the United States , and the president of the United States is a man who has False southwestern portion of Lyndon B. Johnson National Historic 6 [' southwestern', ' portion', ' of', ' Lyndon', ' B', '.', ' Johnson']
+3260 775 Name of father of x -1 Name of father of Lyndon B. Johnson Samuel Ealy Johnson Lyndon B. Johnson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' president' ' of' ' the' ' United' ' States' ' is' ' a' ' man'
+ ' who' ' has']" , the president of the United States , and the president of the United States is a man who has False American Vice President Lyndon B. Johnson in the film 6 [' American', ' Vice', ' President', ' Lyndon', ' B', '.', ' Johnson']
+3261 775 Name of father of x -1 Name of father of Lyndon B. Johnson Samuel Ealy Johnson Lyndon B. Johnson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' president' ' of' ' the' ' United' ' States' ' is' ' a' ' man'
+ ' who' ' has']" , the president of the United States , and the president of the United States is a man who has False to President Lyndon B. Johnson on October 25, 5 [' to', ' President', ' Lyndon', ' B', '.', ' Johnson']
+3262 775 Name of father of x -1 Name of father of Lyndon B. Johnson Samuel Ealy Johnson Lyndon B. Johnson "[',' ' the' ' president' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' president' ' of' ' the' ' United' ' States' ' is' ' a' ' man'
+ ' who' ' has']" , the president of the United States , and the president of the United States is a man who has False location of NASA's Lyndon B. Johnson Space Center. Other 7 "[' location', ' of', ' NASA', ""'s"", ' Lyndon', ' B', '.', ' Johnson']"
+3263 776 Name of father of x -1 Name of father of Joseph Addison Lancelot Addison Joseph Addison "[',' ' the' ' author' ' of' ' the' ' _' 'Spect' 'ator' '_' ',' ' and'
+ ' the' ' _' 'Guard' 'ian' '_' ',' ' and' ' the' '\n']" ", the author of the _ Spect ator _ , and the _ Guard ian _ , and the
+" False by the writers Joseph Addison and Richard 5 [' by', ' the', ' writers', ' Joseph', ' Add', 'ison']
+3264 776 Name of father of x -1 Name of father of Joseph Addison Lancelot Addison Joseph Addison "[',' ' the' ' author' ' of' ' the' ' _' 'Spect' 'ator' '_' ',' ' and'
+ ' the' ' _' 'Guard' 'ian' '_' ',' ' and' ' the' '\n']" ", the author of the _ Spect ator _ , and the _ Guard ian _ , and the
+" False in historian Joseph Addison Waddell's Annals of 4 [' in', ' historian', ' Joseph', ' Add', 'ison']
+3265 776 Name of father of x -1 Name of father of Joseph Addison Lancelot Addison Joseph Addison "[',' ' the' ' author' ' of' ' the' ' _' 'Spect' 'ator' '_' ',' ' and'
+ ' the' ' _' 'Guard' 'ian' '_' ',' ' and' ' the' '\n']" ", the author of the _ Spect ator _ , and the _ Guard ian _ , and the
+" False shared by the writers Joseph Addison and Richard 6 [' shared', ' by', ' the', ' writers', ' Joseph', ' Add', 'ison']
+3266 776 Name of father of x -1 Name of father of Joseph Addison Lancelot Addison Joseph Addison "[',' ' the' ' author' ' of' ' the' ' _' 'Spect' 'ator' '_' ',' ' and'
+ ' the' ' _' 'Guard' 'ian' '_' ',' ' and' ' the' '\n']" ", the author of the _ Spect ator _ , and the _ Guard ian _ , and the
+" False memorialized in historian Joseph Addison Waddell's Annals 6 [' memorial', 'ized', ' in', ' historian', ' Joseph', ' Add', 'ison']
+3267 776 Name of father of x -1 Name of father of Joseph Addison Lancelot Addison Joseph Addison "[',' ' the' ' author' ' of' ' the' ' _' 'Spect' 'ator' '_' ',' ' and'
+ ' the' ' _' 'Guard' 'ian' '_' ',' ' and' ' the' '\n']" ", the author of the _ Spect ator _ , and the _ Guard ian _ , and the
+" False memorialized in historian Joseph Addison Waddell's Annals 6 [' memorial', 'ized', ' in', ' historian', ' Joseph', ' Add', 'ison']
+3268 777 Name of father of x -1 Name of father of Elie Wiesel Shlomo Wiesel Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False Finkelstein argues that Elie Wiesel and others exploit 8 [' Fin', 'kel', 'stein', ' argues', ' that', ' El', 'ie', ' W', 'iesel']
+3269 777 Name of father of x -1 Name of father of Elie Wiesel Shlomo Wiesel Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False Finkelstein argues that Elie Wiesel and others 8 [' Fin', 'kel', 'stein', ' argues', ' that', ' El', 'ie', ' W', 'iesel']
+3270 777 Name of father of x -1 Name of father of Elie Wiesel Shlomo Wiesel Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False In October 2002, the Elie Wiesel Foundation 8 [' In', ' October', ' 2002', ',', ' the', ' El', 'ie', ' W', 'iesel']
+3271 777 Name of father of x -1 Name of father of Elie Wiesel Shlomo Wiesel Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False Nobel Prize winner Elie Wiesel calls the Holocaust 6 [' Nobel', ' Prize', ' winner', ' El', 'ie', ' W', 'iesel']
+3272 777 Name of father of x -1 Name of father of Elie Wiesel Shlomo Wiesel Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False argues that Elie Wiesel and others exploit 5 [' argues', ' that', ' El', 'ie', ' W', 'iesel']
+3273 778 Name of father of x -1 Name of father of Niki de Saint Phalle André, Comte de Saint Phalle Niki de Saint Phalle "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' daughter' ','
+ ' the' ' artist' ""'s"" ' daughter' ""'s"" ' husband' ',' ' and' ' the'
+ ' artist']" , the artist , and the artist 's daughter , the artist 's daughter 's husband , and the artist False the sculptor Niki de Saint Phalle built her multiple-piece 8 [' the', ' sculpt', 'or', ' N', 'iki', ' de', ' Saint', ' Ph', 'alle']
+3274 778 Name of father of x -1 Name of father of Niki de Saint Phalle André, Comte de Saint Phalle Niki de Saint Phalle "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' daughter' ','
+ ' the' ' artist' ""'s"" ' daughter' ""'s"" ' husband' ',' ' and' ' the'
+ ' artist']" , the artist , and the artist 's daughter , the artist 's daughter 's husband , and the artist False California, the sculptor Niki de Saint Phalle built her multiple-piece 10 [' California', ',', ' the', ' sculpt', 'or', ' N', 'iki', ' de', ' Saint', ' Ph', 'alle']
+3275 778 Name of father of x -1 Name of father of Niki de Saint Phalle André, Comte de Saint Phalle Niki de Saint Phalle "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' daughter' ','
+ ' the' ' artist' ""'s"" ' daughter' ""'s"" ' husband' ',' ' and' ' the'
+ ' artist']" , the artist , and the artist 's daughter , the artist 's daughter 's husband , and the artist False California, the sculptor Niki de Saint Phalle built her 10 [' California', ',', ' the', ' sculpt', 'or', ' N', 'iki', ' de', ' Saint', ' Ph', 'alle']
+3276 779 Name of father of x -1 Name of father of Kate Beckinsale Richard Beckinsale Kate Beckinsale "[',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' late' ' actor'
+ ' Christopher' ' Beck' 'ins' 'ale' ',' ' who' ' died' ' in' ' a' ' car'
+ ' crash']" , who is the daughter of the late actor Christopher Beck ins ale , who died in a car crash False Festival starring Kate Beckinsale and Chloe Sevigny 5 [' Festival', ' starring', ' Kate', ' Beck', 'ins', 'ale']
+3277 779 Name of father of x -1 Name of father of Kate Beckinsale Richard Beckinsale Kate Beckinsale "[',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' late' ' actor'
+ ' Christopher' ' Beck' 'ins' 'ale' ',' ' who' ' died' ' in' ' a' ' car'
+ ' crash']" , who is the daughter of the late actor Christopher Beck ins ale , who died in a car crash False Stillman and starring Kate Beckinsale and Chloe Sevigny, 7 [' Still', 'man', ' and', ' starring', ' Kate', ' Beck', 'ins', 'ale']
+3278 779 Name of father of x -1 Name of father of Kate Beckinsale Richard Beckinsale Kate Beckinsale "[',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' late' ' actor'
+ ' Christopher' ' Beck' 'ins' 'ale' ',' ' who' ' died' ' in' ' a' ' car'
+ ' crash']" , who is the daughter of the late actor Christopher Beck ins ale , who died in a car crash False appeared opposite Kate Beckinsale in a production of 5 [' appeared', ' opposite', ' Kate', ' Beck', 'ins', 'ale']
+3279 779 Name of father of x -1 Name of father of Kate Beckinsale Richard Beckinsale Kate Beckinsale "[',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' late' ' actor'
+ ' Christopher' ' Beck' 'ins' 'ale' ',' ' who' ' died' ' in' ' a' ' car'
+ ' crash']" , who is the daughter of the late actor Christopher Beck ins ale , who died in a car crash False and starring Kate Beckinsale and Chloe Sevigny, 5 [' and', ' starring', ' Kate', ' Beck', 'ins', 'ale']
+3280 779 Name of father of x -1 Name of father of Kate Beckinsale Richard Beckinsale Kate Beckinsale "[',' ' who' ' is' ' the' ' daughter' ' of' ' the' ' late' ' actor'
+ ' Christopher' ' Beck' 'ins' 'ale' ',' ' who' ' died' ' in' ' a' ' car'
+ ' crash']" , who is the daughter of the late actor Christopher Beck ins ale , who died in a car crash False appeared opposite Kate Beckinsale in a production 5 [' appeared', ' opposite', ' Kate', ' Beck', 'ins', 'ale']
+3281 784 Name of father of x -1 Name of father of Beatrix Potter Rupert William Potter Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 2 ['Beat', 'rix', ' Potter']
+3282 784 Name of father of x -1 Name of father of Beatrix Potter Rupert William Potter Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " this period, and Beatrix Potter lived there as well.
+" 6 [' this', ' period', ',', ' and', ' Beat', 'rix', ' Potter']
+3283 784 Name of father of x -1 Name of father of Beatrix Potter Rupert William Potter Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 2 ['Beat', 'rix', ' Potter']
+3284 784 Name of father of x -1 Name of father of Beatrix Potter Rupert William Potter Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False that makes Beatrix Potter so much fun for 4 [' that', ' makes', ' Beat', 'rix', ' Potter']
+3285 784 Name of father of x -1 Name of father of Beatrix Potter Rupert William Potter Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Rabbit series by Beatrix Potter to commemorate 5 [' Rabbit', ' series', ' by', ' Beat', 'rix', ' Potter']
+3286 785 Name of father of x -1 Name of father of Eric Hobsbawm Leopold Percy Hobsbaum Eric Hobsbawm "[',' ' the' ' author' ' of' ' the' ' book' ',' ' The' ' Age' ' of'
+ ' Extrem' 'es' ':' ' The' ' Short' ' Tw' 'ent' 'ieth' ' Century' ',']" , the author of the book , The Age of Extrem es : The Short Tw ent ieth Century , False " late Marxist historian Eric Hobsbawm remarked that ""One" 8 [' late', ' Marxist', ' historian', ' Eric', ' H', 'obs', 'b', 'aw', 'm']
+3287 785 Name of father of x -1 Name of father of Eric Hobsbawm Leopold Percy Hobsbaum Eric Hobsbawm "[',' ' the' ' author' ' of' ' the' ' book' ',' ' The' ' Age' ' of'
+ ' Extrem' 'es' ':' ' The' ' Short' ' Tw' 'ent' 'ieth' ' Century' ',']" , the author of the book , The Age of Extrem es : The Short Tw ent ieth Century , False " communist historian Eric Hobsbawm put it, ""Suez" 7 [' communist', ' historian', ' Eric', ' H', 'obs', 'b', 'aw', 'm']
+3288 785 Name of father of x -1 Name of father of Eric Hobsbawm Leopold Percy Hobsbaum Eric Hobsbawm "[',' ' the' ' author' ' of' ' the' ' book' ',' ' The' ' Age' ' of'
+ ' Extrem' 'es' ':' ' The' ' Short' ' Tw' 'ent' 'ieth' ' Century' ',']" , the author of the book , The Age of Extrem es : The Short Tw ent ieth Century , False " Marxist historian Eric Hobsbawm remarked that ""One" 7 [' Marxist', ' historian', ' Eric', ' H', 'obs', 'b', 'aw', 'm']
+3289 785 Name of father of x -1 Name of father of Eric Hobsbawm Leopold Percy Hobsbaum Eric Hobsbawm "[',' ' the' ' author' ' of' ' the' ' book' ',' ' The' ' Age' ' of'
+ ' Extrem' 'es' ':' ' The' ' Short' ' Tw' 'ent' 'ieth' ' Century' ',']" , the author of the book , The Age of Extrem es : The Short Tw ent ieth Century , False " historian Eric Hobsbawm remarked that ""One" 6 [' historian', ' Eric', ' H', 'obs', 'b', 'aw', 'm']
+3290 785 Name of father of x -1 Name of father of Eric Hobsbawm Leopold Percy Hobsbaum Eric Hobsbawm "[',' ' the' ' author' ' of' ' the' ' book' ',' ' The' ' Age' ' of'
+ ' Extrem' 'es' ':' ' The' ' Short' ' Tw' 'ent' 'ieth' ' Century' ',']" , the author of the book , The Age of Extrem es : The Short Tw ent ieth Century , False " Marxist historian Eric Hobsbawm remarked that ""One" 7 [' Marxist', ' historian', ' Eric', ' H', 'obs', 'b', 'aw', 'm']
+3291 786 Name of father of x -1 Name of father of Mikhail Bakunin Aleksandr Bakunin Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' member' ' of' ' the' ' Russian' ' Imperial' ' family' '.' '\n' '\n']" ", the son of a Russian noble man , and a member of the Russian Imperial family .
+
+" False starred as Mikhail Bakunin in Tom Stoppard's 4 [' starred', ' as', ' Mikhail', ' Bak', 'unin']
+3292 786 Name of father of x -1 Name of father of Mikhail Bakunin Aleksandr Bakunin Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' member' ' of' ' the' ' Russian' ' Imperial' ' family' '.' '\n' '\n']" ", the son of a Russian noble man , and a member of the Russian Imperial family .
+
+" False associated with Mikhail Bakunin and Johann Most. 4 [' associated', ' with', ' Mikhail', ' Bak', 'unin']
+3293 786 Name of father of x -1 Name of father of Mikhail Bakunin Aleksandr Bakunin Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' member' ' of' ' the' ' Russian' ' Imperial' ' family' '.' '\n' '\n']" ", the son of a Russian noble man , and a member of the Russian Imperial family .
+
+" False associated with Mikhail Bakunin and Johann Most. 4 [' associated', ' with', ' Mikhail', ' Bak', 'unin']
+3294 786 Name of father of x -1 Name of father of Mikhail Bakunin Aleksandr Bakunin Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' member' ' of' ' the' ' Russian' ' Imperial' ' family' '.' '\n' '\n']" ", the son of a Russian noble man , and a member of the Russian Imperial family .
+
+" False anarchist theorist Mikhail Bakunin from its bookstore 4 [' anarchist', ' theorist', ' Mikhail', ' Bak', 'unin']
+3295 786 Name of father of x -1 Name of father of Mikhail Bakunin Aleksandr Bakunin Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' member' ' of' ' the' ' Russian' ' Imperial' ' family' '.' '\n' '\n']" ", the son of a Russian noble man , and a member of the Russian Imperial family .
+
+" False Hawke starred as Mikhail Bakunin in Tom Stoppard's 6 [' Haw', 'ke', ' starred', ' as', ' Mikhail', ' Bak', 'unin']
+3296 787 Name of father of x -1 Name of father of Charles I of England James VI and I Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False of the deposed Charles I of England to be brought 7 [' of', ' the', ' dep', 'osed', ' Charles', ' I', ' of', ' England']
+3297 787 Name of father of x -1 Name of father of Charles I of England James VI and I Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False was referred by Charles I of England to the House of 6 [' was', ' referred', ' by', ' Charles', ' I', ' of', ' England']
+3298 787 Name of father of x -1 Name of father of Charles I of England James VI and I Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False 3 ['Charles', ' I', ' of', ' England']
+3299 787 Name of father of x -1 Name of father of Charles I of England James VI and I Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " England =
+" 5 [' England', ' =', 'Charles', ' I', ' of', ' England']
+3300 787 Name of father of x -1 Name of father of Charles I of England James VI and I Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False equestrian statue of Charles I of England and the two 8 [' equ', 'est', 'rian', ' statue', ' of', ' Charles', ' I', ' of', ' England']
+3301 788 Name of father of x -1 Name of father of Frederick II, Holy Roman Emperor Henry VI Frederick II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False Tunisians in 1231 Frederick II, Holy Roman Emperor minted the augustalis. 10 [' Tunis', 'ians', ' in', ' 12', '31', ' Frederick', ' II', ',', ' Holy', ' Roman', ' Emperor']
+3302 788 Name of father of x -1 Name of father of Frederick II, Holy Roman Emperor Henry VI Frederick II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False sworn fidelity to Frederick II, Holy Roman Emperor during the Mongol 8 [' sworn', ' fidelity', ' to', ' Frederick', ' II', ',', ' Holy', ' Roman', ' Emperor']
+3303 788 Name of father of x -1 Name of father of Frederick II, Holy Roman Emperor Henry VI Frederick II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False Tunisians in 1231 Frederick II, Holy Roman Emperor minted the augustalis. 10 [' Tunis', 'ians', ' in', ' 12', '31', ' Frederick', ' II', ',', ' Holy', ' Roman', ' Emperor']
+3304 788 Name of father of x -1 Name of father of Frederick II, Holy Roman Emperor Henry VI Frederick II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False Tunisians in 1231 Frederick II, Holy Roman Emperor minted the augustalis. 10 [' Tunis', 'ians', ' in', ' 12', '31', ' Frederick', ' II', ',', ' Holy', ' Roman', ' Emperor']
+3305 788 Name of father of x -1 Name of father of Frederick II, Holy Roman Emperor Henry VI Frederick II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False Tunisians in 1231 Frederick II, Holy Roman Emperor minted the 10 [' Tunis', 'ians', ' in', ' 12', '31', ' Frederick', ' II', ',', ' Holy', ' Roman', ' Emperor']
+3306 789 Name of father of x -1 Name of father of Julian Julius Constantius Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' is' ' a' ' man'
+ ' who' ' has' ' been' ' charged' ' with' ' rape' ',' ' sexual' ' assault'
+ ',']" Assange , the founder of Wikileaks , is a man who has been charged with rape , sexual assault , False years despite Julian's innovative work. 2 [' years', ' despite', ' Julian']
+3307 789 Name of father of x -1 Name of father of Julian Julius Constantius Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' is' ' a' ' man'
+ ' who' ' has' ' been' ' charged' ' with' ' rape' ',' ' sexual' ' assault'
+ ',']" Assange , the founder of Wikileaks , is a man who has been charged with rape , sexual assault , False Around the time of Julian's birth, the Beatles 4 [' Around', ' the', ' time', ' of', ' Julian']
+3308 789 Name of father of x -1 Name of father of Julian Julius Constantius Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' is' ' a' ' man'
+ ' who' ' has' ' been' ' charged' ' with' ' rape' ',' ' sexual' ' assault'
+ ',']" Assange , the founder of Wikileaks , is a man who has been charged with rape , sexual assault , False 1 ['Jul', 'ian']
+3309 789 Name of father of x -1 Name of father of Julian Julius Constantius Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' is' ' a' ' man'
+ ' who' ' has' ' been' ' charged' ' with' ' rape' ',' ' sexual' ' assault'
+ ',']" Assange , the founder of Wikileaks , is a man who has been charged with rape , sexual assault , False paintings, and specimens, Julian Hume has inferred 5 [' paintings', ',', ' and', ' specimens', ',', ' Julian']
+3310 789 Name of father of x -1 Name of father of Julian Julius Constantius Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' is' ' a' ' man'
+ ' who' ' has' ' been' ' charged' ' with' ' rape' ',' ' sexual' ' assault'
+ ',']" Assange , the founder of Wikileaks , is a man who has been charged with rape , sexual assault , False skeleton in 1914, but Julian P. Hume noted several 5 [' skeleton', ' in', ' 1914', ',', ' but', ' Julian']
+3311 790 Name of father of x -1 Name of father of Adolf Erik Nordenskiöld Nils Gustaf Nordenskiöld Adolf Erik Nordenskiöld "[',' ' the' ' Swedish' ' explorer' ' and' ' Arctic' ' explorer' ',' ' who'
+ ' was' ' born' ' in' ' Stockholm' ',' ' Sweden' ',' ' on' ' this' ' day'
+ ' in']" , the Swedish explorer and Arctic explorer , who was born in Stockholm , Sweden , on this day in False Siberian coast for news of Adolf Erik Nordenskiöld and his ship Vega. 11 [' Siberian', ' coast', ' for', ' news', ' of', ' Adolf', ' Erik', ' Nord', 'ens', 'ki', 'ö', 'ld']
+3312 790 Name of father of x -1 Name of father of Adolf Erik Nordenskiöld Nils Gustaf Nordenskiöld Adolf Erik Nordenskiöld "[',' ' the' ' Swedish' ' explorer' ' and' ' Arctic' ' explorer' ',' ' who'
+ ' was' ' born' ' in' ' Stockholm' ',' ' Sweden' ',' ' on' ' this' ' day'
+ ' in']" , the Swedish explorer and Arctic explorer , who was born in Stockholm , Sweden , on this day in False Swedish explorer Adolf Erik Nordenskiöld after the latter's 8 [' Swedish', ' explorer', ' Adolf', ' Erik', ' Nord', 'ens', 'ki', 'ö', 'ld']
+3313 790 Name of father of x -1 Name of father of Adolf Erik Nordenskiöld Nils Gustaf Nordenskiöld Adolf Erik Nordenskiöld "[',' ' the' ' Swedish' ' explorer' ' and' ' Arctic' ' explorer' ',' ' who'
+ ' was' ' born' ' in' ' Stockholm' ',' ' Sweden' ',' ' on' ' this' ' day'
+ ' in']" , the Swedish explorer and Arctic explorer , who was born in Stockholm , Sweden , on this day in False coast for news of Adolf Erik Nordenskiöld and his ship 10 [' coast', ' for', ' news', ' of', ' Adolf', ' Erik', ' Nord', 'ens', 'ki', 'ö', 'ld']
+3314 790 Name of father of x -1 Name of father of Adolf Erik Nordenskiöld Nils Gustaf Nordenskiöld Adolf Erik Nordenskiöld "[',' ' the' ' Swedish' ' explorer' ' and' ' Arctic' ' explorer' ',' ' who'
+ ' was' ' born' ' in' ' Stockholm' ',' ' Sweden' ',' ' on' ' this' ' day'
+ ' in']" , the Swedish explorer and Arctic explorer , who was born in Stockholm , Sweden , on this day in False for news of Adolf Erik Nordenskiöld and his ship Vega. 9 [' for', ' news', ' of', ' Adolf', ' Erik', ' Nord', 'ens', 'ki', 'ö', 'ld']
+3315 790 Name of father of x -1 Name of father of Adolf Erik Nordenskiöld Nils Gustaf Nordenskiöld Adolf Erik Nordenskiöld "[',' ' the' ' Swedish' ' explorer' ' and' ' Arctic' ' explorer' ',' ' who'
+ ' was' ' born' ' in' ' Stockholm' ',' ' Sweden' ',' ' on' ' this' ' day'
+ ' in']" , the Swedish explorer and Arctic explorer , who was born in Stockholm , Sweden , on this day in False been those of Adolf Erik Nordenskiöld in 1883, and Robert 9 [' been', ' those', ' of', ' Adolf', ' Erik', ' Nord', 'ens', 'ki', 'ö', 'ld']
+3316 791 Name of father of x -1 Name of father of Paris Hilton Richard Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' powerful' ' man' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the daughter of a wealthy , famous , and powerful man .
+
+ I 'm not sure if" False " Clearasil ads and Paris Hilton in House of Wax""." 6 [' Clear', 'as', 'il', ' ads', ' and', ' Paris', ' Hilton']
+3317 791 Name of father of x -1 Name of father of Paris Hilton Richard Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' powerful' ' man' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the daughter of a wealthy , famous , and powerful man .
+
+ I 'm not sure if" False Lewinsky, Paris Hilton and Britney Spears 4 [' Lew', 'insky', ',', ' Paris', ' Hilton']
+3318 791 Name of father of x -1 Name of father of Paris Hilton Richard Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' powerful' ' man' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the daughter of a wealthy , famous , and powerful man .
+
+ I 'm not sure if" False script and alludes to Paris Hilton (London spoofing Paris, 6 [' script', ' and', ' all', 'udes', ' to', ' Paris', ' Hilton']
+3319 791 Name of father of x -1 Name of father of Paris Hilton Richard Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' powerful' ' man' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the daughter of a wealthy , famous , and powerful man .
+
+ I 'm not sure if" False Nicole Kidman, Paris Hilton and Stefani herself. 5 [' Nicole', ' Kid', 'man', ',', ' Paris', ' Hilton']
+3320 791 Name of father of x -1 Name of father of Paris Hilton Richard Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' powerful' ' man' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", the daughter of a wealthy , famous , and powerful man .
+
+ I 'm not sure if" False Celebrities like Paris Hilton and Christina 4 [' Celeb', 'rities', ' like', ' Paris', ' Hilton']
+3321 792 Name of father of x -1 Name of father of James Clerk Maxwell John Clerk-Maxwell of Middlebie James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' James'
+ ' Clerk' ' Maxwell' ' is' ' not' ' known' '.' '\n' '\n' 'The']" ", the
+
+ The name of the father of James Clerk Maxwell is not known .
+
+ The" False and others of James Clerk Maxwell and John Witherspoon. 5 [' and', ' others', ' of', ' James', ' Clerk', ' Maxwell']
+3322 792 Name of father of x -1 Name of father of James Clerk Maxwell John Clerk-Maxwell of Middlebie James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' James'
+ ' Clerk' ' Maxwell' ' is' ' not' ' known' '.' '\n' '\n' 'The']" ", the
+
+ The name of the father of James Clerk Maxwell is not known .
+
+ The" False 2 ['James', ' Clerk', ' Maxwell']
+3323 792 Name of father of x -1 Name of father of James Clerk Maxwell John Clerk-Maxwell of Middlebie James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' James'
+ ' Clerk' ' Maxwell' ' is' ' not' ' known' '.' '\n' '\n' 'The']" ", the
+
+ The name of the father of James Clerk Maxwell is not known .
+
+ The" False Lord Kelvin), James Clerk Maxwell and Joule 5 [' Lord', ' Kelvin', '),', ' James', ' Clerk', ' Maxwell']
+3324 792 Name of father of x -1 Name of father of James Clerk Maxwell John Clerk-Maxwell of Middlebie James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' James'
+ ' Clerk' ' Maxwell' ' is' ' not' ' known' '.' '\n' '\n' 'The']" ", the
+
+ The name of the father of James Clerk Maxwell is not known .
+
+ The" False Scottish scientists James Clerk Maxwell and David Brewster 4 [' Scottish', ' scientists', ' James', ' Clerk', ' Maxwell']
+3325 792 Name of father of x -1 Name of father of James Clerk Maxwell John Clerk-Maxwell of Middlebie James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' James'
+ ' Clerk' ' Maxwell' ' is' ' not' ' known' '.' '\n' '\n' 'The']" ", the
+
+ The name of the father of James Clerk Maxwell is not known .
+
+ The" False explained until 1864 when James Clerk Maxwell unified a number 7 [' explained', ' until', ' 18', '64', ' when', ' James', ' Clerk', ' Maxwell']
+3326 793 Name of father of x -1 Name of father of Nikola Tesla Milutin Tesla Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False Müller, and scientist Nikola Tesla was one of 6 [' Mü', 'ller', ',', ' and', ' scientist', ' Nikola', ' Tesla']
+3327 793 Name of father of x -1 Name of father of Nikola Tesla Milutin Tesla Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False Charles Steinmetz or Nikola Tesla — were popularly conceived 6 [' Charles', ' Stein', 'met', 'z', ' or', ' Nikola', ' Tesla']
+3328 793 Name of father of x -1 Name of father of Nikola Tesla Milutin Tesla Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False portrayed physicist Nikola Tesla in the Christopher 3 [' portrayed', ' physicist', ' Nikola', ' Tesla']
+3329 793 Name of father of x -1 Name of father of Nikola Tesla Milutin Tesla Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False Müller, and scientist Nikola Tesla was one of those 6 [' Mü', 'ller', ',', ' and', ' scientist', ' Nikola', ' Tesla']
+3330 793 Name of father of x -1 Name of father of Nikola Tesla Milutin Tesla Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False meets scientist Nikola Tesla (David Bowie) and 3 [' meets', ' scientist', ' Nikola', ' Tesla']
+3331 794 Name of father of x -1 Name of father of Plácido Domingo Plácido Domingo Ferrer Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False the Three Tenors, Plácido Domingo and José Carreras 11 [' the', ' Three', ' Ten', 'ors', ',', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+3332 794 Name of father of x -1 Name of father of Plácido Domingo Plácido Domingo Ferrer Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False performances. Plácido Domingo first recorded Cavaradossi 8 [' performances', '.', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+3333 794 Name of father of x -1 Name of father of Plácido Domingo Plácido Domingo Ferrer Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False Symphony, with Plácido Domingo as baritone 9 [' Symphony', ',', ' with', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+3334 794 Name of father of x -1 Name of father of Plácido Domingo Plácido Domingo Ferrer Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False live performances. Plácido Domingo first recorded Cavaradossi 9 [' live', ' performances', '.', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+3335 794 Name of father of x -1 Name of father of Plácido Domingo Plácido Domingo Ferrer Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False Rodríguez and Plácido Domingo to record 12 [' Rod', 'r', 'í', 'g', 'uez', ' and', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+3336 796 Name of father of x -1 Name of father of Christopher Lee Geoffrey Trollope Lee Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Christopher' ' Lee' ','
+ ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lady' ' Lee' ',']" , the son of the late Sir Christopher Lee , the actor , and his wife , Lady Lee , False " Burgomaster
+" 4 [' Burg', 'om', 'aster', 'Christopher', ' Lee']
+3337 796 Name of father of x -1 Name of father of Christopher Lee Geoffrey Trollope Lee Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Christopher' ' Lee' ','
+ ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lady' ' Lee' ',']" , the son of the late Sir Christopher Lee , the actor , and his wife , Lady Lee , False disassembled, although Christopher Lee said that the 5 [' dis', 'assembled', ',', ' although', ' Christopher', ' Lee']
+3338 796 Name of father of x -1 Name of father of Christopher Lee Geoffrey Trollope Lee Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Christopher' ' Lee' ','
+ ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lady' ' Lee' ',']" , the son of the late Sir Christopher Lee , the actor , and his wife , Lady Lee , False Tim Curry and Christopher Lee as the voice of Death. 4 [' Tim', ' Curry', ' and', ' Christopher', ' Lee']
+3339 796 Name of father of x -1 Name of father of Christopher Lee Geoffrey Trollope Lee Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Christopher' ' Lee' ','
+ ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lady' ' Lee' ',']" , the son of the late Sir Christopher Lee , the actor , and his wife , Lady Lee , False then 78-year-old Christopher Lee relied on a 7 [' then', ' 78', '-', 'year', '-', 'old', ' Christopher', ' Lee']
+3340 796 Name of father of x -1 Name of father of Christopher Lee Geoffrey Trollope Lee Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Christopher' ' Lee' ','
+ ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lady' ' Lee' ',']" , the son of the late Sir Christopher Lee , the actor , and his wife , Lady Lee , False action of the kind that Christopher Lee or Peter Cushing 6 [' action', ' of', ' the', ' kind', ' that', ' Christopher', ' Lee']
+3341 797 Name of father of x -1 Name of father of Ingrid Bergman Justus Bergman Ingrid Bergman "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Burgess Meredith and Ingrid Bergman was seen by both 6 [' Burgess', ' Meredith', ' and', ' In', 'grid', ' Berg', 'man']
+3342 797 Name of father of x -1 Name of father of Ingrid Bergman Justus Bergman Ingrid Bergman "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Meredith and Ingrid Bergman was seen by 5 [' Meredith', ' and', ' In', 'grid', ' Berg', 'man']
+3343 797 Name of father of x -1 Name of father of Ingrid Bergman Justus Bergman Ingrid Bergman "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False an affair with Ingrid Bergman during the production 6 [' an', ' affair', ' with', ' In', 'grid', ' Berg', 'man']
+3344 797 Name of father of x -1 Name of father of Ingrid Bergman Justus Bergman Ingrid Bergman "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " ""she's got this Ingrid Bergman thing going on," 8 "[' ""', 'she', ""'s"", ' got', ' this', ' In', 'grid', ' Berg', 'man']"
+3345 797 Name of father of x -1 Name of father of Ingrid Bergman Justus Bergman Ingrid Bergman "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " ""she's got this Ingrid Bergman thing going" 8 "[' ""', 'she', ""'s"", ' got', ' this', ' In', 'grid', ' Berg', 'man']"
+3346 798 Name of father of x -1 Name of father of Archimedes Fidias Archimedes "[',' ' the' ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',' ' and' ' the'
+ ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',']" , the son of E ury st he us , and the son of E ury st he us , False this treatise, Archimedes spells out 6 [' this', ' treat', 'ise', ',', ' Arch', 'im', 'edes']
+3347 798 Name of father of x -1 Name of father of Archimedes Fidias Archimedes "[',' ' the' ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',' ' and' ' the'
+ ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',']" , the son of E ury st he us , and the son of E ury st he us , False known anecdote about Archimedes tells of how he invented 5 [' known', ' anecdote', ' about', ' Arch', 'im', 'edes']
+3348 798 Name of father of x -1 Name of father of Archimedes Fidias Archimedes "[',' ' the' ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',' ' and' ' the'
+ ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',']" , the son of E ury st he us , and the son of E ury st he us , False The tomb of Archimedes carried a sculpture 5 [' The', ' tomb', ' of', ' Arch', 'im', 'edes']
+3349 798 Name of father of x -1 Name of father of Archimedes Fidias Archimedes "[',' ' the' ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',' ' and' ' the'
+ ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',']" , the son of E ury st he us , and the son of E ury st he us , False The OSS agent Archimedes Patti, who was 6 [' The', ' O', 'SS', ' agent', ' Arch', 'im', 'edes']
+3350 798 Name of father of x -1 Name of father of Archimedes Fidias Archimedes "[',' ' the' ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',' ' and' ' the'
+ ' son' ' of' ' E' 'ury' 'st' 'he' 'us' ',']" , the son of E ury st he us , and the son of E ury st he us , False on the Moon named Archimedes (29.7°N, 4.0°W) 6 [' on', ' the', ' Moon', ' named', ' Arch', 'im', 'edes']
+3351 799 Name of father of x -1 Name of father of Romain Gary Arieh-Leïb Kacew Romain Gary "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False In the original Romain Gary novel, this was not 5 [' In', ' the', ' original', ' Rom', 'ain', ' Gary']
+3352 799 Name of father of x -1 Name of father of Romain Gary Arieh-Leïb Kacew Romain Gary "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False autobiographical novel written by Romain Gary of the same name. 7 [' autobi', 'ographical', ' novel', ' written', ' by', ' Rom', 'ain', ' Gary']
+3353 799 Name of father of x -1 Name of father of Romain Gary Arieh-Leïb Kacew Romain Gary "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False 2 ['Rom', 'ain', ' Gary']
+3354 799 Name of father of x -1 Name of father of Romain Gary Arieh-Leïb Kacew Romain Gary "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False written by Romain Gary of the same name. 4 [' written', ' by', ' Rom', 'ain', ' Gary']
+3355 799 Name of father of x -1 Name of father of Romain Gary Arieh-Leïb Kacew Romain Gary "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False novel written by Romain Gary of the same 5 [' novel', ' written', ' by', ' Rom', 'ain', ' Gary']
+3356 801 Name of father of x -1 Name of father of John Cox Jeffrey James Cox John Cox "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' and' ' the' ' groom' ""'s"" ' father' ',' ' and' ' the' ' groom']" , the father of the bride , and the groom , and the groom 's father , and the groom False Professor of English John Cox suggests that Shakespeare 4 [' Professor', ' of', ' English', ' John', ' Cox']
+3357 802 Name of father of x -1 Name of father of Hilary Swank Stephen Michael Swank Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actor , and a great person .
+
+ I am a huge fan of" False created for Hilary Swank in the 2003 film 5 [' created', ' for', ' Hil', 'ary', ' Sw', 'ank']
+3358 802 Name of father of x -1 Name of father of Hilary Swank Stephen Michael Swank Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actor , and a great person .
+
+ I am a huge fan of" False co-starred with Hilary Swank in The Core, a film 8 [' co', '-', 'star', 'red', ' with', ' Hil', 'ary', ' Sw', 'ank']
+3359 802 Name of father of x -1 Name of father of Hilary Swank Stephen Michael Swank Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actor , and a great person .
+
+ I am a huge fan of" False unknown actor Hilary Swank sent a videotape 5 [' unknown', ' actor', ' Hil', 'ary', ' Sw', 'ank']
+3360 802 Name of father of x -1 Name of father of Hilary Swank Stephen Michael Swank Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actor , and a great person .
+
+ I am a huge fan of" False May 11, 2011, Hilary Swank was reportedly in 8 [' May', ' 11', ',', ' 2011', ',', ' Hil', 'ary', ' Sw', 'ank']
+3361 802 Name of father of x -1 Name of father of Hilary Swank Stephen Michael Swank Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actor , and a great person .
+
+ I am a huge fan of" False over whether Hilary Swank is hot, but said 5 [' over', ' whether', ' Hil', 'ary', ' Sw', 'ank']
+3362 803 Name of father of x -1 Name of father of Henry IV of France Antoine of Navarre Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False French ownership. Henry IV of France ordered the demolition 6 [' French', ' ownership', '.', ' Henry', ' IV', ' of', ' France']
+3363 803 Name of father of x -1 Name of father of Henry IV of France Antoine of Navarre Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False fighting with King Henry IV of France during the French 6 [' fighting', ' with', ' King', ' Henry', ' IV', ' of', ' France']
+3364 803 Name of father of x -1 Name of father of Henry IV of France Antoine of Navarre Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False French ownership. Henry IV of France ordered the demolition 6 [' French', ' ownership', '.', ' Henry', ' IV', ' of', ' France']
+3365 803 Name of father of x -1 Name of father of Henry IV of France Antoine of Navarre Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False they approached Henry IV of France for funding, but 5 [' they', ' approached', ' Henry', ' IV', ' of', ' France']
+3366 803 Name of father of x -1 Name of father of Henry IV of France Antoine of Navarre Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False French ownership. Henry IV of France ordered the 6 [' French', ' ownership', '.', ' Henry', ' IV', ' of', ' France']
+3367 804 Name of father of x -1 Name of father of Roger Penrose Lionel Penrose Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show'
+ ',' ' and' ' I' ' am' ' a' ' big' ' fan' ' of']" ", the
+
+ I am a big fan of the show , and I am a big fan of" False of causal structure, Roger Penrose and others developed 6 [' of', ' causal', ' structure', ',', ' Roger', ' Pen', 'rose']
+3368 804 Name of father of x -1 Name of father of Roger Penrose Lionel Penrose Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show'
+ ',' ' and' ' I' ' am' ' a' ' big' ' fan' ' of']" ", the
+
+ I am a big fan of the show , and I am a big fan of" False mathematician Roger Penrose and his father 3 [' mathematician', ' Roger', ' Pen', 'rose']
+3369 804 Name of father of x -1 Name of father of Roger Penrose Lionel Penrose Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show'
+ ',' ' and' ' I' ' am' ' a' ' big' ' fan' ' of']" ", the
+
+ I am a big fan of the show , and I am a big fan of" False participants. Both Roger Penrose and H. S. M. Coxeter 5 [' participants', '.', ' Both', ' Roger', ' Pen', 'rose']
+3370 804 Name of father of x -1 Name of father of Roger Penrose Lionel Penrose Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show'
+ ',' ' and' ' I' ' am' ' a' ' big' ' fan' ' of']" ", the
+
+ I am a big fan of the show , and I am a big fan of" False participants. Both Roger Penrose and H. S. M. Coxeter 5 [' participants', '.', ' Both', ' Roger', ' Pen', 'rose']
+3371 804 Name of father of x -1 Name of father of Roger Penrose Lionel Penrose Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show'
+ ',' ' and' ' I' ' am' ' a' ' big' ' fan' ' of']" ", the
+
+ I am a big fan of the show , and I am a big fan of" False introduced by Roger Penrose in a 1974 4 [' introduced', ' by', ' Roger', ' Pen', 'rose']
+3372 805 Name of father of x -1 Name of father of Nina Hagen Hans Oliva-Hagen Nina Hagen "[',' ' the' ' daughter' ' of' ' a' ' German' ' immigrant' ',' ' was'
+ ' born' ' in' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' family'
+ ' moved']" ", the daughter of a German immigrant , was born in the United States .
+
+ The family moved" False Runaways or singers like Nina Hagen and Carole Pope. 7 [' Run', 'aways', ' or', ' singers', ' like', ' Nina', ' H', 'agen']
+3373 805 Name of father of x -1 Name of father of Nina Hagen Hans Oliva-Hagen Nina Hagen "[',' ' the' ' daughter' ' of' ' a' ' German' ' immigrant' ',' ' was'
+ ' born' ' in' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' family'
+ ' moved']" ", the daughter of a German immigrant , was born in the United States .
+
+ The family moved" False singers like Nina Hagen and Carole Pope. 4 [' singers', ' like', ' Nina', ' H', 'agen']
+3374 805 Name of father of x -1 Name of father of Nina Hagen Hans Oliva-Hagen Nina Hagen "[',' ' the' ' daughter' ' of' ' a' ' German' ' immigrant' ',' ' was'
+ ' born' ' in' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' family'
+ ' moved']" ", the daughter of a German immigrant , was born in the United States .
+
+ The family moved" False such as the Nina Hagen Band and S.Y.P.H. 5 [' such', ' as', ' the', ' Nina', ' H', 'agen']
+3375 805 Name of father of x -1 Name of father of Nina Hagen Hans Oliva-Hagen Nina Hagen "[',' ' the' ' daughter' ' of' ' a' ' German' ' immigrant' ',' ' was'
+ ' born' ' in' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' family'
+ ' moved']" ", the daughter of a German immigrant , was born in the United States .
+
+ The family moved" False singers like Nina Hagen and Carole Pope. 4 [' singers', ' like', ' Nina', ' H', 'agen']
+3376 805 Name of father of x -1 Name of father of Nina Hagen Hans Oliva-Hagen Nina Hagen "[',' ' the' ' daughter' ' of' ' a' ' German' ' immigrant' ',' ' was'
+ ' born' ' in' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' family'
+ ' moved']" ", the daughter of a German immigrant , was born in the United States .
+
+ The family moved" False such as the Nina Hagen Band and S.Y.P.H. 5 [' such', ' as', ' the', ' Nina', ' H', 'agen']
+3377 807 Name of father of x -1 Name of father of Gore Vidal Eugene Luther Vidal Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False stepbrother of author Gore Vidal after Olds'father 6 [' step', 'brother', ' of', ' author', ' Gore', ' V', 'idal']
+3378 807 Name of father of x -1 Name of father of Gore Vidal Eugene Luther Vidal Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False stepbrother of author Gore Vidal after Olds'father 6 [' step', 'brother', ' of', ' author', ' Gore', ' V', 'idal']
+3379 807 Name of father of x -1 Name of father of Gore Vidal Eugene Luther Vidal Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Bisexual author Gore Vidal (1925-2012) is a 5 [' B', 'isexual', ' author', ' Gore', ' V', 'idal']
+3380 807 Name of father of x -1 Name of father of Gore Vidal Eugene Luther Vidal Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Joyce, and Gore Vidal lying about, 5 [' Joyce', ',', ' and', ' Gore', ' V', 'idal']
+3381 807 Name of father of x -1 Name of father of Gore Vidal Eugene Luther Vidal Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False of Wills's book. Gore Vidal also draws attention 8 "[' of', ' W', 'ills', ""'s"", ' book', '.', ' Gore', ' V', 'idal']"
+3382 808 Name of father of x -1 Name of father of Alexander VI Jofré Llançol i Escrivà Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the' ' first'
+ ' of' ' the' ' two' ',' '\n' '\n' 'The' ' first' ' of']" ".
+
+ The first of the two , the first of the two ,
+
+ The first of" False Italian War, Pope Alexander VI had, with French 5 [' Italian', ' War', ',', ' Pope', ' Alexander', ' VI']
+3383 808 Name of father of x -1 Name of father of Alexander VI Jofré Llançol i Escrivà Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the' ' first'
+ ' of' ' the' ' two' ',' '\n' '\n' 'The' ' first' ' of']" ".
+
+ The first of the two , the first of the two ,
+
+ The first of" False Nicholas V and Alexander VI respectively. 4 [' Nicholas', ' V', ' and', ' Alexander', ' VI']
+3384 808 Name of father of x -1 Name of father of Alexander VI Jofré Llançol i Escrivà Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the' ' first'
+ ' of' ' the' ' two' ',' '\n' '\n' 'The' ' first' ' of']" ".
+
+ The first of the two , the first of the two ,
+
+ The first of" False Nicholas V and Alexander VI respectively. St Leonard's 4 [' Nicholas', ' V', ' and', ' Alexander', ' VI']
+3385 808 Name of father of x -1 Name of father of Alexander VI Jofré Llançol i Escrivà Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the' ' first'
+ ' of' ' the' ' two' ',' '\n' '\n' 'The' ' first' ' of']" ".
+
+ The first of the two , the first of the two ,
+
+ The first of" False the Pope, as Pope Alexander VI had granted the Indies 6 [' the', ' Pope', ',', ' as', ' Pope', ' Alexander', ' VI']
+3386 808 Name of father of x -1 Name of father of Alexander VI Jofré Llançol i Escrivà Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the' ' first'
+ ' of' ' the' ' two' ',' '\n' '\n' 'The' ' first' ' of']" ".
+
+ The first of the two , the first of the two ,
+
+ The first of" False intervened between Alexander VI and Julius II, 3 [' intervened', ' between', ' Alexander', ' VI']
+3387 809 Name of father of x -1 Name of father of Robert Frost William Prescott Frost Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' son' ',' ' and'
+ ' the' ' poet' ""'s"" ' son' ""'s"" ' son' ',' ' and' ' the']" , the poet , and the poet 's son , and the poet 's son 's son , and the False medallions depicted Robert Frost and Alexander 5 [' med', 'all', 'ions', ' depicted', ' Robert', ' Frost']
+3388 809 Name of father of x -1 Name of father of Robert Frost William Prescott Frost Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' son' ',' ' and'
+ ' the' ' poet' ""'s"" ' son' ""'s"" ' son' ',' ' and' ' the']" , the poet , and the poet 's son , and the poet 's son 's son , and the False recitation since Robert Frost at President John 4 [' rec', 'itation', ' since', ' Robert', ' Frost']
+3389 809 Name of father of x -1 Name of father of Robert Frost William Prescott Frost Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' son' ',' ' and'
+ ' the' ' poet' ""'s"" ' son' ""'s"" ' son' ',' ' and' ' the']" , the poet , and the poet 's son , and the poet 's son 's son , and the False inaugural poem since Robert Frost at the 1961 inauguration 4 [' inaugural', ' poem', ' since', ' Robert', ' Frost']
+3390 809 Name of father of x -1 Name of father of Robert Frost William Prescott Frost Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' son' ',' ' and'
+ ' the' ' poet' ""'s"" ' son' ""'s"" ' son' ',' ' and' ' the']" , the poet , and the poet 's son , and the poet 's son 's son , and the False recitation since Robert Frost at President John 4 [' rec', 'itation', ' since', ' Robert', ' Frost']
+3391 809 Name of father of x -1 Name of father of Robert Frost William Prescott Frost Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' son' ',' ' and'
+ ' the' ' poet' ""'s"" ' son' ""'s"" ' son' ',' ' and' ' the']" , the poet , and the poet 's son , and the poet 's son 's son , and the False is based on the Robert Frost poem of the same 5 [' is', ' based', ' on', ' the', ' Robert', ' Frost']
+3392 811 Name of father of x -1 Name of father of Carl Philipp Emanuel Bach Johann Sebastian Bach Carl Philipp Emanuel Bach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'Brand' 'enburg'
+ ' Concert' 'os' '""' ' and' ' the' ' ""' 'Gold' 'berg' ' Vari' 'ations' '""']" ", the composer of the famous "" Brand enburg Concert os "" and the "" Gold berg Vari ations """ False 3 ['Carl', ' Philipp', ' Emanuel', ' Bach']
+3393 811 Name of father of x -1 Name of father of Carl Philipp Emanuel Bach Johann Sebastian Bach Carl Philipp Emanuel Bach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'Brand' 'enburg'
+ ' Concert' 'os' '""' ' and' ' the' ' ""' 'Gold' 'berg' ' Vari' 'ations' '""']" ", the composer of the famous "" Brand enburg Concert os "" and the "" Gold berg Vari ations """ False 3 ['Carl', ' Philipp', ' Emanuel', ' Bach']
+3394 812 Name of father of x -1 Name of father of Lise Meitner Philipp Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False of the discovery by Lise Meitner and Otto Frisch 8 [' of', ' the', ' discovery', ' by', ' L', 'ise', ' Me', 'it', 'ner']
+3395 812 Name of father of x -1 Name of father of Lise Meitner Philipp Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False (and naming) by Lise Meitner and Otto Frisch soon 9 [' (', 'and', ' naming', ')', ' by', ' L', 'ise', ' Me', 'it', 'ner']
+3396 812 Name of father of x -1 Name of father of Lise Meitner Philipp Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False fission in uranium, by Lise Meitner and Otto Hahn, 10 [' f', 'ission', ' in', ' uranium', ',', ' by', ' L', 'ise', ' Me', 'it', 'ner']
+3397 812 Name of father of x -1 Name of father of Lise Meitner Philipp Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False explanation by Lise Meitner and Otto Frisch, was 6 [' explanation', ' by', ' L', 'ise', ' Me', 'it', 'ner']
+3398 812 Name of father of x -1 Name of father of Lise Meitner Philipp Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False nuclear fission by Lise Meitner in the February 8 [' nuclear', ' f', 'ission', ' by', ' L', 'ise', ' Me', 'it', 'ner']
+3399 814 Name of father of x -1 Name of father of Paul VI Giorgio Montini Paul VI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ' of' ' the' ' Roman'
+ ' Catholic' ' Church' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope of the Roman Catholic Church .
+
+ The Pope is the head" False promulgated by Paul VI in 1969, but 4 [' promulg', 'ated', ' by', ' Paul', ' VI']
+3400 814 Name of father of x -1 Name of father of Paul VI Giorgio Montini Paul VI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ' of' ' the' ' Roman'
+ ' Catholic' ' Church' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope of the Roman Catholic Church .
+
+ The Pope is the head" False during the race to Pope Paul VI while visiting Vatican 6 [' during', ' the', ' race', ' to', ' Pope', ' Paul', ' VI']
+3401 814 Name of father of x -1 Name of father of Paul VI Giorgio Montini Paul VI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ' of' ' the' ' Roman'
+ ' Catholic' ' Church' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope of the Roman Catholic Church .
+
+ The Pope is the head" False that Pope Paul VI was the last 3 [' that', ' Pope', ' Paul', ' VI']
+3402 814 Name of father of x -1 Name of father of Paul VI Giorgio Montini Paul VI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ' of' ' the' ' Roman'
+ ' Catholic' ' Church' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope of the Roman Catholic Church .
+
+ The Pope is the head" False the new method. Paul VI later expanded the 5 [' the', ' new', ' method', '.', ' Paul', ' VI']
+3403 814 Name of father of x -1 Name of father of Paul VI Giorgio Montini Paul VI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ' of' ' the' ' Roman'
+ ' Catholic' ' Church' '.' '\n' '\n' 'The' ' Pope' ' is' ' the' ' head']" ", the Pope , and the Pope of the Roman Catholic Church .
+
+ The Pope is the head" False in 1964, Pope Paul VI authorized Jesuit 5 [' in', ' 1964', ',', ' Pope', ' Paul', ' VI']
+3404 815 Name of father of x -1 Name of father of Marco Polo Niccolò Polo Marco Polo "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' Marco' ' Polo'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", and the
+
+ Name of father of Marco Polo , and the
+
+ Name of father of" False Italian traveller Marco Polo recorded that there 3 [' Italian', ' traveller', ' Marco', ' Polo']
+3405 815 Name of father of x -1 Name of father of Marco Polo Niccolò Polo Marco Polo "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' Marco' ' Polo'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", and the
+
+ Name of father of Marco Polo , and the
+
+ Name of father of" False merchants like Marco Polo made their way 3 [' merchants', ' like', ' Marco', ' Polo']
+3406 815 Name of father of x -1 Name of father of Marco Polo Niccolò Polo Marco Polo "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' Marco' ' Polo'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", and the
+
+ Name of father of Marco Polo , and the
+
+ Name of father of" False " Club ===
+" 3 [' Club', ' ===', 'Marco', ' Polo']
+3407 815 Name of father of x -1 Name of father of Marco Polo Niccolò Polo Marco Polo "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' Marco' ' Polo'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", and the
+
+ Name of father of Marco Polo , and the
+
+ Name of father of" False Indian princes, and Marco Polo reported that 5 [' Indian', ' princes', ',', ' and', ' Marco', ' Polo']
+3408 815 Name of father of x -1 Name of father of Marco Polo Niccolò Polo Marco Polo "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' Marco' ' Polo'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", and the
+
+ Name of father of Marco Polo , and the
+
+ Name of father of" False 1 ['Marco', ' Polo']
+3409 816 Name of father of x -1 Name of father of Mother Teresa Nikollë Bojaxhiu Mother Teresa "[',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the' ' saint' ' of' ' the'
+ ' poor' ',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the']" , the saint of the poor , the saint of the poor , the saint of the poor , the False " Airport and Rinas Mother Teresa Airport (Tirana).
+" 5 [' Airport', ' and', ' R', 'inas', ' Mother', ' Teresa']
+3410 816 Name of father of x -1 Name of father of Mother Teresa Nikollë Bojaxhiu Mother Teresa "[',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the' ' saint' ' of' ' the'
+ ' poor' ',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the']" , the saint of the poor , the saint of the poor , the saint of the poor , the False Indian citizen, Mother Teresa in 1980, and 4 [' Indian', ' citizen', ',', ' Mother', ' Teresa']
+3411 816 Name of father of x -1 Name of father of Mother Teresa Nikollë Bojaxhiu Mother Teresa "[',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the' ' saint' ' of' ' the'
+ ' poor' ',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the']" , the saint of the poor , the saint of the poor , the saint of the poor , the False The award for Mother Teresa was announced in 4 [' The', ' award', ' for', ' Mother', ' Teresa']
+3412 816 Name of father of x -1 Name of father of Mother Teresa Nikollë Bojaxhiu Mother Teresa "[',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the' ' saint' ' of' ' the'
+ ' poor' ',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the']" , the saint of the poor , the saint of the poor , the saint of the poor , the False wealth status. Mother Teresa was awarded the Nobel 4 [' wealth', ' status', '.', ' Mother', ' Teresa']
+3413 816 Name of father of x -1 Name of father of Mother Teresa Nikollë Bojaxhiu Mother Teresa "[',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the' ' saint' ' of' ' the'
+ ' poor' ',' ' the' ' saint' ' of' ' the' ' poor' ',' ' the']" , the saint of the poor , the saint of the poor , the saint of the poor , the False " other two are Mother Teresa and Nelson Mandela.
+" 4 [' other', ' two', ' are', ' Mother', ' Teresa']
+3414 818 Name of father of x -1 Name of father of Wilhelm von Humboldt Alexander Georg von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False Alexander and Wilhelm von Humboldt famously championed 7 [' Alexander', ' and', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+3415 818 Name of father of x -1 Name of father of Wilhelm von Humboldt Alexander Georg von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False linguist and diplomat Wilhelm von Humboldt were close 9 [' lingu', 'ist', ' and', ' diplomat', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+3416 818 Name of father of x -1 Name of father of Wilhelm von Humboldt Alexander Georg von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False Alexander and Wilhelm von Humboldt famously championed 7 [' Alexander', ' and', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+3417 818 Name of father of x -1 Name of father of Wilhelm von Humboldt Alexander Georg von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False brothers Alexander and Wilhelm von Humboldt famously championed 8 [' brothers', ' Alexander', ' and', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+3418 818 Name of father of x -1 Name of father of Wilhelm von Humboldt Alexander Georg von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False Alexander and Wilhelm von Humboldt famously championed 7 [' Alexander', ' and', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+3419 819 Name of father of x -1 Name of father of Kamala Harris Donald J. Harris Kamala Harris "[',' ' the' ' daughter' ' of' ' a' ' former' ' president' ' of' ' the'
+ ' United' ' States' ',' ' and' ' a' ' former' ' attorney' ' general'
+ ' of' ' California' ',']" , the daughter of a former president of the United States , and a former attorney general of California , False District Attorney Kamala Harris issued a warrant 4 [' District', ' Attorney', ' Kam', 'ala', ' Harris']
+3420 820 Name of father of x -1 Name of father of Wilhelm Ostwald Gottfried Ostwald Wilhelm Ostwald "[',' ' the' ' German' ' chemist' ' who' ' discovered' ' the' ' law' ' of'
+ ' the' ' conservation' ' of' ' mass' '.' '\n' '\n' 'The' ' first' ' of'
+ ' the']" ", the German chemist who discovered the law of the conservation of mass .
+
+ The first of the" False chemistry) by Wilhelm Ostwald in 1892 and into French 5 [' chemistry', ')', ' by', ' Wilhelm', ' Ost', 'wald']
+3421 820 Name of father of x -1 Name of father of Wilhelm Ostwald Gottfried Ostwald Wilhelm Ostwald "[',' ' the' ' German' ' chemist' ' who' ' discovered' ' the' ' law' ' of'
+ ' the' ' conservation' ' of' ' mass' '.' '\n' '\n' 'The' ' first' ' of'
+ ' the']" ", the German chemist who discovered the law of the conservation of mass .
+
+ The first of the" False for chemistry) by Wilhelm Ostwald in 1892 and 6 [' for', ' chemistry', ')', ' by', ' Wilhelm', ' Ost', 'wald']
+3422 820 Name of father of x -1 Name of father of Wilhelm Ostwald Gottfried Ostwald Wilhelm Ostwald "[',' ' the' ' German' ' chemist' ' who' ' discovered' ' the' ' law' ' of'
+ ' the' ' conservation' ' of' ' mass' '.' '\n' '\n' 'The' ' first' ' of'
+ ' the']" ", the German chemist who discovered the law of the conservation of mass .
+
+ The first of the" False chemistry) by Wilhelm Ostwald in 1892 and into 5 [' chemistry', ')', ' by', ' Wilhelm', ' Ost', 'wald']
+3423 821 Name of father of x -1 Name of father of Serge Gainsbourg Joseph Ginsburg Serge Gainsbourg "[',' ' the' ' French' ' singer' '-' 'song' 'writer' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Jane' ' Bir' 'kin' '.']" , the French singer - song writer , and his wife , the actress and singer Jane Bir kin . False from 1958, while Serge Gainsbourg used the theme 7 [' from', ' 1958', ',', ' while', ' Serge', ' G', 'ains', 'bourg']
+3424 821 Name of father of x -1 Name of father of Serge Gainsbourg Joseph Ginsburg Serge Gainsbourg "[',' ' the' ' French' ' singer' '-' 'song' 'writer' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Jane' ' Bir' 'kin' '.']" , the French singer - song writer , and his wife , the actress and singer Jane Bir kin . False from 1958, while Serge Gainsbourg used the theme 7 [' from', ' 1958', ',', ' while', ' Serge', ' G', 'ains', 'bourg']
+3425 821 Name of father of x -1 Name of father of Serge Gainsbourg Joseph Ginsburg Serge Gainsbourg "[',' ' the' ' French' ' singer' '-' 'song' 'writer' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Jane' ' Bir' 'kin' '.']" , the French singer - song writer , and his wife , the actress and singer Jane Bir kin . False duet between Serge Gainsbourg and Jane Birkin. 6 [' du', 'et', ' between', ' Serge', ' G', 'ains', 'bourg']
+3426 821 Name of father of x -1 Name of father of Serge Gainsbourg Joseph Ginsburg Serge Gainsbourg "[',' ' the' ' French' ' singer' '-' 'song' 'writer' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Jane' ' Bir' 'kin' '.']" , the French singer - song writer , and his wife , the actress and singer Jane Bir kin . False tune from 1958, while Serge Gainsbourg used the theme 8 [' tune', ' from', ' 1958', ',', ' while', ' Serge', ' G', 'ains', 'bourg']
+3427 821 Name of father of x -1 Name of father of Serge Gainsbourg Joseph Ginsburg Serge Gainsbourg "[',' ' the' ' French' ' singer' '-' 'song' 'writer' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Jane' ' Bir' 'kin' '.']" , the French singer - song writer , and his wife , the actress and singer Jane Bir kin . False from 1958, while Serge Gainsbourg used the theme 7 [' from', ' 1958', ',', ' while', ' Serge', ' G', 'ains', 'bourg']
+3428 822 Name of father of x -1 Name of father of Heinrich von Kleist Joachim Friedrich von Kleist Heinrich von Kleist "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '11' '.' '\n' '\n']" ", the German poet , who was born in 17 77 , and died in 18 11 .
+
+" False " Grillparzer, and Heinrich von Kleist to be his ""true blood" 9 [' Grill', 'par', 'zer', ',', ' and', ' Hein', 'rich', ' von', ' Kle', 'ist']
+3429 822 Name of father of x -1 Name of father of Heinrich von Kleist Joachim Friedrich von Kleist Heinrich von Kleist "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '11' '.' '\n' '\n']" ", the German poet , who was born in 17 77 , and died in 18 11 .
+
+" False " Franz Grillparzer, and Heinrich von Kleist to be his ""true blood" 10 [' Franz', ' Grill', 'par', 'zer', ',', ' and', ' Hein', 'rich', ' von', ' Kle', 'ist']
+3430 822 Name of father of x -1 Name of father of Heinrich von Kleist Joachim Friedrich von Kleist Heinrich von Kleist "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '11' '.' '\n' '\n']" ", the German poet , who was born in 17 77 , and died in 18 11 .
+
+" False " Grillparzer, and Heinrich von Kleist to be his ""true" 9 [' Grill', 'par', 'zer', ',', ' and', ' Hein', 'rich', ' von', ' Kle', 'ist']
+3431 822 Name of father of x -1 Name of father of Heinrich von Kleist Joachim Friedrich von Kleist Heinrich von Kleist "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '11' '.' '\n' '\n']" ", the German poet , who was born in 17 77 , and died in 18 11 .
+
+" False " Grillparzer, and Heinrich von Kleist to be his ""true" 9 [' Grill', 'par', 'zer', ',', ' and', ' Hein', 'rich', ' von', ' Kle', 'ist']
+3432 822 Name of father of x -1 Name of father of Heinrich von Kleist Joachim Friedrich von Kleist Heinrich von Kleist "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '11' '.' '\n' '\n']" ", the German poet , who was born in 17 77 , and died in 18 11 .
+
+" False Grillparzer, and Heinrich von Kleist to be his 9 [' Grill', 'par', 'zer', ',', ' and', ' Hein', 'rich', ' von', ' Kle', 'ist']
+3433 823 Name of father of x -1 Name of father of Leoš Janáček Jiří Janáček Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Marie' ' V' 'ond' 'r' 'á' 'č' 'kov' 'á' '.']" , the Czech composer , and his wife , the actress Marie V ond r á č kov á . False " Janáček =
+" 11 [' Jan', 'á', 'č', 'ek', ' =', 'Le', 'o', 'š', ' Jan', 'á', 'č', 'ek']
+3434 823 Name of father of x -1 Name of father of Leoš Janáček Jiří Janáček Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Marie' ' V' 'ond' 'r' 'á' 'č' 'kov' 'á' '.']" , the Czech composer , and his wife , the actress Marie V ond r á č kov á . False 6 ['Le', 'o', 'š', ' Jan', 'á', 'č', 'ek']
+3435 823 Name of father of x -1 Name of father of Leoš Janáček Jiří Janáček Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Marie' ' V' 'ond' 'r' 'á' 'č' 'kov' 'á' '.']" , the Czech composer , and his wife , the actress Marie V ond r á č kov á . False " Leoš Janáček =
+" 5 [' Leo', 'š', ' Jan', 'á', 'č', 'ek']
+3436 823 Name of father of x -1 Name of father of Leoš Janáček Jiří Janáček Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Marie' ' V' 'ond' 'r' 'á' 'č' 'kov' 'á' '.']" , the Czech composer , and his wife , the actress Marie V ond r á č kov á . False The Operas of Leoš Janáček shortly before 9 [' The', ' Oper', 'as', ' of', ' Leo', 'š', ' Jan', 'á', 'č', 'ek']
+3437 823 Name of father of x -1 Name of father of Leoš Janáček Jiří Janáček Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' Marie' ' V' 'ond' 'r' 'á' 'č' 'kov' 'á' '.']" , the Czech composer , and his wife , the actress Marie V ond r á č kov á . False 6 ['Le', 'o', 'š', ' Jan', 'á', 'č', 'ek']
+3438 824 Name of father of x -1 Name of father of George Soros Tivadar Soros George Soros "[',' ' the' ' billionaire' ' fin' 'anc' 'ier' ' and' ' philanthrop' 'ist'
+ ',' ' who' ' has' ' been' ' a' ' major' ' donor' ' to' ' the' ' Clinton'
+ ' Foundation']" , the billionaire fin anc ier and philanthrop ist , who has been a major donor to the Clinton Foundation False who represented George Soros as one of their 3 [' who', ' represented', ' George', ' Soros']
+3439 824 Name of father of x -1 Name of father of George Soros Tivadar Soros George Soros "[',' ' the' ' billionaire' ' fin' 'anc' 'ier' ' and' ' philanthrop' 'ist'
+ ',' ' who' ' has' ' been' ' a' ' major' ' donor' ' to' ' the' ' Clinton'
+ ' Foundation']" , the billionaire fin anc ier and philanthrop ist , who has been a major donor to the Clinton Foundation False represented George Soros as one of their sources 2 [' represented', ' George', ' Soros']
+3440 824 Name of father of x -1 Name of father of George Soros Tivadar Soros George Soros "[',' ' the' ' billionaire' ' fin' 'anc' 'ier' ' and' ' philanthrop' 'ist'
+ ',' ' who' ' has' ' been' ' a' ' major' ' donor' ' to' ' the' ' Clinton'
+ ' Foundation']" , the billionaire fin anc ier and philanthrop ist , who has been a major donor to the Clinton Foundation False investors who represented George Soros as one of their 4 [' investors', ' who', ' represented', ' George', ' Soros']
+3441 824 Name of father of x -1 Name of father of George Soros Tivadar Soros George Soros "[',' ' the' ' billionaire' ' fin' 'anc' 'ier' ' and' ' philanthrop' 'ist'
+ ',' ' who' ' has' ' been' ' a' ' major' ' donor' ' to' ' the' ' Clinton'
+ ' Foundation']" , the billionaire fin anc ier and philanthrop ist , who has been a major donor to the Clinton Foundation False who represented George Soros as one of their 3 [' who', ' represented', ' George', ' Soros']
+3442 824 Name of father of x -1 Name of father of George Soros Tivadar Soros George Soros "[',' ' the' ' billionaire' ' fin' 'anc' 'ier' ' and' ' philanthrop' 'ist'
+ ',' ' who' ' has' ' been' ' a' ' major' ' donor' ' to' ' the' ' Clinton'
+ ' Foundation']" , the billionaire fin anc ier and philanthrop ist , who has been a major donor to the Clinton Foundation False who represented George Soros as one of their sources 3 [' who', ' represented', ' George', ' Soros']
+3443 825 Name of father of x -1 Name of father of Friedrich Hayek August von Hayek Friedrich Hayek "[',' ' the' ' great' ' economist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Austrian' ' School' ' of' ' Economics' ',' ' and' ' the'
+ ' father' ' of']" , the great economist , who was a member of the Austrian School of Economics , and the father of False libertarian movement. Friedrich Hayek saw FEE as part 5 [' libertarian', ' movement', '.', ' Friedrich', ' Hay', 'ek']
+3444 825 Name of father of x -1 Name of father of Friedrich Hayek August von Hayek Friedrich Hayek "[',' ' the' ' great' ' economist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Austrian' ' School' ' of' ' Economics' ',' ' and' ' the'
+ ' father' ' of']" , the great economist , who was a member of the Austrian School of Economics , and the father of False libertarian movement. Friedrich Hayek saw FEE as part of 5 [' libertarian', ' movement', '.', ' Friedrich', ' Hay', 'ek']
+3445 825 Name of father of x -1 Name of father of Friedrich Hayek August von Hayek Friedrich Hayek "[',' ' the' ' great' ' economist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Austrian' ' School' ' of' ' Economics' ',' ' and' ' the'
+ ' father' ' of']" , the great economist , who was a member of the Austrian School of Economics , and the father of False 4 ['F', 'ried', 'rich', ' Hay', 'ek']
+3446 825 Name of father of x -1 Name of father of Friedrich Hayek August von Hayek Friedrich Hayek "[',' ' the' ' great' ' economist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Austrian' ' School' ' of' ' Economics' ',' ' and' ' the'
+ ' father' ' of']" , the great economist , who was a member of the Austrian School of Economics , and the father of False 4 ['F', 'ried', 'rich', ' Hay', 'ek']
+3447 825 Name of father of x -1 Name of father of Friedrich Hayek August von Hayek Friedrich Hayek "[',' ' the' ' great' ' economist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Austrian' ' School' ' of' ' Economics' ',' ' and' ' the'
+ ' father' ' of']" , the great economist , who was a member of the Austrian School of Economics , and the father of False libertarian movement. Friedrich Hayek saw FEE as part 5 [' libertarian', ' movement', '.', ' Friedrich', ' Hay', 'ek']
+3448 826 Name of father of x -1 Name of father of Louis XV of France Louis, Dauphin of France, Duke of Burgundy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of Louis XVI of France , and the" False Polish Succession. Louis XV of France demanded that 7 [' Polish', ' Success', 'ion', '.', ' Louis', ' XV', ' of', ' France']
+3449 826 Name of father of x -1 Name of father of Louis XV of France Louis, Dauphin of France, Duke of Burgundy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of Louis XVI of France , and the" False Polish Succession. Louis XV of France demanded that Maria 7 [' Polish', ' Success', 'ion', '.', ' Louis', ' XV', ' of', ' France']
+3450 826 Name of father of x -1 Name of father of Louis XV of France Louis, Dauphin of France, Duke of Burgundy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of Louis XVI of France , and the" False rarity made King Louis XV of France declare it the only 6 [' rarity', ' made', ' King', ' Louis', ' XV', ' of', ' France']
+3451 826 Name of father of x -1 Name of father of Louis XV of France Louis, Dauphin of France, Duke of Burgundy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of Louis XVI of France , and the" False had it and survived). Louis XV of France succeeded his great-grandfather 8 [' had', ' it', ' and', ' survived', ').', ' Louis', ' XV', ' of', ' France']
+3452 826 Name of father of x -1 Name of father of Louis XV of France Louis, Dauphin of France, Duke of Burgundy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of Louis XVI of France , and the" False rarity made King Louis XV of France declare it 6 [' rarity', ' made', ' King', ' Louis', ' XV', ' of', ' France']
+3453 827 Name of father of x -1 Name of father of Louis de Broglie Victor de Broglie Louis de Broglie "[',' ' the' ' French' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' French' ' Academy' ' of' ' Sciences' ',' ' and' ' a' ' member'
+ ' of']" , the French physicist , who was a member of the French Academy of Sciences , and a member of False 4 ['Louis', ' de', ' Bro', 'gl', 'ie']
+3454 827 Name of father of x -1 Name of father of Louis de Broglie Victor de Broglie Louis de Broglie "[',' ' the' ' French' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' French' ' Academy' ' of' ' Sciences' ',' ' and' ' a' ' member'
+ ' of']" , the French physicist , who was a member of the French Academy of Sciences , and a member of False 4 ['Louis', ' de', ' Bro', 'gl', 'ie']
+3455 827 Name of father of x -1 Name of father of Louis de Broglie Victor de Broglie Louis de Broglie "[',' ' the' ' French' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' French' ' Academy' ' of' ' Sciences' ',' ' and' ' a' ' member'
+ ' of']" , the French physicist , who was a member of the French Academy of Sciences , and a member of False French physicist Louis de Broglie hypothesized that 6 [' French', ' physicist', ' Louis', ' de', ' Bro', 'gl', 'ie']
+3456 828 Name of father of x -1 Name of father of Mickey Rooney Joe Yule Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the actor , and his wife , L illian , who was a singer .
+
+ The couple" False " as Flute and Mickey Rooney as Puck
+" 5 [' as', ' Fl', 'ute', ' and', ' Mickey', ' Rooney']
+3457 828 Name of father of x -1 Name of father of Mickey Rooney Joe Yule Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the actor , and his wife , L illian , who was a singer .
+
+ The couple" False that starred Mickey Rooney as Cap 'n Andy, 3 [' that', ' starred', ' Mickey', ' Rooney']
+3458 828 Name of father of x -1 Name of father of Mickey Rooney Joe Yule Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the actor , and his wife , L illian , who was a singer .
+
+ The couple" False " Stanton, and actors Mickey Rooney and Lash La Rue.
+" 5 [' Stanton', ',', ' and', ' actors', ' Mickey', ' Rooney']
+3459 828 Name of father of x -1 Name of father of Mickey Rooney Joe Yule Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the actor , and his wife , L illian , who was a singer .
+
+ The couple" False 2 ['Mic', 'key', ' Rooney']
+3460 828 Name of father of x -1 Name of father of Mickey Rooney Joe Yule Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the actor , and his wife , L illian , who was a singer .
+
+ The couple" False " custom-role for Mickey Rooney, ""which I did so" 5 [' custom', '-', 'role', ' for', ' Mickey', ' Rooney']
+3461 829 Name of father of x -1 Name of father of Anna Akhmatova Andrey Gorenko Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' her' ' husband' ',' ' the' ' poet'
+ ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' the' ' poet']" , the poet ess , and her husband , the poet , and the poet 's wife , the poet False explored at the Anna Akhmatova Literary and Memorial 7 [' explored', ' at', ' the', ' Anna', ' Ak', 'h', 'mat', 'ova']
+3462 829 Name of father of x -1 Name of father of Anna Akhmatova Andrey Gorenko Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' her' ' husband' ',' ' the' ' poet'
+ ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' the' ' poet']" , the poet ess , and her husband , the poet , and the poet 's wife , the poet False to greet us. Anna Akhmatova was immensely dignified, 8 [' to', ' greet', ' us', '.', ' Anna', ' Ak', 'h', 'mat', 'ova']
+3463 829 Name of father of x -1 Name of father of Anna Akhmatova Andrey Gorenko Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' her' ' husband' ',' ' the' ' poet'
+ ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' the' ' poet']" , the poet ess , and her husband , the poet , and the poet 's wife , the poet False explored at the Anna Akhmatova Literary and Memorial 7 [' explored', ' at', ' the', ' Anna', ' Ak', 'h', 'mat', 'ova']
+3464 829 Name of father of x -1 Name of father of Anna Akhmatova Andrey Gorenko Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' her' ' husband' ',' ' the' ' poet'
+ ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' the' ' poet']" , the poet ess , and her husband , the poet , and the poet 's wife , the poet False Twenty Poems of Anna Akhmatova (trans Jane Kenyon); 8 [' Twenty', ' Po', 'ems', ' of', ' Anna', ' Ak', 'h', 'mat', 'ova']
+3465 829 Name of father of x -1 Name of father of Anna Akhmatova Andrey Gorenko Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' her' ' husband' ',' ' the' ' poet'
+ ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' the' ' poet']" , the poet ess , and her husband , the poet , and the poet 's wife , the poet False Twenty Poems of Anna Akhmatova (trans Jane 8 [' Twenty', ' Po', 'ems', ' of', ' Anna', ' Ak', 'h', 'mat', 'ova']
+3466 830 Name of father of x -1 Name of father of Paracelsus Wilhelm Bombast von Hohenheim Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' father' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of father of Par ac els" False corpus, c. 1310), Paracelsus (De Natura 10 [' corpus', ',', ' c', '.', ' 13', '10', '),', ' Par', 'ac', 'els', 'us']
+3467 830 Name of father of x -1 Name of father of Paracelsus Wilhelm Bombast von Hohenheim Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' father' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of father of Par ac els" False gases, although Paracelsus around 1500, 6 [' gases', ',', ' although', ' Par', 'ac', 'els', 'us']
+3468 830 Name of father of x -1 Name of father of Paracelsus Wilhelm Bombast von Hohenheim Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' father' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of father of Par ac els" False known. In 1530, Paracelsus described a wasting 9 [' known', '.', ' In', ' 15', '30', ',', ' Par', 'ac', 'els', 'us']
+3469 830 Name of father of x -1 Name of father of Paracelsus Wilhelm Bombast von Hohenheim Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' father' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of father of Par ac els" False known. In 1530, Paracelsus described a wasting 9 [' known', '.', ' In', ' 15', '30', ',', ' Par', 'ac', 'els', 'us']
+3470 830 Name of father of x -1 Name of father of Paracelsus Wilhelm Bombast von Hohenheim Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' father' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of father of Par ac els" False corpus, c. 1310), Paracelsus (De Natura Rerum 10 [' corpus', ',', ' c', '.', ' 13', '10', '),', ' Par', 'ac', 'els', 'us']
+3471 831 Name of father of x -1 Name of father of Gustaf VI Adolf of Sweden Gustaf V of Sweden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' and' ' the' ' head']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , and the head" False 6 ['G', 'ust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+3472 831 Name of father of x -1 Name of father of Gustaf VI Adolf of Sweden Gustaf V of Sweden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' and' ' the' ' head']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , and the head" False 6 ['G', 'ust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+3473 831 Name of father of x -1 Name of father of Gustaf VI Adolf of Sweden Gustaf V of Sweden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' and' ' the' ' head']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , and the head" False July by King Gustaf VI Adolf of Sweden before returning 8 [' July', ' by', ' King', ' Gust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+3474 831 Name of father of x -1 Name of father of Gustaf VI Adolf of Sweden Gustaf V of Sweden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' and' ' the' ' head']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , and the head" False 6 ['G', 'ust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+3475 831 Name of father of x -1 Name of father of Gustaf VI Adolf of Sweden Gustaf V of Sweden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' and' ' the' ' head']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , and the head" False 11 July by King Gustaf VI Adolf of Sweden before returning 9 [' 11', ' July', ' by', ' King', ' Gust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+3476 832 Name of father of x -1 Name of father of Charles II of England Charles I of England Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False grant that Charles II of England awarded to 5 [' grant', ' that', ' Charles', ' II', ' of', ' England']
+3477 832 Name of father of x -1 Name of father of Charles II of England Charles I of England Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False marriage treaty of Charles II of England and Catherine 6 [' marriage', ' treaty', ' of', ' Charles', ' II', ' of', ' England']
+3478 832 Name of father of x -1 Name of father of Charles II of England Charles I of England Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False land grant that Charles II of England awarded to seven 6 [' land', ' grant', ' that', ' Charles', ' II', ' of', ' England']
+3479 832 Name of father of x -1 Name of father of Charles II of England Charles I of England Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False land grant that Charles II of England awarded to seven 6 [' land', ' grant', ' that', ' Charles', ' II', ' of', ' England']
+3480 832 Name of father of x -1 Name of father of Charles II of England Charles I of England Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False a land grant from Charles II of England to seven of his supporters 7 [' a', ' land', ' grant', ' from', ' Charles', ' II', ' of', ' England']
+3481 834 Name of father of x -1 Name of father of Maria Theresa of Austria Charles VI, Holy Roman Emperor Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' '\n' '\n' 'daughter' ' of' ' the' ' Emperor' ' Francis' ' II' ',']" ", the daughter of the Emperor Joseph II , and the
+
+ daughter of the Emperor Francis II ," False Frederick and Archduchess Maria Theresa of Austria that the balance 9 [' Frederick', ' and', ' Arch', 'du', 'che', 'ss', ' Maria', ' Theresa', ' of', ' Austria']
+3482 834 Name of father of x -1 Name of father of Maria Theresa of Austria Charles VI, Holy Roman Emperor Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' '\n' '\n' 'daughter' ' of' ' the' ' Emperor' ' Francis' ' II' ',']" ", the daughter of the Emperor Joseph II , and the
+
+ daughter of the Emperor Francis II ," False ceased when Empress Maria Theresa of Austria sent her personal 6 [' ceased', ' when', ' Empress', ' Maria', ' Theresa', ' of', ' Austria']
+3483 834 Name of father of x -1 Name of father of Maria Theresa of Austria Charles VI, Holy Roman Emperor Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' '\n' '\n' 'daughter' ' of' ' the' ' Emperor' ' Francis' ' II' ',']" ", the daughter of the Emperor Joseph II , and the
+
+ daughter of the Emperor Francis II ," False charter issued by Maria Theresa of Austria in 1741. The unit 6 [' charter', ' issued', ' by', ' Maria', ' Theresa', ' of', ' Austria']
+3484 834 Name of father of x -1 Name of father of Maria Theresa of Austria Charles VI, Holy Roman Emperor Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' '\n' '\n' 'daughter' ' of' ' the' ' Emperor' ' Francis' ' II' ',']" ", the daughter of the Emperor Joseph II , and the
+
+ daughter of the Emperor Francis II ," False Werbepatent) issued by Maria Theresa of Austria on 27 February 11 [' W', 'erb', 'ep', 'at', 'ent', ')', ' issued', ' by', ' Maria', ' Theresa', ' of', ' Austria']
+3485 834 Name of father of x -1 Name of father of Maria Theresa of Austria Charles VI, Holy Roman Emperor Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' '\n' '\n' 'daughter' ' of' ' the' ' Emperor' ' Francis' ' II' ',']" ", the daughter of the Emperor Joseph II , and the
+
+ daughter of the Emperor Francis II ," False and Archduchess Maria Theresa of Austria that the balance of 8 [' and', ' Arch', 'du', 'che', 'ss', ' Maria', ' Theresa', ' of', ' Austria']
+3486 835 Name of father of x -1 Name of father of Prince Philip, Duke of Edinburgh Prince Andrew of Greece and Denmark Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False engagement to the Prince Philip, Duke of Edinburgh on 10 July 1947. Prince 8 [' engagement', ' to', ' the', ' Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3487 835 Name of father of x -1 Name of father of Prince Philip, Duke of Edinburgh Prince Andrew of Greece and Denmark Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False 5 ['Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3488 835 Name of father of x -1 Name of father of Prince Philip, Duke of Edinburgh Prince Andrew of Greece and Denmark Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False 1970s, British Prince Philip, Duke of Edinburgh competed with a 9 [' 1970', 's', ',', ' British', ' Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3489 835 Name of father of x -1 Name of father of Prince Philip, Duke of Edinburgh Prince Andrew of Greece and Denmark Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False 1970s, British Prince Philip, Duke of Edinburgh competed with 9 [' 1970', 's', ',', ' British', ' Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3490 835 Name of father of x -1 Name of father of Prince Philip, Duke of Edinburgh Prince Andrew of Greece and Denmark Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False 5 ['Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3491 836 Name of father of x -1 Name of father of Hermann Göring Heinrich Ernst Göring Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False Reichsmarschall Hermann Göring and the Oberkommando 7 [' Reich', 'sm', 'ars', 'chall', ' Herman', 'n', ' Gö', 'ring']
+3492 836 Name of father of x -1 Name of father of Hermann Göring Heinrich Ernst Göring Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False " forces, reporting to Hermann Göring that ""Japanese" 7 [' forces', ',', ' reporting', ' to', ' Herman', 'n', ' Gö', 'ring']
+3493 836 Name of father of x -1 Name of father of Hermann Göring Heinrich Ernst Göring Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False Division and the Hermann Göring Panzer Division, 6 [' Division', ' and', ' the', ' Herman', 'n', ' Gö', 'ring']
+3494 836 Name of father of x -1 Name of father of Hermann Göring Heinrich Ernst Göring Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False the Interior, and Hermann Göring Minister of 7 [' the', ' Interior', ',', ' and', ' Herman', 'n', ' Gö', 'ring']
+3495 836 Name of father of x -1 Name of father of Hermann Göring Heinrich Ernst Göring Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False July 31, 1941, Hermann Göring gave written 8 [' July', ' 31', ',', ' 1941', ',', ' Herman', 'n', ' Gö', 'ring']
+3496 837 Name of father of x -1 Name of father of Galen Aelius Nicon Galen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gal' 'en' ','
+ ' and' ' the' ' name' ' of' ' the' ' mother' ' of' ' Gal']" ", and the
+
+ Name of mother of Gal en , and the name of the mother of Gal" False idea. First, Galen Clark and others 5 [' idea', '.', ' First', ',', ' Gal', 'en']
+3497 837 Name of father of x -1 Name of father of Galen Aelius Nicon Galen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gal' 'en' ','
+ ' and' ' the' ' name' ' of' ' the' ' mother' ' of' ' Gal']" ", and the
+
+ Name of mother of Gal en , and the name of the mother of Gal" False included a shot of Galen Tyrol (Douglas) 5 [' included', ' a', ' shot', ' of', ' Gal', 'en']
+3498 837 Name of father of x -1 Name of father of Galen Aelius Nicon Galen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gal' 'en' ','
+ ' and' ' the' ' name' ' of' ' the' ' mother' ' of' ' Gal']" ", and the
+
+ Name of mother of Gal en , and the name of the mother of Gal" False (Michael Trucco), Galen Tyrol (Aaron Douglas) 7 [' (', 'Michael', ' Tru', 'cc', 'o', '),', ' Gal', 'en']
+3499 837 Name of father of x -1 Name of father of Galen Aelius Nicon Galen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gal' 'en' ','
+ ' and' ' the' ' name' ' of' ' the' ' mother' ' of' ' Gal']" ", and the
+
+ Name of mother of Gal en , and the name of the mother of Gal" False the philosopher Galen Strawson and 3 [' the', ' philosopher', ' Gal', 'en']
+3500 837 Name of father of x -1 Name of father of Galen Aelius Nicon Galen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gal' 'en' ','
+ ' and' ' the' ' name' ' of' ' the' ' mother' ' of' ' Gal']" ", and the
+
+ Name of mother of Gal en , and the name of the mother of Gal" False the Bible, and Galen's writings, 5 [' the', ' Bible', ',', ' and', ' Gal', 'en']
+3501 838 Name of father of x -1 Name of father of Eugène Viollet-le-Duc Emmanuel Louis Nicolas Viollet-le-Duc Eugène Viollet-le-Duc "[',' ' the' ' architect' ' of' ' the' ' Cathedral' ' of' ' Notre' '-' 'D'
+ 'ame' ' de' ' Paris' ',' ' and' ' the' ' author' ' of' ' the' ' famous']" , the architect of the Cathedral of Notre - D ame de Paris , and the author of the famous False in France, Eugène Viollet-le-Duc reworked the 12 [' in', ' France', ',', ' Eug', 'è', 'ne', ' Viol', 'let', '-', 'le', '-', 'D', 'uc']
+3502 838 Name of father of x -1 Name of father of Eugène Viollet-le-Duc Emmanuel Louis Nicolas Viollet-le-Duc Eugène Viollet-le-Duc "[',' ' the' ' architect' ' of' ' the' ' Cathedral' ' of' ' Notre' '-' 'D'
+ 'ame' ' de' ' Paris' ',' ' and' ' the' ' author' ' of' ' the' ' famous']" , the architect of the Cathedral of Notre - D ame de Paris , and the author of the famous False Castle, while in France, Eugène Viollet-le-Duc reworked the keeps 15 [' Castle', ',', ' while', ' in', ' France', ',', ' Eug', 'è', 'ne', ' Viol', 'let', '-', 'le', '-', 'D', 'uc']
+3503 838 Name of father of x -1 Name of father of Eugène Viollet-le-Duc Emmanuel Louis Nicolas Viollet-le-Duc Eugène Viollet-le-Duc "[',' ' the' ' architect' ' of' ' the' ' Cathedral' ' of' ' Notre' '-' 'D'
+ 'ame' ' de' ' Paris' ',' ' and' ' the' ' author' ' of' ' the' ' famous']" , the architect of the Cathedral of Notre - D ame de Paris , and the author of the famous False staircase. Influenced by Eugène Viollet-le-Duc he decided 14 [' staircase', '.', ' Influ', 'enced', ' by', ' Eug', 'è', 'ne', ' Viol', 'let', '-', 'le', '-', 'D', 'uc']
+3504 838 Name of father of x -1 Name of father of Eugène Viollet-le-Duc Emmanuel Louis Nicolas Viollet-le-Duc Eugène Viollet-le-Duc "[',' ' the' ' architect' ' of' ' the' ' Cathedral' ' of' ' Notre' '-' 'D'
+ 'ame' ' de' ' Paris' ',' ' and' ' the' ' author' ' of' ' the' ' famous']" , the architect of the Cathedral of Notre - D ame de Paris , and the author of the famous False French architect Eugène Viollet-le-Duc who was undertaking 11 [' French', ' architect', ' Eug', 'è', 'ne', ' Viol', 'let', '-', 'le', '-', 'D', 'uc']
+3505 838 Name of father of x -1 Name of father of Eugène Viollet-le-Duc Emmanuel Louis Nicolas Viollet-le-Duc Eugène Viollet-le-Duc "[',' ' the' ' architect' ' of' ' the' ' Cathedral' ' of' ' Notre' '-' 'D'
+ 'ame' ' de' ' Paris' ',' ' and' ' the' ' author' ' of' ' the' ' famous']" , the architect of the Cathedral of Notre - D ame de Paris , and the author of the famous False staircase. Influenced by Eugène Viollet-le-Duc he decided 14 [' staircase', '.', ' Influ', 'enced', ' by', ' Eug', 'è', 'ne', ' Viol', 'let', '-', 'le', '-', 'D', 'uc']
+3506 839 Name of father of x -1 Name of father of Henry Fielding Edmund Fielding Henry Fielding "[',' ' the' ' author' ' of' ' _' 'Tom' ' Jones' '_' ',' ' and' ' _'
+ 'Joseph' ' Andrews' '_' ',' ' and' '\n' ' ' ' ' ' ']" ", the author of _ Tom Jones _ , and _ Joseph Andrews _ , and
+ " False he agreed with Henry Fielding that one of the 5 [' he', ' agreed', ' with', ' Henry', ' Field', 'ing']
+3507 839 Name of father of x -1 Name of father of Henry Fielding Edmund Fielding Henry Fielding "[',' ' the' ' author' ' of' ' _' 'Tom' ' Jones' '_' ',' ' and' ' _'
+ 'Joseph' ' Andrews' '_' ',' ' and' '\n' ' ' ' ' ' ']" ", the author of _ Tom Jones _ , and _ Joseph Andrews _ , and
+ " False in his Life of Henry Fielding (1855), connected 6 [' in', ' his', ' Life', ' of', ' Henry', ' Field', 'ing']
+3508 839 Name of father of x -1 Name of father of Henry Fielding Edmund Fielding Henry Fielding "[',' ' the' ' author' ' of' ' _' 'Tom' ' Jones' '_' ',' ' and' ' _'
+ 'Joseph' ' Andrews' '_' ',' ' and' '\n' ' ' ' ' ' ']" ", the author of _ Tom Jones _ , and _ Joseph Andrews _ , and
+ " False he agreed with Henry Fielding that one of 5 [' he', ' agreed', ' with', ' Henry', ' Field', 'ing']
+3509 839 Name of father of x -1 Name of father of Henry Fielding Edmund Fielding Henry Fielding "[',' ' the' ' author' ' of' ' _' 'Tom' ' Jones' '_' ',' ' and' ' _'
+ 'Joseph' ' Andrews' '_' ',' ' and' '\n' ' ' ' ' ' ']" ", the author of _ Tom Jones _ , and _ Joseph Andrews _ , and
+ " False experienced members, Henry Fielding joined the management's 5 [' experienced', ' members', ',', ' Henry', ' Field', 'ing']
+3510 839 Name of father of x -1 Name of father of Henry Fielding Edmund Fielding Henry Fielding "[',' ' the' ' author' ' of' ' _' 'Tom' ' Jones' '_' ',' ' and' ' _'
+ 'Joseph' ' Andrews' '_' ',' ' and' '\n' ' ' ' ' ' ']" ", the author of _ Tom Jones _ , and _ Joseph Andrews _ , and
+ " False in his Life of Henry Fielding (1855), connected the 6 [' in', ' his', ' Life', ' of', ' Henry', ' Field', 'ing']
+3511 840 Name of father of x -1 Name of father of Greta Garbo Karl Alfred Gustafsson Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Daughters (1928). Greta Garbo was the one non – native 9 [' D', 'aughters', ' (', '19', '28', ').', ' Gret', 'a', ' Gar', 'bo']
+3512 840 Name of father of x -1 Name of father of Greta Garbo Karl Alfred Gustafsson Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False stars from Europe, Greta Garbo and Marlene Dietrich. 7 [' stars', ' from', ' Europe', ',', ' Gret', 'a', ' Gar', 'bo']
+3513 840 Name of father of x -1 Name of father of Greta Garbo Karl Alfred Gustafsson Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False to appear opposite Greta Garbo in Queen Christina, 6 [' to', ' appear', ' opposite', ' Gret', 'a', ' Gar', 'bo']
+3514 840 Name of father of x -1 Name of father of Greta Garbo Karl Alfred Gustafsson Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Daughters (1928). Greta Garbo was the one non 9 [' D', 'aughters', ' (', '19', '28', ').', ' Gret', 'a', ' Gar', 'bo']
+3515 840 Name of father of x -1 Name of father of Greta Garbo Karl Alfred Gustafsson Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Daughters (1928). Greta Garbo was the one non 9 [' D', 'aughters', ' (', '19', '28', ').', ' Gret', 'a', ' Gar', 'bo']
+3516 841 Name of father of x -1 Name of father of Augustus John Edwin William John Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False capitals of Europe. Augustus John and his sister Gwen 5 [' capitals', ' of', ' Europe', '.', ' Augustus', ' John']
+3517 841 Name of father of x -1 Name of father of Augustus John Edwin William John Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False Woolf, Roger Fry, Augustus John and many other 7 [' Wool', 'f', ',', ' Roger', ' Fry', ',', ' Augustus', ' John']
+3518 841 Name of father of x -1 Name of father of Augustus John Edwin William John Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False Notable artists Augustus John and Walter Sickert. 4 [' Not', 'able', ' artists', ' Augustus', ' John']
+3519 841 Name of father of x -1 Name of father of Augustus John Edwin William John Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False capitals of Europe. Augustus John and his sister Gwen 5 [' capitals', ' of', ' Europe', '.', ' Augustus', ' John']
+3520 841 Name of father of x -1 Name of father of Augustus John Edwin William John Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False Woolf, Roger Fry, Augustus John and many other writers 7 [' Wool', 'f', ',', ' Roger', ' Fry', ',', ' Augustus', ' John']
+3521 842 Name of father of x -1 Name of father of Keanu Reeves Samuel Nowlin Reeves, Jr. Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False faint pencil marks. Keanu Reeves and John Cleese drew 6 [' faint', ' pencil', ' marks', '.', ' Ke', 'anu', ' Reeves']
+3522 842 Name of father of x -1 Name of father of Keanu Reeves Samuel Nowlin Reeves, Jr. Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False contest with Keanu Reeves to see how long 4 [' contest', ' with', ' Ke', 'anu', ' Reeves']
+3523 842 Name of father of x -1 Name of father of Keanu Reeves Samuel Nowlin Reeves, Jr. Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False which also starred Keanu Reeves and Morgan Freeman. 5 [' which', ' also', ' starred', ' Ke', 'anu', ' Reeves']
+3524 842 Name of father of x -1 Name of father of Keanu Reeves Samuel Nowlin Reeves, Jr. Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False 2 ['Ke', 'anu', ' Reeves']
+3525 842 Name of father of x -1 Name of father of Keanu Reeves Samuel Nowlin Reeves, Jr. Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False Canadian actor Keanu Reeves from the 1999 science 4 [' Canadian', ' actor', ' Ke', 'anu', ' Reeves']
+3526 843 Name of father of x -1 Name of father of Serena Williams Richard A Williams Jr Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and his wife , Alexis Oh anian , the co - founder of Reddit , False eventual champion Serena Williams in the second round. 4 [' eventual', ' champion', ' Sere', 'na', ' Williams']
+3527 843 Name of father of x -1 Name of father of Serena Williams Richard A Williams Jr Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and his wife , Alexis Oh anian , the co - founder of Reddit , False Raonic's fans. Serena Williams described the sleeve 7 "[' Ra', 'onic', ""'s"", ' fans', '.', ' Sere', 'na', ' Williams']"
+3528 843 Name of father of x -1 Name of father of Serena Williams Richard A Williams Jr Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and his wife , Alexis Oh anian , the co - founder of Reddit , False along with Dementieva, Serena Williams and Venus Williams. 9 [' along', ' with', ' D', 'ement', 'ie', 'va', ',', ' Sere', 'na', ' Williams']
+3529 843 Name of father of x -1 Name of father of Serena Williams Richard A Williams Jr Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and his wife , Alexis Oh anian , the co - founder of Reddit , False Kuznetsova defeated Serena Williams in three sets before 8 [' K', 'uz', 'net', 'so', 'va', ' defeated', ' Sere', 'na', ' Williams']
+3530 843 Name of father of x -1 Name of father of Serena Williams Richard A Williams Jr Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and his wife , Alexis Oh anian , the co - founder of Reddit , False first career win over Serena Williams before losing in the 6 [' first', ' career', ' win', ' over', ' Sere', 'na', ' Williams']
+3531 844 Name of father of x -1 Name of father of Heinrich Himmler Joseph Gebhard Himmler Heinrich Himmler "[',' ' the' ' Nazi' ' Party' ""'s"" ' chief' ' of' ' the' ' SS' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n']" ", the Nazi Party 's chief of the SS , was a member of the Nazi Party .
+
+" False reportedly because Heinrich Himmler intervened in 6 [' reportedly', ' because', ' Hein', 'rich', ' H', 'imm', 'ler']
+3532 844 Name of father of x -1 Name of father of Heinrich Himmler Joseph Gebhard Himmler Heinrich Himmler "[',' ' the' ' Nazi' ' Party' ""'s"" ' chief' ' of' ' the' ' SS' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n']" ", the Nazi Party 's chief of the SS , was a member of the Nazi Party .
+
+" False speeches made by Heinrich Himmler in October 1943 7 [' speeches', ' made', ' by', ' Hein', 'rich', ' H', 'imm', 'ler']
+3533 844 Name of father of x -1 Name of father of Heinrich Himmler Joseph Gebhard Himmler Heinrich Himmler "[',' ' the' ' Nazi' ' Party' ""'s"" ' chief' ' of' ' the' ' SS' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n']" ", the Nazi Party 's chief of the SS , was a member of the Nazi Party .
+
+" False 5 ['He', 'in', 'rich', ' H', 'imm', 'ler']
+3534 844 Name of father of x -1 Name of father of Heinrich Himmler Joseph Gebhard Himmler Heinrich Himmler "[',' ' the' ' Nazi' ' Party' ""'s"" ' chief' ' of' ' the' ' SS' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n']" ", the Nazi Party 's chief of the SS , was a member of the Nazi Party .
+
+" False Reichsführer-SS Heinrich Himmler and the supervision 11 [' Reich', 'sf', 'ü', 'h', 'rer', '-', 'SS', ' Hein', 'rich', ' H', 'imm', 'ler']
+3535 844 Name of father of x -1 Name of father of Heinrich Himmler Joseph Gebhard Himmler Heinrich Himmler "[',' ' the' ' Nazi' ' Party' ""'s"" ' chief' ' of' ' the' ' SS' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n']" ", the Nazi Party 's chief of the SS , was a member of the Nazi Party .
+
+" False including Adolf Eichmann, Heinrich Himmler and Adolf Hitler. 10 [' including', ' Adolf', ' E', 'ich', 'mann', ',', ' Hein', 'rich', ' H', 'imm', 'ler']
+3536 847 Name of father of x -1 Name of father of Samuel Johnson Michael Johnson Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False and enjoyment of Samuel Johnson's life and works. 4 [' and', ' enjoyment', ' of', ' Samuel', ' Johnson']
+3537 847 Name of father of x -1 Name of father of Samuel Johnson Michael Johnson Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False 2 ['Sam', 'uel', ' Johnson']
+3538 847 Name of father of x -1 Name of father of Samuel Johnson Michael Johnson Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False " lyricists labelled by Samuel Johnson as ""the Metaphysical" 5 [' lyric', 'ists', ' labelled', ' by', ' Samuel', ' Johnson']
+3539 847 Name of father of x -1 Name of father of Samuel Johnson Michael Johnson Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False Years since Doctor Samuel Johnson advised me to get 4 [' Years', ' since', ' Doctor', ' Samuel', ' Johnson']
+3540 847 Name of father of x -1 Name of father of Samuel Johnson Michael Johnson Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False " in the 18th century Samuel Johnson noted that:
+" 6 [' in', ' the', ' 18', 'th', ' century', ' Samuel', ' Johnson']
+3541 848 Name of father of x -1 Name of father of Dmitri Mendeleev Ivan Pavlovich Mendeleyev Dmitri Mendeleev "[',' ' the' ' father' ' of' ' the' ' periodic' ' table' '.' '\n' '\n'
+ 'The' ' periodic' ' table' ' of' ' the' ' elements' ' is' ' a' ' table'
+ ' that']" ", the father of the periodic table .
+
+ The periodic table of the elements is a table that" False 1869, when Dmitri Mendeleev published his periodic 8 [' 18', '69', ',', ' when', ' Dmit', 'ri', ' Mend', 'ele', 'ev']
+3542 848 Name of father of x -1 Name of father of Dmitri Mendeleev Ivan Pavlovich Mendeleyev Dmitri Mendeleev "[',' ' the' ' father' ' of' ' the' ' periodic' ' table' '.' '\n' '\n'
+ 'The' ' periodic' ' table' ' of' ' the' ' elements' ' is' ' a' ' table'
+ ' that']" ", the father of the periodic table .
+
+ The periodic table of the elements is a table that" False 5 ['D', 'mit', 'ri', ' Mend', 'ele', 'ev']
+3543 848 Name of father of x -1 Name of father of Dmitri Mendeleev Ivan Pavlovich Mendeleyev Dmitri Mendeleev "[',' ' the' ' father' ' of' ' the' ' periodic' ' table' '.' '\n' '\n'
+ 'The' ' periodic' ' table' ' of' ' the' ' elements' ' is' ' a' ' table'
+ ' that']" ", the father of the periodic table .
+
+ The periodic table of the elements is a table that" False 5 ['D', 'mit', 'ri', ' Mend', 'ele', 'ev']
+3544 848 Name of father of x -1 Name of father of Dmitri Mendeleev Ivan Pavlovich Mendeleyev Dmitri Mendeleev "[',' ' the' ' father' ' of' ' the' ' periodic' ' table' '.' '\n' '\n'
+ 'The' ' periodic' ' table' ' of' ' the' ' elements' ' is' ' a' ' table'
+ ' that']" ", the father of the periodic table .
+
+ The periodic table of the elements is a table that" False properties were predicted by Dmitri Mendeleev before the element 8 [' properties', ' were', ' predicted', ' by', ' Dmit', 'ri', ' Mend', 'ele', 'ev']
+3545 848 Name of father of x -1 Name of father of Dmitri Mendeleev Ivan Pavlovich Mendeleyev Dmitri Mendeleev "[',' ' the' ' father' ' of' ' the' ' periodic' ' table' '.' '\n' '\n'
+ 'The' ' periodic' ' table' ' of' ' the' ' elements' ' is' ' a' ' table'
+ ' that']" ", the father of the periodic table .
+
+ The periodic table of the elements is a table that" False helium and argon, Dmitri Mendeleev included these noble 9 [' helium', ' and', ' arg', 'on', ',', ' Dmit', 'ri', ' Mend', 'ele', 'ev']
+3546 849 Name of father of x -1 Name of father of Michael Faraday James Faraday Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False The scientist Michael Faraday described the situation 4 [' The', ' scientist', ' Michael', ' Far', 'aday']
+3547 849 Name of father of x -1 Name of father of Michael Faraday James Faraday Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False would later interest Michael Faraday and James Clerk 5 [' would', ' later', ' interest', ' Michael', ' Far', 'aday']
+3548 849 Name of father of x -1 Name of father of Michael Faraday James Faraday Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False expanded upon by Michael Faraday in 1833. Current 5 [' expanded', ' upon', ' by', ' Michael', ' Far', 'aday']
+3549 849 Name of father of x -1 Name of father of Michael Faraday James Faraday Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False 2 ['Michael', ' Far', 'aday']
+3550 849 Name of father of x -1 Name of father of Michael Faraday James Faraday Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False since 1834 when Michael Faraday published his works 6 [' since', ' 18', '34', ' when', ' Michael', ' Far', 'aday']
+3551 850 Name of father of x -1 Name of father of August von Kotzebue Levin Kotzebue August von Kotzebue "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False German dramatist August von Kotzebue in March 1819 8 [' German', ' dram', 'at', 'ist', ' August', ' von', ' Kot', 'zeb', 'ue']
+3552 850 Name of father of x -1 Name of father of August von Kotzebue Levin Kotzebue August von Kotzebue "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False German dramatist August von Kotzebue in March 1819 by a 8 [' German', ' dram', 'at', 'ist', ' August', ' von', ' Kot', 'zeb', 'ue']
+3553 850 Name of father of x -1 Name of father of August von Kotzebue Levin Kotzebue August von Kotzebue "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False German dramatist August von Kotzebue in March 1819 8 [' German', ' dram', 'at', 'ist', ' August', ' von', ' Kot', 'zeb', 'ue']
+3554 850 Name of father of x -1 Name of father of August von Kotzebue Levin Kotzebue August von Kotzebue "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False German dramatist August von Kotzebue in March 1819 by a 8 [' German', ' dram', 'at', 'ist', ' August', ' von', ' Kot', 'zeb', 'ue']
+3555 851 Name of father of x -1 Name of father of Paul Éluard Eugène Grindel Paul Éluard "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' painter' ','
+ ' and' ' his' ' wife' ',' ' the' ' painter' ""'s"" ' wife' ',']" , the poet , and his wife , the painter , and his wife , the painter 's wife , False Apollinaire, Max Jacob, Paul Éluard and Louis Aragon. 11 [' Ap', 'oll', 'ina', 'ire', ',', ' Max', ' Jacob', ',', ' Paul', ' É', 'lu', 'ard']
+3556 852 Name of father of x -1 Name of father of Bono Bob Hewson Bono "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False " ""The Electric Co."", Bono left the stage" 6 "[' ""', 'The', ' Electric', ' Co', '."",', ' Bon', 'o']"
+3557 852 Name of father of x -1 Name of father of Bono Bob Hewson Bono "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False sleepless night, Bono of the rock group 5 [' slee', 'pless', ' night', ',', ' Bon', 'o']
+3558 852 Name of father of x -1 Name of father of Bono Bob Hewson Bono "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False 1983's War Tour, Bono continued to reassure 6 "[' 1983', ""'s"", ' War', ' Tour', ',', ' Bon', 'o']"
+3559 852 Name of father of x -1 Name of father of Bono Bob Hewson Bono "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False Mirror Ball Man, Bono dressed in a 5 [' Mirror', ' Ball', ' Man', ',', ' Bon', 'o']
+3560 852 Name of father of x -1 Name of father of Bono Bob Hewson Bono "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the father of the groom .
+
+ The bride and groom" False Horizon's release, Bono said he was 5 "[' Horizon', ""'s"", ' release', ',', ' Bon', 'o']"
+3561 853 Name of father of x -1 Name of father of Alfred Russel Wallace Thomas Vere Wallace Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False " Wallace =
+" 7 [' Wallace', ' =', 'A', 'lf', 'red', ' Rus', 'sel', ' Wallace']
+3562 853 Name of father of x -1 Name of father of Alfred Russel Wallace Thomas Vere Wallace Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False first suggested by Alfred Russel Wallace before 1888. 6 [' first', ' suggested', ' by', ' Alfred', ' Rus', 'sel', ' Wallace']
+3563 853 Name of father of x -1 Name of father of Alfred Russel Wallace Thomas Vere Wallace Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False 5 ['A', 'lf', 'red', ' Rus', 'sel', ' Wallace']
+3564 853 Name of father of x -1 Name of father of Alfred Russel Wallace Thomas Vere Wallace Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False 5 ['A', 'lf', 'red', ' Rus', 'sel', ' Wallace']
+3565 853 Name of father of x -1 Name of father of Alfred Russel Wallace Thomas Vere Wallace Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False British naturalist Alfred Russel Wallace provided a much-quoted 6 [' British', ' natural', 'ist', ' Alfred', ' Rus', 'sel', ' Wallace']
+3566 854 Name of father of x -1 Name of father of Aung San Suu Kyi Aung San Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False and Burma's Aung San Suu Kyi (1991). As the laureate 9 "[' and', ' Burma', ""'s"", ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']"
+3567 854 Name of father of x -1 Name of father of Aung San Suu Kyi Aung San Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False factor in the refusal of Aung San Suu Kyi to leave Burma 11 [' factor', ' in', ' the', ' refusal', ' of', ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']
+3568 854 Name of father of x -1 Name of father of Aung San Suu Kyi Aung San Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False (1935) and Burma's Aung San Suu Kyi (1991). As the 13 "[' (', '19', '35', ')', ' and', ' Burma', ""'s"", ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']"
+3569 854 Name of father of x -1 Name of father of Aung San Suu Kyi Aung San Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False (1935) and Burma's Aung San Suu Kyi (1991). As 13 "[' (', '19', '35', ')', ' and', ' Burma', ""'s"", ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']"
+3570 854 Name of father of x -1 Name of father of Aung San Suu Kyi Aung San Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False and Burma's Aung San Suu Kyi (1991). As the 9 "[' and', ' Burma', ""'s"", ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']"
+3571 855 Name of father of x -1 Name of father of Klemens von Metternich Franz Georg Karl von Metternich Klemens von Metternich "[',' ' the' ' Austrian' ' chancellor' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' the' ' father' ' of' ' the' ' Austrian' ' chancellor' ',' ' who'
+ ' was']" ", the Austrian chancellor , and the
+
+ Name of the father of the Austrian chancellor , who was" False reactionary minister Klemens von Metternich and his secret 8 [' reactionary', ' minister', ' K', 'lem', 'ens', ' von', ' Met', 'tern', 'ich']
+3572 855 Name of father of x -1 Name of father of Klemens von Metternich Franz Georg Karl von Metternich Klemens von Metternich "[',' ' the' ' Austrian' ' chancellor' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' the' ' father' ' of' ' the' ' Austrian' ' chancellor' ',' ' who'
+ ' was']" ", the Austrian chancellor , and the
+
+ Name of the father of the Austrian chancellor , who was" False reactionary minister Klemens von Metternich and his secret police 8 [' reactionary', ' minister', ' K', 'lem', 'ens', ' von', ' Met', 'tern', 'ich']
+3573 857 Name of father of x -1 Name of father of Andrei Tarkovsky Arseny Tarkovsky Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False the director Andrei Tarkovsky had two motives 6 [' the', ' director', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3574 857 Name of father of x -1 Name of father of Andrei Tarkovsky Arseny Tarkovsky Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False Russian filmmaker Andrei Tarkovsky praised Chaplin 6 [' Russian', ' filmmaker', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3575 857 Name of father of x -1 Name of father of Andrei Tarkovsky Arseny Tarkovsky Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False filmmakers such as Andrei Tarkovsky and Robert 7 [' filmmakers', ' such', ' as', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3576 857 Name of father of x -1 Name of father of Andrei Tarkovsky Arseny Tarkovsky Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False filmmakers such as Andrei Tarkovsky and Robert Bresson. 7 [' filmmakers', ' such', ' as', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3577 857 Name of father of x -1 Name of father of Andrei Tarkovsky Arseny Tarkovsky Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False Russian filmmaker Andrei Tarkovsky praised Chaplin as 6 [' Russian', ' filmmaker', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3578 858 Name of father of x -1 Name of father of Drew Barrymore John Drew Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False plus stars Drew Barrymore and Jimmy Fallon, 4 [' plus', ' stars', ' Drew', ' Barry', 'more']
+3579 858 Name of father of x -1 Name of father of Drew Barrymore John Drew Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False 3 ['D', 'rew', ' Barry', 'more']
+3580 858 Name of father of x -1 Name of father of Drew Barrymore John Drew Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False critical response. Drew Barrymore read the script 5 [' critical', ' response', '.', ' Drew', ' Barry', 'more']
+3581 858 Name of father of x -1 Name of father of Drew Barrymore John Drew Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False Beetlejuice. Drew Barrymore previously auditioned 6 [' Beetle', 'ju', 'ice', '.', ' Drew', ' Barry', 'more']
+3582 858 Name of father of x -1 Name of father of Drew Barrymore John Drew Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False " as Sugar
+" 5 [' as', ' Sugar', 'D', 'rew', ' Barry', 'more']
+3583 861 Name of father of x -1 Name of father of John Singleton Copley Richard Copley John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Rev' '.' ' John' ' C']" , the artist , and his wife , Mary , who was the daughter of the Rev . John C False 5 ['John', ' Sing', 'leton', ' C', 'ople', 'y']
+3584 861 Name of father of x -1 Name of father of John Singleton Copley Richard Copley John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Rev' '.' ' John' ' C']" , the artist , and his wife , Mary , who was the daughter of the Rev . John C False colonial 7 [' colon', 'ia', 'John', ' Sing', 'leton', ' C', 'ople', 'y']
+3585 861 Name of father of x -1 Name of father of John Singleton Copley Richard Copley John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Rev' '.' ' John' ' C']" , the artist , and his wife , Mary , who was the daughter of the Rev . John C False appears in a John Singleton Copley portrait of 8 [' appears', ' in', ' a', ' John', ' Sing', 'leton', ' C', 'ople', 'y']
+3586 861 Name of father of x -1 Name of father of John Singleton Copley Richard Copley John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Rev' '.' ' John' ' C']" , the artist , and his wife , Mary , who was the daughter of the Rev . John C False and appears in a John Singleton Copley portrait of ca. 9 [' and', ' appears', ' in', ' a', ' John', ' Sing', 'leton', ' C', 'ople', 'y']
+3587 861 Name of father of x -1 Name of father of John Singleton Copley Richard Copley John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Rev' '.' ' John' ' C']" , the artist , and his wife , Mary , who was the daughter of the Rev . John C False 5 ['John', ' Sing', 'leton', ' C', 'ople', 'y']
+3588 862 Name of father of x -1 Name of father of Thomas Moore John Moore Thomas Moore "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' little' ' bit' ' of' ' a' ' fan'
+ ' of' ' the' ' show' ',' ' but' ' I' ' have' ' to']" ", the
+
+ I am a little bit of a fan of the show , but I have to" False on a poem by Thomas Moore with characters 5 [' on', ' a', ' poem', ' by', ' Thomas', ' Moore']
+3589 862 Name of father of x -1 Name of father of Thomas Moore John Moore Thomas Moore "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' little' ' bit' ' of' ' a' ' fan'
+ ' of' ' the' ' show' ',' ' but' ' I' ' have' ' to']" ", the
+
+ I am a little bit of a fan of the show , but I have to" False construction shipwright Thomas Moore tested that the hull 4 [' construction', ' ship', 'wright', ' Thomas', ' Moore']
+3590 862 Name of father of x -1 Name of father of Thomas Moore John Moore Thomas Moore "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' little' ' bit' ' of' ' a' ' fan'
+ ' of' ' the' ' show' ',' ' but' ' I' ' have' ' to']" ", the
+
+ I am a little bit of a fan of the show , but I have to" False construction shipwright Thomas Moore tested that the hull 4 [' construction', ' ship', 'wright', ' Thomas', ' Moore']
+3591 862 Name of father of x -1 Name of father of Thomas Moore John Moore Thomas Moore "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' little' ' bit' ' of' ' a' ' fan'
+ ' of' ' the' ' show' ',' ' but' ' I' ' have' ' to']" ", the
+
+ I am a little bit of a fan of the show , but I have to" False the Irish Melodies of Thomas Moore and ballads 6 [' the', ' Irish', ' Mel', 'odies', ' of', ' Thomas', ' Moore']
+3592 862 Name of father of x -1 Name of father of Thomas Moore John Moore Thomas Moore "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' little' ' bit' ' of' ' a' ' fan'
+ ' of' ' the' ' show' ',' ' but' ' I' ' have' ' to']" ", the
+
+ I am a little bit of a fan of the show , but I have to" False Melodies of Thomas Moore and ballads such 4 [' Mel', 'odies', ' of', ' Thomas', ' Moore']
+3593 863 Name of father of x -1 Name of father of Ed Sheeran John Sheeran Ed Sheeran "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' friend']" , who was a member of the band , and the band 's bass ist , who was a friend False 3 ['Ed', ' She', 'er', 'an']
+3594 863 Name of father of x -1 Name of father of Ed Sheeran John Sheeran Ed Sheeran "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' friend']" , who was a member of the band , and the band 's bass ist , who was a friend False singer-songwriter Ed Sheeran would appear 7 [' singer', '-', 'song', 'writer', ' Ed', ' She', 'er', 'an']
+3595 863 Name of father of x -1 Name of father of Ed Sheeran John Sheeran Ed Sheeran "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' friend']" , who was a member of the band , and the band 's bass ist , who was a friend False 3 ['Ed', ' She', 'er', 'an']
+3596 863 Name of father of x -1 Name of father of Ed Sheeran John Sheeran Ed Sheeran "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' friend']" , who was a member of the band , and the band 's bass ist , who was a friend False 3 ['Ed', ' She', 'er', 'an']
+3597 863 Name of father of x -1 Name of father of Ed Sheeran John Sheeran Ed Sheeran "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' band' ',' ' and' ' the'
+ ' band' ""'s"" ' bass' 'ist' ',' ' who' ' was' ' a' ' friend']" , who was a member of the band , and the band 's bass ist , who was a friend False recording artist Ed Sheeran for his second 5 [' recording', ' artist', ' Ed', ' She', 'er', 'an']
+3598 864 Name of father of x -1 Name of father of Peter the Great Alexei I of Russia Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False Court, and had Peter the Great (1672 – 1725) 6 [' Court', ',', ' and', ' had', ' Peter', ' the', ' Great']
+3599 864 Name of father of x -1 Name of father of Peter the Great Alexei I of Russia Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False Petersburg, Peter the Great ordered the 4 [' Petersburg', ',', ' Peter', ' the', ' Great']
+3600 864 Name of father of x -1 Name of father of Peter the Great Alexei I of Russia Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False artillery. In April, Peter the Great managed to 7 [' artillery', '.', ' In', ' April', ',', ' Peter', ' the', ' Great']
+3601 864 Name of father of x -1 Name of father of Peter the Great Alexei I of Russia Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False 2 ['Peter', ' the', ' Great']
+3602 864 Name of father of x -1 Name of father of Peter the Great Alexei I of Russia Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False patriotism related to Peter the Great resurfaced. He 5 [' patriotism', ' related', ' to', ' Peter', ' the', ' Great']
+3603 866 Name of father of x -1 Name of father of T. E. Lawrence Sir Thomas Chapman, 7th Baronet T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' his' ',' ' and' ' who']" , the author of the book , and the man who had been a friend of his , and who False was hopeful that T. E. Lawrence and the Sherifial 7 [' was', ' hopeful', ' that', ' T', '.', ' E', '.', ' Lawrence']
+3604 866 Name of father of x -1 Name of father of T. E. Lawrence Sir Thomas Chapman, 7th Baronet T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' his' ',' ' and' ' who']" , the author of the book , and the man who had been a friend of his , and who False " undergraduate thesis of T. E. Lawrence (""Lawrence of Arabia"")," 7 [' undergraduate', ' thesis', ' of', ' T', '.', ' E', '.', ' Lawrence']
+3605 866 Name of father of x -1 Name of father of T. E. Lawrence Sir Thomas Chapman, 7th Baronet T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' his' ',' ' and' ' who']" , the author of the book , and the man who had been a friend of his , and who False to join Colonel T. E. Lawrence and his irregular 7 [' to', ' join', ' Colonel', ' T', '.', ' E', '.', ' Lawrence']
+3606 866 Name of father of x -1 Name of father of T. E. Lawrence Sir Thomas Chapman, 7th Baronet T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' his' ',' ' and' ' who']" , the author of the book , and the man who had been a friend of his , and who False " undergraduate thesis of T. E. Lawrence (""Lawrence of Arabia"")," 7 [' undergraduate', ' thesis', ' of', ' T', '.', ' E', '.', ' Lawrence']
+3607 866 Name of father of x -1 Name of father of T. E. Lawrence Sir Thomas Chapman, 7th Baronet T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' man' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' his' ',' ' and' ' who']" , the author of the book , and the man who had been a friend of his , and who False " thesis of T. E. Lawrence (""Lawrence of" 6 [' thesis', ' of', ' T', '.', ' E', '.', ' Lawrence']
+3608 867 Name of father of x -1 Name of father of Wolfgang Pauli Wolfgang Joseph Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False 3 ['Wolf', 'gang', ' Paul', 'i']
+3609 867 Name of father of x -1 Name of father of Wolfgang Pauli Wolfgang Joseph Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False mechanics. After Wolfgang Pauli announced his 5 [' mechanics', '.', ' After', ' Wolfgang', ' Paul', 'i']
+3610 867 Name of father of x -1 Name of father of Wolfgang Pauli Wolfgang Joseph Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False complicated. When he met Wolfgang Pauli for the first time, 7 [' complicated', '.', ' When', ' he', ' met', ' Wolfgang', ' Paul', 'i']
+3611 867 Name of father of x -1 Name of father of Wolfgang Pauli Wolfgang Joseph Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False Austrian physicist Wolfgang Pauli observed that 4 [' Austrian', ' physicist', ' Wolfgang', ' Paul', 'i']
+3612 867 Name of father of x -1 Name of father of Wolfgang Pauli Wolfgang Joseph Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False Austrian physicist Wolfgang Pauli observed that 4 [' Austrian', ' physicist', ' Wolfgang', ' Paul', 'i']
+3613 868 Name of father of x -1 Name of father of Philippe Pétain Omer-Venant Pétain Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' at'
+ ' the' ' Battle' ' of' ' the' ' Mar' 'ne' ',' ' and' ' who' ' had'
+ ' been']" , the French general who had been defeated at the Battle of the Mar ne , and who had been False defensive-minded Philippe Pétain to the offensive-minded 6 [' defensive', '-', 'minded', ' Philippe', ' P', 'é', 'tain']
+3614 868 Name of father of x -1 Name of father of Philippe Pétain Omer-Venant Pétain Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' at'
+ ' the' ' Battle' ' of' ' the' ' Mar' 'ne' ',' ' and' ' who' ' had'
+ ' been']" , the French general who had been defeated at the Battle of the Mar ne , and who had been False tanks. Marshal Philippe Pétain described them 6 [' tanks', '.', ' Marshal', ' Philippe', ' P', 'é', 'tain']
+3615 868 Name of father of x -1 Name of father of Philippe Pétain Omer-Venant Pétain Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' at'
+ ' the' ' Battle' ' of' ' the' ' Mar' 'ne' ',' ' and' ' who' ' had'
+ ' been']" , the French general who had been defeated at the Battle of the Mar ne , and who had been False by Marshal Philippe Pétain of the Vichy 5 [' by', ' Marshal', ' Philippe', ' P', 'é', 'tain']
+3616 868 Name of father of x -1 Name of father of Philippe Pétain Omer-Venant Pétain Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' at'
+ ' the' ' Battle' ' of' ' the' ' Mar' 'ne' ',' ' and' ' who' ' had'
+ ' been']" , the French general who had been defeated at the Battle of the Mar ne , and who had been False replaced by General Philippe Pétain who immediately 6 [' replaced', ' by', ' General', ' Philippe', ' P', 'é', 'tain']
+3617 868 Name of father of x -1 Name of father of Philippe Pétain Omer-Venant Pétain Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' at'
+ ' the' ' Battle' ' of' ' the' ' Mar' 'ne' ',' ' and' ' who' ' had'
+ ' been']" , the French general who had been defeated at the Battle of the Mar ne , and who had been False replaced by General Philippe Pétain who immediately 6 [' replaced', ' by', ' General', ' Philippe', ' P', 'é', 'tain']
+3618 870 Name of father of x -1 Name of father of Guglielmo Marconi Giuseppe Marconi Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False submarine Guglielmo Marconi on 11 July after the 7 [' submarine', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3619 870 Name of father of x -1 Name of father of Guglielmo Marconi Giuseppe Marconi Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False the Italian submarine Guglielmo Marconi on 11 July after the 9 [' the', ' Italian', ' submarine', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3620 870 Name of father of x -1 Name of father of Guglielmo Marconi Giuseppe Marconi Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False Emmanuel invited Guglielmo Marconi to accompany him 8 [' Emmanuel', ' invited', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3621 870 Name of father of x -1 Name of father of Guglielmo Marconi Giuseppe Marconi Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False tests were conducted by Guglielmo Marconi and were supervised 10 [' tests', ' were', ' conducted', ' by', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3622 870 Name of father of x -1 Name of father of Guglielmo Marconi Giuseppe Marconi Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False as Bologna Guglielmo Marconi Airport, named 10 [' as', ' B', 'olog', 'na', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3623 871 Name of father of x -1 Name of father of Richard Dawkins Clinton John Dawkins Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False disagreed with Richard Dawkins about the importance 3 [' disagreed', ' with', ' Richard', ' Dawkins']
+3624 871 Name of father of x -1 Name of father of Richard Dawkins Clinton John Dawkins Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False 1 ['Richard', ' Dawkins']
+3625 871 Name of father of x -1 Name of father of Richard Dawkins Clinton John Dawkins Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False it is known as the Richard Dawkins Award, in honour 6 [' it', ' is', ' known', ' as', ' the', ' Richard', ' Dawkins']
+3626 871 Name of father of x -1 Name of father of Richard Dawkins Clinton John Dawkins Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False 1 ['Richard', ' Dawkins']
+3627 871 Name of father of x -1 Name of father of Richard Dawkins Clinton John Dawkins Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False article on atheist Richard Dawkins featured a picture 4 [' article', ' on', ' atheist', ' Richard', ' Dawkins']
+3628 872 Name of father of x -1 Name of father of Christo Vladimir Yavashev Christo "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' Son' ' of' ' God' ','
+ ' and' ' the' ' Holy' ' Spirit' ',' ' and' ' the' ' Father']" , the son of God , and the Son of God , and the Holy Spirit , and the Father False his assistant, Christo Wunderlich, 4 [' his', ' assistant', ',', ' Christ', 'o']
+3629 872 Name of father of x -1 Name of father of Christo Vladimir Yavashev Christo "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' Son' ' of' ' God' ','
+ ' and' ' the' ' Holy' ' Spirit' ',' ' and' ' the' ' Father']" , the son of God , and the Son of God , and the Holy Spirit , and the Father False them to Monte Christo in the Dominican Republic. 4 [' them', ' to', ' Monte', ' Christ', 'o']
+3630 872 Name of father of x -1 Name of father of Christo Vladimir Yavashev Christo "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' Son' ' of' ' God' ','
+ ' and' ' the' ' Holy' ' Spirit' ',' ' and' ' the' ' Father']" , the son of God , and the Son of God , and the Holy Spirit , and the Father False Maronem sine Christo possimus dicere 6 [' Mar', 'on', 'em', ' s', 'ine', ' Christ', 'o']
+3631 872 Name of father of x -1 Name of father of Christo Vladimir Yavashev Christo "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' Son' ' of' ' God' ','
+ ' and' ' the' ' Holy' ' Spirit' ',' ' and' ' the' ' Father']" , the son of God , and the Son of God , and the Holy Spirit , and the Father False Rauschenberg, Christo and Andy Warhol 6 [' Ra', 'us', 'chen', 'berg', ',', ' Christ', 'o']
+3632 872 Name of father of x -1 Name of father of Christo Vladimir Yavashev Christo "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' Son' ' of' ' God' ','
+ ' and' ' the' ' Holy' ' Spirit' ',' ' and' ' the' ' Father']" , the son of God , and the Son of God , and the Holy Spirit , and the Father False " ""Ich folge Christo nach"" (I follow after" 6 "[' ""', 'I', 'ch', ' fol', 'ge', ' Christ', 'o']"
+3633 873 Name of father of x -1 Name of father of Sergei Rachmaninoff Vasily Arkadyevich Rachmaninoff Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' ',' ' and']" , the composer , and his wife , the pian ist , Elena R ach man in off , and False Russian composer Sergei Rachmaninoff between October and 7 [' Russian', ' composer', ' Sergei', ' R', 'ach', 'man', 'in', 'off']
+3634 873 Name of father of x -1 Name of father of Sergei Rachmaninoff Vasily Arkadyevich Rachmaninoff Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' ',' ' and']" , the composer , and his wife , the pian ist , Elena R ach man in off , and False the pianist Sergei Rachmaninoff arranged for her to 8 [' the', ' pian', 'ist', ' Sergei', ' R', 'ach', 'man', 'in', 'off']
+3635 873 Name of father of x -1 Name of father of Sergei Rachmaninoff Vasily Arkadyevich Rachmaninoff Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' ',' ' and']" , the composer , and his wife , the pian ist , Elena R ach man in off , and False Russian composer Sergei Rachmaninoff between October and 7 [' Russian', ' composer', ' Sergei', ' R', 'ach', 'man', 'in', 'off']
+3636 873 Name of father of x -1 Name of father of Sergei Rachmaninoff Vasily Arkadyevich Rachmaninoff Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' ',' ' and']" , the composer , and his wife , the pian ist , Elena R ach man in off , and False speculation neither Sergei Rachmaninoff nor Arturo Toscanini 7 [' speculation', ' neither', ' Sergei', ' R', 'ach', 'man', 'in', 'off']
+3637 873 Name of father of x -1 Name of father of Sergei Rachmaninoff Vasily Arkadyevich Rachmaninoff Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' ',' ' and']" , the composer , and his wife , the pian ist , Elena R ach man in off , and False and the pianist Sergei Rachmaninoff arranged for 9 [' and', ' the', ' pian', 'ist', ' Sergei', ' R', 'ach', 'man', 'in', 'off']
+3638 875 Name of father of x -1 Name of father of Christiaan Huygens Constantijn Huygens Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False " clock was made by Christiaan Huygens in the 1650s.
+" 10 [' clock', ' was', ' made', ' by', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3639 875 Name of father of x -1 Name of father of Christiaan Huygens Constantijn Huygens Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False Isaac Newton, Christiaan Huygens and later scientists. 9 [' Isaac', ' Newton', ',', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3640 875 Name of father of x -1 Name of father of Christiaan Huygens Constantijn Huygens Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False book Cosmotheoros, Christiaan Huygens estimated the 12 [' book', ' Cos', 'm', 'othe', 'oros', ',', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3641 875 Name of father of x -1 Name of father of Christiaan Huygens Constantijn Huygens Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False Cosmotheoros, Christiaan Huygens estimated the distance 11 [' Cos', 'm', 'othe', 'oros', ',', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3642 875 Name of father of x -1 Name of father of Christiaan Huygens Constantijn Huygens Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False brothers Ludwig and Christiaan Huygens in 1667, where they 9 [' brothers', ' Ludwig', ' and', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3643 876 Name of father of x -1 Name of father of Louis XIII of France Henry IV of France Louis XIII of France "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' France' ',' ' and' ' the'
+ ' King' ' of' ' France' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of France , and the King of France , and the King of France" False Margaret (1553 – 1615), Louis XIII of France (1601 – 43), 11 [' Margaret', ' (', '15', '53', ' –', ' 16', '15', '),', ' Louis', ' XIII', ' of', ' France']
+3644 876 Name of father of x -1 Name of father of Louis XIII of France Henry IV of France Louis XIII of France "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' France' ',' ' and' ' the'
+ ' King' ' of' ' France' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of France , and the King of France , and the King of France" False (1553 – 1615), Louis XIII of France (1601 – 43), 10 [' (', '15', '53', ' –', ' 16', '15', '),', ' Louis', ' XIII', ' of', ' France']
+3645 876 Name of father of x -1 Name of father of Louis XIII of France Henry IV of France Louis XIII of France "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' France' ',' ' and' ' the'
+ ' King' ' of' ' France' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of France , and the King of France , and the King of France" False (1553 – 1615), Louis XIII of France (1601 – 43), Anne of 10 [' (', '15', '53', ' –', ' 16', '15', '),', ' Louis', ' XIII', ' of', ' France']
+3646 876 Name of father of x -1 Name of father of Louis XIII of France Henry IV of France Louis XIII of France "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' France' ',' ' and' ' the'
+ ' King' ' of' ' France' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of France , and the King of France , and the King of France" False Margaret (1553 – 1615), Louis XIII of France (1601 – 43), Anne 11 [' Margaret', ' (', '15', '53', ' –', ' 16', '15', '),', ' Louis', ' XIII', ' of', ' France']
+3647 876 Name of father of x -1 Name of father of Louis XIII of France Henry IV of France Louis XIII of France "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' France' ',' ' and' ' the'
+ ' King' ' of' ' France' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of France , and the King of France , and the King of France" False Margaret (1553 – 1615), Louis XIII of France (1601 – 43), 11 [' Margaret', ' (', '15', '53', ' –', ' 16', '15', '),', ' Louis', ' XIII', ' of', ' France']
+3648 877 Name of father of x -1 Name of father of Pete Seeger Charles Seeger Pete Seeger "[',' ' the' ' famous' ' folk' ' singer' ',' ' and' ' his' ' wife' ',' ' T'
+ 'oshi' ' See' 'ger' ',' ' a' ' folk' ' singer' ' and' ' activist']" , the famous folk singer , and his wife , T oshi See ger , a folk singer and activist False Kirk) joined Pete Seeger and the Sesame 5 [' Kirk', ')', ' joined', ' Pete', ' See', 'ger']
+3649 877 Name of father of x -1 Name of father of Pete Seeger Charles Seeger Pete Seeger "[',' ' the' ' famous' ' folk' ' singer' ',' ' and' ' his' ' wife' ',' ' T'
+ 'oshi' ' See' 'ger' ',' ' a' ' folk' ' singer' ' and' ' activist']" , the famous folk singer , and his wife , T oshi See ger , a folk singer and activist False Belafonte, Odetta, Pete Seeger and others. 9 [' Bel', 'af', 'onte', ',', ' Od', 'etta', ',', ' Pete', ' See', 'ger']
+3650 877 Name of father of x -1 Name of father of Pete Seeger Charles Seeger Pete Seeger "[',' ' the' ' famous' ' folk' ' singer' ',' ' and' ' his' ' wife' ',' ' T'
+ 'oshi' ' See' 'ger' ',' ' a' ' folk' ' singer' ' and' ' activist']" , the famous folk singer , and his wife , T oshi See ger , a folk singer and activist False by Woody Guthrie, Pete Seeger and Bo Diddley, 7 [' by', ' Woody', ' Guth', 'rie', ',', ' Pete', ' See', 'ger']
+3651 877 Name of father of x -1 Name of father of Pete Seeger Charles Seeger Pete Seeger "[',' ' the' ' famous' ' folk' ' singer' ',' ' and' ' his' ' wife' ',' ' T'
+ 'oshi' ' See' 'ger' ',' ' a' ' folk' ' singer' ' and' ' activist']" , the famous folk singer , and his wife , T oshi See ger , a folk singer and activist False by Woody Guthrie, Pete Seeger and Bo Diddley, 7 [' by', ' Woody', ' Guth', 'rie', ',', ' Pete', ' See', 'ger']
+3652 877 Name of father of x -1 Name of father of Pete Seeger Charles Seeger Pete Seeger "[',' ' the' ' famous' ' folk' ' singer' ',' ' and' ' his' ' wife' ',' ' T'
+ 'oshi' ' See' 'ger' ',' ' a' ' folk' ' singer' ' and' ' activist']" , the famous folk singer , and his wife , T oshi See ger , a folk singer and activist False (Brother Kirk) joined Pete Seeger and the Sesame 7 [' (', 'Brother', ' Kirk', ')', ' joined', ' Pete', ' See', 'ger']
+3653 878 Name of father of x -1 Name of father of Ron Howard Rance Howard Ron Howard "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' have'
+ ' a' ' lot' ' of' ' respect' ' for' ' him' '.' ' I' ' think']" , who is a friend of mine , and I have a lot of respect for him . I think False summer of 2002. Ron Howard had the original idea 5 [' summer', ' of', ' 2002', '.', ' Ron', ' Howard']
+3654 878 Name of father of x -1 Name of father of Ron Howard Rance Howard Ron Howard "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' have'
+ ' a' ' lot' ' of' ' respect' ' for' ' him' '.' ' I' ' think']" , who is a friend of mine , and I have a lot of respect for him . I think False be produced. Ron Howard had meant to 4 [' be', ' produced', '.', ' Ron', ' Howard']
+3655 878 Name of father of x -1 Name of father of Ron Howard Rance Howard Ron Howard "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' have'
+ ' a' ' lot' ' of' ' respect' ' for' ' him' '.' ' I' ' think']" , who is a friend of mine , and I have a lot of respect for him . I think False Smith (18) and Ron Howard (18) were the only 6 [' Smith', ' (', '18', ')', ' and', ' Ron', ' Howard']
+3656 878 Name of father of x -1 Name of father of Ron Howard Rance Howard Ron Howard "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' have'
+ ' a' ' lot' ' of' ' respect' ' for' ' him' '.' ' I' ' think']" , who is a friend of mine , and I have a lot of respect for him . I think False 2008, in an interview, Ron Howard commented that guest 7 [' 2008', ',', ' in', ' an', ' interview', ',', ' Ron', ' Howard']
+3657 878 Name of father of x -1 Name of father of Ron Howard Rance Howard Ron Howard "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' have'
+ ' a' ' lot' ' of' ' respect' ' for' ' him' '.' ' I' ' think']" , who is a friend of mine , and I have a lot of respect for him . I think False lifestyles. The director, Ron Howard explained the reason 6 [' lifestyles', '.', ' The', ' director', ',', ' Ron', ' Howard']
+3658 879 Name of father of x -1 Name of father of William James Henry James, Sr. William James "[',' ' the' ' son' ' of' ' the' ' late' ' William' ' James' ',' ' the'
+ ' father' ' of' ' the' ' famous' ' James' ',' ' the' ' father' ' of'
+ ' the']" , the son of the late William James , the father of the famous James , the father of the False conduct. Historian William James claims that opinion 5 [' conduct', '.', ' Hist', 'orian', ' William', ' James']
+3659 879 Name of father of x -1 Name of father of William James Henry James, Sr. William James "[',' ' the' ' son' ' of' ' the' ' late' ' William' ' James' ',' ' the'
+ ' father' ' of' ' the' ' famous' ' James' ',' ' the' ' father' ' of'
+ ' the']" , the son of the late William James , the father of the famous James , the father of the False unsuccessfully campaigned for William James (Tsakanikas) to 4 [' unsuccessfully', ' campaigned', ' for', ' William', ' James']
+3660 879 Name of father of x -1 Name of father of William James Henry James, Sr. William James "[',' ' the' ' son' ' of' ' the' ' late' ' William' ' James' ',' ' the'
+ ' father' ' of' ' the' ' famous' ' James' ',' ' the' ' father' ' of'
+ ' the']" , the son of the late William James , the father of the famous James , the father of the False coast. Historians William James and William Laird 5 [' coast', '.', ' Hist', 'orians', ' William', ' James']
+3661 879 Name of father of x -1 Name of father of William James Henry James, Sr. William James "[',' ' the' ' son' ' of' ' the' ' late' ' William' ' James' ',' ' the'
+ ' father' ' of' ' the' ' famous' ' James' ',' ' the' ' father' ' of'
+ ' the']" , the son of the late William James , the father of the famous James , the father of the False contemporary historian William James this conglomeration 3 [' contemporary', ' historian', ' William', ' James']
+3662 879 Name of father of x -1 Name of father of William James Henry James, Sr. William James "[',' ' the' ' son' ' of' ' the' ' late' ' William' ' James' ',' ' the'
+ ' father' ' of' ' the' ' famous' ' James' ',' ' the' ' father' ' of'
+ ' the']" , the son of the late William James , the father of the famous James , the father of the False summary, historian William James indicates that the 4 [' summary', ',', ' historian', ' William', ' James']
+3663 882 Name of father of x -1 Name of father of Chiang Kai-shek Chiang Chao-tsung Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ',' ' and' ' the' ' Japanese' ' leader'
+ ',' ' Hide' 'ki' ' To' 'jo' ',' ' were' ' both' ' born' ' in' ' the']" , the Chinese leader , and the Japanese leader , Hide ki To jo , were both born in the False political leader Chiang Kai-shek inviting Hergé 7 [' political', ' leader', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3664 882 Name of father of x -1 Name of father of Chiang Kai-shek Chiang Chao-tsung Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ',' ' and' ' the' ' Japanese' ' leader'
+ ',' ' Hide' 'ki' ' To' 'jo' ',' ' were' ' both' ' born' ' in' ' the']" , the Chinese leader , and the Japanese leader , Hide ki To jo , were both born in the False by Generalissimo Chiang Kai-shek retreated to Taiwan, 9 [' by', ' General', 'iss', 'imo', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3665 882 Name of father of x -1 Name of father of Chiang Kai-shek Chiang Chao-tsung Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ',' ' and' ' the' ' Japanese' ' leader'
+ ',' ' Hide' 'ki' ' To' 'jo' ',' ' were' ' both' ' born' ' in' ' the']" , the Chinese leader , and the Japanese leader , Hide ki To jo , were both born in the False military aid to Chiang Kai-shek and his Nationalist 8 [' military', ' aid', ' to', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3666 882 Name of father of x -1 Name of father of Chiang Kai-shek Chiang Chao-tsung Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ',' ' and' ' the' ' Japanese' ' leader'
+ ',' ' Hide' 'ki' ' To' 'jo' ',' ' were' ' both' ' born' ' in' ' the']" , the Chinese leader , and the Japanese leader , Hide ki To jo , were both born in the False diverted to Chiang Kai-shek International 7 [' diverted', ' to', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3667 882 Name of father of x -1 Name of father of Chiang Kai-shek Chiang Chao-tsung Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ',' ' and' ' the' ' Japanese' ' leader'
+ ',' ' Hide' 'ki' ' To' 'jo' ',' ' were' ' both' ' born' ' in' ' the']" , the Chinese leader , and the Japanese leader , Hide ki To jo , were both born in the False ordered by Chiang Kai-shek not to resist the 7 [' ordered', ' by', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3668 883 Name of father of x -1 Name of father of Michelle Obama Fraser Robinson III Michelle Obama "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the']" , the first lady of the United States , and the first lady of the United States , and the False compete with Michelle Obama and Oprah Winfrey, 3 [' compete', ' with', ' Michelle', ' Obama']
+3669 883 Name of father of x -1 Name of father of Michelle Obama Fraser Robinson III Michelle Obama "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the']" , the first lady of the United States , and the first lady of the United States , and the False a bit when Michelle Obama chides her husband 4 [' a', ' bit', ' when', ' Michelle', ' Obama']
+3670 883 Name of father of x -1 Name of father of Michelle Obama Fraser Robinson III Michelle Obama "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the']" , the first lady of the United States , and the first lady of the United States , and the False crisis. First Lady Michelle Obama deserves credit for 5 [' crisis', '.', ' First', ' Lady', ' Michelle', ' Obama']
+3671 883 Name of father of x -1 Name of father of Michelle Obama Fraser Robinson III Michelle Obama "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the']" , the first lady of the United States , and the first lady of the United States , and the False nor First Lady Michelle Obama were home at the 4 [' nor', ' First', ' Lady', ' Michelle', ' Obama']
+3672 883 Name of father of x -1 Name of father of Michelle Obama Fraser Robinson III Michelle Obama "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the']" , the first lady of the United States , and the first lady of the United States , and the False " often labeled Michelle Obama as an ""Angry" 3 [' often', ' labeled', ' Michelle', ' Obama']
+3673 884 Name of father of x -1 Name of father of Giacomo Meyerbeer Jacob Herz Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False " Felix Mendelssohn, Giacomo Meyerbeer and Henry Litolff.
+" 10 [' Felix', ' Mend', 'els', 'so', 'hn', ',', ' Gi', 'ac', 'omo', ' Meyer', 'beer']
+3674 884 Name of father of x -1 Name of father of Giacomo Meyerbeer Jacob Herz Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False " played the role of Giacomo Meyerbeer in the 1983 film Wagner.
+" 8 [' played', ' the', ' role', ' of', ' Gi', 'ac', 'omo', ' Meyer', 'beer']
+3675 884 Name of father of x -1 Name of father of Giacomo Meyerbeer Jacob Herz Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False the composer Giacomo Meyerbeer was so impressed 6 [' the', ' composer', ' Gi', 'ac', 'omo', ' Meyer', 'beer']
+3676 884 Name of father of x -1 Name of father of Giacomo Meyerbeer Jacob Herz Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False 4 ['G', 'iac', 'omo', ' Meyer', 'beer']
+3677 884 Name of father of x -1 Name of father of Giacomo Meyerbeer Jacob Herz Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False the role of Giacomo Meyerbeer in the 1983 film 7 [' the', ' role', ' of', ' Gi', 'ac', 'omo', ' Meyer', 'beer']
+3678 885 Name of father of x -1 Name of father of Ludwig Tieck Johann Ludwig Tieck Ludwig Tieck "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' a' ' friend' ' of' ' Go'
+ 'ethe' ""'s"" ',' ' and' ' who' ' had' ' been' '\n' '\n']" ", the German poet , who was a friend of Go ethe 's , and who had been
+
+" False over the centuries. Ludwig Tieck published a 6 [' over', ' the', ' centuries', '.', ' Ludwig', ' Tie', 'ck']
+3679 885 Name of father of x -1 Name of father of Ludwig Tieck Johann Ludwig Tieck Ludwig Tieck "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' a' ' friend' ' of' ' Go'
+ 'ethe' ""'s"" ',' ' and' ' who' ' had' ' been' '\n' '\n']" ", the German poet , who was a friend of Go ethe 's , and who had been
+
+" False by the works of Ludwig Tieck and E. T. A. Hoffmann. 6 [' by', ' the', ' works', ' of', ' Ludwig', ' Tie', 'ck']
+3680 885 Name of father of x -1 Name of father of Ludwig Tieck Johann Ludwig Tieck Ludwig Tieck "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' a' ' friend' ' of' ' Go'
+ 'ethe' ""'s"" ',' ' and' ' who' ' had' ' been' '\n' '\n']" ", the German poet , who was a friend of Go ethe 's , and who had been
+
+" False romantic travel story that Ludwig Tieck established 6 [' romantic', ' travel', ' story', ' that', ' Ludwig', ' Tie', 'ck']
+3681 885 Name of father of x -1 Name of father of Ludwig Tieck Johann Ludwig Tieck Ludwig Tieck "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' a' ' friend' ' of' ' Go'
+ 'ethe' ""'s"" ',' ' and' ' who' ' had' ' been' '\n' '\n']" ", the German poet , who was a friend of Go ethe 's , and who had been
+
+" False the works of Ludwig Tieck and E. T. A. Hoffmann. 5 [' the', ' works', ' of', ' Ludwig', ' Tie', 'ck']
+3682 885 Name of father of x -1 Name of father of Ludwig Tieck Johann Ludwig Tieck Ludwig Tieck "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' a' ' friend' ' of' ' Go'
+ 'ethe' ""'s"" ',' ' and' ' who' ' had' ' been' '\n' '\n']" ", the German poet , who was a friend of Go ethe 's , and who had been
+
+" False travel story that Ludwig Tieck established with his 5 [' travel', ' story', ' that', ' Ludwig', ' Tie', 'ck']
+3683 887 Name of father of x -1 Name of father of Lu Xun Zhou Boyi Lu Xun "[',' ' the' ' great' ' Chinese' ' poet' ',' ' was' ' born' ' in'
+ ' Beijing' ' in' ' 1895' '.' ' He' ' was' ' a' ' contemporary' ' of'
+ ' Mao' ' Zed']" , the great Chinese poet , was born in Beijing in 1895 . He was a contemporary of Mao Zed False 2 ['Lu', ' X', 'un']
+3684 887 Name of father of x -1 Name of father of Lu Xun Zhou Boyi Lu Xun "[',' ' the' ' great' ' Chinese' ' poet' ',' ' was' ' born' ' in'
+ ' Beijing' ' in' ' 1895' '.' ' He' ' was' ' a' ' contemporary' ' of'
+ ' Mao' ' Zed']" , the great Chinese poet , was born in Beijing in 1895 . He was a contemporary of Mao Zed False 2 ['Lu', ' X', 'un']
+3685 887 Name of father of x -1 Name of father of Lu Xun Zhou Boyi Lu Xun "[',' ' the' ' great' ' Chinese' ' poet' ',' ' was' ' born' ' in'
+ ' Beijing' ' in' ' 1895' '.' ' He' ' was' ' a' ' contemporary' ' of'
+ ' Mao' ' Zed']" , the great Chinese poet , was born in Beijing in 1895 . He was a contemporary of Mao Zed False 2 ['Lu', ' X', 'un']
+3686 887 Name of father of x -1 Name of father of Lu Xun Zhou Boyi Lu Xun "[',' ' the' ' great' ' Chinese' ' poet' ',' ' was' ' born' ' in'
+ ' Beijing' ' in' ' 1895' '.' ' He' ' was' ' a' ' contemporary' ' of'
+ ' Mao' ' Zed']" , the great Chinese poet , was born in Beijing in 1895 . He was a contemporary of Mao Zed False citizens. Hu Shih and Lu Xun were pioneers in modern 8 [' citizens', '.', ' Hu', ' Sh', 'ih', ' and', ' Lu', ' X', 'un']
+3687 887 Name of father of x -1 Name of father of Lu Xun Zhou Boyi Lu Xun "[',' ' the' ' great' ' Chinese' ' poet' ',' ' was' ' born' ' in'
+ ' Beijing' ' in' ' 1895' '.' ' He' ' was' ' a' ' contemporary' ' of'
+ ' Mao' ' Zed']" , the great Chinese poet , was born in Beijing in 1895 . He was a contemporary of Mao Zed False 2 ['Lu', ' X', 'un']
+3688 888 Name of father of x -1 Name of father of Rudolf Virchow Karl Christian Siegfried Virchow Rudolf Virchow "[',' ' the' ' great' ' German' ' path' 'ologist' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' path' 'ologist']" ", the great German path ologist , who was a great admire r of the
+
+ German path ologist" False 5 ['R', 'ud', 'olf', ' Vir', 'ch', 'ow']
+3689 888 Name of father of x -1 Name of father of Rudolf Virchow Karl Christian Siegfried Virchow Rudolf Virchow "[',' ' the' ' great' ' German' ' path' 'ologist' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' path' 'ologist']" ", the great German path ologist , who was a great admire r of the
+
+ German path ologist" False Theodor Schwann and Rudolf Virchow and was among 9 [' The', 'odor', ' Schw', 'ann', ' and', ' Rud', 'olf', ' Vir', 'ch', 'ow']
+3690 888 Name of father of x -1 Name of father of Rudolf Virchow Karl Christian Siegfried Virchow Rudolf Virchow "[',' ' the' ' great' ' German' ' path' 'ologist' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' path' 'ologist']" ", the great German path ologist , who was a great admire r of the
+
+ German path ologist" False Robert Remak (1852) and Rudolf Virchow (1855) who decisively 12 [' Robert', ' Rem', 'ak', ' (', '18', '52', ')', ' and', ' Rud', 'olf', ' Vir', 'ch', 'ow']
+3691 888 Name of father of x -1 Name of father of Rudolf Virchow Karl Christian Siegfried Virchow Rudolf Virchow "[',' ' the' ' great' ' German' ' path' 'ologist' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' path' 'ologist']" ", the great German path ologist , who was a great admire r of the
+
+ German path ologist" False Remak (1852) and Rudolf Virchow (1855) who 11 [' Rem', 'ak', ' (', '18', '52', ')', ' and', ' Rud', 'olf', ' Vir', 'ch', 'ow']
+3692 888 Name of father of x -1 Name of father of Rudolf Virchow Karl Christian Siegfried Virchow Rudolf Virchow "[',' ' the' ' great' ' German' ' path' 'ologist' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' path' 'ologist']" ", the great German path ologist , who was a great admire r of the
+
+ German path ologist" False German pathologist Rudolf Virchow postulated 7 [' German', ' path', 'ologist', ' Rud', 'olf', ' Vir', 'ch', 'ow']
+3693 889 Name of father of x -1 Name of father of Tove Lo Magnus Nilsson Tove Lo "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False co-directed by Tove Lo and Motellet 6 [' co', '-', 'directed', ' by', ' T', 'ove', ' Lo']
+3694 889 Name of father of x -1 Name of father of Tove Lo Magnus Nilsson Tove Lo "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Noel Gallagher, Tove Lo and Barack Obama. The 5 [' Noel', ' Gallagher', ',', ' T', 'ove', ' Lo']
+3695 889 Name of father of x -1 Name of father of Tove Lo Magnus Nilsson Tove Lo "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 2 ['T', 'ove', ' Lo']
+3696 889 Name of father of x -1 Name of father of Tove Lo Magnus Nilsson Tove Lo "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False was co-directed by Tove Lo and Motellet 7 [' was', ' co', '-', 'directed', ' by', ' T', 'ove', ' Lo']
+3697 889 Name of father of x -1 Name of father of Tove Lo Magnus Nilsson Tove Lo "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Gallagher, Tove Lo and Barack Obama. 4 [' Gallagher', ',', ' T', 'ove', ' Lo']
+3698 890 Name of father of x -1 Name of father of Tom Cruise Thomas Cruise Mapother III Tom Cruise "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", and the other is the name of the father of the bride .
+
+ The bride and groom" False Business where Tom Cruise dances to 3 [' Business', ' where', ' Tom', ' Cruise']
+3699 890 Name of father of x -1 Name of father of Tom Cruise Thomas Cruise Mapother III Tom Cruise "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", and the other is the name of the father of the bride .
+
+ The bride and groom" False " the jokes about Tom Cruise ""work splendidly" 4 [' the', ' jokes', ' about', ' Tom', ' Cruise']
+3700 890 Name of father of x -1 Name of father of Tom Cruise Thomas Cruise Mapother III Tom Cruise "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", and the other is the name of the father of the bride .
+
+ The bride and groom" False " get a video of Tom Cruise off the Internet.""." 5 [' get', ' a', ' video', ' of', ' Tom', ' Cruise']
+3701 890 Name of father of x -1 Name of father of Tom Cruise Thomas Cruise Mapother III Tom Cruise "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", and the other is the name of the father of the bride .
+
+ The bride and groom" False The cast includes Tom Cruise as Precrime 4 [' The', ' cast', ' includes', ' Tom', ' Cruise']
+3702 890 Name of father of x -1 Name of father of Tom Cruise Thomas Cruise Mapother III Tom Cruise "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' father'
+ ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", and the other is the name of the father of the bride .
+
+ The bride and groom" False a foam mock-up of Tom Cruise so the frame could 7 [' a', ' foam', ' mock', '-', 'up', ' of', ' Tom', ' Cruise']
+3703 892 Name of father of x -1 Name of father of Ferdinand II of Aragon John II of Aragon Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' England' ',' ' and' ' the']" , the King of Spain , and the King of France , and the King of England , and the False horses from Charles V, Ferdinand II of Aragon and the Duke 9 [' horses', ' from', ' Charles', ' V', ',', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3704 892 Name of father of x -1 Name of father of Ferdinand II of Aragon John II of Aragon Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' England' ',' ' and' ' the']" , the King of Spain , and the King of France , and the King of England , and the False Louis XII of France, Ferdinand II of Aragon and Maximilian I, the 9 [' Louis', ' XII', ' of', ' France', ',', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3705 892 Name of father of x -1 Name of father of Ferdinand II of Aragon John II of Aragon Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' England' ',' ' and' ' the']" , the King of Spain , and the King of France , and the King of England , and the False Monarchs of Spain, Ferdinand II of Aragon and Isabella I 9 [' Mon', 'archs', ' of', ' Spain', ',', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3706 892 Name of father of x -1 Name of father of Ferdinand II of Aragon John II of Aragon Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' England' ',' ' and' ' the']" , the King of Spain , and the King of France , and the King of England , and the False Monarchs of Spain, Ferdinand II of Aragon and Isabella I 9 [' Mon', 'archs', ' of', ' Spain', ',', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3707 892 Name of father of x -1 Name of father of Ferdinand II of Aragon John II of Aragon Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' England' ',' ' and' ' the']" , the King of Spain , and the King of France , and the King of England , and the False horses from Charles V, Ferdinand II of Aragon and the Duke of 9 [' horses', ' from', ' Charles', ' V', ',', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3708 893 Name of father of x -1 Name of father of Gregory of Nazianzus Gregory of Nazianzus the Elder Gregory of Nazianzus "[',' ' the' ' first' ' Christian' ' bishop' ' of' ' Constantinople' ','
+ ' and' ' the' ' first' ' to' ' be' ' called' ' a' ' saint' '.' '\n' '\n'
+ 'The']" ", the first Christian bishop of Constantinople , and the first to be called a saint .
+
+ The" False Church Father, Gregory of Nazianzus and by Jerome in 8 [' Church', ' Father', ',', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3709 893 Name of father of x -1 Name of father of Gregory of Nazianzus Gregory of Nazianzus the Elder Gregory of Nazianzus "[',' ' the' ' first' ' Christian' ' bishop' ' of' ' Constantinople' ','
+ ' and' ' the' ' first' ' to' ' be' ' called' ' a' ' saint' '.' '\n' '\n'
+ 'The']" ", the first Christian bishop of Constantinople , and the first to be called a saint .
+
+ The" False Church Father, Gregory of Nazianzus and by Jerome in 8 [' Church', ' Father', ',', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3710 893 Name of father of x -1 Name of father of Gregory of Nazianzus Gregory of Nazianzus the Elder Gregory of Nazianzus "[',' ' the' ' first' ' Christian' ' bishop' ' of' ' Constantinople' ','
+ ' and' ' the' ' first' ' to' ' be' ' called' ' a' ' saint' '.' '\n' '\n'
+ 'The']" ", the first Christian bishop of Constantinople , and the first to be called a saint .
+
+ The" False " June, with Macrina, Gregory of Nazianzus and Basil the Great.
+" 11 [' June', ',', ' with', ' Mac', 'rina', ',', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3711 893 Name of father of x -1 Name of father of Gregory of Nazianzus Gregory of Nazianzus the Elder Gregory of Nazianzus "[',' ' the' ' first' ' Christian' ' bishop' ' of' ' Constantinople' ','
+ ' and' ' the' ' first' ' to' ' be' ' called' ' a' ' saint' '.' '\n' '\n'
+ 'The']" ", the first Christian bishop of Constantinople , and the first to be called a saint .
+
+ The" False " with Macrina, Gregory of Nazianzus and Basil the Great.
+" 9 [' with', ' Mac', 'rina', ',', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3712 893 Name of father of x -1 Name of father of Gregory of Nazianzus Gregory of Nazianzus the Elder Gregory of Nazianzus "[',' ' the' ' first' ' Christian' ' bishop' ' of' ' Constantinople' ','
+ ' and' ' the' ' first' ' to' ' be' ' called' ' a' ' saint' '.' '\n' '\n'
+ 'The']" ", the first Christian bishop of Constantinople , and the first to be called a saint .
+
+ The" False Church Father, Gregory of Nazianzus and by Jerome in his 8 [' Church', ' Father', ',', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3713 894 Name of father of x -1 Name of father of Ansel Adams Charles Hitchcock Adams Ansel Adams "[',' ' the' ' famous' ' photographer' ',' ' was' ' a' ' great' ' admire'
+ 'r' ' of' ' the' ' work' ' of' ' the' ' artist' '.' '\n' '\n' 'The']" ", the famous photographer , was a great admire r of the work of the artist .
+
+ The" False adjacent to Yosemite: the Ansel Adams Wilderness to 7 [' adjacent', ' to', ' Yosemite', ':', ' the', ' An', 'sel', ' Adams']
+3714 894 Name of father of x -1 Name of father of Ansel Adams Charles Hitchcock Adams Ansel Adams "[',' ' the' ' famous' ' photographer' ',' ' was' ' a' ' great' ' admire'
+ 'r' ' of' ' the' ' work' ' of' ' the' ' artist' '.' '\n' '\n' 'The']" ", the famous photographer , was a great admire r of the work of the artist .
+
+ The" False 2 ['An', 'sel', ' Adams']
+3715 894 Name of father of x -1 Name of father of Ansel Adams Charles Hitchcock Adams Ansel Adams "[',' ' the' ' famous' ' photographer' ',' ' was' ' a' ' great' ' admire'
+ 'r' ' of' ' the' ' work' ' of' ' the' ' artist' '.' '\n' '\n' 'The']" ", the famous photographer , was a great admire r of the work of the artist .
+
+ The" False Wright first met Ansel Adams at a family 5 [' Wright', ' first', ' met', ' An', 'sel', ' Adams']
+3716 894 Name of father of x -1 Name of father of Ansel Adams Charles Hitchcock Adams Ansel Adams "[',' ' the' ' famous' ' photographer' ',' ' was' ' a' ' great' ' admire'
+ 'r' ' of' ' the' ' work' ' of' ' the' ' artist' '.' '\n' '\n' 'The']" ", the famous photographer , was a great admire r of the work of the artist .
+
+ The" False his autobiography, Ansel Adams called Cedric 5 [' his', ' autobiography', ',', ' An', 'sel', ' Adams']
+3717 894 Name of father of x -1 Name of father of Ansel Adams Charles Hitchcock Adams Ansel Adams "[',' ' the' ' famous' ' photographer' ',' ' was' ' a' ' great' ' admire'
+ 'r' ' of' ' the' ' work' ' of' ' the' ' artist' '.' '\n' '\n' 'The']" ", the famous photographer , was a great admire r of the work of the artist .
+
+ The" False best friends, Ansel Adams and his wife Virginia. 5 [' best', ' friends', ',', ' An', 'sel', ' Adams']
+3718 895 Name of father of x -1 Name of father of Ben Nicholson William Nicholson Ben Nicholson "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' and' ' the' ' groom' ""'s"" ' father' ',' ' and' ' the' ' groom']" , the father of the bride , and the groom , and the groom 's father , and the groom False second husband Ben Nicholson moved into a 3 [' second', ' husband', ' Ben', ' Nicholson']
+3719 896 Name of father of x -1 Name of father of Origen Leonides of Alexandria Origen "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False Christian commentator Origen responded, in 3 [' Christian', ' commentator', ' Orig', 'en']
+3720 896 Name of father of x -1 Name of father of Origen Leonides of Alexandria Origen "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False outside of the town. Origen of Alexandria, writing 6 [' outside', ' of', ' the', ' town', '.', ' Orig', 'en']
+3721 896 Name of father of x -1 Name of father of Origen Leonides of Alexandria Origen "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False the letters of Origen to Philip and his 4 [' the', ' letters', ' of', ' Orig', 'en']
+3722 896 Name of father of x -1 Name of father of Origen Leonides of Alexandria Origen "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False commentator Origen responded, in 2 [' commentator', ' Orig', 'en']
+3723 896 Name of father of x -1 Name of father of Origen Leonides of Alexandria Origen "[',' ' the' ' son' ' of' ' God' ',' ' and' ' the' ' son' ' of' ' God' ','
+ ' and' ' the' ' son' ' of' ' God' ',' ' and' ' the']" , the son of God , and the son of God , and the son of God , and the False 1 ['Orig', 'en']
+3724 897 Name of father of x -1 Name of father of Nicholas Roerich Konstantin Roerich Nicholas Roerich "[',' ' the' ' Russian' ' painter' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' '\n' '\n' 'Russian' ' Revolution' ',' ' and' ' who' ' was' ' a']" ", the Russian painter , who was a friend of the
+
+ Russian Revolution , and who was a" False Massine, with the Nicholas Roerich designs retained; 8 [' Mass', 'ine', ',', ' with', ' the', ' Nicholas', ' Ro', 'er', 'ich']
+3725 897 Name of father of x -1 Name of father of Nicholas Roerich Konstantin Roerich Nicholas Roerich "[',' ' the' ' Russian' ' painter' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' '\n' '\n' 'Russian' ' Revolution' ',' ' and' ' who' ' was' ' a']" ", the Russian painter , who was a friend of the
+
+ Russian Revolution , and who was a" False Massine, with the Nicholas Roerich designs retained; 8 [' Mass', 'ine', ',', ' with', ' the', ' Nicholas', ' Ro', 'er', 'ich']
+3726 897 Name of father of x -1 Name of father of Nicholas Roerich Konstantin Roerich Nicholas Roerich "[',' ' the' ' Russian' ' painter' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' '\n' '\n' 'Russian' ' Revolution' ',' ' and' ' who' ' was' ' a']" ", the Russian painter , who was a friend of the
+
+ Russian Revolution , and who was a" False Inspired by both Nicholas Roerich stories and the Kingdom 6 [' Inspired', ' by', ' both', ' Nicholas', ' Ro', 'er', 'ich']
+3727 897 Name of father of x -1 Name of father of Nicholas Roerich Konstantin Roerich Nicholas Roerich "[',' ' the' ' Russian' ' painter' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' '\n' '\n' 'Russian' ' Revolution' ',' ' and' ' who' ' was' ' a']" ", the Russian painter , who was a friend of the
+
+ Russian Revolution , and who was a" False Massine, with the Nicholas Roerich designs retained; 8 [' Mass', 'ine', ',', ' with', ' the', ' Nicholas', ' Ro', 'er', 'ich']
+3728 899 Name of father of x -1 Name of father of Alan Turing Julius Mathison Turing Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' computer' '.' '\n' '\n' 'The' ' Turing' ' Award'
+ ' is']" ", the father of modern computing , and the father of the computer .
+
+ The Turing Award is" False titled For Alan Turing in 2006, which 3 [' titled', ' For', ' Alan', ' Turing']
+3729 899 Name of father of x -1 Name of father of Alan Turing Julius Mathison Turing Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' computer' '.' '\n' '\n' 'The' ' Turing' ' Award'
+ ' is']" ", the father of modern computing , and the father of the computer .
+
+ The Turing Award is" False 1 ['Alan', ' Turing']
+3730 899 Name of father of x -1 Name of father of Alan Turing Julius Mathison Turing Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' computer' '.' '\n' '\n' 'The' ' Turing' ' Award'
+ ' is']" ", the father of modern computing , and the father of the computer .
+
+ The Turing Award is" False appropriate as Alan Turing was properly convicted 3 [' appropriate', ' as', ' Alan', ' Turing']
+3731 899 Name of father of x -1 Name of father of Alan Turing Julius Mathison Turing Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' computer' '.' '\n' '\n' 'The' ' Turing' ' Award'
+ ' is']" ", the father of modern computing , and the father of the computer .
+
+ The Turing Award is" False mathematician Alan Turing with Hugh Alexander 2 [' mathematician', ' Alan', ' Turing']
+3732 899 Name of father of x -1 Name of father of Alan Turing Julius Mathison Turing Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' computer' '.' '\n' '\n' 'The' ' Turing' ' Award'
+ ' is']" ", the father of modern computing , and the father of the computer .
+
+ The Turing Award is" False mathematician Alan Turing with Hugh Alexander 2 [' mathematician', ' Alan', ' Turing']
+3733 900 Name of father of x -1 Name of father of Sun Yat-sen Sun Dacheng Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Chinese' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.'
+ '\n' '\n']" ", the founder of the Chinese Republic , and the first president of the Republic of China .
+
+" False originally envisioned by Sun Yat-sen in The International 7 [' originally', ' envisioned', ' by', ' Sun', ' Y', 'at', '-', 'sen']
+3734 900 Name of father of x -1 Name of father of Sun Yat-sen Sun Dacheng Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Chinese' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.'
+ '\n' '\n']" ", the founder of the Chinese Republic , and the first president of the Republic of China .
+
+" False revolutionaries, headed by Sun Yat-sen, were concerned with 8 [' revolutionaries', ',', ' headed', ' by', ' Sun', ' Y', 'at', '-', 'sen']
+3735 900 Name of father of x -1 Name of father of Sun Yat-sen Sun Dacheng Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Chinese' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.'
+ '\n' '\n']" ", the founder of the Chinese Republic , and the first president of the Republic of China .
+
+" False Jawaharlal Nehru, Mme. Sun Yat-sen and Albert Einstein. 14 [' Jaw', 'ah', 'arl', 'al', ' Neh', 'ru', ',', ' M', 'me', '.', ' Sun', ' Y', 'at', '-', 'sen']
+3736 900 Name of father of x -1 Name of father of Sun Yat-sen Sun Dacheng Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Chinese' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.'
+ '\n' '\n']" ", the founder of the Chinese Republic , and the first president of the Republic of China .
+
+" False unveiled at Sun Yat-sen Memorial Hall 6 [' unveiled', ' at', ' Sun', ' Y', 'at', '-', 'sen']
+3737 900 Name of father of x -1 Name of father of Sun Yat-sen Sun Dacheng Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Chinese' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.'
+ '\n' '\n']" ", the founder of the Chinese Republic , and the first president of the Republic of China .
+
+" False revolutionary figure Sun Yat-sen visited southeast 6 [' revolutionary', ' figure', ' Sun', ' Y', 'at', '-', 'sen']
+3738 901 Name of father of x -1 Name of father of Henry Fonda William Brace Fonda Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' mother' ' of' ' the' ' actor' ',' ' and' ' the']" , the actor , and his wife , Mary , who was the mother of the actor , and the False Clementine (1946) – Stars Henry Fonda and directed 10 [' Clement', 'ine', ' (', '19', '46', ')', ' –', ' Stars', ' Henry', ' F', 'onda']
+3739 901 Name of father of x -1 Name of father of Henry Fonda William Brace Fonda Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' mother' ' of' ' the' ' actor' ',' ' and' ' the']" , the actor , and his wife , Mary , who was the mother of the actor , and the False " villainous role"" akin to Henry Fonda in Once Upon" 8 "[' villain', 'ous', ' role', '""', ' akin', ' to', ' Henry', ' F', 'onda']"
+3740 901 Name of father of x -1 Name of father of Henry Fonda William Brace Fonda Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' mother' ' of' ' the' ' actor' ',' ' and' ' the']" , the actor , and his wife , Mary , who was the mother of the actor , and the False concerts at Hollywood's Henry Fonda Theater, on January 6 "[' concerts', ' at', ' Hollywood', ""'s"", ' Henry', ' F', 'onda']"
+3741 901 Name of father of x -1 Name of father of Henry Fonda William Brace Fonda Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' mother' ' of' ' the' ' actor' ',' ' and' ' the']" , the actor , and his wife , Mary , who was the mother of the actor , and the False (1946) – Stars Henry Fonda and directed by John 8 [' (', '19', '46', ')', ' –', ' Stars', ' Henry', ' F', 'onda']
+3742 901 Name of father of x -1 Name of father of Henry Fonda William Brace Fonda Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' mother' ' of' ' the' ' actor' ',' ' and' ' the']" , the actor , and his wife , Mary , who was the mother of the actor , and the False Clementine (1946) – Stars Henry Fonda and directed 10 [' Clement', 'ine', ' (', '19', '46', ')', ' –', ' Stars', ' Henry', ' F', 'onda']
+3743 902 Name of father of x -1 Name of father of Ellen DeGeneres Elliott Everett DeGeneres Jr. Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure'
+ ' if' ' I' ""'m"" ' the' ' only' ' one' ' who' ' thinks']" ", who is a lesbian .
+
+ I 'm not sure if I 'm the only one who thinks" False the song on The Ellen DeGeneres Show on October 7 [' the', ' song', ' on', ' The', ' Ellen', ' De', 'Gene', 'res']
+3744 902 Name of father of x -1 Name of father of Ellen DeGeneres Elliott Everett DeGeneres Jr. Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure'
+ ' if' ' I' ""'m"" ' the' ' only' ' one' ' who' ' thinks']" ", who is a lesbian .
+
+ I 'm not sure if I 'm the only one who thinks" False by series star Ellen DeGeneres with Mark Driscoll, 6 [' by', ' series', ' star', ' Ellen', ' De', 'Gene', 'res']
+3745 902 Name of father of x -1 Name of father of Ellen DeGeneres Elliott Everett DeGeneres Jr. Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure'
+ ' if' ' I' ""'m"" ' the' ' only' ' one' ' who' ' thinks']" ", who is a lesbian .
+
+ I 'm not sure if I 'm the only one who thinks" False American comedian Ellen DeGeneres also pointed this 5 [' American', ' comedian', ' Ellen', ' De', 'Gene', 'res']
+3746 902 Name of father of x -1 Name of father of Ellen DeGeneres Elliott Everett DeGeneres Jr. Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure'
+ ' if' ' I' ""'m"" ' the' ' only' ' one' ' who' ' thinks']" ", who is a lesbian .
+
+ I 'm not sure if I 'm the only one who thinks" False Alexa Chung, The Ellen DeGeneres Show, Good Morning 7 [' Alexa', ' Chung', ',', ' The', ' Ellen', ' De', 'Gene', 'res']
+3747 902 Name of father of x -1 Name of father of Ellen DeGeneres Elliott Everett DeGeneres Jr. Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure'
+ ' if' ' I' ""'m"" ' the' ' only' ' one' ' who' ' thinks']" ", who is a lesbian .
+
+ I 'm not sure if I 'm the only one who thinks" False together on The Ellen DeGeneres Show's season 12 premiere 6 [' together', ' on', ' The', ' Ellen', ' De', 'Gene', 'res']
+3748 903 Name of father of x -1 Name of father of Jack Kerouac Léo Alcide Kerouac Jack Kerouac "[',' ' the' ' author' ' of' ' On' ' the' ' Road' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Beat' ' Generation' '.' '\n' '\n' 'The' ' Beat']" ", the author of On the Road , and the father of the Beat Generation .
+
+ The Beat" False Flaubert, Jack Kerouac and Henry James, 7 [' Fl', 'au', 'bert', ',', ' Jack', ' Ker', 'ou', 'ac']
+3749 903 Name of father of x -1 Name of father of Jack Kerouac Léo Alcide Kerouac Jack Kerouac "[',' ' the' ' author' ' of' ' On' ' the' ' Road' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Beat' ' Generation' '.' '\n' '\n' 'The' ' Beat']" ", the author of On the Road , and the father of the Beat Generation .
+
+ The Beat" False William Burroughs, Jack Kerouac and Allen Ginsberg 8 [' William', ' Bur', 'rough', 's', ',', ' Jack', ' Ker', 'ou', 'ac']
+3750 903 Name of father of x -1 Name of father of Jack Kerouac Léo Alcide Kerouac Jack Kerouac "[',' ' the' ' author' ' of' ' On' ' the' ' Road' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Beat' ' Generation' '.' '\n' '\n' 'The' ' Beat']" ", the author of On the Road , and the father of the Beat Generation .
+
+ The Beat" False adaptation of the Jack Kerouac novel of the 6 [' adaptation', ' of', ' the', ' Jack', ' Ker', 'ou', 'ac']
+3751 903 Name of father of x -1 Name of father of Jack Kerouac Léo Alcide Kerouac Jack Kerouac "[',' ' the' ' author' ' of' ' On' ' the' ' Road' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Beat' ' Generation' '.' '\n' '\n' 'The' ' Beat']" ", the author of On the Road , and the father of the Beat Generation .
+
+ The Beat" False Flaubert, Jack Kerouac and Henry James, 7 [' Fl', 'au', 'bert', ',', ' Jack', ' Ker', 'ou', 'ac']
+3752 903 Name of father of x -1 Name of father of Jack Kerouac Léo Alcide Kerouac Jack Kerouac "[',' ' the' ' author' ' of' ' On' ' the' ' Road' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Beat' ' Generation' '.' '\n' '\n' 'The' ' Beat']" ", the author of On the Road , and the father of the Beat Generation .
+
+ The Beat" False Dickens, Gustave Flaubert, Jack Kerouac and Henry James, 11 [' Dickens', ',', ' Gust', 'ave', ' Fl', 'au', 'bert', ',', ' Jack', ' Ker', 'ou', 'ac']
+3753 904 Name of father of x -1 Name of father of Kiefer Sutherland Donald Sutherland Kiefer Sutherland "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack'
+ ' Bauer' ' in' ' the' ' TV' ' series' ' 24' '.' '\n' '\n' 'The' ' actor']" ", the actor who played the role of Jack Bauer in the TV series 24 .
+
+ The actor" False Gainsbourg, Kiefer Sutherland and Charlotte Rampling 7 [' G', 'ains', 'bourg', ',', ' K', 'ief', 'er', ' Sutherland']
+3754 904 Name of father of x -1 Name of father of Kiefer Sutherland Donald Sutherland Kiefer Sutherland "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack'
+ ' Bauer' ' in' ' the' ' TV' ' series' ' 24' '.' '\n' '\n' 'The' ' actor']" ", the actor who played the role of Jack Bauer in the TV series 24 .
+
+ The actor" False Gordon, and Kiefer Sutherland were going to be 6 [' Gordon', ',', ' and', ' K', 'ief', 'er', ' Sutherland']
+3755 904 Name of father of x -1 Name of father of Kiefer Sutherland Donald Sutherland Kiefer Sutherland "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack'
+ ' Bauer' ' in' ' the' ' TV' ' series' ' 24' '.' '\n' '\n' 'The' ' actor']" ", the actor who played the role of Jack Bauer in the TV series 24 .
+
+ The actor" False Howard Gordon, and Kiefer Sutherland were going to be 7 [' Howard', ' Gordon', ',', ' and', ' K', 'ief', 'er', ' Sutherland']
+3756 904 Name of father of x -1 Name of father of Kiefer Sutherland Donald Sutherland Kiefer Sutherland "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack'
+ ' Bauer' ' in' ' the' ' TV' ' series' ' 24' '.' '\n' '\n' 'The' ' actor']" ", the actor who played the role of Jack Bauer in the TV series 24 .
+
+ The actor" False Phone Booth, with Kiefer Sutherland and Colin Farrell. 7 [' Phone', ' Booth', ',', ' with', ' K', 'ief', 'er', ' Sutherland']
+3757 904 Name of father of x -1 Name of father of Kiefer Sutherland Donald Sutherland Kiefer Sutherland "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' Jack'
+ ' Bauer' ' in' ' the' ' TV' ' series' ' 24' '.' '\n' '\n' 'The' ' actor']" ", the actor who played the role of Jack Bauer in the TV series 24 .
+
+ The actor" False " lover"". Actor Kiefer Sutherland has been cast" 6 "[' lover', '"".', ' Actor', ' K', 'ief', 'er', ' Sutherland']"
+3758 905 Name of father of x -1 Name of father of Jennifer Aniston John Aniston Jennifer Aniston "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' name' ' of' ' the' ' groom' '.' '\n' '\n']" ", and the name of the father of the bride , and the name of the groom .
+
+" False Wanderlust, with Jennifer Aniston and Paul Rudd. 6 [' Wander', 'lust', ',', ' with', ' Jennifer', ' An', 'iston']
+3759 905 Name of father of x -1 Name of father of Jennifer Aniston John Aniston Jennifer Aniston "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' name' ' of' ' the' ' groom' '.' '\n' '\n']" ", and the name of the father of the bride , and the name of the groom .
+
+" False chemistry between Jennifer Aniston and her then-husband 4 [' chemistry', ' between', ' Jennifer', ' An', 'iston']
+3760 905 Name of father of x -1 Name of father of Jennifer Aniston John Aniston Jennifer Aniston "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' name' ' of' ' the' ' groom' '.' '\n' '\n']" ", and the name of the father of the bride , and the name of the groom .
+
+" False co-starred with Jennifer Aniston in the romantic 7 [' co', '-', 'star', 'red', ' with', ' Jennifer', ' An', 'iston']
+3761 905 Name of father of x -1 Name of father of Jennifer Aniston John Aniston Jennifer Aniston "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' name' ' of' ' the' ' groom' '.' '\n' '\n']" ", and the name of the father of the bride , and the name of the groom .
+
+" False Wanderlust, with Jennifer Aniston and Paul Rudd. 6 [' Wander', 'lust', ',', ' with', ' Jennifer', ' An', 'iston']
+3762 905 Name of father of x -1 Name of father of Jennifer Aniston John Aniston Jennifer Aniston "[',' ' and' ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' name' ' of' ' the' ' groom' '.' '\n' '\n']" ", and the name of the father of the bride , and the name of the groom .
+
+" False the title role and Jennifer Aniston in her film debut. 6 [' the', ' title', ' role', ' and', ' Jennifer', ' An', 'iston']
+3763 906 Name of father of x -1 Name of father of Pietro Bembo Bernardo Bembo Pietro Bembo "[',' ' the' ' famous' ' Ven' 'et' 'ian' ' printer' ',' ' who' ' was'
+ ' the' ' first' ' to' ' print' ' the' ' first' ' edition' ' of' ' the'
+ ' _']" , the famous Ven et ian printer , who was the first to print the first edition of the _ False and cleric Pietro Bembo which was Manutius'first 6 [' and', ' cleric', ' Piet', 'ro', ' B', 'em', 'bo']
+3764 906 Name of father of x -1 Name of father of Pietro Bembo Bernardo Bembo Pietro Bembo "[',' ' the' ' famous' ' Ven' 'et' 'ian' ' printer' ',' ' who' ' was'
+ ' the' ' first' ' to' ' print' ' the' ' first' ' edition' ' of' ' the'
+ ' _']" , the famous Ven et ian printer , who was the first to print the first edition of the _ False first by Cardinal Pietro Bembo and then by 7 [' first', ' by', ' Cardinal', ' Piet', 'ro', ' B', 'em', 'bo']
+3765 906 Name of father of x -1 Name of father of Pietro Bembo Bernardo Bembo Pietro Bembo "[',' ' the' ' famous' ' Ven' 'et' 'ian' ' printer' ',' ' who' ' was'
+ ' the' ' first' ' to' ' print' ' the' ' first' ' edition' ' of' ' the'
+ ' _']" , the famous Ven et ian printer , who was the first to print the first edition of the _ False by Cardinal Pietro Bembo and then by the 6 [' by', ' Cardinal', ' Piet', 'ro', ' B', 'em', 'bo']
+3766 907 Name of father of x -1 Name of father of Roald Amundsen Jens Amundsen Roald Amundsen "[',' ' the' ' Norwegian' ' explorer' ' who' ' made' ' the' ' first'
+ ' successful' ' flight' ' over' ' the' ' South' ' Pole' '.' '\n' '\n'
+ 'The' ' first' ' flight']" ", the Norwegian explorer who made the first successful flight over the South Pole .
+
+ The first flight" False later used by Roald Amundsen in his successful 7 [' later', ' used', ' by', ' Ro', 'ald', ' Am', 'und', 'sen']
+3767 907 Name of father of x -1 Name of father of Roald Amundsen Jens Amundsen Roald Amundsen "[',' ' the' ' Norwegian' ' explorer' ' who' ' made' ' the' ' first'
+ ' successful' ' flight' ' over' ' the' ' South' ' Pole' '.' '\n' '\n'
+ 'The' ' first' ' flight']" ", the Norwegian explorer who made the first successful flight over the South Pole .
+
+ The first flight" False expressed their admiration; Roald Amundsen wrote, in a letter 8 [' expressed', ' their', ' admiration', ';', ' Ro', 'ald', ' Am', 'und', 'sen']
+3768 907 Name of father of x -1 Name of father of Roald Amundsen Jens Amundsen Roald Amundsen "[',' ' the' ' Norwegian' ' explorer' ' who' ' made' ' the' ' first'
+ ' successful' ' flight' ' over' ' the' ' South' ' Pole' '.' '\n' '\n'
+ 'The' ' first' ' flight']" ", the Norwegian explorer who made the first successful flight over the South Pole .
+
+ The first flight" False Pole lost to Roald Amundsen and ended in the 7 [' Pole', ' lost', ' to', ' Ro', 'ald', ' Am', 'und', 'sen']
+3769 907 Name of father of x -1 Name of father of Roald Amundsen Jens Amundsen Roald Amundsen "[',' ' the' ' Norwegian' ' explorer' ' who' ' made' ' the' ' first'
+ ' successful' ' flight' ' over' ' the' ' South' ' Pole' '.' '\n' '\n'
+ 'The' ' first' ' flight']" ", the Norwegian explorer who made the first successful flight over the South Pole .
+
+ The first flight" False polar elite: Roald Amundsen was a long-time 7 [' polar', ' elite', ':', ' Ro', 'ald', ' Am', 'und', 'sen']
+3770 907 Name of father of x -1 Name of father of Roald Amundsen Jens Amundsen Roald Amundsen "[',' ' the' ' Norwegian' ' explorer' ' who' ' made' ' the' ' first'
+ ' successful' ' flight' ' over' ' the' ' South' ' Pole' '.' '\n' '\n'
+ 'The' ' first' ' flight']" ", the Norwegian explorer who made the first successful flight over the South Pole .
+
+ The first flight" False 1890. In 1910 Roald Amundsen used skis on 8 [' 1890', '.', ' In', ' 1910', ' Ro', 'ald', ' Am', 'und', 'sen']
+3771 908 Name of father of x -1 Name of father of Paul Robeson William Drew Robeson I Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False prestigious production, with Paul Robeson in the title 6 [' prestigious', ' production', ',', ' with', ' Paul', ' Rob', 'eson']
+3772 908 Name of father of x -1 Name of father of Paul Robeson William Drew Robeson I Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False the meeting, actor Paul Robeson did. Robeson, though 6 [' the', ' meeting', ',', ' actor', ' Paul', ' Rob', 'eson']
+3773 908 Name of father of x -1 Name of father of Paul Robeson William Drew Robeson I Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False performances of singer Paul Robeson and to Langston Hughes' 5 [' performances', ' of', ' singer', ' Paul', ' Rob', 'eson']
+3774 908 Name of father of x -1 Name of father of Paul Robeson William Drew Robeson I Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False guest, Black artist Paul Robeson back to Portland, 6 [' guest', ',', ' Black', ' artist', ' Paul', ' Rob', 'eson']
+3775 908 Name of father of x -1 Name of father of Paul Robeson William Drew Robeson I Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False " Robeson ===
+" 5 [' Rob', 'eson', ' ===', 'Paul', ' Rob', 'eson']
+3776 909 Name of father of x -1 Name of father of Oliver Cromwell Robert Cromwell Oliver Cromwell "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' Oliver' ' Crom' 'well' ','
+ ' the' ' great' '-' 'great' '-' 'grand' 'son' ' of' ' Oliver']" , the great - grand son of Oliver Crom well , the great - great - grand son of Oliver False forces under Oliver Cromwell attacked in early 4 [' forces', ' under', ' Oliver', ' Crom', 'well']
+3777 909 Name of father of x -1 Name of father of Oliver Cromwell Robert Cromwell Oliver Cromwell "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' Oliver' ' Crom' 'well' ','
+ ' the' ' great' '-' 'great' '-' 'grand' 'son' ' of' ' Oliver']" , the great - grand son of Oliver Crom well , the great - great - grand son of Oliver False 3 ['Ol', 'iver', ' Crom', 'well']
+3778 909 Name of father of x -1 Name of father of Oliver Cromwell Robert Cromwell Oliver Cromwell "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' Oliver' ' Crom' 'well' ','
+ ' the' ' great' '-' 'great' '-' 'grand' 'son' ' of' ' Oliver']" , the great - grand son of Oliver Crom well , the great - great - grand son of Oliver False After the death of Oliver Cromwell and the fall 6 [' After', ' the', ' death', ' of', ' Oliver', ' Crom', 'well']
+3779 909 Name of father of x -1 Name of father of Oliver Cromwell Robert Cromwell Oliver Cromwell "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' Oliver' ' Crom' 'well' ','
+ ' the' ' great' '-' 'great' '-' 'grand' 'son' ' of' ' Oliver']" , the great - grand son of Oliver Crom well , the great - great - grand son of Oliver False Scotland prompted Oliver Cromwell to lead a force 4 [' Scotland', ' prompted', ' Oliver', ' Crom', 'well']
+3780 909 Name of father of x -1 Name of father of Oliver Cromwell Robert Cromwell Oliver Cromwell "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' Oliver' ' Crom' 'well' ','
+ ' the' ' great' '-' 'great' '-' 'grand' 'son' ' of' ' Oliver']" , the great - grand son of Oliver Crom well , the great - great - grand son of Oliver False success of Oliver Cromwell and his New Model 4 [' success', ' of', ' Oliver', ' Crom', 'well']
+3781 910 Name of father of x -1 Name of father of John Donne John Donne John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ""'s"" ' father' ',' ' and' ' the']" , the poet , and the poet 's father , and the poet 's father 's father , and the False Alexander Pope, John Donne and Jonathan Swift. 5 [' Alexander', ' Pope', ',', ' John', ' Don', 'ne']
+3782 910 Name of father of x -1 Name of father of John Donne John Donne John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ""'s"" ' father' ',' ' and' ' the']" , the poet , and the poet 's father , and the poet 's father 's father , and the False " even earlier by John Donne in his Meditations.
+" 5 [' even', ' earlier', ' by', ' John', ' Don', 'ne']
+3783 910 Name of father of x -1 Name of father of John Donne John Donne John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ""'s"" ' father' ',' ' and' ' the']" , the poet , and the poet 's father , and the poet 's father 's father , and the False Holy Sonnets of John Donne and the Second 6 [' Holy', ' Son', 'nets', ' of', ' John', ' Don', 'ne']
+3784 910 Name of father of x -1 Name of father of John Donne John Donne John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ""'s"" ' father' ',' ' and' ' the']" , the poet , and the poet 's father , and the poet 's father 's father , and the False " the Poet from John Donne to Julia Alvarez""." 6 [' the', ' Po', 'et', ' from', ' John', ' Don', 'ne']
+3785 910 Name of father of x -1 Name of father of John Donne John Donne John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' father' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ""'s"" ' father' ',' ' and' ' the']" , the poet , and the poet 's father , and the poet 's father 's father , and the False Genji), Alexander Pope, John Donne and Jonathan Swift. 8 [' Gen', 'ji', '),', ' Alexander', ' Pope', ',', ' John', ' Don', 'ne']
+3786 911 Name of father of x -1 Name of father of Ginger Rogers Eddins McMath Ginger Rogers "[',' ' the' ' famous' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' famous' ' dancer' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw']" ", the famous dancer , and the mother of the famous dancer .
+
+ The first time I saw" False Astaire and Ginger Rogers danced across in 4 [' Ast', 'aire', ' and', ' Ginger', ' Rogers']
+3787 911 Name of father of x -1 Name of father of Ginger Rogers Eddins McMath Ginger Rogers "[',' ' the' ' famous' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' famous' ' dancer' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw']" ", the famous dancer , and the mother of the famous dancer .
+
+ The first time I saw" False Nancy danced like Ginger Rogers and could administer 4 [' Nancy', ' danced', ' like', ' Ginger', ' Rogers']
+3788 911 Name of father of x -1 Name of father of Ginger Rogers Eddins McMath Ginger Rogers "[',' ' the' ' famous' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' famous' ' dancer' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw']" ", the famous dancer , and the mother of the famous dancer .
+
+ The first time I saw" False paired her with Ginger Rogers in a role which 4 [' paired', ' her', ' with', ' Ginger', ' Rogers']
+3789 911 Name of father of x -1 Name of father of Ginger Rogers Eddins McMath Ginger Rogers "[',' ' the' ' famous' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' famous' ' dancer' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw']" ", the famous dancer , and the mother of the famous dancer .
+
+ The first time I saw" False actress and dancer Ginger Rogers owned the 1,000-acre 4 [' actress', ' and', ' dancer', ' Ginger', ' Rogers']
+3790 911 Name of father of x -1 Name of father of Ginger Rogers Eddins McMath Ginger Rogers "[',' ' the' ' famous' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' famous' ' dancer' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw']" ", the famous dancer , and the mother of the famous dancer .
+
+ The first time I saw" False co-starring with Ginger Rogers for the eighth time 6 [' co', '-', 'star', 'ring', ' with', ' Ginger', ' Rogers']
+3791 912 Name of father of x -1 Name of father of Tim Burton Bill Burton Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' director' ' of' ' ""' 'The'
+ ' Nightmare']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the director of "" The Nightmare" False dark feel similar to Tim Burton's Alice in Wonderland: 5 [' dark', ' feel', ' similar', ' to', ' Tim', ' Burton']
+3792 912 Name of father of x -1 Name of father of Tim Burton Bill Burton Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' director' ' of' ' ""' 'The'
+ ' Nightmare']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the director of "" The Nightmare" False 1 ['Tim', ' Burton']
+3793 912 Name of father of x -1 Name of father of Tim Burton Bill Burton Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' director' ' of' ' ""' 'The'
+ ' Nightmare']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the director of "" The Nightmare" False essentially made a Tim Burton movie without 4 [' essentially', ' made', ' a', ' Tim', ' Burton']
+3794 912 Name of father of x -1 Name of father of Tim Burton Bill Burton Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' director' ' of' ' ""' 'The'
+ ' Nightmare']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the director of "" The Nightmare" False Knight Returns, Tim Burton has mentioned 4 [' Knight', ' Returns', ',', ' Tim', ' Burton']
+3795 912 Name of father of x -1 Name of father of Tim Burton Bill Burton Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' director' ' of' ' ""' 'The'
+ ' Nightmare']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the director of "" The Nightmare" False referred to as Tim Burton's Corpse Bride, 4 [' referred', ' to', ' as', ' Tim', ' Burton']
+3796 913 Name of father of x -1 Name of father of Paul Bourget Justin Bourget Paul Bourget "[',' ' the' ' author' ' of' ' _' 'The' ' Dawn' ' of' ' a' ' New' ' Era'
+ '_' ',' ' and' ' _' 'The' ' Dawn' ' of' ' a' ' New']" , the author of _ The Dawn of a New Era _ , and _ The Dawn of a New False Rudyard Kipling, Paul Bourget and Sarah Bernhardt. 7 [' Rud', 'yard', ' Ki', 'pling', ',', ' Paul', ' Bour', 'get']
+3797 915 Name of father of x -1 Name of father of Mstislav Rostropovich Leopold Rostropovich Mstislav Rostropovich "[',' ' the' ' famous' ' cell' 'ist' ',' ' who' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oy' 'ars' 'k' ',']" , the famous cell ist , who was born in the village of K ras n oy ars k , False Cello Concerto with Mstislav Rostropovich (1958), and a pioneering 11 [' C', 'ello', ' Conc', 'erto', ' with', ' M', 'st', 'isl', 'av', ' Rost', 'rop', 'ovich']
+3798 915 Name of father of x -1 Name of father of Mstislav Rostropovich Leopold Rostropovich Mstislav Rostropovich "[',' ' the' ' famous' ' cell' 'ist' ',' ' who' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oy' 'ars' 'k' ',']" , the famous cell ist , who was born in the village of K ras n oy ars k , False from clear: Mstislav Rostropovich described 9 [' from', ' clear', ':', ' M', 'st', 'isl', 'av', ' Rost', 'rop', 'ovich']
+3799 915 Name of father of x -1 Name of father of Mstislav Rostropovich Leopold Rostropovich Mstislav Rostropovich "[',' ' the' ' famous' ' cell' 'ist' ',' ' who' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oy' 'ars' 'k' ',']" , the famous cell ist , who was born in the village of K ras n oy ars k , False Concerto with Mstislav Rostropovich (1958), and a pioneering 9 [' Conc', 'erto', ' with', ' M', 'st', 'isl', 'av', ' Rost', 'rop', 'ovich']
+3800 915 Name of father of x -1 Name of father of Mstislav Rostropovich Leopold Rostropovich Mstislav Rostropovich "[',' ' the' ' famous' ' cell' 'ist' ',' ' who' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oy' 'ars' 'k' ',']" , the famous cell ist , who was born in the village of K ras n oy ars k , False " far from clear: Mstislav Rostropovich described it as ""tender""." 10 [' far', ' from', ' clear', ':', ' M', 'st', 'isl', 'av', ' Rost', 'rop', 'ovich']
+3801 915 Name of father of x -1 Name of father of Mstislav Rostropovich Leopold Rostropovich Mstislav Rostropovich "[',' ' the' ' famous' ' cell' 'ist' ',' ' who' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oy' 'ars' 'k' ',']" , the famous cell ist , who was born in the village of K ras n oy ars k , False including cellist Mstislav Rostropovich — have vowed not 9 [' including', ' cell', 'ist', ' M', 'st', 'isl', 'av', ' Rost', 'rop', 'ovich']
+3802 916 Name of father of x -1 Name of father of Richard Attenborough Frederick Attenborough Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' the' ' same'
+ ' year' ' as' ' the' ' film' ',' ' and' ' who' ' died' ' in' ' 2014']" , the actor , who was born in the same year as the film , and who died in 2014 False Chelsea chairman Richard Attenborough asking for a loan of 5 [' Chelsea', ' chairman', ' Richard', ' Att', 'en', 'borough']
+3803 916 Name of father of x -1 Name of father of Richard Attenborough Frederick Attenborough Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' the' ' same'
+ ' year' ' as' ' the' ' film' ',' ' and' ' who' ' died' ' in' ' 2014']" , the actor , who was born in the same year as the film , and who died in 2014 False " play M and Richard Attenborough as director.
+" 6 [' play', ' M', ' and', ' Richard', ' Att', 'en', 'borough']
+3804 916 Name of father of x -1 Name of father of Richard Attenborough Frederick Attenborough Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' the' ' same'
+ ' year' ' as' ' the' ' film' ',' ' and' ' who' ' died' ' in' ' 2014']" , the actor , who was born in the same year as the film , and who died in 2014 False " to play M and Richard Attenborough as director.
+" 7 [' to', ' play', ' M', ' and', ' Richard', ' Att', 'en', 'borough']
+3805 916 Name of father of x -1 Name of father of Richard Attenborough Frederick Attenborough Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' the' ' same'
+ ' year' ' as' ' the' ' film' ',' ' and' ' who' ' died' ' in' ' 2014']" , the actor , who was born in the same year as the film , and who died in 2014 False directors, namely Richard Attenborough and Shyam Benegal; 6 [' directors', ',', ' namely', ' Richard', ' Att', 'en', 'borough']
+3806 916 Name of father of x -1 Name of father of Richard Attenborough Frederick Attenborough Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' the' ' same'
+ ' year' ' as' ' the' ' film' ',' ' and' ' who' ' died' ' in' ' 2014']" , the actor , who was born in the same year as the film , and who died in 2014 False Martin Scorsese and Richard Attenborough also pointed 8 [' Martin', ' Sc', 'ors', 'ese', ' and', ' Richard', ' Att', 'en', 'borough']
+3807 917 Name of father of x -1 Name of father of Lars von Trier Fritz Michael Hartmann Lars von Trier "[""'s"" ' ""' 'Ant' 'ich' 'rist' '""' '\n' '\n' 'The' ' Danish' ' director'
+ ""'s"" ' latest' ' film' ',' ' ""' 'Ant' 'ich' 'rist' ',""']" "'s "" Ant ich rist ""
+
+ The Danish director 's latest film , "" Ant ich rist ,""" False 2003 documentary by Lars von Trier and Jørgen Leth, which 6 [' 2003', ' documentary', ' by', ' Lars', ' von', ' T', 'rier']
+3808 917 Name of father of x -1 Name of father of Lars von Trier Fritz Michael Hartmann Lars von Trier "[""'s"" ' ""' 'Ant' 'ich' 'rist' '""' '\n' '\n' 'The' ' Danish' ' director'
+ ""'s"" ' latest' ' film' ',' ' ""' 'Ant' 'ich' 'rist' ',""']" "'s "" Ant ich rist ""
+
+ The Danish director 's latest film , "" Ant ich rist ,""" False documentary by Lars von Trier and Jørgen Leth, which 5 [' documentary', ' by', ' Lars', ' von', ' T', 'rier']
+3809 917 Name of father of x -1 Name of father of Lars von Trier Fritz Michael Hartmann Lars von Trier "[""'s"" ' ""' 'Ant' 'ich' 'rist' '""' '\n' '\n' 'The' ' Danish' ' director'
+ ""'s"" ' latest' ' film' ',' ' ""' 'Ant' 'ich' 'rist' ',""']" "'s "" Ant ich rist ""
+
+ The Danish director 's latest film , "" Ant ich rist ,""" False Danish filmmaker Lars von Trier drew inspiration 5 [' Danish', ' filmmaker', ' Lars', ' von', ' T', 'rier']
+3810 917 Name of father of x -1 Name of father of Lars von Trier Fritz Michael Hartmann Lars von Trier "[""'s"" ' ""' 'Ant' 'ich' 'rist' '""' '\n' '\n' 'The' ' Danish' ' director'
+ ""'s"" ' latest' ' film' ',' ' ""' 'Ant' 'ich' 'rist' ',""']" "'s "" Ant ich rist ""
+
+ The Danish director 's latest film , "" Ant ich rist ,""" False Danish filmmaker Lars von Trier drew inspiration 5 [' Danish', ' filmmaker', ' Lars', ' von', ' T', 'rier']
+3811 917 Name of father of x -1 Name of father of Lars von Trier Fritz Michael Hartmann Lars von Trier "[""'s"" ' ""' 'Ant' 'ich' 'rist' '""' '\n' '\n' 'The' ' Danish' ' director'
+ ""'s"" ' latest' ' film' ',' ' ""' 'Ant' 'ich' 'rist' ',""']" "'s "" Ant ich rist ""
+
+ The Danish director 's latest film , "" Ant ich rist ,""" False Danish filmmaker Lars von Trier drew inspiration 5 [' Danish', ' filmmaker', ' Lars', ' von', ' T', 'rier']
+3812 918 Name of father of x -1 Name of father of Steve Jobs John Abdulfattah Jandali Steve Jobs "[',' ' the' ' founder' ' of' ' Apple' ',' ' and' ' the' ' man' ' who'
+ ' invented' ' the' ' iPod' ',' ' the' ' iPhone' ',' ' and' ' the' ' iPad']" , the founder of Apple , and the man who invented the iPod , the iPhone , and the iPad False Flynn to us was Steve Jobs and Bill Gates 5 [' Flynn', ' to', ' us', ' was', ' Steve', ' Jobs']
+3813 918 Name of father of x -1 Name of father of Steve Jobs John Abdulfattah Jandali Steve Jobs "[',' ' the' ' founder' ' of' ' Apple' ',' ' and' ' the' ' man' ' who'
+ ' invented' ' the' ' iPod' ',' ' the' ' iPhone' ',' ' and' ' the' ' iPad']" , the founder of Apple , and the man who invented the iPod , the iPhone , and the iPad False executive producers were Steve Jobs and Edwin Catmull. 4 [' executive', ' producers', ' were', ' Steve', ' Jobs']
+3814 918 Name of father of x -1 Name of father of Steve Jobs John Abdulfattah Jandali Steve Jobs "[',' ' the' ' founder' ' of' ' Apple' ',' ' and' ' the' ' man' ' who'
+ ' invented' ' the' ' iPod' ',' ' the' ' iPhone' ',' ' and' ' the' ' iPad']" , the founder of Apple , and the man who invented the iPod , the iPhone , and the iPad False played it for Steve Jobs 10 years ago, Jobs 4 [' played', ' it', ' for', ' Steve', ' Jobs']
+3815 918 Name of father of x -1 Name of father of Steve Jobs John Abdulfattah Jandali Steve Jobs "[',' ' the' ' founder' ' of' ' Apple' ',' ' and' ' the' ' man' ' who'
+ ' invented' ' the' ' iPod' ',' ' the' ' iPhone' ',' ' and' ' the' ' iPad']" , the founder of Apple , and the man who invented the iPod , the iPhone , and the iPad False Kingdom and that Steve Jobs had been directed 4 [' Kingdom', ' and', ' that', ' Steve', ' Jobs']
+3816 918 Name of father of x -1 Name of father of Steve Jobs John Abdulfattah Jandali Steve Jobs "[',' ' the' ' founder' ' of' ' Apple' ',' ' and' ' the' ' man' ' who'
+ ' invented' ' the' ' iPod' ',' ' the' ' iPhone' ',' ' and' ' the' ' iPad']" , the founder of Apple , and the man who invented the iPod , the iPhone , and the iPad False and Pixar's Steve Jobs and John Lasseter. 4 "[' and', ' Pixar', ""'s"", ' Steve', ' Jobs']"
+3817 921 Name of father of x -1 Name of father of Ralph Vaughan Williams Arthur Charles Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False " accompaniment"". Ralph Vaughan Williams was delighted" 5 "[' accompan', 'iment', '"".', ' Ralph', ' Vaughan', ' Williams']"
+3818 921 Name of father of x -1 Name of father of Ralph Vaughan Williams Arthur Charles Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False Delius, Gustav Holst, Ralph Vaughan Williams and others. Present-day 9 [' Del', 'ius', ',', ' Gustav', ' Hol', 'st', ',', ' Ralph', ' Vaughan', ' Williams']
+3819 921 Name of father of x -1 Name of father of Ralph Vaughan Williams Arthur Charles Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False Gustav Holst, Ralph Vaughan Williams and others. Present-day 6 [' Gustav', ' Hol', 'st', ',', ' Ralph', ' Vaughan', ' Williams']
+3820 921 Name of father of x -1 Name of father of Ralph Vaughan Williams Arthur Charles Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False series of portraits of Ralph Vaughan Williams in 1952 – 61. 6 [' series', ' of', ' portraits', ' of', ' Ralph', ' Vaughan', ' Williams']
+3821 921 Name of father of x -1 Name of father of Ralph Vaughan Williams Arthur Charles Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False which English composer Ralph Vaughan Williams based his Sea Songs, 5 [' which', ' English', ' composer', ' Ralph', ' Vaughan', ' Williams']
+3822 922 Name of father of x -1 Name of father of Pius IX Conte Girolamo Mastai Ferretti Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' bishops' ' of'
+ ' the' ' world' ',' ' dated' ' from' ' the' ' Vatican' ',' '\n']" ".
+
+ The Pope 's letter to the bishops of the world , dated from the Vatican ,
+" False protection of Pope Pius IX — who refused 5 [' protection', ' of', ' Pope', ' P', 'ius', ' IX']
+3823 922 Name of father of x -1 Name of father of Pius IX Conte Girolamo Mastai Ferretti Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' bishops' ' of'
+ ' the' ' world' ',' ' dated' ' from' ' the' ' Vatican' ',' '\n']" ".
+
+ The Pope 's letter to the bishops of the world , dated from the Vatican ,
+" False audience with Pope Pius IX in Rome. He 5 [' audience', ' with', ' Pope', ' P', 'ius', ' IX']
+3824 922 Name of father of x -1 Name of father of Pius IX Conte Girolamo Mastai Ferretti Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' bishops' ' of'
+ ' the' ' world' ',' ' dated' ' from' ' the' ' Vatican' ',' '\n']" ".
+
+ The Pope 's letter to the bishops of the world , dated from the Vatican ,
+" False at the time Pope Pius IX proclaimed the 6 [' at', ' the', ' time', ' Pope', ' P', 'ius', ' IX']
+3825 922 Name of father of x -1 Name of father of Pius IX Conte Girolamo Mastai Ferretti Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' bishops' ' of'
+ ' the' ' world' ',' ' dated' ' from' ' the' ' Vatican' ',' '\n']" ".
+
+ The Pope 's letter to the bishops of the world , dated from the Vatican ,
+" False audience with Pope Pius IX in Rome. He eagerly 5 [' audience', ' with', ' Pope', ' P', 'ius', ' IX']
+3826 922 Name of father of x -1 Name of father of Pius IX Conte Girolamo Mastai Ferretti Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' bishops' ' of'
+ ' the' ' world' ',' ' dated' ' from' ' the' ' Vatican' ',' '\n']" ".
+
+ The Pope 's letter to the bishops of the world , dated from the Vatican ,
+" False in 1841, Pope Pius IX ruled that, 7 [' in', ' 18', '41', ',', ' Pope', ' P', 'ius', ' IX']
+3827 923 Name of father of x -1 Name of father of Robert Capa Dezsö Friedmann Robert Capa "[',' ' the' ' famous' ' photographer' ',' ' who' ' died' ' in' ' the'
+ ' Spanish' ' Civil' ' War' '.' '\n' '\n' 'The' ' exhibition' ' is' ' a'
+ ' tribute']" ", the famous photographer , who died in the Spanish Civil War .
+
+ The exhibition is a tribute" False Falling Soldier by Robert Capa is the third 5 [' Falling', ' Soldier', ' by', ' Robert', ' Cap', 'a']
+3828 924 Name of father of x -1 Name of father of Xi Jinping Xi Zhongxun Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Li Changchun, Xi Jinping and Zhou Yongkang, 6 [' Li', ' Chang', 'ch', 'un', ',', ' Xi', ' Jinping']
+3829 924 Name of father of x -1 Name of father of Xi Jinping Xi Zhongxun Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Chinese president Xi Jinping in Beijing, Key 3 [' Chinese', ' president', ' Xi', ' Jinping']
+3830 924 Name of father of x -1 Name of father of Xi Jinping Xi Zhongxun Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Qinglin, Li Changchun, Xi Jinping and Zhou Yongkang, 9 [' Qing', 'lin', ',', ' Li', ' Chang', 'ch', 'un', ',', ' Xi', ' Jinping']
+3831 924 Name of father of x -1 Name of father of Xi Jinping Xi Zhongxun Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Qinglin, Li Changchun, Xi Jinping and Zhou Yongkang, 9 [' Qing', 'lin', ',', ' Li', ' Chang', 'ch', 'un', ',', ' Xi', ' Jinping']
+3832 924 Name of father of x -1 Name of father of Xi Jinping Xi Zhongxun Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Chinese president Xi Jinping sent a special 3 [' Chinese', ' president', ' Xi', ' Jinping']
+3833 925 Name of father of x -1 Name of father of Marquis de Sade Jean-Baptiste-François-Joseph de Sade Marquis de Sade "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Marqu' 'is'
+ ' de' ' S' 'ade' ',' ' and' ' the' '\n' '\n' 'Name']" ", and the
+
+ Name of mother of Marqu is de S ade , and the
+
+ Name" False Caylus, Besenval, the Marquis de Sade and Baudelaire. 12 [' Cay', 'lus', ',', ' Bes', 'en', 'val', ',', ' the', ' Marqu', 'is', ' de', ' S', 'ade']
+3834 925 Name of father of x -1 Name of father of Marquis de Sade Jean-Baptiste-François-Joseph de Sade Marquis de Sade "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Marqu' 'is'
+ ' de' ' S' 'ade' ',' ' and' ' the' '\n' '\n' 'Name']" ", and the
+
+ Name of mother of Marqu is de S ade , and the
+
+ Name" False Besenval, the Marquis de Sade and Baudelaire. 9 [' Bes', 'en', 'val', ',', ' the', ' Marqu', 'is', ' de', ' S', 'ade']
+3835 925 Name of father of x -1 Name of father of Marquis de Sade Jean-Baptiste-François-Joseph de Sade Marquis de Sade "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Marqu' 'is'
+ ' de' ' S' 'ade' ',' ' and' ' the' '\n' '\n' 'Name']" ", and the
+
+ Name of mother of Marqu is de S ade , and the
+
+ Name" False Caylus, Besenval, the Marquis de Sade and Baudelaire. 12 [' Cay', 'lus', ',', ' Bes', 'en', 'val', ',', ' the', ' Marqu', 'is', ' de', ' S', 'ade']
+3836 926 Name of father of x -1 Name of father of Paulus Potter Pieter Symonsz Potter Paulus Potter "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' friend' ' of' ' the'
+ ' family' ' for' ' years' '.' '\n' '\n' '""' 'I' ""'m"" ' sorry']" ", the man who had been a friend of the family for years .
+
+ "" I 'm sorry" False Dutch merchant ship Paulus Potter abandoned 5 [' Dutch', ' merchant', ' ship', ' Paul', 'us', ' Potter']
+3837 926 Name of father of x -1 Name of father of Paulus Potter Pieter Symonsz Potter Paulus Potter "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' friend' ' of' ' the'
+ ' family' ' for' ' years' '.' '\n' '\n' '""' 'I' ""'m"" ' sorry']" ", the man who had been a friend of the family for years .
+
+ "" I 'm sorry" False merchant ship Paulus Potter abandoned and drifting, 4 [' merchant', ' ship', ' Paul', 'us', ' Potter']
+3838 926 Name of father of x -1 Name of father of Paulus Potter Pieter Symonsz Potter Paulus Potter "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' friend' ' of' ' the'
+ ' family' ' for' ' years' '.' '\n' '\n' '""' 'I' ""'m"" ' sorry']" ", the man who had been a friend of the family for years .
+
+ "" I 'm sorry" False merchant ship Paulus Potter abandoned and drifting, 4 [' merchant', ' ship', ' Paul', 'us', ' Potter']
+3839 927 Name of father of x -1 Name of father of Gilbert du Motier, Marquis de Lafayette Michel du Motier, Marquis de La Fayette Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False a Freemason. When Gilbert du Motier, Marquis de Lafayette laid the cornerstone 13 [' a', ' Freem', 'ason', '.', ' When', ' Gilbert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3840 927 Name of father of x -1 Name of father of Gilbert du Motier, Marquis de Lafayette Michel du Motier, Marquis de La Fayette Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False Paul Yves Roch Gilbert du Motier, Marquis de Lafayette (French pronunciation: 13 [' Paul', ' Y', 'ves', ' R', 'och', ' Gilbert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3841 927 Name of father of x -1 Name of father of Gilbert du Motier, Marquis de Lafayette Michel du Motier, Marquis de La Fayette Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False " Marquis de Lafayette =
+" 14 [' Marqu', 'is', ' de', ' Lafayette', ' =', 'Gil', 'bert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3842 927 Name of father of x -1 Name of father of Gilbert du Motier, Marquis de Lafayette Michel du Motier, Marquis de La Fayette Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False was a Freemason. When Gilbert du Motier, Marquis de Lafayette laid the cornerstone 14 [' was', ' a', ' Freem', 'ason', '.', ' When', ' Gilbert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3843 927 Name of father of x -1 Name of father of Gilbert du Motier, Marquis de Lafayette Michel du Motier, Marquis de La Fayette Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False a Freemason. When Gilbert du Motier, Marquis de Lafayette laid the cornerstone 13 [' a', ' Freem', 'ason', '.', ' When', ' Gilbert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3844 928 Name of father of x -1 Name of father of François Hollande Georges Hollande François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False " state."" French president François Hollande highlighted Weizsäcker's" 5 "[' state', '.""', ' French', ' president', ' François', ' Hollande']"
+3845 928 Name of father of x -1 Name of father of François Hollande Georges Hollande François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False French president François Hollande highlighted Weizsäcker's 3 [' French', ' president', ' François', ' Hollande']
+3846 928 Name of father of x -1 Name of father of François Hollande Georges Hollande François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False French President François Hollande praised the 3 [' French', ' President', ' François', ' Hollande']
+3847 928 Name of father of x -1 Name of father of François Hollande Georges Hollande François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False French President François Hollande that Qatar will 3 [' French', ' President', ' François', ' Hollande']
+3848 928 Name of father of x -1 Name of father of François Hollande Georges Hollande François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False " state."" French president François Hollande highlighted Weizsäcker's" 5 "[' state', '.""', ' French', ' president', ' François', ' Hollande']"
+3849 930 Name of father of x -1 Name of father of John of Damascus Sarjun ibn Mansur John of Damascus "[',' ' the' ' great' ' Syrian' ' monk' ',' ' who' ' was' ' a'
+ ' contemporary' ' of' ' St' '.' ' John' ' the' ' Baptist' ',' ' and'
+ ' who' ' was']" , the great Syrian monk , who was a contemporary of St . John the Baptist , and who was False " Daniel J (1972), John of Damascus on Islam: The ""Heresy" 7 [' Daniel', ' J', ' (', '1972', '),', ' John', ' of', ' Damascus']
+3850 930 Name of father of x -1 Name of father of John of Damascus Sarjun ibn Mansur John of Damascus "[',' ' the' ' great' ' Syrian' ' monk' ',' ' who' ' was' ' a'
+ ' contemporary' ' of' ' St' '.' ' John' ' the' ' Baptist' ',' ' and'
+ ' who' ' was']" , the great Syrian monk , who was a contemporary of St . John the Baptist , and who was False " Daniel J (1972), John of Damascus on Islam: The ""Heresy" 7 [' Daniel', ' J', ' (', '1972', '),', ' John', ' of', ' Damascus']
+3851 931 Name of father of x -1 Name of father of Helen Keller Arthur Henley Keller Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' who' ' was'
+ ' born' ' in' ' 18' '87' ',' ' and' ' who' ' was' ' the']" , the blind and deaf - m ute , who was born in 18 87 , and who was the False professionals at the Helen Keller Institute in Mumbai. 4 [' professionals', ' at', ' the', ' Helen', ' Keller']
+3852 931 Name of father of x -1 Name of father of Helen Keller Arthur Henley Keller Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' who' ' was'
+ ' born' ' in' ' 18' '87' ',' ' and' ' who' ' was' ' the']" , the blind and deaf - m ute , who was born in 18 87 , and who was the False for the blind, Helen Keller sat on the jury. 5 [' for', ' the', ' blind', ',', ' Helen', ' Keller']
+3853 931 Name of father of x -1 Name of father of Helen Keller Arthur Henley Keller Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' who' ' was'
+ ' born' ' in' ' 18' '87' ',' ' and' ' who' ' was' ' the']" , the blind and deaf - m ute , who was born in 18 87 , and who was the False professionals at the Helen Keller Institute 4 [' professionals', ' at', ' the', ' Helen', ' Keller']
+3854 931 Name of father of x -1 Name of father of Helen Keller Arthur Henley Keller Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' who' ' was'
+ ' born' ' in' ' 18' '87' ',' ' and' ' who' ' was' ' the']" , the blind and deaf - m ute , who was born in 18 87 , and who was the False children of the Helen Keller Institute; she 4 [' children', ' of', ' the', ' Helen', ' Keller']
+3855 931 Name of father of x -1 Name of father of Helen Keller Arthur Henley Keller Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' who' ' was'
+ ' born' ' in' ' 18' '87' ',' ' and' ' who' ' was' ' the']" , the blind and deaf - m ute , who was born in 18 87 , and who was the False professionals at the Helen Keller Institute in 4 [' professionals', ' at', ' the', ' Helen', ' Keller']
+3856 932 Name of father of x -1 Name of father of Jeremy Bentham Jeremiah Bentham Jeremy Bentham "[',' ' the' ' father' ' of' ' the' ' modern' ' school' ' of' ' philosophy'
+ ',' ' and' ' the' ' founder' ' of' ' the' '\n' '\n' 'The' ' first' ' of']" ", the father of the modern school of philosophy , and the founder of the
+
+ The first of" False King's Bench; Jeremy Bentham asserted that 6 "[' King', ""'s"", ' Bench', ';', ' Jeremy', ' Bent', 'ham']"
+3857 932 Name of father of x -1 Name of father of Jeremy Bentham Jeremiah Bentham Jeremy Bentham "[',' ' the' ' father' ' of' ' the' ' modern' ' school' ' of' ' philosophy'
+ ',' ' and' ' the' ' founder' ' of' ' the' '\n' '\n' 'The' ' first' ' of']" ", the father of the modern school of philosophy , and the founder of the
+
+ The first of" False " and Death of Jeremy Bentham ""are the two installments" 5 [' and', ' Death', ' of', ' Jeremy', ' Bent', 'ham']
+3858 932 Name of father of x -1 Name of father of Jeremy Bentham Jeremiah Bentham Jeremy Bentham "[',' ' the' ' father' ' of' ' the' ' modern' ' school' ' of' ' philosophy'
+ ',' ' and' ' the' ' founder' ' of' ' the' '\n' '\n' 'The' ' first' ' of']" ", the father of the modern school of philosophy , and the founder of the
+
+ The first of" False " 'herd-conscience.' Jeremy Bentham noted that:""" 9 "["" '"", 'her', 'd', '-', 'cons', 'cience', "".'"", ' Jeremy', ' Bent', 'ham']"
+3859 932 Name of father of x -1 Name of father of Jeremy Bentham Jeremiah Bentham Jeremy Bentham "[',' ' the' ' father' ' of' ' the' ' modern' ' school' ' of' ' philosophy'
+ ',' ' and' ' the' ' founder' ' of' ' the' '\n' '\n' 'The' ' first' ' of']" ", the father of the modern school of philosophy , and the founder of the
+
+ The first of" False public, quoting Jeremy Bentham when he said 5 [' public', ',', ' quoting', ' Jeremy', ' Bent', 'ham']
+3860 932 Name of father of x -1 Name of father of Jeremy Bentham Jeremiah Bentham Jeremy Bentham "[',' ' the' ' father' ' of' ' the' ' modern' ' school' ' of' ' philosophy'
+ ',' ' and' ' the' ' founder' ' of' ' the' '\n' '\n' 'The' ' first' ' of']" ", the father of the modern school of philosophy , and the founder of the
+
+ The first of" False the King's Bench; Jeremy Bentham asserted that 7 "[' the', ' King', ""'s"", ' Bench', ';', ' Jeremy', ' Bent', 'ham']"
+3861 933 Name of father of x -1 Name of father of Paul Newman Arthur Newman Paul Newman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' board' ' of' ' directors'
+ ' of' ' the' ' American' ' Cancer' ' Society' ',' ' and' ' a' ' member'
+ ' of']" , who was a member of the board of directors of the American Cancer Society , and a member of False Brothers loaned Paul Newman to 20th Century 4 [' Brothers', ' loan', 'ed', ' Paul', ' Newman']
+3862 933 Name of father of x -1 Name of father of Paul Newman Arthur Newman Paul Newman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' board' ' of' ' directors'
+ ' of' ' the' ' American' ' Cancer' ' Society' ',' ' and' ' a' ' member'
+ ' of']" , who was a member of the board of directors of the American Cancer Society , and a member of False 1 ['Paul', ' Newman']
+3863 933 Name of father of x -1 Name of father of Paul Newman Arthur Newman Paul Newman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' board' ' of' ' directors'
+ ' of' ' the' ' American' ' Cancer' ' Society' ',' ' and' ' a' ' member'
+ ' of']" , who was a member of the board of directors of the American Cancer Society , and a member of False Sting, which starred Paul Newman and Robert 5 [' Sting', ',', ' which', ' starred', ' Paul', ' Newman']
+3864 933 Name of father of x -1 Name of father of Paul Newman Arthur Newman Paul Newman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' board' ' of' ' directors'
+ ' of' ' the' ' American' ' Cancer' ' Society' ',' ' and' ' a' ' member'
+ ' of']" , who was a member of the board of directors of the American Cancer Society , and a member of False (non-documentary) movie for Paul Newman before his retirement 9 [' (', 'non', '-', 'document', 'ary', ')', ' movie', ' for', ' Paul', ' Newman']
+3865 933 Name of father of x -1 Name of father of Paul Newman Arthur Newman Paul Newman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' board' ' of' ' directors'
+ ' of' ' the' ' American' ' Cancer' ' Society' ',' ' and' ' a' ' member'
+ ' of']" , who was a member of the board of directors of the American Cancer Society , and a member of False Money (with Paul Newman and Tom Cruise), 4 [' Money', ' (', 'with', ' Paul', ' Newman']
+3866 935 Name of father of x -1 Name of father of Burt Bacharach Bert Bacharach Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Walk' ' on' ' the'
+ ' Wild' ' Side' '""' ' and' ' the' ' song' ' ""' 'I' ""'ll"" ' Never']" ", the composer of the song "" Walk on the Wild Side "" and the song "" I 'll Never" False the music of Burt Bacharach and The Beach Boys'Smiley 7 [' the', ' music', ' of', ' B', 'urt', ' Bach', 'ar', 'ach']
+3867 935 Name of father of x -1 Name of father of Burt Bacharach Bert Bacharach Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Walk' ' on' ' the'
+ ' Wild' ' Side' '""' ' and' ' the' ' song' ' ""' 'I' ""'ll"" ' Never']" ", the composer of the song "" Walk on the Wild Side "" and the song "" I 'll Never" False " to You"", written by Burt Bacharach and Hal David." 9 "[' to', ' You', '"",', ' written', ' by', ' B', 'urt', ' Bach', 'ar', 'ach']"
+3868 935 Name of father of x -1 Name of father of Burt Bacharach Bert Bacharach Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Walk' ' on' ' the'
+ ' Wild' ' Side' '""' ' and' ' the' ' song' ' ""' 'I' ""'ll"" ' Never']" ", the composer of the song "" Walk on the Wild Side "" and the song "" I 'll Never" False 1966, contains the Burt Bacharach and Hal David 8 [' 1966', ',', ' contains', ' the', ' B', 'urt', ' Bach', 'ar', 'ach']
+3869 935 Name of father of x -1 Name of father of Burt Bacharach Bert Bacharach Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Walk' ' on' ' the'
+ ' Wild' ' Side' '""' ' and' ' the' ' song' ' ""' 'I' ""'ll"" ' Never']" ", the composer of the song "" Walk on the Wild Side "" and the song "" I 'll Never" False " Home"", written by Burt Bacharach and Hal David," 8 "[' Home', '"",', ' written', ' by', ' B', 'urt', ' Bach', 'ar', 'ach']"
+3870 935 Name of father of x -1 Name of father of Burt Bacharach Bert Bacharach Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Walk' ' on' ' the'
+ ' Wild' ' Side' '""' ' and' ' the' ' song' ' ""' 'I' ""'ll"" ' Never']" ", the composer of the song "" Walk on the Wild Side "" and the song "" I 'll Never" False " Bacharach – songwriter
+" 10 [' Bach', 'ar', 'ach', ' –', ' song', 'writer', 'B', 'urt', ' Bach', 'ar', 'ach']
+3871 937 Name of father of x -1 Name of father of Nikita Mikhalkov Sergey Mikhalkov Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ' ""' 'The'
+ ' Return' ' of' ' the' ' Living' ' Dead' '""' ' (' '1993' ')' ' and' ' ""']" ", the Russian director of the film "" The Return of the Living Dead "" ( 1993 ) and """ False 6 – 18, director Nikita Mikhalkov observes that to be 10 [' 6', ' –', ' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3872 937 Name of father of x -1 Name of father of Nikita Mikhalkov Sergey Mikhalkov Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ' ""' 'The'
+ ' Return' ' of' ' the' ' Living' ' Dead' '""' ' (' '1993' ')' ' and' ' ""']" ", the Russian director of the film "" The Return of the Living Dead "" ( 1993 ) and """ False – 18, director Nikita Mikhalkov observes that to be 9 [' –', ' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3873 937 Name of father of x -1 Name of father of Nikita Mikhalkov Sergey Mikhalkov Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ' ""' 'The'
+ ' Return' ' of' ' the' ' Living' ' Dead' '""' ' (' '1993' ')' ' and' ' ""']" ", the Russian director of the film "" The Return of the Living Dead "" ( 1993 ) and """ False 18, director Nikita Mikhalkov observes that to 8 [' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3874 937 Name of father of x -1 Name of father of Nikita Mikhalkov Sergey Mikhalkov Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ' ""' 'The'
+ ' Return' ' of' ' the' ' Living' ' Dead' '""' ' (' '1993' ')' ' and' ' ""']" ", the Russian director of the film "" The Return of the Living Dead "" ( 1993 ) and """ False – 18, director Nikita Mikhalkov observes that 9 [' –', ' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3875 937 Name of father of x -1 Name of father of Nikita Mikhalkov Sergey Mikhalkov Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ' ""' 'The'
+ ' Return' ' of' ' the' ' Living' ' Dead' '""' ' (' '1993' ')' ' and' ' ""']" ", the Russian director of the film "" The Return of the Living Dead "" ( 1993 ) and """ False 18, director Nikita Mikhalkov observes that to 8 [' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3876 938 Name of father of x -1 Name of father of Lorenzo de' Medici Piero di Cosimo de' Medici Lorenzo de' Medici "[',' ' the' ' Magn' 'ificent' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Magn ificent , and the
+ " False horse's head. Lorenzo de' Medici sent Leonardo 8 "[' horse', ""'s"", ' head', '.', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3877 938 Name of father of x -1 Name of father of Lorenzo de' Medici Piero di Cosimo de' Medici Lorenzo de' Medici "[',' ' the' ' Magn' 'ificent' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Magn ificent , and the
+ " False marriage to Piero di Lorenzo de' Medici was arranged by his 9 "[' marriage', ' to', ' Pier', 'o', ' di', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3878 938 Name of father of x -1 Name of father of Lorenzo de' Medici Piero di Cosimo de' Medici Lorenzo de' Medici "[',' ' the' ' Magn' 'ificent' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Magn ificent , and the
+ " False marriage to Piero di Lorenzo de' Medici was arranged 9 "[' marriage', ' to', ' Pier', 'o', ' di', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3879 938 Name of father of x -1 Name of father of Lorenzo de' Medici Piero di Cosimo de' Medici Lorenzo de' Medici "[',' ' the' ' Magn' 'ificent' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Magn ificent , and the
+ " False horse's head. Lorenzo de' Medici sent Leonardo 8 "[' horse', ""'s"", ' head', '.', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3880 939 Name of father of x -1 Name of father of Moses Amram Moses "[',' ' the' ' son' ' of' ' Am' 'ram' ',' ' the' ' son' ' of' ' Y' 'oche'
+ 'ved' ',' ' the' ' daughter' ' of' ' Levi' ',' ' the']" , the son of Am ram , the son of Y oche ved , the daughter of Levi , the True Ugandan runner Moses Kipsiro. He competed 3 [' Ug', 'andan', ' runner', ' Moses']
+3881 939 Name of father of x -1 Name of father of Moses Amram Moses "[',' ' the' ' son' ' of' ' Am' 'ram' ',' ' the' ' son' ' of' ' Y' 'oche'
+ 'ved' ',' ' the' ' daughter' ' of' ' Levi' ',' ' the']" , the son of Am ram , the son of Y oche ved , the daughter of Levi , the True the Robert Moses Niagara Power 2 [' the', ' Robert', ' Moses']
+3882 939 Name of father of x -1 Name of father of Moses Amram Moses "[',' ' the' ' son' ' of' ' Am' 'ram' ',' ' the' ' son' ' of' ' Y' 'oche'
+ 'ved' ',' ' the' ' daughter' ' of' ' Levi' ',' ' the']" , the son of Am ram , the son of Y oche ved , the daughter of Levi , the True 1 ['M', 'oses']
+3883 939 Name of father of x -1 Name of father of Moses Amram Moses "[',' ' the' ' son' ' of' ' Am' 'ram' ',' ' the' ' son' ' of' ' Y' 'oche'
+ 'ved' ',' ' the' ' daughter' ' of' ' Levi' ',' ' the']" , the son of Am ram , the son of Y oche ved , the daughter of Levi , the True Thomas Girtin, Moses Griffith, 5 [' Thomas', ' G', 'irt', 'in', ',', ' Moses']
+3884 939 Name of father of x -1 Name of father of Moses Amram Moses "[',' ' the' ' son' ' of' ' Am' 'ram' ',' ' the' ' son' ' of' ' Y' 'oche'
+ 'ved' ',' ' the' ' daughter' ' of' ' Levi' ',' ' the']" , the son of Am ram , the son of Y oche ved , the daughter of Levi , the True tapped Robert Moses to head a newly 2 [' tapped', ' Robert', ' Moses']
+3885 941 Name of father of x -1 Name of father of Jean-Paul Belmondo Paul Belmondo Jean-Paul Belmondo "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' French' ' actor' ' in' ' the' ' film' ' ""' 'B' 'reath' 'less' '""'
+ ' and']" ", the actor who played the role of the French actor in the film "" B reath less "" and" False owner opposite Jean-Paul Belmondo in Bolognini's 7 [' owner', ' opposite', ' Jean', '-', 'Paul', ' Bel', 'mond', 'o']
+3886 941 Name of father of x -1 Name of father of Jean-Paul Belmondo Paul Belmondo Jean-Paul Belmondo "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' French' ' actor' ' in' ' the' ' film' ' ""' 'B' 'reath' 'less' '""'
+ ' and']" ", the actor who played the role of the French actor in the film "" B reath less "" and" False Scoumoune with Jean-Paul Belmondo and Michel 11 [' Sc', 'ou', 'm', 'ou', 'ne', ' with', ' Jean', '-', 'Paul', ' Bel', 'mond', 'o']
+3887 941 Name of father of x -1 Name of father of Jean-Paul Belmondo Paul Belmondo Jean-Paul Belmondo "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' French' ' actor' ' in' ' the' ' film' ' ""' 'B' 'reath' 'less' '""'
+ ' and']" ", the actor who played the role of the French actor in the film "" B reath less "" and" False George Raft and Jean-Paul Belmondo were given top billing 9 [' George', ' Ra', 'ft', ' and', ' Jean', '-', 'Paul', ' Bel', 'mond', 'o']
+3888 941 Name of father of x -1 Name of father of Jean-Paul Belmondo Paul Belmondo Jean-Paul Belmondo "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' French' ' actor' ' in' ' the' ' film' ' ""' 'B' 'reath' 'less' '""'
+ ' and']" ", the actor who played the role of the French actor in the film "" B reath less "" and" False starred opposite Jean-Paul Belmondo as Vénus in the 7 [' starred', ' opposite', ' Jean', '-', 'Paul', ' Bel', 'mond', 'o']
+3889 941 Name of father of x -1 Name of father of Jean-Paul Belmondo Paul Belmondo Jean-Paul Belmondo "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' French' ' actor' ' in' ' the' ' film' ' ""' 'B' 'reath' 'less' '""'
+ ' and']" ", the actor who played the role of the French actor in the film "" B reath less "" and" False 5 ['Jean', '-', 'Paul', ' Bel', 'mond', 'o']
+3890 942 Name of father of x -1 Name of father of Thomas Babington Macaulay, 1st Baron Macaulay Zachary Macaulay Thomas Babington Macaulay, 1st Baron Macaulay "['\n' '\n' 'Thomas' ' Bab' 'ington' ' Mac' 'aul' 'ay' ',' ' 1' 'st'
+ ' Baron' ' Mac' 'aul' 'ay' ' (' '17' '41' '–' '18']" "
+
+ Thomas Bab ington Mac aul ay , 1 st Baron Mac aul ay ( 17 41 – 18" False 19th-century historian Thomas Babington Macaulay, 1st Baron Macaulay was granted (English) 17 [' 19', 'th', '-', 'century', ' historian', ' Thomas', ' Bab', 'ington', ' Mac', 'aul', 'ay', ',', ' 1', 'st', ' Baron', ' Mac', 'aul', 'ay']
+3891 945 Name of father of x -1 Name of father of Boris Johnson Stanley Johnson Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False after her. In 2005, Boris Johnson (former Mayor 7 [' after', ' her', '.', ' In', ' 2005', ',', ' Boris', ' Johnson']
+3892 945 Name of father of x -1 Name of father of Boris Johnson Stanley Johnson Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False Hall overseen by Boris Johnson at 3.5 pounds 4 [' Hall', ' overseen', ' by', ' Boris', ' Johnson']
+3893 945 Name of father of x -1 Name of father of Boris Johnson Stanley Johnson Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False " meeting, declaring Mayor Boris Johnson to be a ""jester-despot""," 5 [' meeting', ',', ' declaring', ' Mayor', ' Boris', ' Johnson']
+3894 945 Name of father of x -1 Name of father of Boris Johnson Stanley Johnson Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False military, while Mayor Boris Johnson arranged for 5 [' military', ',', ' while', ' Mayor', ' Boris', ' Johnson']
+3895 945 Name of father of x -1 Name of father of Boris Johnson Stanley Johnson Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False military, while Mayor Boris Johnson arranged for the 5 [' military', ',', ' while', ' Mayor', ' Boris', ' Johnson']
+3896 946 Name of father of x -1 Name of father of Bjørnstjerne Bjørnson Peder Bjørnson Bjørnstjerne Bjørnson "['\n' '\n' 'B' 'j' 'ø' 'rn' 'st' 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ B j ø rn st jer ne Bj ø rn son ( ; ; ; ; ; ;" False Norwegian poet Bjørnstjerne Bjørnson and Icelandic sagas. 11 [' Norwegian', ' poet', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3897 946 Name of father of x -1 Name of father of Bjørnstjerne Bjørnson Peder Bjørnson Bjørnstjerne Bjørnson "['\n' '\n' 'B' 'j' 'ø' 'rn' 'st' 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ B j ø rn st jer ne Bj ø rn son ( ; ; ; ; ; ;" False and novelist Bjørnstjerne Bjørnson was the guest 11 [' and', ' novelist', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3898 946 Name of father of x -1 Name of father of Bjørnstjerne Bjørnson Peder Bjørnson Bjørnstjerne Bjørnson "['\n' '\n' 'B' 'j' 'ø' 'rn' 'st' 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ B j ø rn st jer ne Bj ø rn son ( ; ; ; ; ; ;" False for lèse majesté. Bjørnstjerne Bjørnson and Lars Holst were 17 [' for', ' l', 'è', 'se', ' maj', 'est', 'é', '.', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3899 946 Name of father of x -1 Name of father of Bjørnstjerne Bjørnson Peder Bjørnson Bjørnstjerne Bjørnson "['\n' '\n' 'B' 'j' 'ø' 'rn' 'st' 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ B j ø rn st jer ne Bj ø rn son ( ; ; ; ; ; ;" False for lèse majesté. Bjørnstjerne Bjørnson and Lars Holst were 17 [' for', ' l', 'è', 'se', ' maj', 'est', 'é', '.', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3900 946 Name of father of x -1 Name of father of Bjørnstjerne Bjørnson Peder Bjørnson Bjørnstjerne Bjørnson "['\n' '\n' 'B' 'j' 'ø' 'rn' 'st' 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' ';'
+ ' ;' ' ;' ' ;' ' ;' ' ;']" "
+
+ B j ø rn st jer ne Bj ø rn son ( ; ; ; ; ; ;" False literature. Later Bjørnstjerne Bjørnson wrote a well-known 12 [' literature', '.', ' Later', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3901 947 Name of father of x -1 Name of father of Franz Lehár Franz Lehár Sr. Franz Lehár "[',' ' the' ' composer' ' of' ' the' ' o' 'pe' 'rett' 'a' ' _' 'The'
+ ' Merry' ' Widow' '_' ',' ' and' ' the' ' lib' 'rett' 'ist']" , the composer of the o pe rett a _ The Merry Widow _ , and the lib rett ist False composers, including Franz Lehár and Johann Strauss, 7 [' compos', 'ers', ',', ' including', ' Franz', ' Leh', 'á', 'r']
+3902 947 Name of father of x -1 Name of father of Franz Lehár Franz Lehár Sr. Franz Lehár "[',' ' the' ' composer' ' of' ' the' ' o' 'pe' 'rett' 'a' ' _' 'The'
+ ' Merry' ' Widow' '_' ',' ' and' ' the' ' lib' 'rett' 'ist']" , the composer of the o pe rett a _ The Merry Widow _ , and the lib rett ist False composers, including Franz Lehár and Johann Strauss, 7 [' compos', 'ers', ',', ' including', ' Franz', ' Leh', 'á', 'r']
+3903 947 Name of father of x -1 Name of father of Franz Lehár Franz Lehár Sr. Franz Lehár "[',' ' the' ' composer' ' of' ' the' ' o' 'pe' 'rett' 'a' ' _' 'The'
+ ' Merry' ' Widow' '_' ',' ' and' ' the' ' lib' 'rett' 'ist']" , the composer of the o pe rett a _ The Merry Widow _ , and the lib rett ist False composers, including Franz Lehár and Johann 7 [' compos', 'ers', ',', ' including', ' Franz', ' Leh', 'á', 'r']
+3904 947 Name of father of x -1 Name of father of Franz Lehár Franz Lehár Sr. Franz Lehár "[',' ' the' ' composer' ' of' ' the' ' o' 'pe' 'rett' 'a' ' _' 'The'
+ ' Merry' ' Widow' '_' ',' ' and' ' the' ' lib' 'rett' 'ist']" , the composer of the o pe rett a _ The Merry Widow _ , and the lib rett ist False composers, including Franz Lehár and Johann Strauss, 7 [' compos', 'ers', ',', ' including', ' Franz', ' Leh', 'á', 'r']
+3905 948 Name of father of x -1 Name of father of Olivia de Havilland Walter Augustus de Havilland Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' house' ' was' ' a' ' large' ',' ' r' 'ambling' ',']" ", who was a friend of the family .
+
+ The house was a large , r ambling ," False involving actress Olivia de Havilland decades before. Jared 6 [' involving', ' actress', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3906 948 Name of father of x -1 Name of father of Olivia de Havilland Walter Augustus de Havilland Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' house' ' was' ' a' ' large' ',' ' r' 'ambling' ',']" ", who was a friend of the family .
+
+ The house was a large , r ambling ," False " Havilland =
+" 10 [' Hav', 'ill', 'and', ' =', 'O', 'liv', 'ia', ' de', ' Hav', 'ill', 'and']
+3907 948 Name of father of x -1 Name of father of Olivia de Havilland Walter Augustus de Havilland Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' house' ' was' ' a' ' large' ',' ' r' 'ambling' ',']" ", who was a friend of the family .
+
+ The house was a large , r ambling ," False as happened with Olivia de Havilland and Bette Davis. 7 [' as', ' happened', ' with', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3908 948 Name of father of x -1 Name of father of Olivia de Havilland Walter Augustus de Havilland Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' house' ' was' ' a' ' large' ',' ' r' 'ambling' ',']" ", who was a friend of the family .
+
+ The house was a large , r ambling ," False " Flinn identifying Olivia de Havilland as a ""standout""" 7 [' Fl', 'inn', ' identifying', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3909 948 Name of father of x -1 Name of father of Olivia de Havilland Walter Augustus de Havilland Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' house' ' was' ' a' ' large' ',' ' r' 'ambling' ',']" ", who was a friend of the family .
+
+ The house was a large , r ambling ," False and later, Olivia de Havilland over Leigh, 7 [' and', ' later', ',', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3910 950 Name of father of x -1 Name of father of August Wilhelm Schlegel Johann Adolf Schlegel August Wilhelm Schlegel "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' August' ' Wilhelm'
+ ' Sch' 'leg' 'el' ',' ' the' '\n' '\n' 'Name' ' of' ' father']" ", the
+
+ Name of father of August Wilhelm Sch leg el , the
+
+ Name of father" False the brothers August Wilhelm Schlegel (1767 – 1845) and 6 [' the', ' brothers', ' August', ' Wilhelm', ' Sch', 'leg', 'el']
+3911 950 Name of father of x -1 Name of father of August Wilhelm Schlegel Johann Adolf Schlegel August Wilhelm Schlegel "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' August' ' Wilhelm'
+ ' Sch' 'leg' 'el' ',' ' the' '\n' '\n' 'Name' ' of' ' father']" ", the
+
+ Name of father of August Wilhelm Sch leg el , the
+
+ Name of father" False 1831), the brothers August Wilhelm Schlegel (1767 – 1845) and Karl 9 [' 18', '31', '),', ' the', ' brothers', ' August', ' Wilhelm', ' Sch', 'leg', 'el']
+3912 951 Name of father of x -1 Name of father of Zbigniew Brzezinski Tadeusz Brzeziński Zbigniew Brzezinski "[',' ' the' ' former' ' national' ' security' ' adviser' ' to'
+ ' President' ' Jimmy' ' Carter' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Tr' 'ilateral' ' Commission']" , the former national security adviser to President Jimmy Carter , who was a member of the Tr ilateral Commission False scientists, such as Zbigniew Brzezinski and Frank Gibney, 10 [' scientists', ',', ' such', ' as', ' Z', 'b', 'ign', 'iew', ' Br', 'ze', 'zinski']
+3913 951 Name of father of x -1 Name of father of Zbigniew Brzezinski Tadeusz Brzeziński Zbigniew Brzezinski "[',' ' the' ' former' ' national' ' security' ' adviser' ' to'
+ ' President' ' Jimmy' ' Carter' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Tr' 'ilateral' ' Commission']" , the former national security adviser to President Jimmy Carter , who was a member of the Tr ilateral Commission False scientists, such as Zbigniew Brzezinski and Frank Gibney, 10 [' scientists', ',', ' such', ' as', ' Z', 'b', 'ign', 'iew', ' Br', 'ze', 'zinski']
+3914 951 Name of father of x -1 Name of father of Zbigniew Brzezinski Tadeusz Brzeziński Zbigniew Brzezinski "[',' ' the' ' former' ' national' ' security' ' adviser' ' to'
+ ' President' ' Jimmy' ' Carter' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Tr' 'ilateral' ' Commission']" , the former national security adviser to President Jimmy Carter , who was a member of the Tr ilateral Commission False scientists, such as Zbigniew Brzezinski and Frank Gibney, 10 [' scientists', ',', ' such', ' as', ' Z', 'b', 'ign', 'iew', ' Br', 'ze', 'zinski']
+3915 952 Name of father of x -1 Name of father of Arthur Balfour James Maitland Balfour Arthur Balfour "[',' ' the' '\n' '\n' '1' '.' '\n' '\n' '2' '.' '\n' '\n' '3' '.' '\n'
+ '\n' '4' '.' '\n' '\n']" ", the
+
+ 1 .
+
+ 2 .
+
+ 3 .
+
+ 4 .
+
+" False Foreign Secretary Arthur Balfour contacted Baron 5 [' Foreign', ' Secretary', ' Arthur', ' B', 'alf', 'our']
+3916 952 Name of father of x -1 Name of father of Arthur Balfour James Maitland Balfour Arthur Balfour "[',' ' the' '\n' '\n' '1' '.' '\n' '\n' '2' '.' '\n' '\n' '3' '.' '\n'
+ '\n' '4' '.' '\n' '\n']" ", the
+
+ 1 .
+
+ 2 .
+
+ 3 .
+
+ 4 .
+
+" False Foreign Secretary Arthur Balfour contacted Baron 5 [' Foreign', ' Secretary', ' Arthur', ' B', 'alf', 'our']
+3917 952 Name of father of x -1 Name of father of Arthur Balfour James Maitland Balfour Arthur Balfour "[',' ' the' '\n' '\n' '1' '.' '\n' '\n' '2' '.' '\n' '\n' '3' '.' '\n'
+ '\n' '4' '.' '\n' '\n']" ", the
+
+ 1 .
+
+ 2 .
+
+ 3 .
+
+ 4 .
+
+" False Foreign Secretary Arthur Balfour contacted Baron Rothschild, 5 [' Foreign', ' Secretary', ' Arthur', ' B', 'alf', 'our']
+3918 952 Name of father of x -1 Name of father of Arthur Balfour James Maitland Balfour Arthur Balfour "[',' ' the' '\n' '\n' '1' '.' '\n' '\n' '2' '.' '\n' '\n' '3' '.' '\n'
+ '\n' '4' '.' '\n' '\n']" ", the
+
+ 1 .
+
+ 2 .
+
+ 3 .
+
+ 4 .
+
+" False Conservative leaders Arthur Balfour and Lord Lansdowne 5 [' Conservative', ' leaders', ' Arthur', ' B', 'alf', 'our']
+3919 952 Name of father of x -1 Name of father of Arthur Balfour James Maitland Balfour Arthur Balfour "[',' ' the' '\n' '\n' '1' '.' '\n' '\n' '2' '.' '\n' '\n' '3' '.' '\n'
+ '\n' '4' '.' '\n' '\n']" ", the
+
+ 1 .
+
+ 2 .
+
+ 3 .
+
+ 4 .
+
+" False Conservative leaders Arthur Balfour and Lord Lansdowne 5 [' Conservative', ' leaders', ' Arthur', ' B', 'alf', 'our']
+3920 953 Name of father of x -1 Name of father of Jacques Cousteau Daniel Cousteau Jacques Cousteau "[',' ' the' ' famous' ' French' ' ocean' 'ographer' ',' ' who' ' died'
+ ' in' ' 1997' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the']" ", the famous French ocean ographer , who died in 1997 .
+
+ The first time I saw the" False oceanographic researcher Jacques Cousteau described the 6 [' ocean', 'ographic', ' researcher', ' Jacques', ' Cou', 'ste', 'au']
+3921 953 Name of father of x -1 Name of father of Jacques Cousteau Daniel Cousteau Jacques Cousteau "[',' ' the' ' famous' ' French' ' ocean' 'ographer' ',' ' who' ' died'
+ ' in' ' 1997' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the']" ", the famous French ocean ographer , who died in 1997 .
+
+ The first time I saw the" False and, in the 1940s, Jacques Cousteau helped develop the 10 [' and', ',', ' in', ' the', ' 1940', 's', ',', ' Jacques', ' Cou', 'ste', 'au']
+3922 953 Name of father of x -1 Name of father of Jacques Cousteau Daniel Cousteau Jacques Cousteau "[',' ' the' ' famous' ' French' ' ocean' 'ographer' ',' ' who' ' died'
+ ' in' ' 1997' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the']" ", the famous French ocean ographer , who died in 1997 .
+
+ The first time I saw the" False underwater explorer Jacques Cousteau began diving 5 [' underwater', ' explorer', ' Jacques', ' Cou', 'ste', 'au']
+3923 953 Name of father of x -1 Name of father of Jacques Cousteau Daniel Cousteau Jacques Cousteau "[',' ' the' ' famous' ' French' ' ocean' 'ographer' ',' ' who' ' died'
+ ' in' ' 1997' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the']" ", the famous French ocean ographer , who died in 1997 .
+
+ The first time I saw the" False and, in the 1940s, Jacques Cousteau helped develop 10 [' and', ',', ' in', ' the', ' 1940', 's', ',', ' Jacques', ' Cou', 'ste', 'au']
+3924 953 Name of father of x -1 Name of father of Jacques Cousteau Daniel Cousteau Jacques Cousteau "[',' ' the' ' famous' ' French' ' ocean' 'ographer' ',' ' who' ' died'
+ ' in' ' 1997' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the']" ", the famous French ocean ographer , who died in 1997 .
+
+ The first time I saw the" False scuba dive with Jacques Cousteau in 1953 provided 7 [' sc', 'uba', ' dive', ' with', ' Jacques', ' Cou', 'ste', 'au']
+3925 955 Name of father of x -1 Name of father of Taras Shevchenko Hryhoriy Ivanovych Shevchenko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False 4 ['Tar', 'as', ' She', 'v', 'chenko']
+3926 955 Name of father of x -1 Name of father of Taras Shevchenko Hryhoriy Ivanovych Shevchenko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False of Ukraine Taras Shevchenko (posthumously, 1976 6 [' of', ' Ukraine', ' Tar', 'as', ' She', 'v', 'chenko']
+3927 955 Name of father of x -1 Name of father of Taras Shevchenko Hryhoriy Ivanovych Shevchenko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False on a poem by Taras Shevchenko titled “ Зоре моя вечірняя 8 [' on', ' a', ' poem', ' by', ' Tar', 'as', ' She', 'v', 'chenko']
+3928 955 Name of father of x -1 Name of father of Taras Shevchenko Hryhoriy Ivanovych Shevchenko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False National Prize of Ukraine Taras Shevchenko (posthumously, 8 [' National', ' Prize', ' of', ' Ukraine', ' Tar', 'as', ' She', 'v', 'chenko']
+3929 955 Name of father of x -1 Name of father of Taras Shevchenko Hryhoriy Ivanovych Shevchenko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False of Ukraine Taras Shevchenko (posthumously, 1976 6 [' of', ' Ukraine', ' Tar', 'as', ' She', 'v', 'chenko']
+3930 956 Name of father of x -1 Name of father of Jayne Mansfield Herbert William Palmer Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False Wild, Wild World of Jayne Mansfield (1968) included nude 8 [' Wild', ',', ' Wild', ' World', ' of', ' Jay', 'ne', ' Mans', 'field']
+3931 956 Name of father of x -1 Name of father of Jayne Mansfield Herbert William Palmer Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False on the planet Jayne Mansfield. In the episode of 6 [' on', ' the', ' planet', ' Jay', 'ne', ' Mans', 'field']
+3932 956 Name of father of x -1 Name of father of Jayne Mansfield Herbert William Palmer Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False and names it Jayne Mansfield because she had 6 [' and', ' names', ' it', ' Jay', 'ne', ' Mans', 'field']
+3933 956 Name of father of x -1 Name of father of Jayne Mansfield Herbert William Palmer Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False " says, ""You did the Jayne Mansfield crash without me?""" 9 "[' says', ',', ' ""', 'You', ' did', ' the', ' Jay', 'ne', ' Mans', 'field']"
+3934 956 Name of father of x -1 Name of father of Jayne Mansfield Herbert William Palmer Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False 3 ['Jay', 'ne', ' Mans', 'field']
+3935 957 Name of father of x -1 Name of father of Hans Bethe Albrecht Bethe Hans Bethe "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Manhattan' ' Project' ',' ' the' ' team' ' that' ' developed'
+ ' the' ' atomic']" , the famous physicist , who was a member of the Manhattan Project , the team that developed the atomic False Scientists like Hans Bethe and George Gamow 4 [' Scientists', ' like', ' Hans', ' Bet', 'he']
+3936 957 Name of father of x -1 Name of father of Hans Bethe Albrecht Bethe Hans Bethe "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Manhattan' ' Project' ',' ' the' ' team' ' that' ' developed'
+ ' the' ' atomic']" , the famous physicist , who was a member of the Manhattan Project , the team that developed the atomic False core, including Hans Bethe who had worked 5 [' core', ',', ' including', ' Hans', ' Bet', 'he']
+3937 957 Name of father of x -1 Name of father of Hans Bethe Albrecht Bethe Hans Bethe "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Manhattan' ' Project' ',' ' the' ' team' ' that' ' developed'
+ ' the' ' atomic']" , the famous physicist , who was a member of the Manhattan Project , the team that developed the atomic False accepted an offer from Hans Bethe to work with 6 [' accepted', ' an', ' offer', ' from', ' Hans', ' Bet', 'he']
+3938 957 Name of father of x -1 Name of father of Hans Bethe Albrecht Bethe Hans Bethe "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Manhattan' ' Project' ',' ' the' ' team' ' that' ' developed'
+ ' the' ' atomic']" , the famous physicist , who was a member of the Manhattan Project , the team that developed the atomic False with his friend Hans Bethe in developing 5 [' with', ' his', ' friend', ' Hans', ' Bet', 'he']
+3939 957 Name of father of x -1 Name of father of Hans Bethe Albrecht Bethe Hans Bethe "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Manhattan' ' Project' ',' ' the' ' team' ' that' ' developed'
+ ' the' ' atomic']" , the famous physicist , who was a member of the Manhattan Project , the team that developed the atomic False " Hans Bethe =
+" 2 [' Hans', ' Bet', 'he']
+3940 958 Name of father of x -1 Name of father of Genghis Khan Yesugei Genghis Khan "[',' ' the' ' great' ' conquer' 'or' ' of' ' the' ' world' ',' ' and'
+ ' the' ' founder' ' of' ' the' ' Mongol' ' Empire' '.' '\n' '\n' 'The']" ", the great conquer or of the world , and the founder of the Mongol Empire .
+
+ The" False " Siberia: ""While Genghis Khan was holding" 7 "[' Siberia', ':', ' ""', 'While', ' Gen', 'gh', 'is', ' Khan']"
+3941 958 Name of father of x -1 Name of father of Genghis Khan Yesugei Genghis Khan "[',' ' the' ' great' ' conquer' 'or' ' of' ' the' ' world' ',' ' and'
+ ' the' ' founder' ' of' ' the' ' Mongol' ' Empire' '.' '\n' '\n' 'The']" ", the great conquer or of the world , and the founder of the Mongol Empire .
+
+ The" False into Kievan Rus'; Genghis Khan defeated Turkic 9 "[' into', ' K', 'ie', 'van', ' Rus', ""';"", ' Gen', 'gh', 'is', ' Khan']"
+3942 958 Name of father of x -1 Name of father of Genghis Khan Yesugei Genghis Khan "[',' ' the' ' great' ' conquer' 'or' ' of' ' the' ' world' ',' ' and'
+ ' the' ' founder' ' of' ' the' ' Mongol' ' Empire' '.' '\n' '\n' 'The']" ", the great conquer or of the world , and the founder of the Mongol Empire .
+
+ The" False Empire beginning with Genghis Khan in the early 6 [' Empire', ' beginning', ' with', ' Gen', 'gh', 'is', ' Khan']
+3943 958 Name of father of x -1 Name of father of Genghis Khan Yesugei Genghis Khan "[',' ' the' ' great' ' conquer' 'or' ' of' ' the' ' world' ',' ' and'
+ ' the' ' founder' ' of' ' the' ' Mongol' ' Empire' '.' '\n' '\n' 'The']" ", the great conquer or of the world , and the founder of the Mongol Empire .
+
+ The" False between Prester John and Genghis Khan was elaborated 8 [' between', ' Pre', 'ster', ' John', ' and', ' Gen', 'gh', 'is', ' Khan']
+3944 958 Name of father of x -1 Name of father of Genghis Khan Yesugei Genghis Khan "[',' ' the' ' great' ' conquer' 'or' ' of' ' the' ' world' ',' ' and'
+ ' the' ' founder' ' of' ' the' ' Mongol' ' Empire' '.' '\n' '\n' 'The']" ", the great conquer or of the world , and the founder of the Mongol Empire .
+
+ The" False and Song. Genghis Khan had died in 1227 during 6 [' and', ' Song', '.', ' Gen', 'gh', 'is', ' Khan']
+3945 959 Name of father of x -1 Name of father of Trajan Marcus Ulpius Traianus Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False 1 ['Tra', 'jan']
+3946 959 Name of father of x -1 Name of father of Trajan Marcus Ulpius Traianus Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False 198 AD, failing as Trajan once did to capture 6 [' 198', ' AD', ',', ' failing', ' as', ' Tra', 'jan']
+3947 959 Name of father of x -1 Name of father of Trajan Marcus Ulpius Traianus Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False lost Justice of Trajan and Herkinbald, 4 [' lost', ' Justice', ' of', ' Tra', 'jan']
+3948 959 Name of father of x -1 Name of father of Trajan Marcus Ulpius Traianus Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False months of 116 AD, Trajan captured the Persian 6 [' months', ' of', ' 116', ' AD', ',', ' Tra', 'jan']
+3949 959 Name of father of x -1 Name of father of Trajan Marcus Ulpius Traianus Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False Further east, Trajan turned his 4 [' Further', ' east', ',', ' Tra', 'jan']
+3950 960 Name of father of x -1 Name of father of Johan Huizinga Dirk Huizinga Johan Huizinga "[',' ' the' ' Dutch' ' historian' ' of' ' the' ' Renaissance' ',' ' who'
+ ' was' ' born' ' in' ' the' ' Netherlands' ' in' ' 1919' '.' '\n' '\n'
+ 'The']" ", the Dutch historian of the Renaissance , who was born in the Netherlands in 1919 .
+
+ The" False 20th century. Johan Huizinga was the first 8 [' 20', 'th', ' century', '.', ' Joh', 'an', ' Hu', 'izing', 'a']
+3951 960 Name of father of x -1 Name of father of Johan Huizinga Dirk Huizinga Johan Huizinga "[',' ' the' ' Dutch' ' historian' ' of' ' the' ' Renaissance' ',' ' who'
+ ' was' ' born' ' in' ' the' ' Netherlands' ' in' ' 1919' '.' '\n' '\n'
+ 'The']" ", the Dutch historian of the Renaissance , who was born in the Netherlands in 1919 .
+
+ The" False evocative and alive. Johan Huizinga said that art 9 [' ev', 'ocative', ' and', ' alive', '.', ' Joh', 'an', ' Hu', 'izing', 'a']
+3952 960 Name of father of x -1 Name of father of Johan Huizinga Dirk Huizinga Johan Huizinga "[',' ' the' ' Dutch' ' historian' ' of' ' the' ' Renaissance' ',' ' who'
+ ' was' ' born' ' in' ' the' ' Netherlands' ' in' ' 1919' '.' '\n' '\n'
+ 'The']" ", the Dutch historian of the Renaissance , who was born in the Netherlands in 1919 .
+
+ The" False and alive. Johan Huizinga said that art of 7 [' and', ' alive', '.', ' Joh', 'an', ' Hu', 'izing', 'a']
+3953 960 Name of father of x -1 Name of father of Johan Huizinga Dirk Huizinga Johan Huizinga "[',' ' the' ' Dutch' ' historian' ' of' ' the' ' Renaissance' ',' ' who'
+ ' was' ' born' ' in' ' the' ' Netherlands' ' in' ' 1919' '.' '\n' '\n'
+ 'The']" ", the Dutch historian of the Renaissance , who was born in the Netherlands in 1919 .
+
+ The" False evocative and alive. Johan Huizinga said that art of the 9 [' ev', 'ocative', ' and', ' alive', '.', ' Joh', 'an', ' Hu', 'izing', 'a']
+3954 960 Name of father of x -1 Name of father of Johan Huizinga Dirk Huizinga Johan Huizinga "[',' ' the' ' Dutch' ' historian' ' of' ' the' ' Renaissance' ',' ' who'
+ ' was' ' born' ' in' ' the' ' Netherlands' ' in' ' 1919' '.' '\n' '\n'
+ 'The']" ", the Dutch historian of the Renaissance , who was born in the Netherlands in 1919 .
+
+ The" False 4 ['J', 'ohan', ' Hu', 'izing', 'a']
+3955 961 Name of father of x -1 Name of father of Arrigo Boito Silvestro Boito Arrigo Boito "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ',' ' and' ' died'
+ ' in' ' 18' '91' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 44 , and died in 18 91 .
+
+ The first" False libretto was adapted by Arrigo Boito from Shakespeare's 10 [' lib', 'rett', 'o', ' was', ' adapted', ' by', ' Ar', 'rig', 'o', ' Bo', 'ito']
+3956 961 Name of father of x -1 Name of father of Arrigo Boito Silvestro Boito Arrigo Boito "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ',' ' and' ' died'
+ ' in' ' 18' '91' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 44 , and died in 18 91 .
+
+ The first" False Falstaff of Arrigo Boito and Giuseppe 7 [' Fal', 'staff', ' of', ' Ar', 'rig', 'o', ' Bo', 'ito']
+3957 961 Name of father of x -1 Name of father of Arrigo Boito Silvestro Boito Arrigo Boito "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ',' ' and' ' died'
+ ' in' ' 18' '91' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 44 , and died in 18 91 .
+
+ The first" False the Falstaff of Arrigo Boito and Giuseppe Verdi 8 [' the', ' Fal', 'staff', ' of', ' Ar', 'rig', 'o', ' Bo', 'ito']
+3958 961 Name of father of x -1 Name of father of Arrigo Boito Silvestro Boito Arrigo Boito "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ',' ' and' ' died'
+ ' in' ' 18' '91' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 44 , and died in 18 91 .
+
+ The first" False was adapted by Arrigo Boito from Shakespeare's 7 [' was', ' adapted', ' by', ' Ar', 'rig', 'o', ' Bo', 'ito']
+3959 961 Name of father of x -1 Name of father of Arrigo Boito Silvestro Boito Arrigo Boito "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ',' ' and' ' died'
+ ' in' ' 18' '91' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 44 , and died in 18 91 .
+
+ The first" False adapted by Arrigo Boito from Shakespeare's 6 [' adapted', ' by', ' Ar', 'rig', 'o', ' Bo', 'ito']
+3960 963 Name of father of x -1 Name of father of Sylvia Plath Otto Plath Sylvia Plath "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' daughter' ',' ' and'
+ ' the' ' poet' ""'s"" ' daughter' ""'s"" ' daughter' ',' ' and' ' the']" , the poet , and the poet 's daughter , and the poet 's daughter 's daughter , and the False writings of Sylvia Plath at the time. The 4 [' writings', ' of', ' Sylvia', ' Pl', 'ath']
+3961 963 Name of father of x -1 Name of father of Sylvia Plath Otto Plath Sylvia Plath "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' daughter' ',' ' and'
+ ' the' ' poet' ""'s"" ' daughter' ""'s"" ' daughter' ',' ' and' ' the']" , the poet , and the poet 's daughter , and the poet 's daughter 's daughter , and the False depression and interest in Sylvia Plath while recording 6 [' depression', ' and', ' interest', ' in', ' Sylvia', ' Pl', 'ath']
+3962 963 Name of father of x -1 Name of father of Sylvia Plath Otto Plath Sylvia Plath "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' daughter' ',' ' and'
+ ' the' ' poet' ""'s"" ' daughter' ""'s"" ' daughter' ',' ' and' ' the']" , the poet , and the poet 's daughter , and the poet 's daughter 's daughter , and the False in poetry, with Sylvia Plath being her favourite. 6 [' in', ' poetry', ',', ' with', ' Sylvia', ' Pl', 'ath']
+3963 963 Name of father of x -1 Name of father of Sylvia Plath Otto Plath Sylvia Plath "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' daughter' ',' ' and'
+ ' the' ' poet' ""'s"" ' daughter' ""'s"" ' daughter' ',' ' and' ' the']" , the poet , and the poet 's daughter , and the poet 's daughter 's daughter , and the False authors such as Sylvia Plath and Ted Hughes 5 [' authors', ' such', ' as', ' Sylvia', ' Pl', 'ath']
+3964 963 Name of father of x -1 Name of father of Sylvia Plath Otto Plath Sylvia Plath "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' daughter' ',' ' and'
+ ' the' ' poet' ""'s"" ' daughter' ""'s"" ' daughter' ',' ' and' ' the']" , the poet , and the poet 's daughter , and the poet 's daughter 's daughter , and the False Roethke and Sylvia Plath wrote villanelles 6 [' Ro', 'eth', 'ke', ' and', ' Sylvia', ' Pl', 'ath']
+3965 964 Name of father of x -1 Name of father of Saddam Hussein Hussein 'Abid al-Majid Saddam Hussein "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' the' ' father' ' of'
+ ' the' ' nation' ',' ' the' ' father' ' of' ' the' ' nation' ',' ' the']" , the father of the nation , the father of the nation , the father of the nation , the False " necessary because Saddam Hussein ""had not abandoned" 3 [' necessary', ' because', ' Saddam', ' Hussein']
+3966 964 Name of father of x -1 Name of father of Saddam Hussein Hussein 'Abid al-Majid Saddam Hussein "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' the' ' father' ' of'
+ ' the' ' nation' ',' ' the' ' father' ' of' ' the' ' nation' ',' ' the']" , the father of the nation , the father of the nation , the father of the nation , the False " discovery of dictator Saddam Hussein in a ""spider hole""" 4 [' discovery', ' of', ' dictator', ' Saddam', ' Hussein']
+3967 964 Name of father of x -1 Name of father of Saddam Hussein Hussein 'Abid al-Majid Saddam Hussein "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' the' ' father' ' of'
+ ' the' ' nation' ',' ' the' ' father' ' of' ' the' ' nation' ',' ' the']" , the father of the nation , the father of the nation , the father of the nation , the False East until Saddam Hussein was toppled. 3 [' East', ' until', ' Saddam', ' Hussein']
+3968 964 Name of father of x -1 Name of father of Saddam Hussein Hussein 'Abid al-Majid Saddam Hussein "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' the' ' father' ' of'
+ ' the' ' nation' ',' ' the' ' father' ' of' ' the' ' nation' ',' ' the']" , the father of the nation , the father of the nation , the father of the nation , the False rejoice when Saddam Hussein is sentenced, 3 [' rejoice', ' when', ' Saddam', ' Hussein']
+3969 964 Name of father of x -1 Name of father of Saddam Hussein Hussein 'Abid al-Majid Saddam Hussein "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' the' ' father' ' of'
+ ' the' ' nation' ',' ' the' ' father' ' of' ' the' ' nation' ',' ' the']" , the father of the nation , the father of the nation , the father of the nation , the False McCormick, as well as Saddam Hussein (even though during 7 [' McCorm', 'ick', ',', ' as', ' well', ' as', ' Saddam', ' Hussein']
+3970 966 Name of father of x -1 Name of father of André Previn Jack Priwin André Previn "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' mother' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's mother , the pian False with pianist André Previn and with members 6 [' with', ' pian', 'ist', ' And', 'ré', ' Pre', 'vin']
+3971 966 Name of father of x -1 Name of father of André Previn Jack Priwin André Previn "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' mother' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's mother , the pian False Ormandy, Prêtre, André Previn and Leopold Stokowski. 11 [' Or', 'm', 'andy', ',', ' Pr', 'ê', 'tre', ',', ' And', 'ré', ' Pre', 'vin']
+3972 966 Name of father of x -1 Name of father of André Previn Jack Priwin André Previn "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' mother' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's mother , the pian False Weller 1980 – 85, André Previn 1985 – 92, Vladimir 9 [' Well', 'er', ' 1980', ' –', ' 85', ',', ' And', 'ré', ' Pre', 'vin']
+3973 966 Name of father of x -1 Name of father of André Previn Jack Priwin André Previn "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' mother' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's mother , the pian False Eugene Ormandy, Prêtre, André Previn and Leopold Stokowski. 12 [' Eugene', ' Or', 'm', 'andy', ',', ' Pr', 'ê', 'tre', ',', ' And', 'ré', ' Pre', 'vin']
+3974 966 Name of father of x -1 Name of father of André Previn Jack Priwin André Previn "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' and' ' the' ' composer' ""'s"" ' mother' ',' ' the' ' pian']" , the composer , and his wife , the pian ist , and the composer 's mother , the pian False was performed with André Previn as conductor 6 [' was', ' performed', ' with', ' And', 'ré', ' Pre', 'vin']
+3975 968 Name of father of x -1 Name of father of Nastassja Kinski Klaus Kinski Nastassja Kinski "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' woman' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was']" ", the actress who played the role of the young woman in the film .
+
+ The film was" False talked to actress Nastassja Kinski about playing 8 [' talked', ' to', ' actress', ' N', 'ast', 'ass', 'ja', ' K', 'inski']
+3976 968 Name of father of x -1 Name of father of Nastassja Kinski Klaus Kinski Nastassja Kinski "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' woman' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was']" ", the actress who played the role of the young woman in the film .
+
+ The film was" False The Border, and Nastassja Kinski in Cat People (all 9 [' The', ' Border', ',', ' and', ' N', 'ast', 'ass', 'ja', ' K', 'inski']
+3977 968 Name of father of x -1 Name of father of Nastassja Kinski Klaus Kinski Nastassja Kinski "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' woman' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was']" ", the actress who played the role of the young woman in the film .
+
+ The film was" False in The Border, and Nastassja Kinski in Cat People (all 10 [' in', ' The', ' Border', ',', ' and', ' N', 'ast', 'ass', 'ja', ' K', 'inski']
+3978 968 Name of father of x -1 Name of father of Nastassja Kinski Klaus Kinski Nastassja Kinski "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' young' ' woman' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' was']" ", the actress who played the role of the young woman in the film .
+
+ The film was" False talked to actress Nastassja Kinski about playing 8 [' talked', ' to', ' actress', ' N', 'ast', 'ass', 'ja', ' K', 'inski']
+3979 969 Name of father of x -1 Name of father of Robert F. Kennedy Joseph P. Kennedy Sr. Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' Kennedy' ' family' ' has' ' been' ' a'
+ ' part' ' of' ' the' ' American' ' political' ' landscape' ' for'
+ ' generations' '.']" ", Jr .
+
+ The Kennedy family has been a part of the American political landscape for generations ." False national leaders, such as Robert F. Kennedy who was elected 8 [' national', ' leaders', ',', ' such', ' as', ' Robert', ' F', '.', ' Kennedy']
+3980 969 Name of father of x -1 Name of father of Robert F. Kennedy Joseph P. Kennedy Sr. Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' Kennedy' ' family' ' has' ' been' ' a'
+ ' part' ' of' ' the' ' American' ' political' ' landscape' ' for'
+ ' generations' '.']" ", Jr .
+
+ The Kennedy family has been a part of the American political landscape for generations ." False and Attorney General Robert F. Kennedy condemned the decision 6 [' and', ' Attorney', ' General', ' Robert', ' F', '.', ' Kennedy']
+3981 969 Name of father of x -1 Name of father of Robert F. Kennedy Joseph P. Kennedy Sr. Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' Kennedy' ' family' ' has' ' been' ' a'
+ ' part' ' of' ' the' ' American' ' political' ' landscape' ' for'
+ ' generations' '.']" ", Jr .
+
+ The Kennedy family has been a part of the American political landscape for generations ." False morning at Robert F. Kennedy Stadium where 5 [' morning', ' at', ' Robert', ' F', '.', ' Kennedy']
+3982 969 Name of father of x -1 Name of father of Robert F. Kennedy Joseph P. Kennedy Sr. Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' Kennedy' ' family' ' has' ' been' ' a'
+ ' part' ' of' ' the' ' American' ' political' ' landscape' ' for'
+ ' generations' '.']" ", Jr .
+
+ The Kennedy family has been a part of the American political landscape for generations ." False I Give was held at Robert F. Kennedy Memorial Stadium 8 [' I', ' Give', ' was', ' held', ' at', ' Robert', ' F', '.', ' Kennedy']
+3983 969 Name of father of x -1 Name of father of Robert F. Kennedy Joseph P. Kennedy Sr. Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'The' ' Kennedy' ' family' ' has' ' been' ' a'
+ ' part' ' of' ' the' ' American' ' political' ' landscape' ' for'
+ ' generations' '.']" ", Jr .
+
+ The Kennedy family has been a part of the American political landscape for generations ." False attention paid to Robert F. Kennedy during the Select 6 [' attention', ' paid', ' to', ' Robert', ' F', '.', ' Kennedy']
+3984 970 Name of father of x -1 Name of father of Francis I of France Charles Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False service of King Francis I of France – was the first European 6 [' service', ' of', ' King', ' Francis', ' I', ' of', ' France']
+3985 970 Name of father of x -1 Name of father of Francis I of France Charles Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False enlist the support of Francis I of France for his intended 7 [' enlist', ' the', ' support', ' of', ' Francis', ' I', ' of', ' France']
+3986 970 Name of father of x -1 Name of father of Francis I of France Charles Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " Henry VIII wrote to Francis I of France that ""Divine" 7 [' Henry', ' VIII', ' wrote', ' to', ' Francis', ' I', ' of', ' France']
+3987 970 Name of father of x -1 Name of father of Francis I of France Charles Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False comuneros, and King Francis I of France used the opportunity 9 [' com', 'un', 'eros', ',', ' and', ' King', ' Francis', ' I', ' of', ' France']
+3988 970 Name of father of x -1 Name of father of Francis I of France Charles Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False meeting with Francis I of France at the Field of the 5 [' meeting', ' with', ' Francis', ' I', ' of', ' France']
+3989 972 Name of father of x -1 Name of father of Jules Michelet Jean-François Furcy-Michelet Jules Michelet "[',' ' the' ' French' ' historian' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'French' ' Revolution' ',' ' and'
+ ' who']" ", the French historian , who was a great admire r of the
+
+ French Revolution , and who" False Ernest Renan and Jules Michelet were mobilized 7 [' Ernest', ' Ren', 'an', ' and', ' J', 'ules', ' Miche', 'let']
+3990 972 Name of father of x -1 Name of father of Jules Michelet Jean-François Furcy-Michelet Jules Michelet "[',' ' the' ' French' ' historian' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'French' ' Revolution' ',' ' and'
+ ' who']" ", the French historian , who was a great admire r of the
+
+ French Revolution , and who" False French historian Jules Michelet (1798 – 1874) defined 5 [' French', ' historian', ' J', 'ules', ' Miche', 'let']
+3991 972 Name of father of x -1 Name of father of Jules Michelet Jean-François Furcy-Michelet Jules Michelet "[',' ' the' ' French' ' historian' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'French' ' Revolution' ',' ' and'
+ ' who']" ", the French historian , who was a great admire r of the
+
+ French Revolution , and who" False replaced the cruiser Jules Michelet as the flagship. 6 [' replaced', ' the', ' cruiser', ' J', 'ules', ' Miche', 'let']
+3992 972 Name of father of x -1 Name of father of Jules Michelet Jean-François Furcy-Michelet Jules Michelet "[',' ' the' ' French' ' historian' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'French' ' Revolution' ',' ' and'
+ ' who']" ", the French historian , who was a great admire r of the
+
+ French Revolution , and who" False French historian Jules Michelet (1798 – 1874) in 5 [' French', ' historian', ' J', 'ules', ' Miche', 'let']
+3993 972 Name of father of x -1 Name of father of Jules Michelet Jean-François Furcy-Michelet Jules Michelet "[',' ' the' ' French' ' historian' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'French' ' Revolution' ',' ' and'
+ ' who']" ", the French historian , who was a great admire r of the
+
+ French Revolution , and who" False replaced the cruiser Jules Michelet as the flagship. 6 [' replaced', ' the', ' cruiser', ' J', 'ules', ' Miche', 'let']
+3994 973 Name of father of x -1 Name of father of Edward Gibbon Edward Gibbon Edward Gibbon "[',' ' the' ' historian' ',' ' and' ' the' ' author' ' of' ' the' ' _'
+ 'Decl' 'ine' ' and' ' Fall' ' of' ' the' ' Roman' ' Empire' '_' ',']" , the historian , and the author of the _ Decl ine and Fall of the Roman Empire _ , False of the Roman Empire, Edward Gibbon notes that the disciples 7 [' of', ' the', ' Roman', ' Empire', ',', ' Edward', ' Gib', 'bon']
+3995 973 Name of father of x -1 Name of father of Edward Gibbon Edward Gibbon Edward Gibbon "[',' ' the' ' historian' ',' ' and' ' the' ' author' ' of' ' the' ' _'
+ 'Decl' 'ine' ' and' ' Fall' ' of' ' the' ' Roman' ' Empire' '_' ',']" , the historian , and the author of the _ Decl ine and Fall of the Roman Empire _ , False historians and scholars. Edward Gibbon suggested its founding 6 [' historians', ' and', ' scholars', '.', ' Edward', ' Gib', 'bon']
+3996 973 Name of father of x -1 Name of father of Edward Gibbon Edward Gibbon Edward Gibbon "[',' ' the' ' historian' ',' ' and' ' the' ' author' ' of' ' the' ' _'
+ 'Decl' 'ine' ' and' ' Fall' ' of' ' the' ' Roman' ' Empire' '_' ',']" , the historian , and the author of the _ Decl ine and Fall of the Roman Empire _ , False 2 ['Edward', ' Gib', 'bon']
+3997 973 Name of father of x -1 Name of father of Edward Gibbon Edward Gibbon Edward Gibbon "[',' ' the' ' historian' ',' ' and' ' the' ' author' ' of' ' the' ' _'
+ 'Decl' 'ine' ' and' ' Fall' ' of' ' the' ' Roman' ' Empire' '_' ',']" , the historian , and the author of the _ Decl ine and Fall of the Roman Empire _ , False the Roman Empire by Edward Gibbon (1737 – 94) further 6 [' the', ' Roman', ' Empire', ' by', ' Edward', ' Gib', 'bon']
+3998 973 Name of father of x -1 Name of father of Edward Gibbon Edward Gibbon Edward Gibbon "[',' ' the' ' historian' ',' ' and' ' the' ' author' ' of' ' the' ' _'
+ 'Decl' 'ine' ' and' ' Fall' ' of' ' the' ' Roman' ' Empire' '_' ',']" , the historian , and the author of the _ Decl ine and Fall of the Roman Empire _ , False 2 ['Edward', ' Gib', 'bon']
+3999 974 Name of father of x -1 Name of father of Elizabeth Warren Donald Jones Herring Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' president' ',' ' and'
+ ' the' ' daughter' ' of' ' a' ' former' ' vice' ' president' '.' '\n'
+ '\n' 'The']" ", the daughter of a former president , and the daughter of a former vice president .
+
+ The" False that listed Elizabeth Warren ’ s great-great-great 3 [' that', ' listed', ' Elizabeth', ' Warren']
+4000 974 Name of father of x -1 Name of father of Elizabeth Warren Donald Jones Herring Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' president' ',' ' and'
+ ' the' ' daughter' ' of' ' a' ' former' ' vice' ' president' '.' '\n'
+ '\n' 'The']" ", the daughter of a former president , and the daughter of a former vice president .
+
+ The" False application that listed Elizabeth Warren ’ s great-great-great 4 [' application', ' that', ' listed', ' Elizabeth', ' Warren']
+4001 974 Name of father of x -1 Name of father of Elizabeth Warren Donald Jones Herring Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' president' ',' ' and'
+ ' the' ' daughter' ' of' ' a' ' former' ' vice' ' president' '.' '\n'
+ '\n' 'The']" ", the daughter of a former president , and the daughter of a former vice president .
+
+ The" False Capuano joined Senator Elizabeth Warren and Boston Mayor 6 [' Cap', 'u', 'ano', ' joined', ' Senator', ' Elizabeth', ' Warren']
+4002 974 Name of father of x -1 Name of father of Elizabeth Warren Donald Jones Herring Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' president' ',' ' and'
+ ' the' ' daughter' ' of' ' a' ' former' ' vice' ' president' '.' '\n'
+ '\n' 'The']" ", the daughter of a former president , and the daughter of a former vice president .
+
+ The" False joined Senator Elizabeth Warren and Boston 3 [' joined', ' Senator', ' Elizabeth', ' Warren']
+4003 974 Name of father of x -1 Name of father of Elizabeth Warren Donald Jones Herring Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' president' ',' ' and'
+ ' the' ' daughter' ' of' ' a' ' former' ' vice' ' president' '.' '\n'
+ '\n' 'The']" ", the daughter of a former president , and the daughter of a former vice president .
+
+ The" False calling for Elizabeth Warren to run for President 3 [' calling', ' for', ' Elizabeth', ' Warren']
+4004 975 Name of father of x -1 Name of father of Alexander Hamilton James A. Hamilton Alexander Hamilton "[',' ' the' ' first' ' U' '.' 'S' '.' ' Secretary' ' of' ' the'
+ ' Treasury' ',' ' and' ' the' ' first' ' U' '.' 'S' '.' ' Secretary']" , the first U . S . Secretary of the Treasury , and the first U . S . Secretary False and surgeon Alexander Hamilton Rice, Jr., a 3 [' and', ' surgeon', ' Alexander', ' Hamilton']
+4005 975 Name of father of x -1 Name of father of Alexander Hamilton James A. Hamilton Alexander Hamilton "[',' ' the' ' first' ' U' '.' 'S' '.' ' Secretary' ' of' ' the'
+ ' Treasury' ',' ' and' ' the' ' first' ' U' '.' 'S' '.' ' Secretary']" , the first U . S . Secretary of the Treasury , and the first U . S . Secretary False One statue honoring Alexander Hamilton in Chicago was 4 [' One', ' statue', ' honoring', ' Alexander', ' Hamilton']
+4006 975 Name of father of x -1 Name of father of Alexander Hamilton James A. Hamilton Alexander Hamilton "[',' ' the' ' first' ' U' '.' 'S' '.' ' Secretary' ' of' ' the'
+ ' Treasury' ',' ' and' ' the' ' first' ' U' '.' 'S' '.' ' Secretary']" , the first U . S . Secretary of the Treasury , and the first U . S . Secretary False Johnson (Connecticut), Alexander Hamilton (New York), Gouverneur 6 [' Johnson', ' (', 'Connect', 'icut', '),', ' Alexander', ' Hamilton']
+4007 975 Name of father of x -1 Name of father of Alexander Hamilton James A. Hamilton Alexander Hamilton "[',' ' the' ' first' ' U' '.' 'S' '.' ' Secretary' ' of' ' the'
+ ' Treasury' ',' ' and' ' the' ' first' ' U' '.' 'S' '.' ' Secretary']" , the first U . S . Secretary of the Treasury , and the first U . S . Secretary False Captain and supported Alexander Hamilton in the Siege of Yorktown. 4 [' Captain', ' and', ' supported', ' Alexander', ' Hamilton']
+4008 975 Name of father of x -1 Name of father of Alexander Hamilton James A. Hamilton Alexander Hamilton "[',' ' the' ' first' ' U' '.' 'S' '.' ' Secretary' ' of' ' the'
+ ' Treasury' ',' ' and' ' the' ' first' ' U' '.' 'S' '.' ' Secretary']" , the first U . S . Secretary of the Treasury , and the first U . S . Secretary False approach to the Alexander Hamilton Bridge. The Deegan 4 [' approach', ' to', ' the', ' Alexander', ' Hamilton']
+4009 977 Name of father of x -1 Name of father of Philip IV of Spain Philip III of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False December 15, 1637, Philip IV of Spain imposed an obligatory 9 [' December', ' 15', ',', ' 16', '37', ',', ' Philip', ' IV', ' of', ' Spain']
+4010 977 Name of father of x -1 Name of father of Philip IV of Spain Philip III of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False December 15, 1637, Philip IV of Spain imposed an obligatory 9 [' December', ' 15', ',', ' 16', '37', ',', ' Philip', ' IV', ' of', ' Spain']
+4011 977 Name of father of x -1 Name of father of Philip IV of Spain Philip III of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False control of this trade. Philip IV of Spain (reigned 1621 8 [' control', ' of', ' this', ' trade', '.', ' Philip', ' IV', ' of', ' Spain']
+4012 977 Name of father of x -1 Name of father of Philip IV of Spain Philip III of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False December 15, 1637, Philip IV of Spain imposed an obligatory 9 [' December', ' 15', ',', ' 16', '37', ',', ' Philip', ' IV', ' of', ' Spain']
+4013 977 Name of father of x -1 Name of father of Philip IV of Spain Philip III of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' King' ' of' ' France' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the King of France ,
+" False control of this trade. Philip IV of Spain (reigned 1621 – 8 [' control', ' of', ' this', ' trade', '.', ' Philip', ' IV', ' of', ' Spain']
+4014 978 Name of father of x -1 Name of father of John Huston Walter Huston John Huston "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the' '\n'
+ '\n' '1' '\n' '\n' '2' '\n' '\n' '3' '\n']" ", who was a member of the family of the
+
+ 1
+
+ 2
+
+ 3
+" False mid-1950s director John Huston came to town with 7 [' mid', '-', '1950', 's', ' director', ' John', ' Hust', 'on']
+4015 978 Name of father of x -1 Name of father of John Huston Walter Huston John Huston "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the' '\n'
+ '\n' '1' '\n' '\n' '2' '\n' '\n' '3' '\n']" ", who was a member of the family of the
+
+ 1
+
+ 2
+
+ 3
+" False American director John Huston visited India for 4 [' American', ' director', ' John', ' Hust', 'on']
+4016 978 Name of father of x -1 Name of father of John Huston Walter Huston John Huston "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the' '\n'
+ '\n' '1' '\n' '\n' '2' '\n' '\n' '3' '\n']" ", who was a member of the family of the
+
+ 1
+
+ 2
+
+ 3
+" False Madre, which John Huston eventually directed 6 [' Mad', 're', ',', ' which', ' John', ' Hust', 'on']
+4017 978 Name of father of x -1 Name of father of John Huston Walter Huston John Huston "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the' '\n'
+ '\n' '1' '\n' '\n' '2' '\n' '\n' '3' '\n']" ", who was a member of the family of the
+
+ 1
+
+ 2
+
+ 3
+" False Courage directed by John Huston and adapted 5 [' Courage', ' directed', ' by', ' John', ' Hust', 'on']
+4018 978 Name of father of x -1 Name of father of John Huston Walter Huston John Huston "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the' '\n'
+ '\n' '1' '\n' '\n' '2' '\n' '\n' '3' '\n']" ", who was a member of the family of the
+
+ 1
+
+ 2
+
+ 3
+" False Paul Henreid and John Huston joined other 7 [' Paul', ' Hen', 're', 'id', ' and', ' John', ' Hust', 'on']
+4019 979 Name of father of x -1 Name of father of Leo X Lorenzo de' Medici Leo X "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' a' ' portrait'
+ ' of' ' the' ' pope' ',' ' and' ' the' ' other' ' is' ' a']" ".
+
+ The first of the two is a portrait of the pope , and the other is a" False from Pope Leo X to Henry VIII of England 3 [' from', ' Pope', ' Leo', ' X']
+4020 979 Name of father of x -1 Name of father of Leo X Lorenzo de' Medici Leo X "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' a' ' portrait'
+ ' of' ' the' ' pope' ',' ' and' ' the' ' other' ' is' ' a']" ".
+
+ The first of the two is a portrait of the pope , and the other is a" False a grant from Pope Leo X rewarding Henry 5 [' a', ' grant', ' from', ' Pope', ' Leo', ' X']
+4021 979 Name of father of x -1 Name of father of Leo X Lorenzo de' Medici Leo X "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' a' ' portrait'
+ ' of' ' the' ' pope' ',' ' and' ' the' ' other' ' is' ' a']" ".
+
+ The first of the two is a portrait of the pope , and the other is a" False Traspontina, and in 1513 Pope Leo X (r. 1513 – 11 [' Tr', 'asp', 'ont', 'ina', ',', ' and', ' in', ' 15', '13', ' Pope', ' Leo', ' X']
+4022 979 Name of father of x -1 Name of father of Leo X Lorenzo de' Medici Leo X "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' a' ' portrait'
+ ' of' ' the' ' pope' ',' ' and' ' the' ' other' ' is' ' a']" ".
+
+ The first of the two is a portrait of the pope , and the other is a" False Medici popes, Leo X and Clement VII. 6 [' Medic', 'i', ' pop', 'es', ',', ' Leo', ' X']
+4023 979 Name of father of x -1 Name of father of Leo X Lorenzo de' Medici Leo X "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' a' ' portrait'
+ ' of' ' the' ' pope' ',' ' and' ' the' ' other' ' is' ' a']" ".
+
+ The first of the two is a portrait of the pope , and the other is a" False Portrait of Pope Leo X (c. 1519) by 5 [' Port', 'rait', ' of', ' Pope', ' Leo', ' X']
+4024 980 Name of father of x -1 Name of father of Artemisia Gentileschi Orazio Gentileschi Artemisia Gentileschi "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '93' ' in'
+ ' Florence' ',' ' Italy' '.' ' She' ' was' ' the' ' daughter' ' of' ' O']" , the painter , was born in 15 93 in Florence , Italy . She was the daughter of O False " and paintings by Artemisia Gentileschi (17th century).
+" 7 [' and', ' paintings', ' by', ' Artemis', 'ia', ' Gent', 'iles', 'chi']
+4025 980 Name of father of x -1 Name of father of Artemisia Gentileschi Orazio Gentileschi Artemisia Gentileschi "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '93' ' in'
+ ' Florence' ',' ' Italy' '.' ' She' ' was' ' the' ' daughter' ' of' ' O']" , the painter , was born in 15 93 in Florence , Italy . She was the daughter of O False " and paintings by Artemisia Gentileschi (17th century).
+" 7 [' and', ' paintings', ' by', ' Artemis', 'ia', ' Gent', 'iles', 'chi']
+4026 980 Name of father of x -1 Name of father of Artemisia Gentileschi Orazio Gentileschi Artemisia Gentileschi "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '93' ' in'
+ ' Florence' ',' ' Italy' '.' ' She' ' was' ' the' ' daughter' ' of' ' O']" , the painter , was born in 15 93 in Florence , Italy . She was the daughter of O False " and paintings by Artemisia Gentileschi (17th century).
+" 7 [' and', ' paintings', ' by', ' Artemis', 'ia', ' Gent', 'iles', 'chi']
+4027 980 Name of father of x -1 Name of father of Artemisia Gentileschi Orazio Gentileschi Artemisia Gentileschi "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '93' ' in'
+ ' Florence' ',' ' Italy' '.' ' She' ' was' ' the' ' daughter' ' of' ' O']" , the painter , was born in 15 93 in Florence , Italy . She was the daughter of O False " Sarto, and paintings by Artemisia Gentileschi (17th century).
+" 11 [' S', 'art', 'o', ',', ' and', ' paintings', ' by', ' Artemis', 'ia', ' Gent', 'iles', 'chi']
+4028 981 Name of father of x -1 Name of father of Miloš Forman Otto Kohn Miloš Forman "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' Mil' 'ena' ' D' 'rav' 'i' 'ć' '.']" , the director of the film , and his wife , the actress Mil ena D rav i ć . False insistence of director Miloš Forman when she was cast 6 [' insistence', ' of', ' director', ' Milo', 'š', ' Form', 'an']
+4029 981 Name of father of x -1 Name of father of Miloš Forman Otto Kohn Miloš Forman "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' Mil' 'ena' ' D' 'rav' 'i' 'ć' '.']" , the director of the film , and his wife , the actress Mil ena D rav i ć . False " were playing Miloš Forman playing their roles.
+" 5 [' were', ' playing', ' Milo', 'š', ' Form', 'an']
+4030 981 Name of father of x -1 Name of father of Miloš Forman Otto Kohn Miloš Forman "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' Mil' 'ena' ' D' 'rav' 'i' 'ć' '.']" , the director of the film , and his wife , the actress Mil ena D rav i ć . False insistence of director Miloš Forman when she was cast in 6 [' insistence', ' of', ' director', ' Milo', 'š', ' Form', 'an']
+4031 981 Name of father of x -1 Name of father of Miloš Forman Otto Kohn Miloš Forman "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' Mil' 'ena' ' D' 'rav' 'i' 'ć' '.']" , the director of the film , and his wife , the actress Mil ena D rav i ć . False from the director Miloš Forman to star in a small 6 [' from', ' the', ' director', ' Milo', 'š', ' Form', 'an']
+4032 981 Name of father of x -1 Name of father of Miloš Forman Otto Kohn Miloš Forman "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' his' ' wife' ','
+ ' the' ' actress' ' Mil' 'ena' ' D' 'rav' 'i' 'ć' '.']" , the director of the film , and his wife , the actress Mil ena D rav i ć . False the director Miloš Forman to star in a small 5 [' the', ' director', ' Milo', 'š', ' Form', 'an']
+4033 982 Name of father of x -1 Name of father of Jacob Burckhardt Jakob Burckhardt Jacob Burckhardt "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' Jacob' ' Bur' 'ck'
+ 'hardt' ',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", the
+
+ Name of father of Jacob Bur ck hardt , the
+
+ Name of father of" False The Swiss historian Jacob Burckhardt (1818 – 1897) 6 [' The', ' Swiss', ' historian', ' Jacob', ' Bur', 'ck', 'hardt']
+4034 982 Name of father of x -1 Name of father of Jacob Burckhardt Jakob Burckhardt Jacob Burckhardt "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' Jacob' ' Bur' 'ck'
+ 'hardt' ',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", the
+
+ Name of father of Jacob Bur ck hardt , the
+
+ Name of father of" False Swiss historian Jacob Burckhardt (1818 – 1897) 5 [' Swiss', ' historian', ' Jacob', ' Bur', 'ck', 'hardt']
+4035 982 Name of father of x -1 Name of father of Jacob Burckhardt Jakob Burckhardt Jacob Burckhardt "[',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of' ' Jacob' ' Bur' 'ck'
+ 'hardt' ',' ' the' '\n' '\n' 'Name' ' of' ' father' ' of']" ", the
+
+ Name of father of Jacob Bur ck hardt , the
+
+ Name of father of" False Swiss historian Jacob Burckhardt (1818 – 1897) in 5 [' Swiss', ' historian', ' Jacob', ' Bur', 'ck', 'hardt']
+4036 983 Name of father of x -1 Name of father of Herbert Hoover Jesse Hoover Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Lou' ' Henry' ' Hoover' ',' ' who' ' was'
+ ' the']" , the former president of the United States , and his wife , Lou Henry Hoover , who was the False " Stanford classmate) Herbert Hoover wrote, ""Judge George" 4 [' Stanford', ' classmate', ')', ' Herbert', ' Hoover']
+4037 983 Name of father of x -1 Name of father of Herbert Hoover Jesse Hoover Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Lou' ' Henry' ' Hoover' ',' ' who' ' was'
+ ' the']" , the former president of the United States , and his wife , Lou Henry Hoover , who was the False Commerce Secretary Herbert Hoover campaigned 3 [' Commerce', ' Secretary', ' Herbert', ' Hoover']
+4038 983 Name of father of x -1 Name of father of Herbert Hoover Jesse Hoover Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Lou' ' Henry' ' Hoover' ',' ' who' ' was'
+ ' the']" , the former president of the United States , and his wife , Lou Henry Hoover , who was the False Smith was defeated by Herbert Hoover both nationally and 5 [' Smith', ' was', ' defeated', ' by', ' Herbert', ' Hoover']
+4039 983 Name of father of x -1 Name of father of Herbert Hoover Jesse Hoover Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Lou' ' Henry' ' Hoover' ',' ' who' ' was'
+ ' the']" , the former president of the United States , and his wife , Lou Henry Hoover , who was the False was the residence of Herbert Hoover from his retirement 5 [' was', ' the', ' residence', ' of', ' Herbert', ' Hoover']
+4040 983 Name of father of x -1 Name of father of Herbert Hoover Jesse Hoover Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' his' ' wife' ',' ' Lou' ' Henry' ' Hoover' ',' ' who' ' was'
+ ' the']" , the former president of the United States , and his wife , Lou Henry Hoover , who was the False 1933. President Herbert Hoover feared that too much 4 [' 1933', '.', ' President', ' Herbert', ' Hoover']
+4041 987 Name of father of x -1 Name of father of Pontormo Bartolommeo di Jacopo di Martino Pontormo "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False and through them Pontormo and Correggio. The 5 [' and', ' through', ' them', ' Pont', 'orm', 'o']
+4042 987 Name of father of x -1 Name of father of Pontormo Bartolommeo di Jacopo di Martino Pontormo "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False " and paintings by Pontormo and Domenico Beccafumi.
+" 5 [' and', ' paintings', ' by', ' Pont', 'orm', 'o']
+4043 987 Name of father of x -1 Name of father of Pontormo Bartolommeo di Jacopo di Martino Pontormo "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False through them Pontormo and Correggio. 4 [' through', ' them', ' Pont', 'orm', 'o']
+4044 987 Name of father of x -1 Name of father of Pontormo Bartolommeo di Jacopo di Martino Pontormo "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False and through them Pontormo and Correggio. The 5 [' and', ' through', ' them', ' Pont', 'orm', 'o']
+4045 987 Name of father of x -1 Name of father of Pontormo Bartolommeo di Jacopo di Martino Pontormo "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False " Botticelli and paintings by Pontormo and Domenico Beccafumi.
+" 8 [' Bott', 'ice', 'lli', ' and', ' paintings', ' by', ' Pont', 'orm', 'o']
+4046 988 Name of father of x -1 Name of father of Dean Martin Gaetano Crocetti Dean Martin "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' '/' 's'
+ 'inger' '/' 'act' 'ress' ',' ' D' 'ixie' ' Carter' '.']" , the actor , and his wife , actress / s inger / act ress , D ixie Carter . False and to co-star Dean Martin and Cyd Charisse. 6 [' and', ' to', ' co', '-', 'star', ' Dean', ' Martin']
+4047 988 Name of father of x -1 Name of father of Dean Martin Gaetano Crocetti Dean Martin "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' '/' 's'
+ 'inger' '/' 'act' 'ress' ',' ' D' 'ixie' ' Carter' '.']" , the actor , and his wife , actress / s inger / act ress , D ixie Carter . False Stupid, opposite Dean Martin and Kim Novak. 4 [' Stupid', ',', ' opposite', ' Dean', ' Martin']
+4048 988 Name of father of x -1 Name of father of Dean Martin Gaetano Crocetti Dean Martin "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' '/' 's'
+ 'inger' '/' 'act' 'ress' ',' ' D' 'ixie' ' Carter' '.']" , the actor , and his wife , actress / s inger / act ress , D ixie Carter . False Frank Sinatra and Dean Martin acquired three 5 [' Frank', ' Sin', 'atra', ' and', ' Dean', ' Martin']
+4049 988 Name of father of x -1 Name of father of Dean Martin Gaetano Crocetti Dean Martin "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' '/' 's'
+ 'inger' '/' 'act' 'ress' ',' ' D' 'ixie' ' Carter' '.']" , the actor , and his wife , actress / s inger / act ress , D ixie Carter . False album, as did Dean Martin on his 1970 5 [' album', ',', ' as', ' did', ' Dean', ' Martin']
+4050 988 Name of father of x -1 Name of father of Dean Martin Gaetano Crocetti Dean Martin "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' '/' 's'
+ 'inger' '/' 'act' 'ress' ',' ' D' 'ixie' ' Carter' '.']" , the actor , and his wife , actress / s inger / act ress , D ixie Carter . False comedy partners Dean Martin and Jerry Lewis for 3 [' comedy', ' partners', ' Dean', ' Martin']
+4051 989 Name of father of x -1 Name of father of Antoine Lavoisier Jean-Antoine de Lavoisier Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False In 1772, Antoine Lavoisier used a lens to 8 [' In', ' 17', '72', ',', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+4052 989 Name of father of x -1 Name of father of Antoine Lavoisier Jean-Antoine de Lavoisier Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False structures. In 1780, Antoine Lavoisier used a guinea 10 [' structures', '.', ' In', ' 17', '80', ',', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+4053 989 Name of father of x -1 Name of father of Antoine Lavoisier Jean-Antoine de Lavoisier Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False carbon. In 1772, Antoine Lavoisier showed that diamonds 10 [' carbon', '.', ' In', ' 17', '72', ',', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+4054 989 Name of father of x -1 Name of father of Antoine Lavoisier Jean-Antoine de Lavoisier Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False " element by Antoine Lavoisier in 1777.
+" 6 [' element', ' by', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+4055 989 Name of father of x -1 Name of father of Antoine Lavoisier Jean-Antoine de Lavoisier Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False element. In 1783, Antoine Lavoisier gave the element 10 [' element', '.', ' In', ' 17', '83', ',', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+4056 990 Name of father of x -1 Name of father of Marianne Faithfull Major Prof. Robert Glynn Faithfull Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Williamson as Hamlet and Marianne Faithfull as Ophelia.
+" 8 [' Williamson', ' as', ' Ham', 'let', ' and', ' Marian', 'ne', ' Faith', 'full']
+4057 990 Name of father of x -1 Name of father of Marianne Faithfull Major Prof. Robert Glynn Faithfull Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Five"" were recorded by Marianne Faithfull for Come My" 8 "[' Five', '""', ' were', ' recorded', ' by', ' Marian', 'ne', ' Faith', 'full']"
+4058 990 Name of father of x -1 Name of father of Marianne Faithfull Major Prof. Robert Glynn Faithfull Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Memory Remains"" with Marianne Faithfull on NBC's Saturday" 8 "[' Memory', ' Rem', 'ains', '""', ' with', ' Marian', 'ne', ' Faith', 'full']"
+4059 990 Name of father of x -1 Name of father of Marianne Faithfull Major Prof. Robert Glynn Faithfull Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Grateful Dead, Marianne Faithfull and Robyn 6 [' Grateful', ' Dead', ',', ' Marian', 'ne', ' Faith', 'full']
+4060 990 Name of father of x -1 Name of father of Marianne Faithfull Major Prof. Robert Glynn Faithfull Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Hamlet and Marianne Faithfull as Ophelia.
+" 6 [' Ham', 'let', ' and', ' Marian', 'ne', ' Faith', 'full']
diff --git a/patchscopes/code/preprocessed_data/factual/person_mother.tsv b/patchscopes/code/preprocessed_data/factual/person_mother.tsv
new file mode 100644
index 00000000..40e5a9a1
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/person_mother.tsv
@@ -0,0 +1,14495 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 Name of mother of x -1 Name of mother of Drake Sandi Graham Drake "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake']" ", and the
+
+ Name of mother of Drake , and the
+
+ Name of mother of Drake" False " and two each from Drake and Bastin.
+" 4 [' and', ' two', ' each', ' from', ' Drake']
+1 0 Name of mother of x -1 Name of mother of Drake Sandi Graham Drake "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake']" ", and the
+
+ Name of mother of Drake , and the
+
+ Name of mother of Drake" False " as Bobby ""Iceman"" Drake in the X-Men film" 6 "[' as', ' Bobby', ' ""', 'Ice', 'man', '""', ' Drake']"
+2 0 Name of mother of x -1 Name of mother of Drake Sandi Graham Drake "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake']" ", and the
+
+ Name of mother of Drake , and the
+
+ Name of mother of Drake" False " rhythm guitarist Matt Drake said that he was ""really" 3 [' rhythm', ' guitarist', ' Matt', ' Drake']
+3 0 Name of mother of x -1 Name of mother of Drake Sandi Graham Drake "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake']" ", and the
+
+ Name of mother of Drake , and the
+
+ Name of mother of Drake" False speculation about Drake's sexuality. Boyd 2 [' speculation', ' about', ' Drake']
+4 0 Name of mother of x -1 Name of mother of Drake Sandi Graham Drake "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Drake']" ", and the
+
+ Name of mother of Drake , and the
+
+ Name of mother of Drake" False first released as Drake's follow-up track 3 [' first', ' released', ' as', ' Drake']
+5 1 Name of mother of x -1 Name of mother of Ariana Grande Joan Grande Ariana Grande "[',' ' who' ' is' ' a' ' huge' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ '�' '�' 'I' '�' '�' 'm' ' a' ' huge']" ", who is a huge fan of the show .
+
+ � � I � � m a huge" False row, Ed Sheeran, Ariana Grande, and Hozier. Taylor 9 [' row', ',', ' Ed', ' She', 'er', 'an', ',', ' Ari', 'ana', ' Grande']
+6 1 Name of mother of x -1 Name of mother of Ariana Grande Joan Grande Ariana Grande "[',' ' who' ' is' ' a' ' huge' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ '�' '�' 'I' '�' '�' 'm' ' a' ' huge']" ", who is a huge fan of the show .
+
+ � � I � � m a huge" False recording artist Ariana Grande that features 4 [' recording', ' artist', ' Ari', 'ana', ' Grande']
+7 1 Name of mother of x -1 Name of mother of Ariana Grande Joan Grande Ariana Grande "[',' ' who' ' is' ' a' ' huge' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ '�' '�' 'I' '�' '�' 'm' ' a' ' huge']" ", who is a huge fan of the show .
+
+ � � I � � m a huge" False recording artist Ariana Grande that features 4 [' recording', ' artist', ' Ari', 'ana', ' Grande']
+8 1 Name of mother of x -1 Name of mother of Ariana Grande Joan Grande Ariana Grande "[',' ' who' ' is' ' a' ' huge' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ '�' '�' 'I' '�' '�' 'm' ' a' ' huge']" ", who is a huge fan of the show .
+
+ � � I � � m a huge" False Ed Sheeran, Ariana Grande, and Hozier. Taylor 7 [' Ed', ' She', 'er', 'an', ',', ' Ari', 'ana', ' Grande']
+9 1 Name of mother of x -1 Name of mother of Ariana Grande Joan Grande Ariana Grande "[',' ' who' ' is' ' a' ' huge' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ '�' '�' 'I' '�' '�' 'm' ' a' ' huge']" ", who is a huge fan of the show .
+
+ � � I � � m a huge" False recording artist Ariana Grande that features 4 [' recording', ' artist', ' Ari', 'ana', ' Grande']
+10 2 Name of mother of x -1 Name of mother of Rihanna Monica Braithwaite Rihanna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " Hobson noted that Rihanna ""rejects the" 5 [' Hob', 'son', ' noted', ' that', ' Rih', 'anna']
+11 2 Name of mother of x -1 Name of mother of Rihanna Monica Braithwaite Rihanna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Gallucci also felt that Rihanna's over-sings on 6 [' Gall', 'ucci', ' also', ' felt', ' that', ' Rih', 'anna']
+12 2 Name of mother of x -1 Name of mother of Rihanna Monica Braithwaite Rihanna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " the composition. ""[Rihanna] and panting" 6 "[' the', ' composition', '.', ' ""[', 'R', 'ih', 'anna']"
+13 2 Name of mother of x -1 Name of mother of Rihanna Monica Braithwaite Rihanna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False between Guetta and Rihanna as well as its 5 [' between', ' Gu', 'etta', ' and', ' Rih', 'anna']
+14 2 Name of mother of x -1 Name of mother of Rihanna Monica Braithwaite Rihanna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False of the song and Rihanna's musical direction 5 [' of', ' the', ' song', ' and', ' Rih', 'anna']
+15 3 Name of mother of x -1 Name of mother of Will Smith Caroline Bright Will Smith "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the' ' groom']" , the father of the bride , and the groom , and the groom 's mother , and the groom False " with New Yorkers. Will Smith said, ""I don" 5 [' with', ' New', ' Yorkers', '.', ' Will', ' Smith']
+16 3 Name of mother of x -1 Name of mother of Will Smith Caroline Bright Will Smith "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the' ' groom']" , the father of the bride , and the groom , and the groom 's mother , and the groom False starred alongside Will Smith in Alex Proyas'science 3 [' starred', ' alongside', ' Will', ' Smith']
+17 3 Name of mother of x -1 Name of mother of Will Smith Caroline Bright Will Smith "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the' ' groom']" , the father of the bride , and the groom , and the groom 's mother , and the groom False million. Peters offered Will Smith the role of Superman, 5 [' million', '.', ' Peters', ' offered', ' Will', ' Smith']
+18 3 Name of mother of x -1 Name of mother of Will Smith Caroline Bright Will Smith "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the' ' groom']" , the father of the bride , and the groom , and the groom 's mother , and the groom False He credited actor Will Smith for helping him 4 [' He', ' credited', ' actor', ' Will', ' Smith']
+19 3 Name of mother of x -1 Name of mother of Will Smith Caroline Bright Will Smith "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the' ' groom']" , the father of the bride , and the groom , and the groom 's mother , and the groom False briefly in the Will Smith music video for 4 [' briefly', ' in', ' the', ' Will', ' Smith']
+20 4 Name of mother of x -1 Name of mother of Prince William Princess Diana Prince William "[' and' ' Prince' ' George' ',' ' and' ' the' ' Duke' ' of' ' Cambridge'
+ ',' ' and' ' the' ' Duke' ' of' ' Cambridge' ',' ' Prince' ' George' ','
+ ' Prince']" and Prince George , and the Duke of Cambridge , and the Duke of Cambridge , Prince George , Prince False corridor in Fairfax and Prince William counties. According 5 [' corridor', ' in', ' Fairfax', ' and', ' Prince', ' William']
+21 4 Name of mother of x -1 Name of mother of Prince William Princess Diana Prince William "[' and' ' Prince' ' George' ',' ' and' ' the' ' Duke' ' of' ' Cambridge'
+ ',' ' and' ' the' ' Duke' ' of' ' Cambridge' ',' ' Prince' ' George' ','
+ ' Prince']" and Prince George , and the Duke of Cambridge , and the Duke of Cambridge , Prince George , Prince False Chopawamsic Creek (today in Prince William County, Virginia) 9 [' Chop', 'aw', 'ams', 'ic', ' Creek', ' (', 'today', ' in', ' Prince', ' William']
+22 4 Name of mother of x -1 Name of mother of Prince William Princess Diana Prince William "[' and' ' Prince' ' George' ',' ' and' ' the' ' Duke' ' of' ' Cambridge'
+ ',' ' and' ' the' ' Duke' ' of' ' Cambridge' ',' ' Prince' ' George' ','
+ ' Prince']" and Prince George , and the Duke of Cambridge , and the Duke of Cambridge , Prince George , Prince False Australia at the Wedding of Prince William and Catherine 6 [' Australia', ' at', ' the', ' Wedding', ' of', ' Prince', ' William']
+23 4 Name of mother of x -1 Name of mother of Prince William Princess Diana Prince William "[' and' ' Prince' ' George' ',' ' and' ' the' ' Duke' ' of' ' Cambridge'
+ ',' ' and' ' the' ' Duke' ' of' ' Cambridge' ',' ' Prince' ' George' ','
+ ' Prince']" and Prince George , and the Duke of Cambridge , and the Duke of Cambridge , Prince George , Prince False the wife of Prince William III of Orange. 4 [' the', ' wife', ' of', ' Prince', ' William']
+24 4 Name of mother of x -1 Name of mother of Prince William Princess Diana Prince William "[' and' ' Prince' ' George' ',' ' and' ' the' ' Duke' ' of' ' Cambridge'
+ ',' ' and' ' the' ' Duke' ' of' ' Cambridge' ',' ' Prince' ' George' ','
+ ' Prince']" and Prince George , and the Duke of Cambridge , and the Duke of Cambridge , Prince George , Prince False younger brother Prince William of Denmark. Ernest 3 [' younger', ' brother', ' Prince', ' William']
+25 5 Name of mother of x -1 Name of mother of Prince Harry Princess Diana Prince Harry "[' and' ' Princess' ' Diana' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of'
+ ' two' ',' ' a' ' wife' ',' ' a' ' daughter' ',' ' a']" " and Princess Diana .
+
+ I am a mother of two , a wife , a daughter , a" True 1 ['Prince', ' Harry']
+26 5 Name of mother of x -1 Name of mother of Prince Harry Princess Diana Prince Harry "[' and' ' Princess' ' Diana' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of'
+ ' two' ',' ' a' ' wife' ',' ' a' ' daughter' ',' ' a']" " and Princess Diana .
+
+ I am a mother of two , a wife , a daughter , a" True Jill Biden and Prince Harry in visiting 4 [' Jill', ' Biden', ' and', ' Prince', ' Harry']
+27 5 Name of mother of x -1 Name of mother of Prince Harry Princess Diana Prince Harry "[' and' ' Princess' ' Diana' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of'
+ ' two' ',' ' a' ' wife' ',' ' a' ' daughter' ',' ' a']" " and Princess Diana .
+
+ I am a mother of two , a wife , a daughter , a" True 1 ['Prince', ' Harry']
+28 5 Name of mother of x -1 Name of mother of Prince Harry Princess Diana Prince Harry "[' and' ' Princess' ' Diana' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of'
+ ' two' ',' ' a' ' wife' ',' ' a' ' daughter' ',' ' a']" " and Princess Diana .
+
+ I am a mother of two , a wife , a daughter , a" True to interview Prince Harry as he prepared 3 [' to', ' interview', ' Prince', ' Harry']
+29 5 Name of mother of x -1 Name of mother of Prince Harry Princess Diana Prince Harry "[' and' ' Princess' ' Diana' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of'
+ ' two' ',' ' a' ' wife' ',' ' a' ' daughter' ',' ' a']" " and Princess Diana .
+
+ I am a mother of two , a wife , a daughter , a" True press that allowed Prince Harry to serve in Afghanistan. 4 [' press', ' that', ' allowed', ' Prince', ' Harry']
+30 6 Name of mother of x -1 Name of mother of Franklin D. Roosevelt Sara Ann Delano Franklin D. Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False U.S. President Franklin D. Roosevelt ’ s attitude towards 8 [' U', '.', 'S', '.', ' President', ' Franklin', ' D', '.', ' Roosevelt']
+31 6 Name of mother of x -1 Name of mother of Franklin D. Roosevelt Sara Ann Delano Franklin D. Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False D.C., President Franklin D. Roosevelt invited Skelton to 8 [' D', '.', 'C', '.,', ' President', ' Franklin', ' D', '.', ' Roosevelt']
+32 6 Name of mother of x -1 Name of mother of Franklin D. Roosevelt Sara Ann Delano Franklin D. Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False friendship with Franklin D. Roosevelt and his wife, beginning 5 [' friendship', ' with', ' Franklin', ' D', '.', ' Roosevelt']
+33 6 Name of mother of x -1 Name of mother of Franklin D. Roosevelt Sara Ann Delano Franklin D. Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False canyon and nearing Franklin D. Roosevelt Lake; at 80,000 acres 6 [' canyon', ' and', ' nearing', ' Franklin', ' D', '.', ' Roosevelt']
+34 6 Name of mother of x -1 Name of mother of Franklin D. Roosevelt Sara Ann Delano Franklin D. Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False Secretary of the Navy Franklin D. Roosevelt and former Lane assistant 7 [' Secretary', ' of', ' the', ' Navy', ' Franklin', ' D', '.', ' Roosevelt']
+35 7 Name of mother of x -1 Name of mother of Queen Elizabeth II Queen Elizabeth The Queen Mother Queen Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' '.' '\n' '\n' 'The' ' Queen' ' of' ' England' ' is' ' the']" ", the Queen of England , and the Queen of England .
+
+ The Queen of England is the" False three months later by Queen Elizabeth II in a ceremony 6 [' three', ' months', ' later', ' by', ' Queen', ' Elizabeth', ' II']
+36 7 Name of mother of x -1 Name of mother of Queen Elizabeth II Queen Elizabeth The Queen Mother Queen Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' '.' '\n' '\n' 'The' ' Queen' ' of' ' England' ' is' ' the']" ", the Queen of England , and the Queen of England .
+
+ The Queen of England is the" False 2 ['Queen', ' Elizabeth', ' II']
+37 7 Name of mother of x -1 Name of mother of Queen Elizabeth II Queen Elizabeth The Queen Mother Queen Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' '.' '\n' '\n' 'The' ' Queen' ' of' ' England' ' is' ' the']" ", the Queen of England , and the Queen of England .
+
+ The Queen of England is the" False General (ADC Gen) to Queen Elizabeth II on 5 June 2006. He 9 [' General', ' (', 'AD', 'C', ' Gen', ')', ' to', ' Queen', ' Elizabeth', ' II']
+38 7 Name of mother of x -1 Name of mother of Queen Elizabeth II Queen Elizabeth The Queen Mother Queen Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' '.' '\n' '\n' 'The' ' Queen' ' of' ' England' ' is' ' the']" ", the Queen of England , and the Queen of England .
+
+ The Queen of England is the" False they performed for Queen Elizabeth II at the Royal Variety 5 [' they', ' performed', ' for', ' Queen', ' Elizabeth', ' II']
+39 7 Name of mother of x -1 Name of mother of Queen Elizabeth II Queen Elizabeth The Queen Mother Queen Elizabeth II "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' '.' '\n' '\n' 'The' ' Queen' ' of' ' England' ' is' ' the']" ", the Queen of England , and the Queen of England .
+
+ The Queen of England is the" False foreign visits. Queen Elizabeth II and Prime Ministers 5 [' foreign', ' visits', '.', ' Queen', ' Elizabeth', ' II']
+40 8 Name of mother of x -1 Name of mother of Charles, Prince of Wales Queen Elizabeth II Charles, Prince of Wales "[',' ' and' ' the' ' Duke' ' of' ' York' ',' ' and' ' the' ' Duke' ' of'
+ ' York' ',' ' and' ' the' ' Duke' ' of' ' York' ',' ' and']" , and the Duke of York , and the Duke of York , and the Duke of York , and False cabaret piece when Charles, Prince of Wales and Camilla, Duchess 8 [' cab', 'aret', ' piece', ' when', ' Charles', ',', ' Prince', ' of', ' Wales']
+41 9 Name of mother of x -1 Name of mother of Princess Margaret Queen Elizabeth The Queen Mother Princess Margaret "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False for the coronation, Princess Margaret informed her 6 [' for', ' the', ' coron', 'ation', ',', ' Princess', ' Margaret']
+42 9 Name of mother of x -1 Name of mother of Princess Margaret Queen Elizabeth The Queen Mother Princess Margaret "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False promised marriage of Princess Margaret to the Dauphin. In 4 [' promised', ' marriage', ' of', ' Princess', ' Margaret']
+43 9 Name of mother of x -1 Name of mother of Princess Margaret Queen Elizabeth The Queen Mother Princess Margaret "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False second daughter Princess Margaret died. On 13 3 [' second', ' daughter', ' Princess', ' Margaret']
+44 9 Name of mother of x -1 Name of mother of Princess Margaret Queen Elizabeth The Queen Mother Princess Margaret "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False " Denis Compton and Princess Margaret"". He had gained" 4 [' Denis', ' Compton', ' and', ' Princess', ' Margaret']
+45 9 Name of mother of x -1 Name of mother of Princess Margaret Queen Elizabeth The Queen Mother Princess Margaret "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False instructions that when Princess Margaret has breakfast 4 [' instructions', ' that', ' when', ' Princess', ' Margaret']
+46 10 Name of mother of x -1 Name of mother of Prince Philip Princess Alice of Battenberg Prince Philip "[',' ' Duke' ' of' ' Edinburgh' ',' ' and' ' the' ' Duke' ' of'
+ ' Edinburgh' ',' ' Prince' ' Philip' ',' ' Duke' ' of' ' Edinburgh' ','
+ ' and' ' the']" , Duke of Edinburgh , and the Duke of Edinburgh , Prince Philip , Duke of Edinburgh , and the False apparent and consort. Prince Philip is at present the most 6 [' apparent', ' and', ' cons', 'ort', '.', ' Prince', ' Philip']
+47 10 Name of mother of x -1 Name of mother of Prince Philip Princess Alice of Battenberg Prince Philip "[',' ' Duke' ' of' ' Edinburgh' ',' ' and' ' the' ' Duke' ' of'
+ ' Edinburgh' ',' ' Prince' ' Philip' ',' ' Duke' ' of' ' Edinburgh' ','
+ ' and' ' the']" , Duke of Edinburgh , and the Duke of Edinburgh , Prince Philip , Duke of Edinburgh , and the False so the Queen and Prince Philip took them that morning. 5 [' so', ' the', ' Queen', ' and', ' Prince', ' Philip']
+48 10 Name of mother of x -1 Name of mother of Prince Philip Princess Alice of Battenberg Prince Philip "[',' ' Duke' ' of' ' Edinburgh' ',' ' and' ' the' ' Duke' ' of'
+ ' Edinburgh' ',' ' Prince' ' Philip' ',' ' Duke' ' of' ' Edinburgh' ','
+ ' and' ' the']" , Duke of Edinburgh , and the Duke of Edinburgh , Prince Philip , Duke of Edinburgh , and the False Queen Elizabeth II and Prince Philip during their inaugural 5 [' Queen', ' Elizabeth', ' II', ' and', ' Prince', ' Philip']
+49 10 Name of mother of x -1 Name of mother of Prince Philip Princess Alice of Battenberg Prince Philip "[',' ' Duke' ' of' ' Edinburgh' ',' ' and' ' the' ' Duke' ' of'
+ ' Edinburgh' ',' ' Prince' ' Philip' ',' ' Duke' ' of' ' Edinburgh' ','
+ ' and' ' the']" , Duke of Edinburgh , and the Duke of Edinburgh , Prince Philip , Duke of Edinburgh , and the False " family, among them Prince Philip and Princess Margaret.
+" 5 [' family', ',', ' among', ' them', ' Prince', ' Philip']
+50 10 Name of mother of x -1 Name of mother of Prince Philip Princess Alice of Battenberg Prince Philip "[',' ' Duke' ' of' ' Edinburgh' ',' ' and' ' the' ' Duke' ' of'
+ ' Edinburgh' ',' ' Prince' ' Philip' ',' ' Duke' ' of' ' Edinburgh' ','
+ ' and' ' the']" , Duke of Edinburgh , and the Duke of Edinburgh , Prince Philip , Duke of Edinburgh , and the False veterans, by Prince Philip and officials 4 [' veterans', ',', ' by', ' Prince', ' Philip']
+51 12 Name of mother of x -1 Name of mother of King George VI Queen Mary King George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False weight-for-age race, the King George VI and Queen Elizabeth 10 [' weight', '-', 'for', '-', 'age', ' race', ',', ' the', ' King', ' George', ' VI']
+52 12 Name of mother of x -1 Name of mother of King George VI Queen Mary King George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False the coronation of King George VI and Queen Elizabeth 6 [' the', ' coron', 'ation', ' of', ' King', ' George', ' VI']
+53 12 Name of mother of x -1 Name of mother of King George VI Queen Mary King George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False of the song. King George VI presented Formby 6 [' of', ' the', ' song', '.', ' King', ' George', ' VI']
+54 12 Name of mother of x -1 Name of mother of King George VI Queen Mary King George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False were escorts for King George VI and his wife 6 [' were', ' esc', 'orts', ' for', ' King', ' George', ' VI']
+55 12 Name of mother of x -1 Name of mother of King George VI Queen Mary King George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False aide-de-camp to King George VI in 1941 – a 8 [' aide', '-', 'de', '-', 'camp', ' to', ' King', ' George', ' VI']
+56 18 Name of mother of x -1 Name of mother of J.K. Rowling Anne Volant Rowling J.K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False the nom de plume J.K. Rowling just before 9 [' the', ' nom', ' de', ' pl', 'ume', ' J', '.', 'K', '.', ' Rowling']
+57 18 Name of mother of x -1 Name of mother of J.K. Rowling Anne Volant Rowling J.K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False platform on which J.K. Rowling updates the series 7 [' platform', ' on', ' which', ' J', '.', 'K', '.', ' Rowling']
+58 18 Name of mother of x -1 Name of mother of J.K. Rowling Anne Volant Rowling J.K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False Harry Potter author J.K. Rowling revealed that Albus 7 [' Harry', ' Potter', ' author', ' J', '.', 'K', '.', ' Rowling']
+59 18 Name of mother of x -1 Name of mother of J.K. Rowling Anne Volant Rowling J.K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False illustrations by J.K. Rowling not included in the 6 [' illustrations', ' by', ' J', '.', 'K', '.', ' Rowling']
+60 18 Name of mother of x -1 Name of mother of J.K. Rowling Anne Volant Rowling J.K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False 4 ['J', '.', 'K', '.', ' Rowling']
+61 19 Name of mother of x -1 Name of mother of Peter Paul Rubens Maria Pypelinckx Peter Paul Rubens "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Peter' ' Paul' ' Rub' 'ens' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Peter Paul Rub ens , the painter ," False Renaissance, Peter Paul Rubens and Johann Baptist 5 [' Renaissance', ',', ' Peter', ' Paul', ' Rub', 'ens']
+62 19 Name of mother of x -1 Name of mother of Peter Paul Rubens Maria Pypelinckx Peter Paul Rubens "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Peter' ' Paul' ' Rub' 'ens' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Peter Paul Rub ens , the painter ," False present-day Peter Paul Rubens Street (Maltese: 6 [' present', '-', 'day', ' Peter', ' Paul', ' Rub', 'ens']
+63 19 Name of mother of x -1 Name of mother of Peter Paul Rubens Maria Pypelinckx Peter Paul Rubens "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Peter' ' Paul' ' Rub' 'ens' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Peter Paul Rub ens , the painter ," False studying the work of Peter Paul Rubens — and broadened 7 [' studying', ' the', ' work', ' of', ' Peter', ' Paul', ' Rub', 'ens']
+64 19 Name of mother of x -1 Name of mother of Peter Paul Rubens Maria Pypelinckx Peter Paul Rubens "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Peter' ' Paul' ' Rub' 'ens' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Peter Paul Rub ens , the painter ," False begin at present-day Peter Paul Rubens Street (Maltese: 8 [' begin', ' at', ' present', '-', 'day', ' Peter', ' Paul', ' Rub', 'ens']
+65 19 Name of mother of x -1 Name of mother of Peter Paul Rubens Maria Pypelinckx Peter Paul Rubens "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Peter' ' Paul' ' Rub' 'ens' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Peter Paul Rub ens , the painter ," False studying the work of Peter Paul Rubens — and broadened 7 [' studying', ' the', ' work', ' of', ' Peter', ' Paul', ' Rub', 'ens']
+66 20 Name of mother of x -1 Name of mother of Madonna Madonna Louise Fortin Madonna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False samples from Madonna herself. Almost 2 [' samples', ' from', ' Madonna']
+67 20 Name of mother of x -1 Name of mother of Madonna Madonna Louise Fortin Madonna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False years later, Madonna performed an energetic, 3 [' years', ' later', ',', ' Madonna']
+68 20 Name of mother of x -1 Name of mother of Madonna Madonna Louise Fortin Madonna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Morton, the album made Madonna a household name, 5 [' Morton', ',', ' the', ' album', ' made', ' Madonna']
+69 20 Name of mother of x -1 Name of mother of Madonna Madonna Louise Fortin Madonna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False For the costumes, Madonna collaborated 4 [' For', ' the', ' costumes', ',', ' Madonna']
+70 20 Name of mother of x -1 Name of mother of Madonna Madonna Louise Fortin Madonna "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " his list of ""Top 15 Madonna Singles of All Time""." 6 "[' his', ' list', ' of', ' ""', 'Top', ' 15', ' Madonna']"
+71 21 Name of mother of x -1 Name of mother of Pablo Picasso Maria Picasso y Lopez Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False Three Musicians by Pablo Picasso in which their instruments 6 [' Three', ' Mus', 'icians', ' by', ' Pablo', ' Pic', 'asso']
+72 21 Name of mother of x -1 Name of mother of Pablo Picasso Maria Picasso y Lopez Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False Beach and Pablo Picasso with Mary 4 [' Beach', ' and', ' Pablo', ' Pic', 'asso']
+73 21 Name of mother of x -1 Name of mother of Pablo Picasso Maria Picasso y Lopez Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False included Isadora Duncan, Pablo Picasso and Aleister Crowley. 8 [' included', ' Is', 'ad', 'ora', ' Duncan', ',', ' Pablo', ' Pic', 'asso']
+74 21 Name of mother of x -1 Name of mother of Pablo Picasso Maria Picasso y Lopez Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False Cubists, including Pablo Picasso and Jean Metzinger. 6 [' Cub', 'ists', ',', ' including', ' Pablo', ' Pic', 'asso']
+75 21 Name of mother of x -1 Name of mother of Pablo Picasso Maria Picasso y Lopez Pablo Picasso "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False poetry much as Pablo Picasso did in painting 5 [' poetry', ' much', ' as', ' Pablo', ' Pic', 'asso']
+76 22 Name of mother of x -1 Name of mother of Lady Gaga Cynthia Germanotta Lady Gaga "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False both Beyoncé and Lady Gaga trading verses with 5 [' both', ' Beyon', 'cé', ' and', ' Lady', ' Gaga']
+77 22 Name of mother of x -1 Name of mother of Lady Gaga Cynthia Germanotta Lady Gaga "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Credits adapted from Lady Gaga Presents the Monster 4 [' Credits', ' adapted', ' from', ' Lady', ' Gaga']
+78 22 Name of mother of x -1 Name of mother of Lady Gaga Cynthia Germanotta Lady Gaga "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Bennett and Lady Gaga first met backstage 3 [' Bennett', ' and', ' Lady', ' Gaga']
+79 22 Name of mother of x -1 Name of mother of Lady Gaga Cynthia Germanotta Lady Gaga "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " born, to the character Lady Gaga. He explained: ""What's" 6 [' born', ',', ' to', ' the', ' character', ' Lady', ' Gaga']
+80 22 Name of mother of x -1 Name of mother of Lady Gaga Cynthia Germanotta Lady Gaga "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False 1 ['Lady', ' Gaga']
+81 23 Name of mother of x -1 Name of mother of Rembrandt Neeltje Willemsdr. Zuytbrouck Rembrandt "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Rem' 'brand'
+ 't' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Rem brand t , and the
+
+ Name of mother" False 2 ['Rem', 'brand', 't']
+82 23 Name of mother of x -1 Name of mother of Rembrandt Neeltje Willemsdr. Zuytbrouck Rembrandt "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Rem' 'brand'
+ 't' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Rem brand t , and the
+
+ Name of mother" False the 1650s, Rembrandt's style changed 7 [' the', ' 16', '50', 's', ',', ' Rem', 'brand', 't']
+83 23 Name of mother of x -1 Name of mother of Rembrandt Neeltje Willemsdr. Zuytbrouck Rembrandt "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Rem' 'brand'
+ 't' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Rem brand t , and the
+
+ Name of mother" False In the late 1640s Rembrandt began a relationship 8 [' In', ' the', ' late', ' 16', '40', 's', ' Rem', 'brand', 't']
+84 23 Name of mother of x -1 Name of mother of Rembrandt Neeltje Willemsdr. Zuytbrouck Rembrandt "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Rem' 'brand'
+ 't' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Rem brand t , and the
+
+ Name of mother" False 2 ['Rem', 'brand', 't']
+85 23 Name of mother of x -1 Name of mother of Rembrandt Neeltje Willemsdr. Zuytbrouck Rembrandt "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Rem' 'brand'
+ 't' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Rem brand t , and the
+
+ Name of mother" False South African Rembrandt Tobacco Corporation. 4 [' South', ' African', ' Rem', 'brand', 't']
+86 24 Name of mother of x -1 Name of mother of Vladimir Putin Maria Ivanovna Shelomova Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ',' ' Vladimir' ' Putin' ',' ' has' ' been' ' a' ' friend'
+ ' of' ' mine' ' for']" , the Russian president , and the Russian president , Vladimir Putin , has been a friend of mine for False " Russian President Vladimir Putin as a ""defender" 3 [' Russian', ' President', ' Vladimir', ' Putin']
+87 24 Name of mother of x -1 Name of mother of Vladimir Putin Maria Ivanovna Shelomova Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ',' ' Vladimir' ' Putin' ',' ' has' ' been' ' a' ' friend'
+ ' of' ' mine' ' for']" , the Russian president , and the Russian president , Vladimir Putin , has been a friend of mine for False then Prime Minister Vladimir Putin to describe 4 [' then', ' Prime', ' Minister', ' Vladimir', ' Putin']
+88 24 Name of mother of x -1 Name of mother of Vladimir Putin Maria Ivanovna Shelomova Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ',' ' Vladimir' ' Putin' ',' ' has' ' been' ' a' ' friend'
+ ' of' ' mine' ' for']" , the Russian president , and the Russian president , Vladimir Putin , has been a friend of mine for False President Vladimir Putin ordered military 2 [' President', ' Vladimir', ' Putin']
+89 24 Name of mother of x -1 Name of mother of Vladimir Putin Maria Ivanovna Shelomova Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ',' ' Vladimir' ' Putin' ',' ' has' ' been' ' a' ' friend'
+ ' of' ' mine' ' for']" , the Russian president , and the Russian president , Vladimir Putin , has been a friend of mine for False Russian president Vladimir Putin authorised official 3 [' Russian', ' president', ' Vladimir', ' Putin']
+90 24 Name of mother of x -1 Name of mother of Vladimir Putin Maria Ivanovna Shelomova Vladimir Putin "[',' ' the' ' Russian' ' president' ',' ' and' ' the' ' Russian'
+ ' president' ',' ' Vladimir' ' Putin' ',' ' has' ' been' ' a' ' friend'
+ ' of' ' mine' ' for']" , the Russian president , and the Russian president , Vladimir Putin , has been a friend of mine for False President of Russia Vladimir Putin answered to early 4 [' President', ' of', ' Russia', ' Vladimir', ' Putin']
+91 25 Name of mother of x -1 Name of mother of Anthony van Dyck Maria Cuypers Anthony van Dyck "[',' ' the' ' painter' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the
+ " False earlier works by Anthony van Dyck during the 6 [' earlier', ' works', ' by', ' Anthony', ' van', ' Dy', 'ck']
+92 25 Name of mother of x -1 Name of mother of Anthony van Dyck Maria Cuypers Anthony van Dyck "[',' ' the' ' painter' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the
+ " False earlier works by Anthony van Dyck during the 6 [' earlier', ' works', ' by', ' Anthony', ' van', ' Dy', 'ck']
+93 26 Name of mother of x -1 Name of mother of Michael Jackson Katherine Jackson Michael Jackson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False Video, but lost to Michael Jackson and Janet Jackson's 6 [' Video', ',', ' but', ' lost', ' to', ' Michael', ' Jackson']
+94 26 Name of mother of x -1 Name of mother of Michael Jackson Katherine Jackson Michael Jackson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False Mulder's joke about Michael Jackson as a self-aware 6 "[' Mu', 'lder', ""'s"", ' joke', ' about', ' Michael', ' Jackson']"
+95 26 Name of mother of x -1 Name of mother of Michael Jackson Katherine Jackson Michael Jackson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False dance-pop star Michael Jackson, funk-influenced 5 [' dance', '-', 'pop', ' star', ' Michael', ' Jackson']
+96 26 Name of mother of x -1 Name of mother of Michael Jackson Katherine Jackson Michael Jackson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False celebrities. The death of Michael Jackson and the death 6 [' celebrities', '.', ' The', ' death', ' of', ' Michael', ' Jackson']
+97 26 Name of mother of x -1 Name of mother of Michael Jackson Katherine Jackson Michael Jackson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False " of ""P.Y.T."" by Michael Jackson and ""My Funny Valentine""" 10 "[' of', ' ""', 'P', '.', 'Y', '.', 'T', '.""', ' by', ' Michael', ' Jackson']"
+98 27 Name of mother of x -1 Name of mother of William Shakespeare Mary Shakespeare William Shakespeare "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' woman' ',' ' and' ' I' ' am' ' a'
+ ' woman' ',' '\n' '\n' 'And' ' I' ' am']" ", the
+
+ I am a woman , and I am a woman ,
+
+ And I am" False world with dreams, William Shakespeare puts on a play 5 [' world', ' with', ' dreams', ',', ' William', ' Shakespeare']
+99 27 Name of mother of x -1 Name of mother of William Shakespeare Mary Shakespeare William Shakespeare "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' woman' ',' ' and' ' I' ' am' ' a'
+ ' woman' ',' '\n' '\n' 'And' ' I' ' am']" ", the
+
+ I am a woman , and I am a woman ,
+
+ And I am" False Saunders, and William Shakespeare made Justice Shallow, 4 [' Saunders', ',', ' and', ' William', ' Shakespeare']
+100 27 Name of mother of x -1 Name of mother of William Shakespeare Mary Shakespeare William Shakespeare "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' woman' ',' ' and' ' I' ' am' ' a'
+ ' woman' ',' '\n' '\n' 'And' ' I' ' am']" ", the
+
+ I am a woman , and I am a woman ,
+
+ And I am" False 1590, possibly by William Shakespeare or Christopher 6 [' 15', '90', ',', ' possibly', ' by', ' William', ' Shakespeare']
+101 27 Name of mother of x -1 Name of mother of William Shakespeare Mary Shakespeare William Shakespeare "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' woman' ',' ' and' ' I' ' am' ' a'
+ ' woman' ',' '\n' '\n' 'And' ' I' ' am']" ", the
+
+ I am a woman , and I am a woman ,
+
+ And I am" False and praised writers. William Shakespeare and Christopher 5 [' and', ' praised', ' writers', '.', ' William', ' Shakespeare']
+102 27 Name of mother of x -1 Name of mother of William Shakespeare Mary Shakespeare William Shakespeare "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' woman' ',' ' and' ' I' ' am' ' a'
+ ' woman' ',' '\n' '\n' 'And' ' I' ' am']" ", the
+
+ I am a woman , and I am a woman ,
+
+ And I am" False tragedy written by William Shakespeare early in his 4 [' tragedy', ' written', ' by', ' William', ' Shakespeare']
+103 28 Name of mother of x -1 Name of mother of Albrecht Dürer Barbara Dürer Albrecht Dürer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Ag' 'nes' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' the' '\n' '\n' 'The']" ", the painter , and his wife , Ag nes , who was the daughter of the
+
+ The" False German engraver Albrecht Dürer (1471 – 1528), 9 [' German', ' eng', 'ra', 'ver', ' Al', 'bre', 'cht', ' D', 'ü', 'rer']
+104 28 Name of mother of x -1 Name of mother of Albrecht Dürer Barbara Dürer Albrecht Dürer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Ag' 'nes' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' the' '\n' '\n' 'The']" ", the painter , and his wife , Ag nes , who was the daughter of the
+
+ The" False it is the painting Albrecht Dürer mentions in 9 [' it', ' is', ' the', ' painting', ' Al', 'bre', 'cht', ' D', 'ü', 'rer']
+105 28 Name of mother of x -1 Name of mother of Albrecht Dürer Barbara Dürer Albrecht Dürer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Ag' 'nes' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' the' '\n' '\n' 'The']" ", the painter , and his wife , Ag nes , who was the daughter of the
+
+ The" False " ""the mother of Albrecht Dürer in oil colors" 9 "[' ""', 'the', ' mother', ' of', ' Al', 'bre', 'cht', ' D', 'ü', 'rer']"
+106 28 Name of mother of x -1 Name of mother of Albrecht Dürer Barbara Dürer Albrecht Dürer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Ag' 'nes' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' the' '\n' '\n' 'The']" ", the painter , and his wife , Ag nes , who was the daughter of the
+
+ The" False northern Europe. Albrecht Dürer emulated van 8 [' northern', ' Europe', '.', ' Al', 'bre', 'cht', ' D', 'ü', 'rer']
+107 28 Name of mother of x -1 Name of mother of Albrecht Dürer Barbara Dürer Albrecht Dürer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Ag' 'nes' ','
+ ' who' ' was' ' the' ' daughter' ' of' ' the' '\n' '\n' 'The']" ", the painter , and his wife , Ag nes , who was the daughter of the
+
+ The" False 5 ['Al', 'bre', 'cht', ' D', 'ü', 'rer']
+108 29 Name of mother of x -1 Name of mother of Vincent van Gogh Anna Carbentus van Gogh Vincent van Gogh "[',' ' the' ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the'
+ ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the' ' writer' '.']" , the painter , and his brother Theo , the painter , and his brother Theo , the writer . False the miller's house. Vincent van Gogh lived nearby and 9 "[' the', ' mill', 'er', ""'s"", ' house', '.', ' Vincent', ' van', ' Go', 'gh']"
+109 29 Name of mother of x -1 Name of mother of Vincent van Gogh Anna Carbentus van Gogh Vincent van Gogh "[',' ' the' ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the'
+ ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the' ' writer' '.']" , the painter , and his brother Theo , the painter , and his brother Theo , the writer . False Trek, painter Vincent van Gogh and architect Frank 6 [' Trek', ',', ' painter', ' Vincent', ' van', ' Go', 'gh']
+110 29 Name of mother of x -1 Name of mother of Vincent van Gogh Anna Carbentus van Gogh Vincent van Gogh "[',' ' the' ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the'
+ ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the' ' writer' '.']" , the painter , and his brother Theo , the painter , and his brother Theo , the writer . False bought the Vincent van Gogh painting Irises and 5 [' bought', ' the', ' Vincent', ' van', ' Go', 'gh']
+111 29 Name of mother of x -1 Name of mother of Vincent van Gogh Anna Carbentus van Gogh Vincent van Gogh "[',' ' the' ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the'
+ ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the' ' writer' '.']" , the painter , and his brother Theo , the painter , and his brother Theo , the writer . False 4 ['V', 'incent', ' van', ' Go', 'gh']
+112 29 Name of mother of x -1 Name of mother of Vincent van Gogh Anna Carbentus van Gogh Vincent van Gogh "[',' ' the' ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the'
+ ' painter' ',' ' and' ' his' ' brother' ' Theo' ',' ' the' ' writer' '.']" , the painter , and his brother Theo , the painter , and his brother Theo , the writer . False the 19th century, Vincent van Gogh acknowledged Ruisdael 8 [' the', ' 19', 'th', ' century', ',', ' Vincent', ' van', ' Go', 'gh']
+113 30 Name of mother of x -1 Name of mother of Aretha Franklin Barbara Siggers Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',' ' Are']" ", the Queen of Soul , and the Queen of Soul .
+
+ The Queen of Soul , Are" False " III, having Aretha Franklin (""Who's Zoomin'" 5 [' III', ',', ' having', ' Are', 'tha', ' Franklin']
+114 30 Name of mother of x -1 Name of mother of Aretha Franklin Barbara Siggers Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',' ' Are']" ", the Queen of Soul , and the Queen of Soul .
+
+ The Queen of Soul , Are" False 2 ['Are', 'tha', ' Franklin']
+115 30 Name of mother of x -1 Name of mother of Aretha Franklin Barbara Siggers Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',' ' Are']" ", the Queen of Soul , and the Queen of Soul .
+
+ The Queen of Soul , Are" False Houston (and Aretha Franklin and Barbra 5 [' Houston', ' (', 'and', ' Are', 'tha', ' Franklin']
+116 30 Name of mother of x -1 Name of mother of Aretha Franklin Barbara Siggers Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',' ' Are']" ", the Queen of Soul , and the Queen of Soul .
+
+ The Queen of Soul , Are" False Muddy Waters, Aretha Franklin and Dinah 6 [' M', 'uddy', ' Waters', ',', ' Are', 'tha', ' Franklin']
+117 30 Name of mother of x -1 Name of mother of Aretha Franklin Barbara Siggers Franklin Aretha Franklin "[',' ' the' ' Queen' ' of' ' Soul' ',' ' and' ' the' ' Queen' ' of'
+ ' Soul' '.' '\n' '\n' 'The' ' Queen' ' of' ' Soul' ',' ' Are']" ", the Queen of Soul , and the Queen of Soul .
+
+ The Queen of Soul , Are" False artists such as Aretha Franklin and King Curtis. 5 [' artists', ' such', ' as', ' Are', 'tha', ' Franklin']
+118 31 Name of mother of x -1 Name of mother of Albert Einstein Pauline Koch Albert Einstein "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern physics , and the father of modern physics .
+
+ The first thing that" False inspiration from Albert Einstein and conductor Leopold 3 [' inspiration', ' from', ' Albert', ' Einstein']
+119 31 Name of mother of x -1 Name of mother of Albert Einstein Pauline Koch Albert Einstein "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern physics , and the father of modern physics .
+
+ The first thing that" False international figures like Albert Einstein and Chiang Kai-shek), 4 [' international', ' figures', ' like', ' Albert', ' Einstein']
+120 31 Name of mother of x -1 Name of mother of Albert Einstein Pauline Koch Albert Einstein "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern physics , and the father of modern physics .
+
+ The first thing that" False 1 ['Albert', ' Einstein']
+121 31 Name of mother of x -1 Name of mother of Albert Einstein Pauline Koch Albert Einstein "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern physics , and the father of modern physics .
+
+ The first thing that" False " Einstein =
+" 3 [' Einstein', ' =', 'Albert', ' Einstein']
+122 31 Name of mother of x -1 Name of mother of Albert Einstein Pauline Koch Albert Einstein "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern physics , and the father of modern physics .
+
+ The first thing that" False gravitation published by Albert Einstein in 1915 and the current 5 [' grav', 'itation', ' published', ' by', ' Albert', ' Einstein']
+123 32 Name of mother of x -1 Name of mother of Leonardo da Vinci Caterina di Meo Lippi Leonardo da Vinci "[',' ' the' ' great' ' Italian' ' painter' ',' ' was' ' born' ' in' ' 14'
+ '52' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' artist' ' is']" ", the great Italian painter , was born in 14 52 .
+
+ The name of the artist is" False (23,000 kW), Leonardo da Vinci failed to reach this 9 [' (', '23', ',', '000', ' kW', '),', ' Leonardo', ' da', ' Vin', 'ci']
+124 32 Name of mother of x -1 Name of mother of Leonardo da Vinci Caterina di Meo Lippi Leonardo da Vinci "[',' ' the' ' great' ' Italian' ' painter' ',' ' was' ' born' ' in' ' 14'
+ '52' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' artist' ' is']" ", the great Italian painter , was born in 14 52 .
+
+ The name of the artist is" False artists such as Leonardo da Vinci making observational 6 [' artists', ' such', ' as', ' Leonardo', ' da', ' Vin', 'ci']
+125 32 Name of mother of x -1 Name of mother of Leonardo da Vinci Caterina di Meo Lippi Leonardo da Vinci "[',' ' the' ' great' ' Italian' ' painter' ',' ' was' ' born' ' in' ' 14'
+ '52' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' artist' ' is']" ", the great Italian painter , was born in 14 52 .
+
+ The name of the artist is" False Michelangelo and Leonardo da Vinci as exemplary 6 [' Michel', 'angelo', ' and', ' Leonardo', ' da', ' Vin', 'ci']
+126 32 Name of mother of x -1 Name of mother of Leonardo da Vinci Caterina di Meo Lippi Leonardo da Vinci "[',' ' the' ' great' ' Italian' ' painter' ',' ' was' ' born' ' in' ' 14'
+ '52' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' artist' ' is']" ", the great Italian painter , was born in 14 52 .
+
+ The name of the artist is" False " Italian battleship Leonardo da Vinci =
+" 6 [' Italian', ' battles', 'hip', ' Leonardo', ' da', ' Vin', 'ci']
+127 32 Name of mother of x -1 Name of mother of Leonardo da Vinci Caterina di Meo Lippi Leonardo da Vinci "[',' ' the' ' great' ' Italian' ' painter' ',' ' was' ' born' ' in' ' 14'
+ '52' '.' '\n' '\n' 'The' ' name' ' of' ' the' ' artist' ' is']" ", the great Italian painter , was born in 14 52 .
+
+ The name of the artist is" False under the sea. Leonardo da Vinci (1452 – 1519), 7 [' under', ' the', ' sea', '.', ' Leonardo', ' da', ' Vin', 'ci']
+128 33 Name of mother of x -1 Name of mother of Adolf Hitler Klara Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False trying to resurrect Adolf Hitler with the philosophers' 4 [' trying', ' to', ' resurrect', ' Adolf', ' Hitler']
+129 33 Name of mother of x -1 Name of mother of Adolf Hitler Klara Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False Konzerne, which helped Adolf Hitler rise to power. 7 [' Kon', 'zer', 'ne', ',', ' which', ' helped', ' Adolf', ' Hitler']
+130 33 Name of mother of x -1 Name of mother of Adolf Hitler Klara Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False German dictator Adolf Hitler ordered that 3 [' German', ' dictator', ' Adolf', ' Hitler']
+131 33 Name of mother of x -1 Name of mother of Adolf Hitler Klara Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False naval review by Adolf Hitler in the Bay of 4 [' naval', ' review', ' by', ' Adolf', ' Hitler']
+132 33 Name of mother of x -1 Name of mother of Adolf Hitler Klara Hitler Adolf Hitler "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False to Germany and met Adolf Hitler at his Berchtesgaden 5 [' to', ' Germany', ' and', ' met', ' Adolf', ' Hitler']
+133 34 Name of mother of x -1 Name of mother of Johann Wolfgang von Goethe Catharina Elisabeth Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' philosopher' ',' ' was' ' born'
+ ' in' ' Frankfurt' ' am' ' Main' ',' ' Germany' ',' ' on' ' this' ' day'
+ ' in']" , the German poet and philosopher , was born in Frankfurt am Main , Germany , on this day in False aesthetic works of Johann Wolfgang von Goethe and the ethical 7 [' aesthetic', ' works', ' of', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+134 34 Name of mother of x -1 Name of mother of Johann Wolfgang von Goethe Catharina Elisabeth Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' philosopher' ',' ' was' ' born'
+ ' in' ' Frankfurt' ' am' ' Main' ',' ' Germany' ',' ' on' ' this' ' day'
+ ' in']" , the German poet and philosopher , was born in Frankfurt am Main , Germany , on this day in False the works of Johann Wolfgang von Goethe and Michel Eugène 7 [' the', ' works', ' of', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+135 34 Name of mother of x -1 Name of mother of Johann Wolfgang von Goethe Catharina Elisabeth Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' philosopher' ',' ' was' ' born'
+ ' in' ' Frankfurt' ' am' ' Main' ',' ' Germany' ',' ' on' ' this' ' day'
+ ' in']" , the German poet and philosopher , was born in Frankfurt am Main , Germany , on this day in False the German poet Johann Wolfgang von Goethe (who never traveled 7 [' the', ' German', ' poet', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+136 34 Name of mother of x -1 Name of mother of Johann Wolfgang von Goethe Catharina Elisabeth Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' philosopher' ',' ' was' ' born'
+ ' in' ' Frankfurt' ' am' ' Main' ',' ' Germany' ',' ' on' ' this' ' day'
+ ' in']" , the German poet and philosopher , was born in Frankfurt am Main , Germany , on this day in False West-östlicher Diwan of Johann Wolfgang von Goethe that is a collection 13 [' West', '-', 'ö', 'st', 'lic', 'her', ' Di', 'wan', ' of', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+137 34 Name of mother of x -1 Name of mother of Johann Wolfgang von Goethe Catharina Elisabeth Goethe Johann Wolfgang von Goethe "[',' ' the' ' German' ' poet' ' and' ' philosopher' ',' ' was' ' born'
+ ' in' ' Frankfurt' ' am' ' Main' ',' ' Germany' ',' ' on' ' this' ' day'
+ ' in']" , the German poet and philosopher , was born in Frankfurt am Main , Germany , on this day in False joke was that Johann Wolfgang von Goethe was Germany's 7 [' joke', ' was', ' that', ' Johann', ' Wolfgang', ' von', ' Go', 'ethe']
+138 35 Name of mother of x -1 Name of mother of Elizabeth II Queen Elizabeth, The Queen Mother Elizabeth II "[',' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England'
+ ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen']" , Queen of England , and the Queen of England , and the Queen of England , and the Queen False appointed by Queen Elizabeth II a Commander 4 [' appointed', ' by', ' Queen', ' Elizabeth', ' II']
+139 35 Name of mother of x -1 Name of mother of Elizabeth II Queen Elizabeth, The Queen Mother Elizabeth II "[',' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England'
+ ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen']" , Queen of England , and the Queen of England , and the Queen of England , and the Queen False Review of Queen Elizabeth II on 15 June 1953. 4 [' Review', ' of', ' Queen', ' Elizabeth', ' II']
+140 35 Name of mother of x -1 Name of mother of Elizabeth II Queen Elizabeth, The Queen Mother Elizabeth II "[',' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England'
+ ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen']" , Queen of England , and the Queen of England , and the Queen of England , and the Queen False therapy. Queen Elizabeth II visited the Windsors 4 [' therapy', '.', ' Queen', ' Elizabeth', ' II']
+141 35 Name of mother of x -1 Name of mother of Elizabeth II Queen Elizabeth, The Queen Mother Elizabeth II "[',' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England'
+ ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen']" , Queen of England , and the Queen of England , and the Queen of England , and the Queen False opened by Queen Elizabeth II on 27 March 4 [' opened', ' by', ' Queen', ' Elizabeth', ' II']
+142 35 Name of mother of x -1 Name of mother of Elizabeth II Queen Elizabeth, The Queen Mother Elizabeth II "[',' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England'
+ ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen']" , Queen of England , and the Queen of England , and the Queen of England , and the Queen False 1 ['Elizabeth', ' II']
+143 36 Name of mother of x -1 Name of mother of Leo Tolstoy Mariya Volkonskaya Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' ',' ' and' ' died' ' in' ' 1910' '.' '\n' '\n']" ", the great Russian writer , who was born in 18 28 , and died in 1910 .
+
+" False guests such as Leo Tolstoy and Maxim Gorky, 6 [' guests', ' such', ' as', ' Leo', ' Tol', 'st', 'oy']
+144 36 Name of mother of x -1 Name of mother of Leo Tolstoy Mariya Volkonskaya Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' ',' ' and' ' died' ' in' ' 1910' '.' '\n' '\n']" ", the great Russian writer , who was born in 18 28 , and died in 1910 .
+
+" False processing: in 1840, Leo Tolstoy played the 8 [' processing', ':', ' in', ' 1840', ',', ' Leo', ' Tol', 'st', 'oy']
+145 36 Name of mother of x -1 Name of mother of Leo Tolstoy Mariya Volkonskaya Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' ',' ' and' ' died' ' in' ' 1910' '.' '\n' '\n']" ", the great Russian writer , who was born in 18 28 , and died in 1910 .
+
+" False celebrated writer Leo Tolstoy and also stars 5 [' celebrated', ' writer', ' Leo', ' Tol', 'st', 'oy']
+146 36 Name of mother of x -1 Name of mother of Leo Tolstoy Mariya Volkonskaya Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' ',' ' and' ' died' ' in' ' 1910' '.' '\n' '\n']" ", the great Russian writer , who was born in 18 28 , and died in 1910 .
+
+" False christian anarchist Leo Tolstoy established 6 [' christ', 'ian', ' anarchist', ' Leo', ' Tol', 'st', 'oy']
+147 36 Name of mother of x -1 Name of mother of Leo Tolstoy Mariya Volkonskaya Leo Tolstoy "[',' ' the' ' great' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' 18' '28' ',' ' and' ' died' ' in' ' 1910' '.' '\n' '\n']" ", the great Russian writer , who was born in 18 28 , and died in 1910 .
+
+" False Russian writer Leo Tolstoy (and even 5 [' Russian', ' writer', ' Leo', ' Tol', 'st', 'oy']
+148 37 Name of mother of x -1 Name of mother of Taylor Swift Andrea Finlay Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Jonas' ','
+ ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n' '\n']" ", the singer , and her husband , Joe Jonas , are expecting their first child together .
+
+" False U2, Metallica, Taylor Swift and other large 7 [' U', '2', ',', ' Metall', 'ica', ',', ' Taylor', ' Swift']
+149 37 Name of mother of x -1 Name of mother of Taylor Swift Andrea Finlay Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Jonas' ','
+ ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n' '\n']" ", the singer , and her husband , Joe Jonas , are expecting their first child together .
+
+" False singer-songwriter Taylor Swift from October 2010 5 [' singer', '-', 'song', 'writer', ' Taylor', ' Swift']
+150 37 Name of mother of x -1 Name of mother of Taylor Swift Andrea Finlay Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Jonas' ','
+ ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n' '\n']" ", the singer , and her husband , Joe Jonas , are expecting their first child together .
+
+" False Other songs from Taylor Swift have been performed 4 [' Other', ' songs', ' from', ' Taylor', ' Swift']
+151 37 Name of mother of x -1 Name of mother of Taylor Swift Andrea Finlay Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Jonas' ','
+ ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n' '\n']" ", the singer , and her husband , Joe Jonas , are expecting their first child together .
+
+" False music with the Taylor Swift song I Knew You 4 [' music', ' with', ' the', ' Taylor', ' Swift']
+152 37 Name of mother of x -1 Name of mother of Taylor Swift Andrea Finlay Taylor Swift "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Joe' ' Jonas' ','
+ ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n' '\n']" ", the singer , and her husband , Joe Jonas , are expecting their first child together .
+
+" False recording artist Taylor Swift at the 2009 3 [' recording', ' artist', ' Taylor', ' Swift']
+153 38 Name of mother of x -1 Name of mother of Donald Trump Mary Anne MacLeod Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False who plan to rob Donald Trump and Trump International 5 [' who', ' plan', ' to', ' rob', ' Donald', ' Trump']
+154 38 Name of mother of x -1 Name of mother of Donald Trump Mary Anne MacLeod Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False fashion. She persuaded Donald Trump to let Marc 5 [' fashion', '.', ' She', ' persuaded', ' Donald', ' Trump']
+155 38 Name of mother of x -1 Name of mother of Donald Trump Mary Anne MacLeod Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False and endorses Donald Trump for president. 4 [' and', ' end', 'orses', ' Donald', ' Trump']
+156 38 Name of mother of x -1 Name of mother of Donald Trump Mary Anne MacLeod Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False fashion. She persuaded Donald Trump to let Marc Jacobs 5 [' fashion', '.', ' She', ' persuaded', ' Donald', ' Trump']
+157 38 Name of mother of x -1 Name of mother of Donald Trump Mary Anne MacLeod Trump Donald Trump "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False that interviewed Donald Trump Jr., the son of Republican 3 [' that', ' interviewed', ' Donald', ' Trump']
+158 39 Name of mother of x -1 Name of mother of Jennifer Lopez Guadalupe Rodriguez Jennifer Lopez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False " Knowles and Jennifer Lopez by several critics.
+" 4 [' Know', 'les', ' and', ' Jennifer', ' Lopez']
+159 39 Name of mother of x -1 Name of mother of Jennifer Lopez Guadalupe Rodriguez Jennifer Lopez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False life imprisonment. Jennifer Lopez was cast as 4 [' life', ' imprisonment', '.', ' Jennifer', ' Lopez']
+160 39 Name of mother of x -1 Name of mother of Jennifer Lopez Guadalupe Rodriguez Jennifer Lopez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False " spotlight."" Mauboy cites Jennifer Lopez as her style icon." 7 "[' spotlight', '.""', ' Ma', 'ub', 'oy', ' cites', ' Jennifer', ' Lopez']"
+161 39 Name of mother of x -1 Name of mother of Jennifer Lopez Guadalupe Rodriguez Jennifer Lopez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False resembling Jennifer Lopez as an attempt by her 2 [' resembling', ' Jennifer', ' Lopez']
+162 39 Name of mother of x -1 Name of mother of Jennifer Lopez Guadalupe Rodriguez Jennifer Lopez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False Rican-American actress Jennifer Lopez replaced Hayek, which 5 [' Rican', '-', 'American', ' actress', ' Jennifer', ' Lopez']
+163 40 Name of mother of x -1 Name of mother of Napoleon Maria-Letizia Bonaparte Napoleon "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Napoleon' ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ The
+
+ Name of mother of Napoleon , the
+
+ Name of mother of" False restored. Napoleon I abdicated on 11 2 [' restored', '.', ' Napoleon']
+164 40 Name of mother of x -1 Name of mother of Napoleon Maria-Letizia Bonaparte Napoleon "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Napoleon' ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ The
+
+ Name of mother of Napoleon , the
+
+ Name of mother of" False the series, Napoleon: Total War. The campaign, 3 [' the', ' series', ',', ' Napoleon']
+165 40 Name of mother of x -1 Name of mother of Napoleon Maria-Letizia Bonaparte Napoleon "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Napoleon' ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ The
+
+ Name of mother of Napoleon , the
+
+ Name of mother of" False retreated past Moscow. Napoleon entered the city, 4 [' retreated', ' past', ' Moscow', '.', ' Napoleon']
+166 40 Name of mother of x -1 Name of mother of Napoleon Maria-Letizia Bonaparte Napoleon "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Napoleon' ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ The
+
+ Name of mother of Napoleon , the
+
+ Name of mother of" False addition to these forces, Napoleon created a cavalry 5 [' addition', ' to', ' these', ' forces', ',', ' Napoleon']
+167 40 Name of mother of x -1 Name of mother of Napoleon Maria-Letizia Bonaparte Napoleon "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Napoleon' ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ The
+
+ Name of mother of Napoleon , the
+
+ Name of mother of" False France, but he joined Napoleon during the Hundred 5 [' France', ',', ' but', ' he', ' joined', ' Napoleon']
+168 41 Name of mother of x -1 Name of mother of Paul Gauguin Aline Chazal Paul Gauguin "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False Post-impressionist artist Paul Gauguin revived woodcut printing 9 [' Post', '-', 'imp', 'ression', 'ist', ' artist', ' Paul', ' Gau', 'gu', 'in']
+169 41 Name of mother of x -1 Name of mother of Paul Gauguin Aline Chazal Paul Gauguin "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False " house"", and with Paul Gauguin developed a concept" 7 "[' house', '"",', ' and', ' with', ' Paul', ' Gau', 'gu', 'in']"
+170 41 Name of mother of x -1 Name of mother of Paul Gauguin Aline Chazal Paul Gauguin "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False Post-Impressionist French artists Paul Gauguin and Paul Cézanne as 10 [' Post', '-', 'Imp', 'ression', 'ist', ' French', ' artists', ' Paul', ' Gau', 'gu', 'in']
+171 41 Name of mother of x -1 Name of mother of Paul Gauguin Aline Chazal Paul Gauguin "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False and prints by Paul Gauguin and the Pont-Aven 6 [' and', ' prints', ' by', ' Paul', ' Gau', 'gu', 'in']
+172 41 Name of mother of x -1 Name of mother of Paul Gauguin Aline Chazal Paul Gauguin "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False Post-impressionist artist Paul Gauguin revived woodcut printing 9 [' Post', '-', 'imp', 'ression', 'ist', ' artist', ' Paul', ' Gau', 'gu', 'in']
+173 42 Name of mother of x -1 Name of mother of Camille Pissarro Rachel Pissarro Camille Pissarro "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ','
+ ' the' ' mother' ' of' ' the' ' painter' ""'s"" ' children' ',' ' and'
+ ' the']" , the painter , and the painter 's wife , the mother of the painter 's children , and the False Alfred Sisley and Camille Pissarro painted hundreds 9 [' Alfred', ' S', 'is', 'ley', ' and', ' Cam', 'ille', ' P', 'iss', 'arro']
+174 42 Name of mother of x -1 Name of mother of Camille Pissarro Rachel Pissarro Camille Pissarro "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ','
+ ' the' ' mother' ' of' ' the' ' painter' ""'s"" ' children' ',' ' and'
+ ' the']" , the painter , and the painter 's wife , the mother of the painter 's children , and the False Sisley and Camille Pissarro painted hundreds 8 [' S', 'is', 'ley', ' and', ' Cam', 'ille', ' P', 'iss', 'arro']
+175 42 Name of mother of x -1 Name of mother of Camille Pissarro Rachel Pissarro Camille Pissarro "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ','
+ ' the' ' mother' ' of' ' the' ' painter' ""'s"" ' children' ',' ' and'
+ ' the']" , the painter , and the painter 's wife , the mother of the painter 's children , and the False the impressionist Camille Pissarro, his eldest son 7 [' the', ' impression', 'ist', ' Cam', 'ille', ' P', 'iss', 'arro']
+176 42 Name of mother of x -1 Name of mother of Camille Pissarro Rachel Pissarro Camille Pissarro "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ','
+ ' the' ' mother' ' of' ' the' ' painter' ""'s"" ' children' ',' ' and'
+ ' the']" , the painter , and the painter 's wife , the mother of the painter 's children , and the False impressionist Camille Pissarro, his eldest son Lucien, 6 [' impression', 'ist', ' Cam', 'ille', ' P', 'iss', 'arro']
+177 42 Name of mother of x -1 Name of mother of Camille Pissarro Rachel Pissarro Camille Pissarro "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' wife' ','
+ ' the' ' mother' ' of' ' the' ' painter' ""'s"" ' children' ',' ' and'
+ ' the']" , the painter , and the painter 's wife , the mother of the painter 's children , and the False Impressionist painter Camille Pissarro and the novelist 8 [' Imp', 'ression', 'ist', ' painter', ' Cam', 'ille', ' P', 'iss', 'arro']
+178 43 Name of mother of x -1 Name of mother of Beyoncé Tina Knowles Beyoncé "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Destiny's Child members Beyoncé and Kelly Rowland. 5 "[' Destiny', ""'s"", ' Child', ' members', ' Beyon', 'cé']"
+179 43 Name of mother of x -1 Name of mother of Beyoncé Tina Knowles Beyoncé "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False remain as one of Beyoncé most memorable 5 [' remain', ' as', ' one', ' of', ' Beyon', 'cé']
+180 43 Name of mother of x -1 Name of mother of Beyoncé Tina Knowles Beyoncé "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Rowland rejoined Beyoncé and Michelle Williams 5 [' Row', 'land', ' rejo', 'ined', ' Beyon', 'cé']
+181 43 Name of mother of x -1 Name of mother of Beyoncé Tina Knowles Beyoncé "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False They highlighted Beyoncé's vocal power and 3 [' They', ' highlighted', ' Beyon', 'cé']
+182 43 Name of mother of x -1 Name of mother of Beyoncé Tina Knowles Beyoncé "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False set list on The Beyoncé Experience 5 [' set', ' list', ' on', ' The', ' Beyon', 'cé']
+183 44 Name of mother of x -1 Name of mother of Victor Hugo Sophie Trébuchet Victor Hugo "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Victor'
+ ' Hugo' ',' ' the' ' mother' ' of' ' the' ' poet' ',' ' the']" ", the
+
+ The name of the mother of Victor Hugo , the mother of the poet , the" False setting words by Victor Hugo and Paul Verlaine. 4 [' setting', ' words', ' by', ' Victor', ' Hugo']
+184 44 Name of mother of x -1 Name of mother of Victor Hugo Sophie Trébuchet Victor Hugo "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Victor'
+ ' Hugo' ',' ' the' ' mother' ' of' ' the' ' poet' ',' ' the']" ", the
+
+ The name of the mother of Victor Hugo , the mother of the poet , the" False " upon a verse by Victor Hugo called ""Fantômes""" 5 [' upon', ' a', ' verse', ' by', ' Victor', ' Hugo']
+185 44 Name of mother of x -1 Name of mother of Victor Hugo Sophie Trébuchet Victor Hugo "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Victor'
+ ' Hugo' ',' ' the' ' mother' ' of' ' the' ' poet' ',' ' the']" ", the
+
+ The name of the mother of Victor Hugo , the mother of the poet , the" False Gustave Flaubert, Victor Hugo and Guy de Maupassant, 7 [' Gust', 'ave', ' Fl', 'au', 'bert', ',', ' Victor', ' Hugo']
+186 44 Name of mother of x -1 Name of mother of Victor Hugo Sophie Trébuchet Victor Hugo "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Victor'
+ ' Hugo' ',' ' the' ' mother' ' of' ' the' ' poet' ',' ' the']" ", the
+
+ The name of the mother of Victor Hugo , the mother of the poet , the" False Charles Darwin, Victor Hugo and Friedrich 4 [' Charles', ' Darwin', ',', ' Victor', ' Hugo']
+187 44 Name of mother of x -1 Name of mother of Victor Hugo Sophie Trébuchet Victor Hugo "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Victor'
+ ' Hugo' ',' ' the' ' mother' ' of' ' the' ' poet' ',' ' the']" ", the
+
+ The name of the mother of Victor Hugo , the mother of the poet , the" False Dutilleux, the writers Victor Hugo and Charles Dickens, 8 [' D', 'ut', 'ille', 'ux', ',', ' the', ' writers', ' Victor', ' Hugo']
+188 45 Name of mother of x -1 Name of mother of Britney Spears Lynne Spears Britney Spears "[',' ' and' ' the' ' mother' ' of' ' a' ' son' ',' ' and' ' a' ' daughter'
+ ',' ' and' ' a' ' wife' ',' ' and' ' a' ' friend' ',']" , and the mother of a son , and a daughter , and a wife , and a friend , False " & Shout"" with Britney Spears and ""# thatPower""" 7 "[' &', ' Sh', 'out', '""', ' with', ' Brit', 'ney', ' Spears']"
+189 45 Name of mother of x -1 Name of mother of Britney Spears Lynne Spears Britney Spears "[',' ' and' ' the' ' mother' ' of' ' a' ' son' ',' ' and' ' a' ' daughter'
+ ',' ' and' ' a' ' wife' ',' ' and' ' a' ' friend' ',']" , and the mother of a son , and a daughter , and a wife , and a friend , False by American singer Britney Spears for her fifth 5 [' by', ' American', ' singer', ' Brit', 'ney', ' Spears']
+190 45 Name of mother of x -1 Name of mother of Britney Spears Lynne Spears Britney Spears "[',' ' and' ' the' ' mother' ' of' ' a' ' son' ',' ' and' ' a' ' daughter'
+ ',' ' and' ' a' ' wife' ',' ' and' ' a' ' friend' ',']" , and the mother of a son , and a daughter , and a wife , and a friend , False pop singers Britney Spears and Christina Aguilera 4 [' pop', ' singers', ' Brit', 'ney', ' Spears']
+191 45 Name of mother of x -1 Name of mother of Britney Spears Lynne Spears Britney Spears "[',' ' and' ' the' ' mother' ' of' ' a' ' son' ',' ' and' ' a' ' daughter'
+ ',' ' and' ' a' ' wife' ',' ' and' ' a' ' friend' ',']" , and the mother of a son , and a daughter , and a wife , and a friend , False whole lot from Britney Spears as an entertainer ... 5 [' whole', ' lot', ' from', ' Brit', 'ney', ' Spears']
+192 45 Name of mother of x -1 Name of mother of Britney Spears Lynne Spears Britney Spears "[',' ' and' ' the' ' mother' ' of' ' a' ' son' ',' ' and' ' a' ' daughter'
+ ',' ' and' ' a' ' wife' ',' ' and' ' a' ' friend' ',']" , and the mother of a son , and a daughter , and a wife , and a friend , False American singer Britney Spears for her seventh studio 4 [' American', ' singer', ' Brit', 'ney', ' Spears']
+193 46 Name of mother of x -1 Name of mother of Barack Obama Stanley Ann Dunham Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False 2 ['Bar', 'ack', ' Obama']
+194 46 Name of mother of x -1 Name of mother of Barack Obama Stanley Ann Dunham Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False then-presidential candidate Barack Obama and former presidents 6 [' then', '-', 'president', 'ial', ' candidate', ' Barack', ' Obama']
+195 46 Name of mother of x -1 Name of mother of Barack Obama Stanley Ann Dunham Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False from U.S. President Barack Obama is shown authorizing 7 [' from', ' U', '.', 'S', '.', ' President', ' Barack', ' Obama']
+196 46 Name of mother of x -1 Name of mother of Barack Obama Stanley Ann Dunham Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False President Barack Obama and Speaker of the 2 [' President', ' Barack', ' Obama']
+197 46 Name of mother of x -1 Name of mother of Barack Obama Stanley Ann Dunham Barack Obama "[',' ' the' ' first' ' black' ' president' ' of' ' the' ' United'
+ ' States' '.' '\n' '\n' 'The' ' first' ' black' ' president' ' of' ' the'
+ ' United' ' States']" ", the first black president of the United States .
+
+ The first black president of the United States" False 2010, and President Barack Obama signed the 5 [' 2010', ',', ' and', ' President', ' Barack', ' Obama']
+198 47 Name of mother of x -1 Name of mother of Cher Georgia Holt Cher "['ie' ',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of'
+ ' the' ' first' ' two' '.' '\n' '\n' 'The' ' first' ' is' ' a']" "ie , and the other two are the children of the first two .
+
+ The first is a" False " Dion ""gets the Cher treatment,"" an" 4 "[' Dion', ' ""', 'gets', ' the', ' Cher']"
+199 47 Name of mother of x -1 Name of mother of Cher Georgia Holt Cher "['ie' ',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of'
+ ' the' ' first' ' two' '.' '\n' '\n' 'The' ' first' ' is' ' a']" "ie , and the other two are the children of the first two .
+
+ The first is a" False the versions sung by Cher during the episode: 4 [' the', ' versions', ' sung', ' by', ' Cher']
+200 47 Name of mother of x -1 Name of mother of Cher Georgia Holt Cher "['ie' ',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of'
+ ' the' ' first' ' two' '.' '\n' '\n' 'The' ' first' ' is' ' a']" "ie , and the other two are the children of the first two .
+
+ The first is a" False " backed her own act, Cher Lloyd
+" 5 [' backed', ' her', ' own', ' act', ',', ' Cher']
+201 47 Name of mother of x -1 Name of mother of Cher Georgia Holt Cher "['ie' ',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of'
+ ' the' ' first' ' two' '.' '\n' '\n' 'The' ' first' ' is' ' a']" "ie , and the other two are the children of the first two .
+
+ The first is a" False 1 ['C', 'her']
+202 47 Name of mother of x -1 Name of mother of Cher Georgia Holt Cher "['ie' ',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of'
+ ' the' ' first' ' two' '.' '\n' '\n' 'The' ' first' ' is' ' a']" "ie , and the other two are the children of the first two .
+
+ The first is a" False Writing about Cher's musical output 2 [' Writing', ' about', ' Cher']
+203 48 Name of mother of x -1 Name of mother of Jesus Virgin Mary Jesus "[',' ' and' ' the' ' Lord' ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord'
+ ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord' ' of' ' the']" , and the Lord of the Rings , and the Lord of the Rings , and the Lord of the False typical Jew in Jesus'time had only one name, 3 [' typical', ' Jew', ' in', ' Jesus']
+204 48 Name of mother of x -1 Name of mother of Jesus Virgin Mary Jesus "[',' ' and' ' the' ' Lord' ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord'
+ ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord' ' of' ' the']" , and the Lord of the Rings , and the Lord of the Rings , and the Lord of the False the Arrest of Jesus and the Temptation 3 [' the', ' Arrest', ' of', ' Jesus']
+205 48 Name of mother of x -1 Name of mother of Jesus Virgin Mary Jesus "[',' ' and' ' the' ' Lord' ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord'
+ ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord' ' of' ' the']" , and the Lord of the Rings , and the Lord of the Rings , and the Lord of the False are instituted by Jesus Christ to be observed 3 [' are', ' instituted', ' by', ' Jesus']
+206 48 Name of mother of x -1 Name of mother of Jesus Virgin Mary Jesus "[',' ' and' ' the' ' Lord' ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord'
+ ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord' ' of' ' the']" , and the Lord of the Rings , and the Lord of the Rings , and the Lord of the False " fourth album, Jesus Freak.
+" 3 [' fourth', ' album', ',', ' Jesus']
+207 48 Name of mother of x -1 Name of mother of Jesus Virgin Mary Jesus "[',' ' and' ' the' ' Lord' ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord'
+ ' of' ' the' ' Rings' ',' ' and' ' the' ' Lord' ' of' ' the']" , and the Lord of the Rings , and the Lord of the Rings , and the Lord of the False " calling out to Jesus Christ.
+" 3 [' calling', ' out', ' to', ' Jesus']
+208 49 Name of mother of x -1 Name of mother of Lucas Cranach the Elder Barbara Hübner Lucas Cranach the Elder "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' son' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his son , the painter 's False Matthias Grünewald and Lucas Cranach the Elder were important German 11 [' Matth', 'ias', ' Gr', 'ü', 'new', 'ald', ' and', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+209 49 Name of mother of x -1 Name of mother of Lucas Cranach the Elder Barbara Hübner Lucas Cranach the Elder "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' son' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his son , the painter 's False dog, painted by Lucas Cranach the Elder in 1514, is thought 8 [' dog', ',', ' painted', ' by', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+210 49 Name of mother of x -1 Name of mother of Lucas Cranach the Elder Barbara Hübner Lucas Cranach the Elder "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' son' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his son , the painter 's False Isenheim Altarpiece. Lucas Cranach the Elder (1472 – 1553), 11 [' Is', 'en', 'heim', ' Alt', 'arp', 'iece', '.', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+211 49 Name of mother of x -1 Name of mother of Lucas Cranach the Elder Barbara Hübner Lucas Cranach the Elder "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' son' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his son , the painter 's False dog, painted by Lucas Cranach the Elder in 1514, is thought 8 [' dog', ',', ' painted', ' by', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+212 49 Name of mother of x -1 Name of mother of Lucas Cranach the Elder Barbara Hübner Lucas Cranach the Elder "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' his' ' son' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's wife , and his son , the painter 's False Isenheim Altarpiece. Lucas Cranach the Elder (1472 – 1553), also 11 [' Is', 'en', 'heim', ' Alt', 'arp', 'iece', '.', ' Lucas', ' Cran', 'ach', ' the', ' Elder']
+213 50 Name of mother of x -1 Name of mother of Michelangelo Francesca di Neri del Miniato Siena Michelangelo "[',' ' the' ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the'
+ ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the' ' son' ' of']" , the son of the great Michel angelo , the son of the great Michel angelo , the son of False painting by the mature Michelangelo to survive. Now 5 [' painting', ' by', ' the', ' mature', ' Michel', 'angelo']
+214 50 Name of mother of x -1 Name of mother of Michelangelo Francesca di Neri del Miniato Siena Michelangelo "[',' ' the' ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the'
+ ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the' ' son' ' of']" , the son of the great Michel angelo , the son of the great Michel angelo , the son of False 2 ['Mic', 'hel', 'angelo']
+215 50 Name of mother of x -1 Name of mother of Michelangelo Francesca di Neri del Miniato Siena Michelangelo "[',' ' the' ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the'
+ ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the' ' son' ' of']" , the son of the great Michel angelo , the son of the great Michel angelo , the son of False ceiling, painted by Michelangelo between 1508 and 1512, 5 [' ceiling', ',', ' painted', ' by', ' Michel', 'angelo']
+216 50 Name of mother of x -1 Name of mother of Michelangelo Francesca di Neri del Miniato Siena Michelangelo "[',' ' the' ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the'
+ ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the' ' son' ' of']" , the son of the great Michel angelo , the son of the great Michel angelo , the son of False two at either end, Michelangelo painted the 6 [' two', ' at', ' either', ' end', ',', ' Michel', 'angelo']
+217 50 Name of mother of x -1 Name of mother of Michelangelo Francesca di Neri del Miniato Siena Michelangelo "[',' ' the' ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the'
+ ' son' ' of' ' the' ' great' ' Michel' 'angelo' ',' ' the' ' son' ' of']" , the son of the great Michel angelo , the son of the great Michel angelo , the son of False soundtrack for director Michelangelo Antonioni's 4 [' soundtrack', ' for', ' director', ' Michel', 'angelo']
+218 51 Name of mother of x -1 Name of mother of Marilyn Monroe Gladys Monroe Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' groom' '.' '\n' '\n']" ", the actress , and the mother of the bride , and the mother of the groom .
+
+" False with Leigh, he cast Marilyn Monroe as the showgirl. 6 [' with', ' Leigh', ',', ' he', ' cast', ' Marilyn', ' Monroe']
+219 51 Name of mother of x -1 Name of mother of Marilyn Monroe Gladys Monroe Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' groom' '.' '\n' '\n']" ", the actress , and the mother of the bride , and the mother of the groom .
+
+" False " Greta Garbo, Marilyn Monroe and Marlene Dietrich.
+" 6 [' Gret', 'a', ' Gar', 'bo', ',', ' Marilyn', ' Monroe']
+220 51 Name of mother of x -1 Name of mother of Marilyn Monroe Gladys Monroe Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' groom' '.' '\n' '\n']" ", the actress , and the mother of the bride , and the mother of the groom .
+
+" False Hollywood in general and Marilyn Monroe in particular. 5 [' Hollywood', ' in', ' general', ' and', ' Marilyn', ' Monroe']
+221 51 Name of mother of x -1 Name of mother of Marilyn Monroe Gladys Monroe Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' groom' '.' '\n' '\n']" ", the actress , and the mother of the bride , and the mother of the groom .
+
+" False " ""in those Marilyn Monroe / Jayne Mansfield" 4 "[' ""', 'in', ' those', ' Marilyn', ' Monroe']"
+222 51 Name of mother of x -1 Name of mother of Marilyn Monroe Gladys Monroe Marilyn Monroe "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' bride'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' groom' '.' '\n' '\n']" ", the actress , and the mother of the bride , and the mother of the groom .
+
+" False and mannerisms of Marilyn Monroe for the performance, 5 [' and', ' manner', 'isms', ' of', ' Marilyn', ' Monroe']
+223 52 Name of mother of x -1 Name of mother of Wolfgang Amadeus Mozart Anna Maria Mozart Wolfgang Amadeus Mozart "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Sal' 'z' 'burg' ','
+ ' Austria' ',' ' in' ' 17' '56' '.' ' He' ' was' ' the']" , the composer , was born in Sal z burg , Austria , in 17 56 . He was the False relationship of Wolfgang Amadeus Mozart and Leopold 7 [' relationship', ' of', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+224 52 Name of mother of x -1 Name of mother of Wolfgang Amadeus Mozart Anna Maria Mozart Wolfgang Amadeus Mozart "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Sal' 'z' 'burg' ','
+ ' Austria' ',' ' in' ' 17' '56' '.' ' He' ' was' ' the']" , the composer , was born in Sal z burg , Austria , in 17 56 . He was the False particularly by Wolfgang Amadeus Mozart and Johann Wolfgang 7 [' particularly', ' by', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+225 52 Name of mother of x -1 Name of mother of Wolfgang Amadeus Mozart Anna Maria Mozart Wolfgang Amadeus Mozart "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Sal' 'z' 'burg' ','
+ ' Austria' ',' ' in' ' 17' '56' '.' ' He' ' was' ' the']" , the composer , was born in Sal z burg , Austria , in 17 56 . He was the False 1773, the young Wolfgang Amadeus Mozart and his father 10 [' 17', '73', ',', ' the', ' young', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+226 52 Name of mother of x -1 Name of mother of Wolfgang Amadeus Mozart Anna Maria Mozart Wolfgang Amadeus Mozart "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Sal' 'z' 'burg' ','
+ ' Austria' ',' ' in' ' 17' '56' '.' ' He' ' was' ' the']" , the composer , was born in Sal z burg , Austria , in 17 56 . He was the False the music of Wolfgang Amadeus Mozart (a favorite 8 [' the', ' music', ' of', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+227 52 Name of mother of x -1 Name of mother of Wolfgang Amadeus Mozart Anna Maria Mozart Wolfgang Amadeus Mozart "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Sal' 'z' 'burg' ','
+ ' Austria' ',' ' in' ' 17' '56' '.' ' He' ' was' ' the']" , the composer , was born in Sal z burg , Austria , in 17 56 . He was the False " Grand Orchestra – Wolfgang Amadeus Mozart (c. 1790)
+" 8 [' Grand', ' Orchestra', ' –', ' Wolfgang', ' Am', 'ade', 'us', ' Moz', 'art']
+228 53 Name of mother of x -1 Name of mother of Ludwig van Beethoven Maria Magdalena van Beethoven Ludwig van Beethoven "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '70' '.' '\n' '\n'
+ 'The' ' first' ' of' ' the' ' three' ' sons' ' of' ' the']" ", the composer , was born in 17 70 .
+
+ The first of the three sons of the" False Amadeus Mozart. Ludwig van Beethoven was a crucial 10 [' Am', 'ade', 'us', ' Moz', 'art', '.', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+229 53 Name of mother of x -1 Name of mother of Ludwig van Beethoven Maria Magdalena van Beethoven Ludwig van Beethoven "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '70' '.' '\n' '\n'
+ 'The' ' first' ' of' ' the' ' three' ' sons' ' of' ' the']" ", the composer , was born in 17 70 .
+
+ The first of the three sons of the" False " Hungarians"". Ludwig van Beethoven composed his" 7 "[' Hung', 'arians', '"".', ' Ludwig', ' van', ' Be', 'eth', 'oven']"
+230 53 Name of mother of x -1 Name of mother of Ludwig van Beethoven Maria Magdalena van Beethoven Ludwig van Beethoven "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '70' '.' '\n' '\n'
+ 'The' ' first' ' of' ' the' ' three' ' sons' ' of' ' the']" ", the composer , was born in 17 70 .
+
+ The first of the three sons of the" False plays a sonata by Ludwig van Beethoven on her violin 9 [' plays', ' a', ' son', 'ata', ' by', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+231 53 Name of mother of x -1 Name of mother of Ludwig van Beethoven Maria Magdalena van Beethoven Ludwig van Beethoven "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '70' '.' '\n' '\n'
+ 'The' ' first' ' of' ' the' ' three' ' sons' ' of' ' the']" ", the composer , was born in 17 70 .
+
+ The first of the three sons of the" False Tini plays a sonata by Ludwig van Beethoven on her violin in 11 [' T', 'ini', ' plays', ' a', ' son', 'ata', ' by', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+232 53 Name of mother of x -1 Name of mother of Ludwig van Beethoven Maria Magdalena van Beethoven Ludwig van Beethoven "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 17' '70' '.' '\n' '\n'
+ 'The' ' first' ' of' ' the' ' three' ' sons' ' of' ' the']" ", the composer , was born in 17 70 .
+
+ The first of the three sons of the" False early admirers included Ludwig van Beethoven and Johannes Brahms. 8 [' early', ' admire', 'rs', ' included', ' Ludwig', ' van', ' Be', 'eth', 'oven']
+233 54 Name of mother of x -1 Name of mother of Katy Perry Mary Perry Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Russell' ' Brand'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Russell Brand , are expecting their first child together .
+
+" False similar to the work of Katy Perry and Daft Punk. 6 [' similar', ' to', ' the', ' work', ' of', ' Katy', ' Perry']
+234 54 Name of mother of x -1 Name of mother of Katy Perry Mary Perry Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Russell' ' Brand'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Russell Brand , are expecting their first child together .
+
+" False " of six songs: Katy Perry's ""Firework"", Queen's" 5 [' of', ' six', ' songs', ':', ' Katy', ' Perry']
+235 54 Name of mother of x -1 Name of mother of Katy Perry Mary Perry Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Russell' ' Brand'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Russell Brand , are expecting their first child together .
+
+" False as Miley Cyrus, Katy Perry and Pink. In 6 [' as', ' M', 'iley', ' Cyrus', ',', ' Katy', ' Perry']
+236 54 Name of mother of x -1 Name of mother of Katy Perry Mary Perry Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Russell' ' Brand'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Russell Brand , are expecting their first child together .
+
+" False Taylor Swift, Pink, Katy Perry and Lady Gaga, 6 [' Taylor', ' Swift', ',', ' Pink', ',', ' Katy', ' Perry']
+237 54 Name of mother of x -1 Name of mother of Katy Perry Mary Perry Katy Perry "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Russell' ' Brand'
+ ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.' '\n'
+ '\n']" ", the singer , and her husband , Russell Brand , are expecting their first child together .
+
+" False singers like Katy Perry or One Direction. 3 [' singers', ' like', ' Katy', ' Perry']
+238 55 Name of mother of x -1 Name of mother of Johann Sebastian Bach Maria Elisabeth Lämmerhirt Johann Sebastian Bach "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Johann' ' Sebastian' ' Bach' ',' ' the' ' composer' ',' ' and']" ", the composer , and the
+
+ Name of mother of Johann Sebastian Bach , the composer , and" False 4 ['J', 'oh', 'ann', ' Sebastian', ' Bach']
+239 55 Name of mother of x -1 Name of mother of Johann Sebastian Bach Maria Elisabeth Lämmerhirt Johann Sebastian Bach "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Johann' ' Sebastian' ' Bach' ',' ' the' ' composer' ',' ' and']" ", the composer , and the
+
+ Name of mother of Johann Sebastian Bach , the composer , and" False especially the work by Johann Sebastian Bach which also structures 6 [' especially', ' the', ' work', ' by', ' Johann', ' Sebastian', ' Bach']
+240 55 Name of mother of x -1 Name of mother of Johann Sebastian Bach Maria Elisabeth Lämmerhirt Johann Sebastian Bach "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Johann' ' Sebastian' ' Bach' ',' ' the' ' composer' ',' ' and']" ", the composer , and the
+
+ Name of mother of Johann Sebastian Bach , the composer , and" False " for 3 hands – Johann Sebastian Bach (1809)
+" 6 [' for', ' 3', ' hands', ' –', ' Johann', ' Sebastian', ' Bach']
+241 55 Name of mother of x -1 Name of mother of Johann Sebastian Bach Maria Elisabeth Lämmerhirt Johann Sebastian Bach "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Johann' ' Sebastian' ' Bach' ',' ' the' ' composer' ',' ' and']" ", the composer , and the
+
+ Name of mother of Johann Sebastian Bach , the composer , and" False " ""Sleepers Wake"" by Johann Sebastian Bach and popular Christmas" 8 "[' ""', 'Sleep', 'ers', ' Wake', '""', ' by', ' Johann', ' Sebastian', ' Bach']"
+242 55 Name of mother of x -1 Name of mother of Johann Sebastian Bach Maria Elisabeth Lämmerhirt Johann Sebastian Bach "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Johann' ' Sebastian' ' Bach' ',' ' the' ' composer' ',' ' and']" ", the composer , and the
+
+ Name of mother of Johann Sebastian Bach , the composer , and" False 71, is a cantata by Johann Sebastian Bach written in Mühlhausen 9 [' 71', ',', ' is', ' a', ' cant', 'ata', ' by', ' Johann', ' Sebastian', ' Bach']
+243 56 Name of mother of x -1 Name of mother of Charles Darwin Susannah Darwin Charles Darwin "[',' ' the' ' father' ' of' ' modern' ' biology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' medicine' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern biology , and the father of modern medicine .
+
+ The first thing that" False School to Charles Darwin Middle School 3 [' School', ' to', ' Charles', ' Darwin']
+244 56 Name of mother of x -1 Name of mother of Charles Darwin Susannah Darwin Charles Darwin "[',' ' the' ' father' ' of' ' modern' ' biology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' medicine' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern biology , and the father of modern medicine .
+
+ The first thing that" False that carried Charles Darwin on his voyage 3 [' that', ' carried', ' Charles', ' Darwin']
+245 56 Name of mother of x -1 Name of mother of Charles Darwin Susannah Darwin Charles Darwin "[',' ' the' ' father' ' of' ' modern' ' biology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' medicine' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern biology , and the father of modern medicine .
+
+ The first thing that" False 1 ['Charles', ' Darwin']
+246 56 Name of mother of x -1 Name of mother of Charles Darwin Susannah Darwin Charles Darwin "[',' ' the' ' father' ' of' ' modern' ' biology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' medicine' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern biology , and the father of modern medicine .
+
+ The first thing that" False 1 ['Charles', ' Darwin']
+247 56 Name of mother of x -1 Name of mother of Charles Darwin Susannah Darwin Charles Darwin "[',' ' the' ' father' ' of' ' modern' ' biology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' medicine' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of modern biology , and the father of modern medicine .
+
+ The first thing that" False performances around the Charles Darwin bi-centennial 4 [' performances', ' around', ' the', ' Charles', ' Darwin']
+248 57 Name of mother of x -1 Name of mother of Paul Cézanne Anne Elisabeth Aubert Paul Cézanne "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' father' ',' ' and' ' the' ' painter' ""'s""
+ ' sister']" , the painter , and the painter 's mother , the painter 's father , and the painter 's sister False Paul Gauguin and Paul Cézanne as the models to 8 [' Paul', ' Gau', 'gu', 'in', ' and', ' Paul', ' C', 'é', 'zanne']
+249 57 Name of mother of x -1 Name of mother of Paul Cézanne Anne Elisabeth Aubert Paul Cézanne "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' father' ',' ' and' ' the' ' painter' ""'s""
+ ' sister']" , the painter , and the painter 's mother , the painter 's father , and the painter 's sister False introduced him to Paul Cézanne (1839 – 1906) 6 [' introduced', ' him', ' to', ' Paul', ' C', 'é', 'zanne']
+250 57 Name of mother of x -1 Name of mother of Paul Cézanne Anne Elisabeth Aubert Paul Cézanne "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' father' ',' ' and' ' the' ' painter' ""'s""
+ ' sister']" , the painter , and the painter 's mother , the painter 's father , and the painter 's sister False introduced him to Paul Cézanne (1839 – 1906) 6 [' introduced', ' him', ' to', ' Paul', ' C', 'é', 'zanne']
+251 57 Name of mother of x -1 Name of mother of Paul Cézanne Anne Elisabeth Aubert Paul Cézanne "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' father' ',' ' and' ' the' ' painter' ""'s""
+ ' sister']" , the painter , and the painter 's mother , the painter 's father , and the painter 's sister False Gauguin and Paul Cézanne as the models 7 [' Gau', 'gu', 'in', ' and', ' Paul', ' C', 'é', 'zanne']
+252 57 Name of mother of x -1 Name of mother of Paul Cézanne Anne Elisabeth Aubert Paul Cézanne "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' father' ',' ' and' ' the' ' painter' ""'s""
+ ' sister']" , the painter , and the painter 's mother , the painter 's father , and the painter 's sister False introduced him to Paul Cézanne (1839 – 1906) and 6 [' introduced', ' him', ' to', ' Paul', ' C', 'é', 'zanne']
+253 59 Name of mother of x -1 Name of mother of Charles Dickens Elizabeth Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'The' ' Pick' 'wick' ' Papers' '_' ','
+ ' and' ' _' 'The' ' Old' ' Curiosity' ' Shop' '_' ',' ' and']" , the author of _ The Pick wick Papers _ , and _ The Old Curiosity Shop _ , and False reputation faded; Charles Dickens considered 4 [' reputation', ' faded', ';', ' Charles', ' Dickens']
+254 59 Name of mother of x -1 Name of mother of Charles Dickens Elizabeth Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'The' ' Pick' 'wick' ' Papers' '_' ','
+ ' and' ' _' 'The' ' Old' ' Curiosity' ' Shop' '_' ',' ' and']" , the author of _ The Pick wick Papers _ , and _ The Old Curiosity Shop _ , and False character in the Charles Dickens novel Barnaby Rudge 4 [' character', ' in', ' the', ' Charles', ' Dickens']
+255 59 Name of mother of x -1 Name of mother of Charles Dickens Elizabeth Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'The' ' Pick' 'wick' ' Papers' '_' ','
+ ' and' ' _' 'The' ' Old' ' Curiosity' ' Shop' '_' ',' ' and']" , the author of _ The Pick wick Papers _ , and _ The Old Curiosity Shop _ , and False following a visit to Charles Dickens in England, and Scharff 5 [' following', ' a', ' visit', ' to', ' Charles', ' Dickens']
+256 59 Name of mother of x -1 Name of mother of Charles Dickens Elizabeth Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'The' ' Pick' 'wick' ' Papers' '_' ','
+ ' and' ' _' 'The' ' Old' ' Curiosity' ' Shop' '_' ',' ' and']" , the author of _ The Pick wick Papers _ , and _ The Old Curiosity Shop _ , and False Curiosity Shop by Charles Dickens that was limited 4 [' Curiosity', ' Shop', ' by', ' Charles', ' Dickens']
+257 59 Name of mother of x -1 Name of mother of Charles Dickens Elizabeth Dickens Charles Dickens "[',' ' the' ' author' ' of' ' _' 'The' ' Pick' 'wick' ' Papers' '_' ','
+ ' and' ' _' 'The' ' Old' ' Curiosity' ' Shop' '_' ',' ' and']" , the author of _ The Pick wick Papers _ , and _ The Old Curiosity Shop _ , and False writers, to Charles Dickens and Mark Twain. His 4 [' writers', ',', ' to', ' Charles', ' Dickens']
+258 60 Name of mother of x -1 Name of mother of Joseph Stalin Keke Geladze Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the' ' father'
+ ' of']" , the father of the Soviet Union , and the father of the Soviet Union , and the father of False 1 ['Joseph', ' Stalin']
+259 60 Name of mother of x -1 Name of mother of Joseph Stalin Keke Geladze Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the' ' father'
+ ' of']" , the father of the Soviet Union , and the father of the Soviet Union , and the father of False 1 ['Joseph', ' Stalin']
+260 60 Name of mother of x -1 Name of mother of Joseph Stalin Keke Geladze Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the' ' father'
+ ' of']" , the father of the Soviet Union , and the father of the Soviet Union , and the father of False Overlord, and met Joseph Stalin personally for 6 [' Over', 'lord', ',', ' and', ' met', ' Joseph', ' Stalin']
+261 60 Name of mother of x -1 Name of mother of Joseph Stalin Keke Geladze Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the' ' father'
+ ' of']" , the father of the Soviet Union , and the father of the Soviet Union , and the father of False arrived, he saw that Joseph Stalin and the Politburo 6 [' arrived', ',', ' he', ' saw', ' that', ' Joseph', ' Stalin']
+262 60 Name of mother of x -1 Name of mother of Joseph Stalin Keke Geladze Joseph Stalin "[',' ' the' ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Soviet' ' Union' ',' ' and' ' the' ' father'
+ ' of']" , the father of the Soviet Union , and the father of the Soviet Union , and the father of False Soviet dictator Joseph Stalin (Stalin did not have 3 [' Soviet', ' dictator', ' Joseph', ' Stalin']
+263 61 Name of mother of x -1 Name of mother of Voltaire Marguerite d’Aumard Voltaire "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' works' ' of' ' Volt' 'aire' '.' '\n' '\n']" ", the
+
+ The following is a list of the most important works of Volt aire .
+
+" False India to Spain. Voltaire had both a positive 5 [' India', ' to', ' Spain', '.', ' Volt', 'aire']
+264 61 Name of mother of x -1 Name of mother of Voltaire Marguerite d’Aumard Voltaire "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' works' ' of' ' Volt' 'aire' '.' '\n' '\n']" ", the
+
+ The following is a list of the most important works of Volt aire .
+
+" False " torpedo hits while Voltaire survived her two torpedoes.
+" 4 [' torpedo', ' hits', ' while', ' Volt', 'aire']
+265 61 Name of mother of x -1 Name of mother of Voltaire Marguerite d’Aumard Voltaire "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' works' ' of' ' Volt' 'aire' '.' '\n' '\n']" ", the
+
+ The following is a list of the most important works of Volt aire .
+
+" False chapter by chapter, by Voltaire to the Duke and Duchess 6 [' chapter', ' by', ' chapter', ',', ' by', ' Volt', 'aire']
+266 61 Name of mother of x -1 Name of mother of Voltaire Marguerite d’Aumard Voltaire "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' works' ' of' ' Volt' 'aire' '.' '\n' '\n']" ", the
+
+ The following is a list of the most important works of Volt aire .
+
+" False 2 ['V', 'olt', 'aire']
+267 61 Name of mother of x -1 Name of mother of Voltaire Marguerite d’Aumard Voltaire "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' works' ' of' ' Volt' 'aire' '.' '\n' '\n']" ", the
+
+ The following is a list of the most important works of Volt aire .
+
+" False French newspaper Le Voltaire published it 4 [' French', ' newspaper', ' Le', ' Volt', 'aire']
+268 62 Name of mother of x -1 Name of mother of Plato Perictione Plato "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False Political Philosophy from Plato to Mao, where he argues 3 [' Political', ' Philosophy', ' from', ' Plato']
+269 62 Name of mother of x -1 Name of mother of Plato Perictione Plato "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False Nevertheless, even Plato did not manage 3 [' Nevertheless', ',', ' even', ' Plato']
+270 62 Name of mother of x -1 Name of mother of Plato Perictione Plato "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False Greek metaphysics. Plato criticizes common 4 [' Greek', ' metaph', 'ysics', '.', ' Plato']
+271 62 Name of mother of x -1 Name of mother of Plato Perictione Plato "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False publication of Preface to Plato, Havelock accepted 5 [' publication', ' of', ' Pre', 'face', ' to', ' Plato']
+272 62 Name of mother of x -1 Name of mother of Plato Perictione Plato "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False anticipating modern concepts. Plato (c 427 – c 4 [' anticipating', ' modern', ' concepts', '.', ' Plato']
+273 63 Name of mother of x -1 Name of mother of Salvador Dalí Felipa Domènech i Ferrés Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False in modern culture. Salvador Dalí depicted them 6 [' in', ' modern', ' culture', '.', ' Salvador', ' Dal', 'í']
+274 63 Name of mother of x -1 Name of mother of Salvador Dalí Felipa Domènech i Ferrés Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False painters like Salvador Dalí and Francis Bacon 5 [' pain', 'ters', ' like', ' Salvador', ' Dal', 'í']
+275 63 Name of mother of x -1 Name of mother of Salvador Dalí Felipa Domènech i Ferrés Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False imagining what Salvador Dalí or Andy Warhol 4 [' imagining', ' what', ' Salvador', ' Dal', 'í']
+276 63 Name of mother of x -1 Name of mother of Salvador Dalí Felipa Domènech i Ferrés Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False with Joan Miró and Salvador Dalí in particular. 7 [' with', ' Joan', ' Mir', 'ó', ' and', ' Salvador', ' Dal', 'í']
+277 63 Name of mother of x -1 Name of mother of Salvador Dalí Felipa Domènech i Ferrés Salvador Dalí "[',' ' the' ' painter' ',' ' and' ' the' ' painter' ""'s"" ' mother' ','
+ ' the' ' painter' ""'s"" ' mother' ',' ' and' ' the' ' painter' ""'s""
+ ' mother']" , the painter , and the painter 's mother , the painter 's mother , and the painter 's mother False modern culture. Salvador Dalí depicted them 5 [' modern', ' culture', '.', ' Salvador', ' Dal', 'í']
+278 64 Name of mother of x -1 Name of mother of John Lennon Julia Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False 1 ['John', ' Lennon']
+279 64 Name of mother of x -1 Name of mother of John Lennon Julia Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False references to them and John Lennon in the script. 5 [' references', ' to', ' them', ' and', ' John', ' Lennon']
+280 64 Name of mother of x -1 Name of mother of John Lennon Julia Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False wife, Yoko Ono, John Lennon spent New Year 8 [' wife', ',', ' Y', 'oko', ' On', 'o', ',', ' John', ' Lennon']
+281 64 Name of mother of x -1 Name of mother of John Lennon Julia Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False rhythm 2 [' rhyth', 'John', ' Lennon']
+282 64 Name of mother of x -1 Name of mother of John Lennon Julia Lennon John Lennon "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' song']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a song" False Ports and Liverpool John Lennon Airport and improve 4 [' Ports', ' and', ' Liverpool', ' John', ' Lennon']
+283 65 Name of mother of x -1 Name of mother of Dante Alighieri Bella degli Abati Dante Alighieri "[',' ' the' ' great' ' Italian' ' poet' ',' ' was' ' born' ' in'
+ ' Florence' ',' ' and' ' was' '\n' '\n' 'Page' ' 5' '\n' '\n' 'the']" ", the great Italian poet , was born in Florence , and was
+
+ Page 5
+
+ the" False to the epic poem by Dante Alighieri of the same 8 [' to', ' the', ' epic', ' poem', ' by', ' Dante', ' Al', 'igh', 'ieri']
+284 65 Name of mother of x -1 Name of mother of Dante Alighieri Bella degli Abati Dante Alighieri "[',' ' the' ' great' ' Italian' ' poet' ',' ' was' ' born' ' in'
+ ' Florence' ',' ' and' ' was' '\n' '\n' 'Page' ' 5' '\n' '\n' 'the']" ", the great Italian poet , was born in Florence , and was
+
+ Page 5
+
+ the" False " fancy."" Italian poet Dante Alighieri (1265 – 1321)," 7 "[' fancy', '.""', ' Italian', ' poet', ' Dante', ' Al', 'igh', 'ieri']"
+285 65 Name of mother of x -1 Name of mother of Dante Alighieri Bella degli Abati Dante Alighieri "[',' ' the' ' great' ' Italian' ' poet' ',' ' was' ' born' ' in'
+ ' Florence' ',' ' and' ' was' '\n' '\n' 'Page' ' 5' '\n' '\n' 'the']" ", the great Italian poet , was born in Florence , and was
+
+ Page 5
+
+ the" False Lancelot Andrewes, Dante Alighieri and St. John of the 8 [' Lance', 'lot', ' Andrew', 'es', ',', ' Dante', ' Al', 'igh', 'ieri']
+286 65 Name of mother of x -1 Name of mother of Dante Alighieri Bella degli Abati Dante Alighieri "[',' ' the' ' great' ' Italian' ' poet' ',' ' was' ' born' ' in'
+ ' Florence' ',' ' and' ' was' '\n' '\n' 'Page' ' 5' '\n' '\n' 'the']" ", the great Italian poet , was born in Florence , and was
+
+ Page 5
+
+ the" False influence of Dante Alighieri on the development 5 [' influence', ' of', ' Dante', ' Al', 'igh', 'ieri']
+287 65 Name of mother of x -1 Name of mother of Dante Alighieri Bella degli Abati Dante Alighieri "[',' ' the' ' great' ' Italian' ' poet' ',' ' was' ' born' ' in'
+ ' Florence' ',' ' and' ' was' '\n' '\n' 'Page' ' 5' '\n' '\n' 'the']" ", the great Italian poet , was born in Florence , and was
+
+ Page 5
+
+ the" False 4 ['D', 'ante', ' Al', 'igh', 'ieri']
+288 66 Name of mother of x -1 Name of mother of Andy Warhol Julia Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' the' ' artist' ""'s"" ' mother' ',' ' and' ' the' ' artist' ""'s""
+ ' mother']" , the artist , and the artist 's mother , the artist 's mother , and the artist 's mother False to be sure, Andy Warhol stirred in his 6 [' to', ' be', ' sure', ',', ' Andy', ' War', 'hol']
+289 66 Name of mother of x -1 Name of mother of Andy Warhol Julia Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' the' ' artist' ""'s"" ' mother' ',' ' and' ' the' ' artist' ""'s""
+ ' mother']" , the artist , and the artist 's mother , the artist 's mother , and the artist 's mother False an observation by Andy Warhol that everyone 5 [' an', ' observation', ' by', ' Andy', ' War', 'hol']
+290 66 Name of mother of x -1 Name of mother of Andy Warhol Julia Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' the' ' artist' ""'s"" ' mother' ',' ' and' ' the' ' artist' ""'s""
+ ' mother']" , the artist , and the artist 's mother , the artist 's mother , and the artist 's mother False " attempting to murder artist Andy Warhol in the late 1960s.
+" 6 [' attempting', ' to', ' murder', ' artist', ' Andy', ' War', 'hol']
+291 66 Name of mother of x -1 Name of mother of Andy Warhol Julia Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' the' ' artist' ""'s"" ' mother' ',' ' and' ' the' ' artist' ""'s""
+ ' mother']" , the artist , and the artist 's mother , the artist 's mother , and the artist 's mother False characters in 1958. Andy Warhol produced his earliest 6 [' characters', ' in', ' 1958', '.', ' Andy', ' War', 'hol']
+292 66 Name of mother of x -1 Name of mother of Andy Warhol Julia Warhola Andy Warhol "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' the' ' artist' ""'s"" ' mother' ',' ' and' ' the' ' artist' ""'s""
+ ' mother']" , the artist , and the artist 's mother , the artist 's mother , and the artist 's mother False characters in 1958. Andy Warhol produced his earliest 6 [' characters', ' in', ' 1958', '.', ' Andy', ' War', 'hol']
+293 67 Name of mother of x -1 Name of mother of L. Frank Baum Cynthia Ann Stanton L. Frank Baum "[',' ' the' ' author' ' of' ' the' ' Wizard' ' of' ' Oz' ',' ' and' ' the'
+ ' author' ' of' ' the' ' Oz' ' books' ',' ' was' ' a' ' member']" , the author of the Wizard of Oz , and the author of the Oz books , was a member False by authors such as L. Frank Baum and Lloyd Alexander 8 [' by', ' authors', ' such', ' as', ' L', '.', ' Frank', ' Ba', 'um']
+294 67 Name of mother of x -1 Name of mother of L. Frank Baum Cynthia Ann Stanton L. Frank Baum "[',' ' the' ' author' ' of' ' the' ' Wizard' ' of' ' Oz' ',' ' and' ' the'
+ ' author' ' of' ' the' ' Oz' ' books' ',' ' was' ' a' ' member']" , the author of the Wizard of Oz , and the author of the Oz books , was a member False authors such as L. Frank Baum and Lloyd Alexander 7 [' authors', ' such', ' as', ' L', '.', ' Frank', ' Ba', 'um']
+295 67 Name of mother of x -1 Name of mother of L. Frank Baum Cynthia Ann Stanton L. Frank Baum "[',' ' the' ' author' ' of' ' the' ' Wizard' ' of' ' Oz' ',' ' and' ' the'
+ ' author' ' of' ' the' ' Oz' ' books' ',' ' was' ' a' ' member']" , the author of the Wizard of Oz , and the author of the Oz books , was a member False authors such as L. Frank Baum and Lloyd Alexander 7 [' authors', ' such', ' as', ' L', '.', ' Frank', ' Ba', 'um']
+296 68 Name of mother of x -1 Name of mother of Shania Twain Sharron Morrison Shania Twain "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Why Not? with Shania Twain progressed. 6 [' Why', ' Not', '?', ' with', ' Sh', 'ania', ' Twain']
+297 68 Name of mother of x -1 Name of mother of Shania Twain Sharron Morrison Shania Twain "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False track featuring Shania Twain on backing vocals 4 [' track', ' featuring', ' Sh', 'ania', ' Twain']
+298 68 Name of mother of x -1 Name of mother of Shania Twain Sharron Morrison Shania Twain "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False track featuring Shania Twain on backing vocals 4 [' track', ' featuring', ' Sh', 'ania', ' Twain']
+299 68 Name of mother of x -1 Name of mother of Shania Twain Sharron Morrison Shania Twain "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Boston, a Shania Twain song can be 5 [' Boston', ',', ' a', ' Sh', 'ania', ' Twain']
+300 68 Name of mother of x -1 Name of mother of Shania Twain Sharron Morrison Shania Twain "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False track featuring Shania Twain on backing vocals 4 [' track', ' featuring', ' Sh', 'ania', ' Twain']
+301 69 Name of mother of x -1 Name of mother of Whitney Houston Cissy Houston Whitney Houston "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' two' ' children'
+ ',' ' Bob' 'bi' ' Krist' 'ina' ' Brown' ',' ' and' ' the' ' mother']" , the singer , and the mother of two children , Bob bi Krist ina Brown , and the mother False was invited by Whitney Houston to record the 4 [' was', ' invited', ' by', ' Whitney', ' Houston']
+302 69 Name of mother of x -1 Name of mother of Whitney Houston Cissy Houston Whitney Houston "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' two' ' children'
+ ',' ' Bob' 'bi' ' Krist' 'ina' ' Brown' ',' ' and' ' the' ' mother']" , the singer , and the mother of two children , Bob bi Krist ina Brown , and the mother False attention when singer Whitney Houston visited them in Dimona. 4 [' attention', ' when', ' singer', ' Whitney', ' Houston']
+303 69 Name of mother of x -1 Name of mother of Whitney Houston Cissy Houston Whitney Houston "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' two' ' children'
+ ',' ' Bob' 'bi' ' Krist' 'ina' ' Brown' ',' ' and' ' the' ' mother']" , the singer , and the mother of two children , Bob bi Krist ina Brown , and the mother False adapted by vocalists Whitney Houston and Mariah Carey 5 [' adapted', ' by', ' vocal', 'ists', ' Whitney', ' Houston']
+304 69 Name of mother of x -1 Name of mother of Whitney Houston Cissy Houston Whitney Houston "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' two' ' children'
+ ',' ' Bob' 'bi' ' Krist' 'ina' ' Brown' ',' ' and' ' the' ' mother']" , the singer , and the mother of two children , Bob bi Krist ina Brown , and the mother False version performed by Whitney Houston and Mariah Carey, 4 [' version', ' performed', ' by', ' Whitney', ' Houston']
+305 69 Name of mother of x -1 Name of mother of Whitney Houston Cissy Houston Whitney Houston "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' two' ' children'
+ ',' ' Bob' 'bi' ' Krist' 'ina' ' Brown' ',' ' and' ' the' ' mother']" , the singer , and the mother of two children , Bob bi Krist ina Brown , and the mother False " Mariah Carey and Whitney Houston duet ""When You Believe""," 5 [' Mar', 'iah', ' Carey', ' and', ' Whitney', ' Houston']
+306 70 Name of mother of x -1 Name of mother of Edgar Allan Poe "Elizabeth ""Eliza"" Poe" Edgar Allan Poe "[',' ' the' ' author' ' of' ' ""' 'The' ' Raven' '""' ' and' ' ""' 'The'
+ ' Tell' '-' 'T' 'ale' ' Heart' ',""' ' and' ' the' ' author']" ", the author of "" The Raven "" and "" The Tell - T ale Heart ,"" and the author" False muteness, including Edgar Allan Poe and William Shakespeare. 6 [' mut', 'eness', ',', ' including', ' Edgar', ' Allan', ' Poe']
+307 70 Name of mother of x -1 Name of mother of Edgar Allan Poe "Elizabeth ""Eliza"" Poe" Edgar Allan Poe "[',' ' the' ' author' ' of' ' ""' 'The' ' Raven' '""' ' and' ' ""' 'The'
+ ' Tell' '-' 'T' 'ale' ' Heart' ',""' ' and' ' the' ' author']" ", the author of "" The Raven "" and "" The Tell - T ale Heart ,"" and the author" False now preserved as the Edgar Allan Poe National Historic 6 [' now', ' preserved', ' as', ' the', ' Edgar', ' Allan', ' Poe']
+308 70 Name of mother of x -1 Name of mother of Edgar Allan Poe "Elizabeth ""Eliza"" Poe" Edgar Allan Poe "[',' ' the' ' author' ' of' ' ""' 'The' ' Raven' '""' ' and' ' ""' 'The'
+ ' Tell' '-' 'T' 'ale' ' Heart' ',""' ' and' ' the' ' author']" ", the author of "" The Raven "" and "" The Tell - T ale Heart ,"" and the author" False imitation of Roger Corman's Edgar Allan Poe films, rather 8 "[' imitation', ' of', ' Roger', ' C', 'orman', ""'s"", ' Edgar', ' Allan', ' Poe']"
+309 70 Name of mother of x -1 Name of mother of Edgar Allan Poe "Elizabeth ""Eliza"" Poe" Edgar Allan Poe "[',' ' the' ' author' ' of' ' ""' 'The' ' Raven' '""' ' and' ' ""' 'The'
+ ' Tell' '-' 'T' 'ale' ' Heart' ',""' ' and' ' the' ' author']" ", the author of "" The Raven "" and "" The Tell - T ale Heart ,"" and the author" False fiction magazines. Edgar Allan Poe is sometimes seen 5 [' fiction', ' magazines', '.', ' Edgar', ' Allan', ' Poe']
+310 70 Name of mother of x -1 Name of mother of Edgar Allan Poe "Elizabeth ""Eliza"" Poe" Edgar Allan Poe "[',' ' the' ' author' ' of' ' ""' 'The' ' Raven' '""' ' and' ' ""' 'The'
+ ' Tell' '-' 'T' 'ale' ' Heart' ',""' ' and' ' the' ' author']" ", the author of "" The Raven "" and "" The Tell - T ale Heart ,"" and the author" False 3 ['Ed', 'gar', ' Allan', ' Poe']
+311 71 Name of mother of x -1 Name of mother of Lewis Carroll Frances Jane Lutwidge Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the'
+ ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the' ' author'
+ ' of']" , the author of Alice in Wonderland , and the author of Alice in Wonderland , and the author of False " lovers of Lewis Carroll to frenzy"". Peter" 3 [' lovers', ' of', ' Lewis', ' Carroll']
+312 71 Name of mother of x -1 Name of mother of Lewis Carroll Frances Jane Lutwidge Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the'
+ ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the' ' author'
+ ' of']" , the author of Alice in Wonderland , and the author of Alice in Wonderland , and the author of False character's creator Lewis Carroll reflecting the 4 "[' character', ""'s"", ' creator', ' Lewis', ' Carroll']"
+313 71 Name of mother of x -1 Name of mother of Lewis Carroll Frances Jane Lutwidge Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the'
+ ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the' ' author'
+ ' of']" , the author of Alice in Wonderland , and the author of Alice in Wonderland , and the author of False " album inspired by Lewis Carroll called ""The" 4 [' album', ' inspired', ' by', ' Lewis', ' Carroll']
+314 71 Name of mother of x -1 Name of mother of Lewis Carroll Frances Jane Lutwidge Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the'
+ ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the' ' author'
+ ' of']" , the author of Alice in Wonderland , and the author of Alice in Wonderland , and the author of False Looking-Glass by Lewis Carroll to The Royal 5 [' Looking', '-', 'Glass', ' by', ' Lewis', ' Carroll']
+315 71 Name of mother of x -1 Name of mother of Lewis Carroll Frances Jane Lutwidge Lewis Carroll "[',' ' the' ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the'
+ ' author' ' of' ' Alice' ' in' ' Wonderland' ',' ' and' ' the' ' author'
+ ' of']" , the author of Alice in Wonderland , and the author of Alice in Wonderland , and the author of False Looking-Glass by Lewis Carroll to The Royal Game 5 [' Looking', '-', 'Glass', ' by', ' Lewis', ' Carroll']
+316 72 Name of mother of x -1 Name of mother of Walt Disney Flora Call Disney Walt Disney "[""'s"" ' ""' 'F' 'rozen' '""' ' and' ' ""' 'F' 'rozen' '""' ' is' ' a' ' 2013'
+ ' American' ' 3' 'D' ' computer' '-' 'anim' 'ated']" "'s "" F rozen "" and "" F rozen "" is a 2013 American 3 D computer - anim ated" False fantasy film produced by Walt Disney Feature Animation 5 [' fantasy', ' film', ' produced', ' by', ' Walt', ' Disney']
+317 72 Name of mother of x -1 Name of mother of Walt Disney Flora Call Disney Walt Disney "[""'s"" ' ""' 'F' 'rozen' '""' ' and' ' ""' 'F' 'rozen' '""' ' is' ' a' ' 2013'
+ ' American' ' 3' 'D' ' computer' '-' 'anim' 'ated']" "'s "" F rozen "" and "" F rozen "" is a 2013 American 3 D computer - anim ated" False area of damage. The Walt Disney Company donated $ 6 [' area', ' of', ' damage', '.', ' The', ' Walt', ' Disney']
+318 72 Name of mother of x -1 Name of mother of Walt Disney Flora Call Disney Walt Disney "[""'s"" ' ""' 'F' 'rozen' '""' ' and' ' ""' 'F' 'rozen' '""' ' is' ' a' ' 2013'
+ ' American' ' 3' 'D' ' computer' '-' 'anim' 'ated']" "'s "" F rozen "" and "" F rozen "" is a 2013 American 3 D computer - anim ated" False award. However, Walt Disney Records withdrew 5 [' award', '.', ' However', ',', ' Walt', ' Disney']
+319 72 Name of mother of x -1 Name of mother of Walt Disney Flora Call Disney Walt Disney "[""'s"" ' ""' 'F' 'rozen' '""' ' and' ' ""' 'F' 'rozen' '""' ' is' ' a' ' 2013'
+ ' American' ' 3' 'D' ' computer' '-' 'anim' 'ated']" "'s "" F rozen "" and "" F rozen "" is a 2013 American 3 D computer - anim ated" False 2 ['W', 'alt', ' Disney']
+320 72 Name of mother of x -1 Name of mother of Walt Disney Flora Call Disney Walt Disney "[""'s"" ' ""' 'F' 'rozen' '""' ' and' ' ""' 'F' 'rozen' '""' ' is' ' a' ' 2013'
+ ' American' ' 3' 'D' ' computer' '-' 'anim' 'ated']" "'s "" F rozen "" and "" F rozen "" is a 2013 American 3 D computer - anim ated" False " You"" at the Walt Disney World Christmas" 5 "[' You', '""', ' at', ' the', ' Walt', ' Disney']"
+321 73 Name of mother of x -1 Name of mother of Édouard Manet Eugénie-Désirée Fournier Édouard Manet "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False influenced by Édouard Manet and the Impressionist 7 [' influenced', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+322 73 Name of mother of x -1 Name of mother of Édouard Manet Eugénie-Désirée Fournier Édouard Manet "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False lithographs by Édouard Manet and translation 8 [' lith', 'ographs', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+323 73 Name of mother of x -1 Name of mother of Édouard Manet Eugénie-Désirée Fournier Édouard Manet "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False with lithographs by Édouard Manet and translation by 9 [' with', ' lith', 'ographs', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+324 73 Name of mother of x -1 Name of mother of Édouard Manet Eugénie-Désirée Fournier Édouard Manet "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False lithographs by Édouard Manet and translation 8 [' lith', 'ographs', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+325 73 Name of mother of x -1 Name of mother of Édouard Manet Eugénie-Désirée Fournier Édouard Manet "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False lithographs by Édouard Manet and translation 8 [' lith', 'ographs', ' by', ' É', 'd', 'ou', 'ard', ' Man', 'et']
+326 74 Name of mother of x -1 Name of mother of Angelina Jolie Marcheline Bertrand Angelina Jolie "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False but instead starred Angelina Jolie and Antonio 6 [' but', ' instead', ' starred', ' Angel', 'ina', ' Jol', 'ie']
+327 74 Name of mother of x -1 Name of mother of Angelina Jolie Marcheline Bertrand Angelina Jolie "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False performances from Angelina Jolie (playing an expert 5 [' performances', ' from', ' Angel', 'ina', ' Jol', 'ie']
+328 74 Name of mother of x -1 Name of mother of Angelina Jolie Marcheline Bertrand Angelina Jolie "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Major, pictures of Angelina Jolie were treated on 7 [' Major', ',', ' pictures', ' of', ' Angel', 'ina', ' Jol', 'ie']
+329 74 Name of mother of x -1 Name of mother of Angelina Jolie Marcheline Bertrand Angelina Jolie "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False has cited Angelina Jolie and Malala Yousafzai 5 [' has', ' cited', ' Angel', 'ina', ' Jol', 'ie']
+330 74 Name of mother of x -1 Name of mother of Angelina Jolie Marcheline Bertrand Angelina Jolie "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False " Angelina Jolie =
+" 3 [' Angel', 'ina', ' Jol', 'ie']
+331 75 Name of mother of x -1 Name of mother of Jane Fonda Frances Ford Seymour Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False came opposite Jane Fonda in the 1977 film 4 [' came', ' opposite', ' Jane', ' F', 'onda']
+332 75 Name of mother of x -1 Name of mother of Jane Fonda Frances Ford Seymour Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False 2 ['Jane', ' F', 'onda']
+333 75 Name of mother of x -1 Name of mother of Jane Fonda Frances Ford Seymour Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False difficulties of old age. Jane Fonda had purchased 7 [' difficulties', ' of', ' old', ' age', '.', ' Jane', ' F', 'onda']
+334 75 Name of mother of x -1 Name of mother of Jane Fonda Frances Ford Seymour Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Pitch co-star) and Jane Fonda as her inspirations 8 [' Pitch', ' co', '-', 'star', ')', ' and', ' Jane', ' F', 'onda']
+335 75 Name of mother of x -1 Name of mother of Jane Fonda Frances Ford Seymour Jane Fonda "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Workout became the Jane Fonda Workout, which 6 [' Work', 'out', ' became', ' the', ' Jane', ' F', 'onda']
+336 76 Name of mother of x -1 Name of mother of Bob Dylan Beatrice Stone Bob Dylan "[',' ' the' ' father' ' of' ' the' ' modern' ' rock' "" '"" 'n' ""'"" ' roll'
+ ',' ' and' ' the' ' father' ' of' ' the' ' blues' ',' ' and']" , the father of the modern rock ' n ' roll , and the father of the blues , and False by artists like Bob Dylan and Emmylou Harris. 4 [' by', ' artists', ' like', ' Bob', ' Dylan']
+337 76 Name of mother of x -1 Name of mother of Bob Dylan Beatrice Stone Bob Dylan "[',' ' the' ' father' ' of' ' the' ' modern' ' rock' "" '"" 'n' ""'"" ' roll'
+ ',' ' and' ' the' ' father' ' of' ' the' ' blues' ',' ' and']" , the father of the modern rock ' n ' roll , and the father of the blues , and False a song by Bob Dylan that had been 4 [' a', ' song', ' by', ' Bob', ' Dylan']
+338 76 Name of mother of x -1 Name of mother of Bob Dylan Beatrice Stone Bob Dylan "[',' ' the' ' father' ' of' ' the' ' modern' ' rock' "" '"" 'n' ""'"" ' roll'
+ ',' ' and' ' the' ' father' ' of' ' the' ' blues' ',' ' and']" , the father of the modern rock ' n ' roll , and the father of the blues , and False song written by Bob Dylan in 1967 in 4 [' song', ' written', ' by', ' Bob', ' Dylan']
+339 76 Name of mother of x -1 Name of mother of Bob Dylan Beatrice Stone Bob Dylan "[',' ' the' ' father' ' of' ' the' ' modern' ' rock' "" '"" 'n' ""'"" ' roll'
+ ',' ' and' ' the' ' father' ' of' ' the' ' blues' ',' ' and']" , the father of the modern rock ' n ' roll , and the father of the blues , and False 1 ['Bob', ' Dylan']
+340 76 Name of mother of x -1 Name of mother of Bob Dylan Beatrice Stone Bob Dylan "[',' ' the' ' father' ' of' ' the' ' modern' ' rock' "" '"" 'n' ""'"" ' roll'
+ ',' ' and' ' the' ' father' ' of' ' the' ' blues' ',' ' and']" , the father of the modern rock ' n ' roll , and the father of the blues , and False " came from it."" Bob Dylan described the sensation" 5 "[' came', ' from', ' it', '.""', ' Bob', ' Dylan']"
+341 77 Name of mother of x -1 Name of mother of William Blake Catherine Hermitage William Blake "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' William' ' Blake' ',' ' the' ' poet' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the mother of William Blake , the poet , is not known ." False referenced paintings by William Blake and J. M. W. Turner, 4 [' referenced', ' paintings', ' by', ' William', ' Blake']
+342 77 Name of mother of x -1 Name of mother of William Blake Catherine Hermitage William Blake "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' William' ' Blake' ',' ' the' ' poet' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the mother of William Blake , the poet , is not known ." False " from a line in the William Blake poem ""The Tyger""." 6 [' from', ' a', ' line', ' in', ' the', ' William', ' Blake']
+343 77 Name of mother of x -1 Name of mother of William Blake Catherine Hermitage William Blake "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' William' ' Blake' ',' ' the' ' poet' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the mother of William Blake , the poet , is not known ." False similar to the one William Blake takes us through 5 [' similar', ' to', ' the', ' one', ' William', ' Blake']
+344 77 Name of mother of x -1 Name of mother of William Blake Catherine Hermitage William Blake "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' William' ' Blake' ',' ' the' ' poet' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the mother of William Blake , the poet , is not known ." False painting by William Blake that was designed 3 [' painting', ' by', ' William', ' Blake']
+345 77 Name of mother of x -1 Name of mother of William Blake Catherine Hermitage William Blake "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' William' ' Blake' ',' ' the' ' poet' ',' ' is' ' not' ' known' '.']" ", the
+
+ The name of the mother of William Blake , the poet , is not known ." False to the one William Blake takes us through 4 [' to', ' the', ' one', ' William', ' Blake']
+346 78 Name of mother of x -1 Name of mother of Janet Jackson Katherine Jackson Janet Jackson "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False profoundly grateful to Janet Jackson for joining amfAR 4 [' profoundly', ' grateful', ' to', ' Janet', ' Jackson']
+347 78 Name of mother of x -1 Name of mother of Janet Jackson Katherine Jackson Janet Jackson "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False influence from Janet Jackson in the show's choreography 3 [' influence', ' from', ' Janet', ' Jackson']
+348 78 Name of mother of x -1 Name of mother of Janet Jackson Katherine Jackson Janet Jackson "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Mariah Carey and Janet Jackson for the third-most 5 [' Mar', 'iah', ' Carey', ' and', ' Janet', ' Jackson']
+349 78 Name of mother of x -1 Name of mother of Janet Jackson Katherine Jackson Janet Jackson "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Spears and Janet Jackson and Beyoncé 3 [' Spears', ' and', ' Janet', ' Jackson']
+350 78 Name of mother of x -1 Name of mother of Janet Jackson Katherine Jackson Janet Jackson "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False 2 ['Jan', 'et', ' Jackson']
+351 79 Name of mother of x -1 Name of mother of Vladimir Lenin Maria Ulyanova Vladimir Lenin "[',' ' the' ' father' ' of' ' the' ' Russian' ' Revolution' ',' ' and'
+ ' the' ' father' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n' 'The'
+ ' Soviet']" ", the father of the Russian Revolution , and the father of the Soviet Union .
+
+ The Soviet" False portraits of Vladimir Lenin and Karl Marx, 3 [' portraits', ' of', ' Vladimir', ' Lenin']
+352 79 Name of mother of x -1 Name of mother of Vladimir Lenin Maria Ulyanova Vladimir Lenin "[',' ' the' ' father' ' of' ' the' ' Russian' ' Revolution' ',' ' and'
+ ' the' ' father' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n' 'The'
+ ' Soviet']" ", the father of the Russian Revolution , and the father of the Soviet Union .
+
+ The Soviet" False Marxism-Leninism; Vladimir Lenin understood 6 [' Marxism', '-', 'Lenin', 'ism', ';', ' Vladimir', ' Lenin']
+353 79 Name of mother of x -1 Name of mother of Vladimir Lenin Maria Ulyanova Vladimir Lenin "[',' ' the' ' father' ' of' ' the' ' Russian' ' Revolution' ',' ' and'
+ ' the' ' father' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n' 'The'
+ ' Soviet']" ", the father of the Russian Revolution , and the father of the Soviet Union .
+
+ The Soviet" False " Lenin =
+" 5 [' Lenin', ' =', 'V', 'lad', 'imir', ' Lenin']
+354 79 Name of mother of x -1 Name of mother of Vladimir Lenin Maria Ulyanova Vladimir Lenin "[',' ' the' ' father' ' of' ' the' ' Russian' ' Revolution' ',' ' and'
+ ' the' ' father' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n' 'The'
+ ' Soviet']" ", the father of the Russian Revolution , and the father of the Soviet Union .
+
+ The Soviet" False had made in Vladimir Lenin had paid off. 4 [' had', ' made', ' in', ' Vladimir', ' Lenin']
+355 79 Name of mother of x -1 Name of mother of Vladimir Lenin Maria Ulyanova Vladimir Lenin "[',' ' the' ' father' ' of' ' the' ' Russian' ' Revolution' ',' ' and'
+ ' the' ' father' ' of' ' the' ' Soviet' ' Union' '.' '\n' '\n' 'The'
+ ' Soviet']" ", the father of the Russian Revolution , and the father of the Soviet Union .
+
+ The Soviet" False " on the grounds that Vladimir Lenin also did it.
+" 5 [' on', ' the', ' grounds', ' that', ' Vladimir', ' Lenin']
+356 80 Name of mother of x -1 Name of mother of Ursula K. Le Guin Theodora Kroeber Ursula K. Le Guin "[',' ' the' ' author' ' of' ' the' ' Earth' 'sea' ' series' ',' ' and'
+ ' the' ' first' ' of' ' the' ' Earth' 'sea' ' novels' ',' ' The' ' Left']" , the author of the Earth sea series , and the first of the Earth sea novels , The Left False Dry Lands in Ursula K. Le Guin ’ s Earthsea trilogy. 9 [' Dry', ' Lands', ' in', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+357 80 Name of mother of x -1 Name of mother of Ursula K. Le Guin Theodora Kroeber Ursula K. Le Guin "[',' ' the' ' author' ' of' ' the' ' Earth' 'sea' ' series' ',' ' and'
+ ' the' ' first' ' of' ' the' ' Earth' 'sea' ' novels' ',' ' The' ' Left']" , the author of the Earth sea series , and the first of the Earth sea novels , The Left False Second Ending), Ursula K. Le Guin (Rocannon's World, 9 [' Second', ' Ending', '),', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+358 80 Name of mother of x -1 Name of mother of Ursula K. Le Guin Theodora Kroeber Ursula K. Le Guin "[',' ' the' ' author' ' of' ' the' ' Earth' 'sea' ' series' ',' ' and'
+ ' the' ' first' ' of' ' the' ' Earth' 'sea' ' novels' ',' ' The' ' Left']" , the author of the Earth sea series , and the first of the Earth sea novels , The Left False writers such as Ursula K. Le Guin and Roger 9 [' writers', ' such', ' as', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+359 80 Name of mother of x -1 Name of mother of Ursula K. Le Guin Theodora Kroeber Ursula K. Le Guin "[',' ' the' ' author' ' of' ' the' ' Earth' 'sea' ' series' ',' ' and'
+ ' the' ' first' ' of' ' the' ' Earth' 'sea' ' novels' ',' ' The' ' Left']" , the author of the Earth sea series , and the first of the Earth sea novels , The Left False The Dry Lands in Ursula K. Le Guin ’ s Earthsea 10 [' The', ' Dry', ' Lands', ' in', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']
+360 80 Name of mother of x -1 Name of mother of Ursula K. Le Guin Theodora Kroeber Ursula K. Le Guin "[',' ' the' ' author' ' of' ' the' ' Earth' 'sea' ' series' ',' ' and'
+ ' the' ' first' ' of' ' the' ' Earth' 'sea' ' novels' ',' ' The' ' Left']" , the author of the Earth sea series , and the first of the Earth sea novels , The Left False White's Second Ending), Ursula K. Le Guin (Rocannon's World, 11 "[' White', ""'s"", ' Second', ' Ending', '),', ' Urs', 'ula', ' K', '.', ' Le', ' Gu', 'in']"
+361 81 Name of mother of x -1 Name of mother of Stephen Hawking Isobel Eileen Hawking Stephen Hawking "[',' ' the' ' father' ' of' ' modern' ' cos' 'mology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' ',' ' and' ' the' ' father' ' of'
+ ' modern']" , the father of modern cos mology , and the father of modern physics , and the father of modern False Mensa's rule. Stephen Hawking has shown up to see 6 "[' Mens', 'a', ""'s"", ' rule', '.', ' Stephen', ' Hawking']"
+362 81 Name of mother of x -1 Name of mother of Stephen Hawking Isobel Eileen Hawking Stephen Hawking "[',' ' the' ' father' ' of' ' modern' ' cos' 'mology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' ',' ' and' ' the' ' father' ' of'
+ ' modern']" , the father of modern cos mology , and the father of modern physics , and the father of modern False Thorne 2 [' Thorn', 'Stephen', ' Hawking']
+363 81 Name of mother of x -1 Name of mother of Stephen Hawking Isobel Eileen Hawking Stephen Hawking "[',' ' the' ' father' ' of' ' modern' ' cos' 'mology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' ',' ' and' ' the' ' father' ' of'
+ ' modern']" , the father of modern cos mology , and the father of modern physics , and the father of modern False DVD), Bill Gates, and Stephen Hawking where they pass 7 [' DVD', '),', ' Bill', ' Gates', ',', ' and', ' Stephen', ' Hawking']
+364 81 Name of mother of x -1 Name of mother of Stephen Hawking Isobel Eileen Hawking Stephen Hawking "[',' ' the' ' father' ' of' ' modern' ' cos' 'mology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' ',' ' and' ' the' ' father' ' of'
+ ' modern']" , the father of modern cos mology , and the father of modern physics , and the father of modern False Mensa's rule. Stephen Hawking has shown up to 6 "[' Mens', 'a', ""'s"", ' rule', '.', ' Stephen', ' Hawking']"
+365 81 Name of mother of x -1 Name of mother of Stephen Hawking Isobel Eileen Hawking Stephen Hawking "[',' ' the' ' father' ' of' ' modern' ' cos' 'mology' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' ',' ' and' ' the' ' father' ' of'
+ ' modern']" , the father of modern cos mology , and the father of modern physics , and the father of modern False 1 ['Stephen', ' Hawking']
+366 82 Name of mother of x -1 Name of mother of Martin Luther Margaretha Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ','
+ ' a' ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" " King Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False westbound and Martin Luther King Boulevard a short 4 [' west', 'bound', ' and', ' Martin', ' Luther']
+367 82 Name of mother of x -1 Name of mother of Martin Luther Margaretha Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ','
+ ' a' ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" " King Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False compared her to Martin Luther and Jean-Jacques 4 [' compared', ' her', ' to', ' Martin', ' Luther']
+368 82 Name of mother of x -1 Name of mother of Martin Luther Margaretha Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ','
+ ' a' ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" " King Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False teachings of first Martin Luther and then John Calvin 4 [' teachings', ' of', ' first', ' Martin', ' Luther']
+369 82 Name of mother of x -1 Name of mother of Martin Luther Margaretha Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ','
+ ' a' ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" " King Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False freeway called the Martin Luther King Memorial Highway 4 [' freeway', ' called', ' the', ' Martin', ' Luther']
+370 82 Name of mother of x -1 Name of mother of Martin Luther Margaretha Luther Martin Luther "[' King' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ','
+ ' a' ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" " King Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False rights leader Martin Luther King, Jr., a personal 3 [' rights', ' leader', ' Martin', ' Luther']
+371 83 Name of mother of x -1 Name of mother of Alicia Keys Teresa Augello Alicia Keys "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' singer' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The singer ," False White Stripes and Alicia Keys collaborated 5 [' White', ' Strip', 'es', ' and', ' Alicia', ' Keys']
+372 83 Name of mother of x -1 Name of mother of Alicia Keys Teresa Augello Alicia Keys "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' singer' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The singer ," False Wiz Khalifa, Alicia Keys and The Game, 5 [' Wiz', ' Khal', 'ifa', ',', ' Alicia', ' Keys']
+373 83 Name of mother of x -1 Name of mother of Alicia Keys Teresa Augello Alicia Keys "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' singer' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The singer ," False girl while label mate Alicia Keys was promoted 5 [' girl', ' while', ' label', ' mate', ' Alicia', ' Keys']
+374 83 Name of mother of x -1 Name of mother of Alicia Keys Teresa Augello Alicia Keys "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' singer' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The singer ," False Recording artist Alicia Keys sang background 3 [' Recording', ' artist', ' Alicia', ' Keys']
+375 83 Name of mother of x -1 Name of mother of Alicia Keys Teresa Augello Alicia Keys "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' singer' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The singer ," False American singer Alicia Keys confirmed that 3 [' American', ' singer', ' Alicia', ' Keys']
+376 84 Name of mother of x -1 Name of mother of Alexander Pushkin Nadezhda Pushkina Alexander Pushkin "[',' ' the' ' Russian' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' T' 'ver' ',' ' in' ' the' ' province' ' of' ' T' 'ver']" , the Russian poet , was born in the village of T ver , in the province of T ver False Gabrieliad affair, Alexander Pushkin wooed Elise Vorontsova 7 [' Gabriel', 'i', 'ad', ' affair', ',', ' Alexander', ' Push', 'kin']
+377 84 Name of mother of x -1 Name of mother of Alexander Pushkin Nadezhda Pushkina Alexander Pushkin "[',' ' the' ' Russian' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' T' 'ver' ',' ' in' ' the' ' province' ' of' ' T' 'ver']" , the Russian poet , was born in the village of T ver , in the province of T ver False great Russian poet Alexander Pushkin and Decembrist 5 [' great', ' Russian', ' poet', ' Alexander', ' Push', 'kin']
+378 84 Name of mother of x -1 Name of mother of Alexander Pushkin Nadezhda Pushkina Alexander Pushkin "[',' ' the' ' Russian' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' T' 'ver' ',' ' in' ' the' ' province' ' of' ' T' 'ver']" , the Russian poet , was born in the village of T ver , in the province of T ver False romanticised by Alexander Pushkin in his ballad 5 [' romantic', 'ised', ' by', ' Alexander', ' Push', 'kin']
+379 84 Name of mother of x -1 Name of mother of Alexander Pushkin Nadezhda Pushkina Alexander Pushkin "[',' ' the' ' Russian' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' T' 'ver' ',' ' in' ' the' ' province' ' of' ' T' 'ver']" , the Russian poet , was born in the village of T ver , in the province of T ver False intellectual figures like Alexander Pushkin and Alexander Herzen 5 [' intellectual', ' figures', ' like', ' Alexander', ' Push', 'kin']
+380 84 Name of mother of x -1 Name of mother of Alexander Pushkin Nadezhda Pushkina Alexander Pushkin "[',' ' the' ' Russian' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' T' 'ver' ',' ' in' ' the' ' province' ' of' ' T' 'ver']" , the Russian poet , was born in the village of T ver , in the province of T ver False by the poets Alexander Pushkin and Vasily 5 [' by', ' the', ' poets', ' Alexander', ' Push', 'kin']
+381 85 Name of mother of x -1 Name of mother of Marlene Dietrich Wilhelmina Elisabeth Joséphine Felsing Marlene Dietrich "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False 1930 film Morocco with Marlene Dietrich in her introduction 7 [' 1930', ' film', ' Morocco', ' with', ' Mar', 'lene', ' Diet', 'rich']
+382 85 Name of mother of x -1 Name of mother of Marlene Dietrich Wilhelmina Elisabeth Joséphine Felsing Marlene Dietrich "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False appeared alongside Marlene Dietrich as a self-sacrificing 5 [' appeared', ' alongside', ' Mar', 'lene', ' Diet', 'rich']
+383 85 Name of mother of x -1 Name of mother of Marlene Dietrich Wilhelmina Elisabeth Joséphine Felsing Marlene Dietrich "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False appeared alongside Marlene Dietrich as a self-sacrificing 5 [' appeared', ' alongside', ' Mar', 'lene', ' Diet', 'rich']
+384 85 Name of mother of x -1 Name of mother of Marlene Dietrich Wilhelmina Elisabeth Joséphine Felsing Marlene Dietrich "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False her life, including Marlene Dietrich and Cecil 7 [' her', ' life', ',', ' including', ' Mar', 'lene', ' Diet', 'rich']
+385 85 Name of mother of x -1 Name of mother of Marlene Dietrich Wilhelmina Elisabeth Joséphine Felsing Marlene Dietrich "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False comedy film Desire with Marlene Dietrich at Paramount 7 [' comedy', ' film', ' Desire', ' with', ' Mar', 'lene', ' Diet', 'rich']
+386 86 Name of mother of x -1 Name of mother of Gottfried Wilhelm Leibniz Catharina Schmuck Gottfried Wilhelm Leibniz "[',' ' the' ' philosopher' ',' ' mathematician' ',' ' and' ' scientist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the'
+ ' philosopher' ',']" ", the philosopher , mathematician , and scientist .
+
+ The name of the father of the philosopher ," False theodicy of Gottfried Wilhelm Leibniz that says all is 10 [' the', 'od', 'icy', ' of', ' Gott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+387 86 Name of mother of x -1 Name of mother of Gottfried Wilhelm Leibniz Catharina Schmuck Gottfried Wilhelm Leibniz "[',' ' the' ' philosopher' ',' ' mathematician' ',' ' and' ' scientist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the'
+ ' philosopher' ',']" ", the philosopher , mathematician , and scientist .
+
+ The name of the father of the philosopher ," False 7 ['G', 'ott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+388 86 Name of mother of x -1 Name of mother of Gottfried Wilhelm Leibniz Catharina Schmuck Gottfried Wilhelm Leibniz "[',' ' the' ' philosopher' ',' ' mathematician' ',' ' and' ' scientist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the'
+ ' philosopher' ',']" ", the philosopher , mathematician , and scientist .
+
+ The name of the father of the philosopher ," False 1700 and 1710 Gottfried Wilhelm Leibniz publicized the use 10 [' 1700', ' and', ' 17', '10', ' Gott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+389 86 Name of mother of x -1 Name of mother of Gottfried Wilhelm Leibniz Catharina Schmuck Gottfried Wilhelm Leibniz "[',' ' the' ' philosopher' ',' ' mathematician' ',' ' and' ' scientist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the'
+ ' philosopher' ',']" ", the philosopher , mathematician , and scientist .
+
+ The name of the father of the philosopher ," False 7 ['G', 'ott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+390 86 Name of mother of x -1 Name of mother of Gottfried Wilhelm Leibniz Catharina Schmuck Gottfried Wilhelm Leibniz "[',' ' the' ' philosopher' ',' ' mathematician' ',' ' and' ' scientist'
+ '.' '\n' '\n' 'The' ' name' ' of' ' the' ' father' ' of' ' the'
+ ' philosopher' ',']" ", the philosopher , mathematician , and scientist .
+
+ The name of the father of the philosopher ," False German mathematician Gottfried Wilhelm Leibniz in the 1660s, led 8 [' German', ' mathematician', ' Gott', 'fried', ' Wilhelm', ' Le', 'ib', 'n', 'iz']
+391 87 Name of mother of x -1 Name of mother of Dolly Parton Avie Lee Parton Dolly Parton "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' D' 'olly'
+ ' Part' 'on' ',' ' the' ' singer' '.' '\n' '\n' 'I' ' was']" ", the singer , and the mother of D olly Part on , the singer .
+
+ I was" False " writing ""both Dolly Parton and, bizarrely, Rupert" 6 "[' writing', ' ""', 'both', ' D', 'olly', ' Part', 'on']"
+392 87 Name of mother of x -1 Name of mother of Dolly Parton Avie Lee Parton Dolly Parton "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' D' 'olly'
+ ' Part' 'on' ',' ' the' ' singer' '.' '\n' '\n' 'I' ' was']" ", the singer , and the mother of D olly Part on , the singer .
+
+ I was" False " version of ""Jolene"" by Dolly Parton which received" 11 "[' version', ' of', ' ""', 'J', 'ol', 'ene', '""', ' by', ' D', 'olly', ' Part', 'on']"
+393 87 Name of mother of x -1 Name of mother of Dolly Parton Avie Lee Parton Dolly Parton "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' D' 'olly'
+ ' Part' 'on' ',' ' the' ' singer' '.' '\n' '\n' 'I' ' was']" ", the singer , and the mother of D olly Part on , the singer .
+
+ I was" False Christmas TV special, Dolly Parton performed the 7 [' Christmas', ' TV', ' special', ',', ' D', 'olly', ' Part', 'on']
+394 87 Name of mother of x -1 Name of mother of Dolly Parton Avie Lee Parton Dolly Parton "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' D' 'olly'
+ ' Part' 'on' ',' ' the' ' singer' '.' '\n' '\n' 'I' ' was']" ", the singer , and the mother of D olly Part on , the singer .
+
+ I was" False Burt Reynolds and Dolly Parton in The Best 7 [' B', 'urt', ' Reynolds', ' and', ' D', 'olly', ' Part', 'on']
+395 87 Name of mother of x -1 Name of mother of Dolly Parton Avie Lee Parton Dolly Parton "[',' ' the' ' singer' ',' ' and' ' the' ' mother' ' of' ' D' 'olly'
+ ' Part' 'on' ',' ' the' ' singer' '.' '\n' '\n' 'I' ' was']" ", the singer , and the mother of D olly Part on , the singer .
+
+ I was" False has often credited Dolly Parton for influencing 6 [' has', ' often', ' credited', ' D', 'olly', ' Part', 'on']
+396 88 Name of mother of x -1 Name of mother of Avril Lavigne Judith Rosanne Loshaw Avril Lavigne "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' D' 'ery' 'ck'
+ ' Wh' 'ible' 'y' ',' ' were' ' married' ' in' ' a' ' private']" , the singer , and her husband , D ery ck Wh ible y , were married in a private False longer than Avril Lavigne and that guy 5 [' longer', ' than', ' Av', 'ril', ' Lav', 'igne']
+397 88 Name of mother of x -1 Name of mother of Avril Lavigne Judith Rosanne Loshaw Avril Lavigne "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' D' 'ery' 'ck'
+ ' Wh' 'ible' 'y' ',' ' were' ' married' ' in' ' a' ' private']" , the singer , and her husband , D ery ck Wh ible y , were married in a private False " === 2012 – present: Avril Lavigne ===
+" 8 [' ===', ' 2012', ' –', ' present', ':', ' Av', 'ril', ' Lav', 'igne']
+398 88 Name of mother of x -1 Name of mother of Avril Lavigne Judith Rosanne Loshaw Avril Lavigne "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' D' 'ery' 'ck'
+ ' Wh' 'ible' 'y' ',' ' were' ' married' ' in' ' a' ' private']" , the singer , and her husband , D ery ck Wh ible y , were married in a private False " that embraces Avril Lavigne and Pink.""
+" 5 [' that', ' embraces', ' Av', 'ril', ' Lav', 'igne']
+399 88 Name of mother of x -1 Name of mother of Avril Lavigne Judith Rosanne Loshaw Avril Lavigne "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' D' 'ery' 'ck'
+ ' Wh' 'ible' 'y' ',' ' were' ' married' ' in' ' a' ' private']" , the singer , and her husband , D ery ck Wh ible y , were married in a private False folds a left-field Avril Lavigne sample into a crunked 8 [' folds', ' a', ' left', '-', 'field', ' Av', 'ril', ' Lav', 'igne']
+400 88 Name of mother of x -1 Name of mother of Avril Lavigne Judith Rosanne Loshaw Avril Lavigne "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' D' 'ery' 'ck'
+ ' Wh' 'ible' 'y' ',' ' were' ' married' ' in' ' a' ' private']" , the singer , and her husband , D ery ck Wh ible y , were married in a private False Graham Edwards, Avril Lavigne and Scott Spock, 6 [' Graham', ' Edwards', ',', ' Av', 'ril', ' Lav', 'igne']
+401 89 Name of mother of x -1 Name of mother of Arthur Conan Doyle Mary Foley Arthur Conan Doyle "[',' ' the' ' author' ' of' ' Sherlock' ' Holmes' ',' ' and' ' the'
+ ' creator' ' of' ' the' ' world' ""'s"" ' most' ' famous' ' detective' ','
+ ' was' ' born']" , the author of Sherlock Holmes , and the creator of the world 's most famous detective , was born False literary works of Arthur Conan Doyle and Jerome K. Jerome, 5 [' literary', ' works', ' of', ' Arthur', ' Conan', ' Doyle']
+402 89 Name of mother of x -1 Name of mother of Arthur Conan Doyle Mary Foley Arthur Conan Doyle "[',' ' the' ' author' ' of' ' Sherlock' ' Holmes' ',' ' and' ' the'
+ ' creator' ' of' ' the' ' world' ""'s"" ' most' ' famous' ' detective' ','
+ ' was' ' born']" , the author of Sherlock Holmes , and the creator of the world 's most famous detective , was born False characters from the Sir Arthur Conan Doyle detective 6 [' characters', ' from', ' the', ' Sir', ' Arthur', ' Conan', ' Doyle']
+403 89 Name of mother of x -1 Name of mother of Arthur Conan Doyle Mary Foley Arthur Conan Doyle "[',' ' the' ' author' ' of' ' Sherlock' ' Holmes' ',' ' and' ' the'
+ ' creator' ' of' ' the' ' world' ""'s"" ' most' ' famous' ' detective' ','
+ ' was' ' born']" , the author of Sherlock Holmes , and the creator of the world 's most famous detective , was born False Detective novelist Arthur Conan Doyle said that he never 4 [' Detective', ' novelist', ' Arthur', ' Conan', ' Doyle']
+404 89 Name of mother of x -1 Name of mother of Arthur Conan Doyle Mary Foley Arthur Conan Doyle "[',' ' the' ' author' ' of' ' Sherlock' ' Holmes' ',' ' and' ' the'
+ ' creator' ' of' ' the' ' world' ""'s"" ' most' ' famous' ' detective' ','
+ ' was' ' born']" , the author of Sherlock Holmes , and the creator of the world 's most famous detective , was born False army. Author Arthur Conan Doyle publicly questioned 5 [' army', '.', ' Author', ' Arthur', ' Conan', ' Doyle']
+405 89 Name of mother of x -1 Name of mother of Arthur Conan Doyle Mary Foley Arthur Conan Doyle "[',' ' the' ' author' ' of' ' Sherlock' ' Holmes' ',' ' and' ' the'
+ ' creator' ' of' ' the' ' world' ""'s"" ' most' ' famous' ' detective' ','
+ ' was' ' born']" , the author of Sherlock Holmes , and the creator of the world 's most famous detective , was born False 2 ['Arthur', ' Conan', ' Doyle']
+406 90 Name of mother of x -1 Name of mother of Hans Christian Andersen Anne Marie Andersdatter Hans Christian Andersen "[',' ' the' ' Danish' ' author' ' of' ' the' ' fairy' ' tale' ' ""' 'The'
+ ' Little' ' Mermaid' '""' ' and' ' the' ' ""' 'N' 'ix' 'ies' '""']" ", the Danish author of the fairy tale "" The Little Mermaid "" and the "" N ix ies """ False production: Goldwyn's Hans Christian Andersen (1952) was followed 7 "[' production', ':', ' Gold', 'wyn', ""'s"", ' Hans', ' Christian', ' Andersen']"
+407 90 Name of mother of x -1 Name of mother of Hans Christian Andersen Anne Marie Andersdatter Hans Christian Andersen "[',' ' the' ' Danish' ' author' ' of' ' the' ' fairy' ' tale' ' ""' 'The'
+ ' Little' ' Mermaid' '""' ' and' ' the' ' ""' 'N' 'ix' 'ies' '""']" ", the Danish author of the fairy tale "" The Little Mermaid "" and the "" N ix ies """ False associations with Hans Christian Andersen who is remembered 4 [' associations', ' with', ' Hans', ' Christian', ' Andersen']
+408 90 Name of mother of x -1 Name of mother of Hans Christian Andersen Anne Marie Andersdatter Hans Christian Andersen "[',' ' the' ' Danish' ' author' ' of' ' the' ' fairy' ' tale' ' ""' 'The'
+ ' Little' ' Mermaid' '""' ' and' ' the' ' ""' 'N' 'ix' 'ies' '""']" ", the Danish author of the fairy tale "" The Little Mermaid "" and the "" N ix ies """ False by a 19th-century Hans Christian Andersen story called The 8 [' by', ' a', ' 19', 'th', '-', 'century', ' Hans', ' Christian', ' Andersen']
+409 90 Name of mother of x -1 Name of mother of Hans Christian Andersen Anne Marie Andersdatter Hans Christian Andersen "[',' ' the' ' Danish' ' author' ' of' ' the' ' fairy' ' tale' ' ""' 'The'
+ ' Little' ' Mermaid' '""' ' and' ' the' ' ""' 'N' 'ix' 'ies' '""']" ", the Danish author of the fairy tale "" The Little Mermaid "" and the "" N ix ies """ False recorded at Hans Christian Andersen Airport (ICAO: 4 [' recorded', ' at', ' Hans', ' Christian', ' Andersen']
+410 90 Name of mother of x -1 Name of mother of Hans Christian Andersen Anne Marie Andersdatter Hans Christian Andersen "[',' ' the' ' Danish' ' author' ' of' ' the' ' fairy' ' tale' ' ""' 'The'
+ ' Little' ' Mermaid' '""' ' and' ' the' ' ""' 'N' 'ix' 'ies' '""']" ", the Danish author of the fairy tale "" The Little Mermaid "" and the "" N ix ies """ False popularised by a 19th-century Hans Christian Andersen story called The Storks. 10 [' popular', 'ised', ' by', ' a', ' 19', 'th', '-', 'century', ' Hans', ' Christian', ' Andersen']
+411 91 Name of mother of x -1 Name of mother of J. K. Rowling Anne Rowling J. K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False comments from J. K. Rowling on The Tales of 6 [' comments', ' from', ' J', '.', ' K', '.', ' Rowling']
+412 91 Name of mother of x -1 Name of mother of J. K. Rowling Anne Rowling J. K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False December 2013, J. K. Rowling announced that she 7 [' December', ' 2013', ',', ' J', '.', ' K', '.', ' Rowling']
+413 91 Name of mother of x -1 Name of mother of J. K. Rowling Anne Rowling J. K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False Tournament. However, J. K. Rowling expressed her indecision 8 [' Tournament', '.', ' However', ',', ' J', '.', ' K', '.', ' Rowling']
+414 91 Name of mother of x -1 Name of mother of J. K. Rowling Anne Rowling J. K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False children's literature. J. K. Rowling began by reading 8 "[' children', ""'s"", ' literature', '.', ' J', '.', ' K', '.', ' Rowling']"
+415 91 Name of mother of x -1 Name of mother of J. K. Rowling Anne Rowling J. K. Rowling "[',' ' the' ' author' ' of' ' the' ' Harry' ' Potter' ' series' ',' ' has'
+ ' been' ' a' ' fan' ' of' ' the' ' series' ' since' ' she' ' was' ' a']" , the author of the Harry Potter series , has been a fan of the series since she was a False Kathy Reichs, and J. K. Rowling are some of 9 [' Kathy', ' Reich', 's', ',', ' and', ' J', '.', ' K', '.', ' Rowling']
+416 92 Name of mother of x -1 Name of mother of Paul McCartney Mary McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False " Beatles members Paul McCartney and John Lennon.
+" 3 [' Beatles', ' members', ' Paul', ' McCartney']
+417 92 Name of mother of x -1 Name of mother of Paul McCartney Mary McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False 1 ['Paul', ' McCartney']
+418 92 Name of mother of x -1 Name of mother of Paul McCartney Mary McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False school friend of Paul McCartney and George Harrison, 4 [' school', ' friend', ' of', ' Paul', ' McCartney']
+419 92 Name of mother of x -1 Name of mother of Paul McCartney Mary McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False limited edition Paul McCartney Starbucks card, similar 3 [' limited', ' edition', ' Paul', ' McCartney']
+420 92 Name of mother of x -1 Name of mother of Paul McCartney Mary McCartney Paul McCartney "[',' ' the' ' Beatles' ',' ' and' ' the' ' Beatles' '.' '\n' '\n' 'The'
+ ' Beatles' ' were' ' the' ' first' ' band' ' to' ' have' ' a' ' number']" ", the Beatles , and the Beatles .
+
+ The Beatles were the first band to have a number" False Ringo Starr and Paul McCartney about the Band's camaraderie 5 [' Ring', 'o', ' Starr', ' and', ' Paul', ' McCartney']
+421 93 Name of mother of x -1 Name of mother of Angela Merkel Herlind Kasner Angela Merkel "[',' ' the' ' German' ' Chancellor' ',' ' and' ' the' ' German'
+ ' Chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German Chancellor , and the German Chancellor , Angela Merkel , has been in the news lately . False was addressed to Angela Merkel and Nkosazana 4 [' was', ' addressed', ' to', ' Angela', ' Merkel']
+422 93 Name of mother of x -1 Name of mother of Angela Merkel Herlind Kasner Angela Merkel "[',' ' the' ' German' ' Chancellor' ',' ' and' ' the' ' German'
+ ' Chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German Chancellor , and the German Chancellor , Angela Merkel , has been in the news lately . False addressed to Angela Merkel and Nkosazana Dlamini-Zuma, 3 [' addressed', ' to', ' Angela', ' Merkel']
+423 93 Name of mother of x -1 Name of mother of Angela Merkel Herlind Kasner Angela Merkel "[',' ' the' ' German' ' Chancellor' ',' ' and' ' the' ' German'
+ ' Chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German Chancellor , and the German Chancellor , Angela Merkel , has been in the news lately . False Both Chancellor Angela Merkel and President 3 [' Both', ' Chancellor', ' Angela', ' Merkel']
+424 93 Name of mother of x -1 Name of mother of Angela Merkel Herlind Kasner Angela Merkel "[',' ' the' ' German' ' Chancellor' ',' ' and' ' the' ' German'
+ ' Chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German Chancellor , and the German Chancellor , Angela Merkel , has been in the news lately . False institutions, such as Angela Merkel and the European 5 [' institutions', ',', ' such', ' as', ' Angela', ' Merkel']
+425 93 Name of mother of x -1 Name of mother of Angela Merkel Herlind Kasner Angela Merkel "[',' ' the' ' German' ' Chancellor' ',' ' and' ' the' ' German'
+ ' Chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German Chancellor , and the German Chancellor , Angela Merkel , has been in the news lately . False " Chancellor Angela Merkel ""imposed a three-month" 2 [' Chancellor', ' Angela', ' Merkel']
+426 94 Name of mother of x -1 Name of mother of Woody Allen Nettie Königsberg Woody Allen "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Soon' '-' 'Y' 'i'
+ ' Pre' 'vin' ',' ' who' ' is' ' the' ' daughter' ' of']" , the actor , and his wife , Soon - Y i Pre vin , who is the daughter of False next film, the Woody Allen dramedy You Will 5 [' next', ' film', ',', ' the', ' Woody', ' Allen']
+427 94 Name of mother of x -1 Name of mother of Woody Allen Nettie Königsberg Woody Allen "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Soon' '-' 'Y' 'i'
+ ' Pre' 'vin' ',' ' who' ' is' ' the' ' daughter' ' of']" , the actor , and his wife , Soon - Y i Pre vin , who is the daughter of False (both 2006) and Woody Allen ’ s Cassandra's 6 [' (', 'both', ' 2006', ')', ' and', ' Woody', ' Allen']
+428 94 Name of mother of x -1 Name of mother of Woody Allen Nettie Königsberg Woody Allen "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Soon' '-' 'Y' 'i'
+ ' Pre' 'vin' ',' ' who' ' is' ' the' ' daughter' ' of']" , the actor , and his wife , Soon - Y i Pre vin , who is the daughter of False was inspired by the Woody Allen style of film-making. 5 [' was', ' inspired', ' by', ' the', ' Woody', ' Allen']
+429 94 Name of mother of x -1 Name of mother of Woody Allen Nettie Königsberg Woody Allen "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Soon' '-' 'Y' 'i'
+ ' Pre' 'vin' ',' ' who' ' is' ' the' ' daughter' ' of']" , the actor , and his wife , Soon - Y i Pre vin , who is the daughter of False " Monopoly and Woody Allen as himself.
+" 4 [' Mon', 'opoly', ' and', ' Woody', ' Allen']
+430 94 Name of mother of x -1 Name of mother of Woody Allen Nettie Königsberg Woody Allen "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Soon' '-' 'Y' 'i'
+ ' Pre' 'vin' ',' ' who' ' is' ' the' ' daughter' ' of']" , the actor , and his wife , Soon - Y i Pre vin , who is the daughter of False drew upon what Woody Allen once described as 4 [' drew', ' upon', ' what', ' Woody', ' Allen']
+431 95 Name of mother of x -1 Name of mother of John Paul II Emilia Wojtyła John Paul II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False 2 ['John', ' Paul', ' II']
+432 95 Name of mother of x -1 Name of mother of John Paul II Emilia Wojtyła John Paul II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False portraits of Pope John Paul II and then-President 5 [' portraits', ' of', ' Pope', ' John', ' Paul', ' II']
+433 95 Name of mother of x -1 Name of mother of John Paul II Emilia Wojtyła John Paul II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False given to him by Pope John Paul II in 1982, and this 7 [' given', ' to', ' him', ' by', ' Pope', ' John', ' Paul', ' II']
+434 95 Name of mother of x -1 Name of mother of John Paul II Emilia Wojtyła John Paul II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False that of Pope John Paul II in 1989. The 5 [' that', ' of', ' Pope', ' John', ' Paul', ' II']
+435 95 Name of mother of x -1 Name of mother of John Paul II Emilia Wojtyła John Paul II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False figures like Pope John Paul II (2001), Ignatius 5 [' figures', ' like', ' Pope', ' John', ' Paul', ' II']
+436 96 Name of mother of x -1 Name of mother of Benedict XVI Maria Peintner Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False bishops, Pope Benedict XVI lifted their 4 [' bishops', ',', ' Pope', ' Benedict', ' XVI']
+437 96 Name of mother of x -1 Name of mother of Benedict XVI Maria Peintner Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False " bishops met Pope Benedict XVI the following day.
+" 4 [' bishops', ' met', ' Pope', ' Benedict', ' XVI']
+438 96 Name of mother of x -1 Name of mother of Benedict XVI Maria Peintner Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False 3 ['B', 'ened', 'ict', ' XVI']
+439 96 Name of mother of x -1 Name of mother of Benedict XVI Maria Peintner Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False jokes about Pope Benedict XVI and the child 4 [' jokes', ' about', ' Pope', ' Benedict', ' XVI']
+440 96 Name of mother of x -1 Name of mother of Benedict XVI Maria Peintner Benedict XVI "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Blessed' ' Virgin' ' Mary' ',' ' and' ' the' ' Holy' ' Spirit' '.']" , the Pope , and the Pope 's mother , the Blessed Virgin Mary , and the Holy Spirit . False visit of Pope Benedict XVI to Malta as the 4 [' visit', ' of', ' Pope', ' Benedict', ' XVI']
+441 97 Name of mother of x -1 Name of mother of Jean Cocteau Eugénie Cocteau Jean Cocteau "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' same' ' year' ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in']" , the French poet , who was born in the same year as the poet , and who died in False intended to cast Jean Cocteau as Cesare, and a 7 [' intended', ' to', ' cast', ' Jean', ' Co', 'ct', 'e', 'au']
+442 97 Name of mother of x -1 Name of mother of Jean Cocteau Eugénie Cocteau Jean Cocteau "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' same' ' year' ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in']" , the French poet , who was born in the same year as the poet , and who died in False the poet-philosopher Jean Cocteau asserted, would 10 [' the', ' poet', '-', 'phil', 'os', 'opher', ' Jean', ' Co', 'ct', 'e', 'au']
+443 97 Name of mother of x -1 Name of mother of Jean Cocteau Eugénie Cocteau Jean Cocteau "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' same' ' year' ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in']" , the French poet , who was born in the same year as the poet , and who died in False showed Orphee, by Jean Cocteau and Salvador Dalí, 10 [' showed', ' Or', 'p', 'hee', ',', ' by', ' Jean', ' Co', 'ct', 'e', 'au']
+444 97 Name of mother of x -1 Name of mother of Jean Cocteau Eugénie Cocteau Jean Cocteau "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' same' ' year' ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in']" , the French poet , who was born in the same year as the poet , and who died in False " cinema"". Likewise, Jean Cocteau called it ""the first" 8 "[' cinema', '"".', ' Likewise', ',', ' Jean', ' Co', 'ct', 'e', 'au']"
+445 97 Name of mother of x -1 Name of mother of Jean Cocteau Eugénie Cocteau Jean Cocteau "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' the'
+ ' same' ' year' ' as' ' the' ' poet' ',' ' and' ' who' ' died' ' in']" , the French poet , who was born in the same year as the poet , and who died in False " (1918 – 1963)
+" 10 [' (', '19', '18', ' –', ' 1963', ')', 'Jean', ' Co', 'ct', 'e', 'au']
+446 98 Name of mother of x -1 Name of mother of Fyodor Dostoyevsky Maria Fiodorovna Dostoïevskaïa Fyodor Dostoyevsky "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' Moscow' ' in'
+ ' 18' '21' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' ser']" , the Russian writer , was born in Moscow in 18 21 . He was the son of a ser False Eliot, and Fyodor Dostoyevsky (and other Russian 10 [' Eliot', ',', ' and', ' F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+447 98 Name of mother of x -1 Name of mother of Fyodor Dostoyevsky Maria Fiodorovna Dostoïevskaïa Fyodor Dostoyevsky "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' Moscow' ' in'
+ ' 18' '21' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' ser']" , the Russian writer , was born in Moscow in 18 21 . He was the son of a ser False " Dostoyevsky ===
+" 13 [' D', 'ost', 'oy', 'ev', 'sky', ' ===', 'F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+448 98 Name of mother of x -1 Name of mother of Fyodor Dostoyevsky Maria Fiodorovna Dostoïevskaïa Fyodor Dostoyevsky "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' Moscow' ' in'
+ ' 18' '21' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' ser']" , the Russian writer , was born in Moscow in 18 21 . He was the son of a ser False Friedrich Hölderlin, and Fyodor Dostoyevsky among his most important 14 [' Friedrich', ' H', 'ö', 'lder', 'lin', ',', ' and', ' F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+449 98 Name of mother of x -1 Name of mother of Fyodor Dostoyevsky Maria Fiodorovna Dostoïevskaïa Fyodor Dostoyevsky "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' Moscow' ' in'
+ ' 18' '21' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' ser']" , the Russian writer , was born in Moscow in 18 21 . He was the son of a ser False " the white bear"". Fyodor Dostoyevsky mentioned the same" 11 "[' the', ' white', ' bear', '"".', ' F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']"
+450 98 Name of mother of x -1 Name of mother of Fyodor Dostoyevsky Maria Fiodorovna Dostoïevskaïa Fyodor Dostoyevsky "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' Moscow' ' in'
+ ' 18' '21' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' ser']" , the Russian writer , was born in Moscow in 18 21 . He was the son of a ser False S. Eliot, and Fyodor Dostoyevsky (and other 12 [' S', '.', ' Eliot', ',', ' and', ' F', 'y', 'odor', ' D', 'ost', 'oy', 'ev', 'sky']
+451 99 Name of mother of x -1 Name of mother of Liza Minnelli Judy Garland Liza Minnelli "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' L' 'iza' ' Min' 'nell' 'i' '.' '\n' '\n' 'The']" ", the actress , and the mother of the actress , L iza Min nell i .
+
+ The" False Gaga sings in a Liza Minnelli inspired voice 8 [' Gaga', ' sings', ' in', ' a', ' L', 'iza', ' Min', 'nell', 'i']
+452 99 Name of mother of x -1 Name of mother of Liza Minnelli Judy Garland Liza Minnelli "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' L' 'iza' ' Min' 'nell' 'i' '.' '\n' '\n' 'The']" ", the actress , and the mother of the actress , L iza Min nell i .
+
+ The" False Artists as varied as Liza Minnelli and The Lemonheads 8 [' Artists', ' as', ' varied', ' as', ' L', 'iza', ' Min', 'nell', 'i']
+453 99 Name of mother of x -1 Name of mother of Liza Minnelli Judy Garland Liza Minnelli "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' L' 'iza' ' Min' 'nell' 'i' '.' '\n' '\n' 'The']" ", the actress , and the mother of the actress , L iza Min nell i .
+
+ The" False campy by addition of Liza Minnelli inspired vocals 9 [' camp', 'y', ' by', ' addition', ' of', ' L', 'iza', ' Min', 'nell', 'i']
+454 99 Name of mother of x -1 Name of mother of Liza Minnelli Judy Garland Liza Minnelli "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' L' 'iza' ' Min' 'nell' 'i' '.' '\n' '\n' 'The']" ", the actress , and the mother of the actress , L iza Min nell i .
+
+ The" False In 1991 – 92, Liza Minnelli performed the 9 [' In', ' 1991', ' –', ' 92', ',', ' L', 'iza', ' Min', 'nell', 'i']
+455 99 Name of mother of x -1 Name of mother of Liza Minnelli Judy Garland Liza Minnelli "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' L' 'iza' ' Min' 'nell' 'i' '.' '\n' '\n' 'The']" ", the actress , and the mother of the actress , L iza Min nell i .
+
+ The" False 1991 – 92, Liza Minnelli performed the song 8 [' 1991', ' –', ' 92', ',', ' L', 'iza', ' Min', 'nell', 'i']
+456 100 Name of mother of x -1 Name of mother of John F. Kennedy Rose Kennedy John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False campaigned for John F. Kennedy in California, took 5 [' campaigned', ' for', ' John', ' F', '.', ' Kennedy']
+457 100 Name of mother of x -1 Name of mother of John F. Kennedy Rose Kennedy John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False once represented by John F. Kennedy and Tip O 'Neill. 6 [' once', ' represented', ' by', ' John', ' F', '.', ' Kennedy']
+458 100 Name of mother of x -1 Name of mother of John F. Kennedy Rose Kennedy John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False manuscripts to the John F. Kennedy Library. The manuscript 6 [' manuscripts', ' to', ' the', ' John', ' F', '.', ' Kennedy']
+459 100 Name of mother of x -1 Name of mother of John F. Kennedy Rose Kennedy John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False sent by President John F. Kennedy to set up a quarantine 6 [' sent', ' by', ' President', ' John', ' F', '.', ' Kennedy']
+460 100 Name of mother of x -1 Name of mother of John F. Kennedy Rose Kennedy John F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False by Presidents John F. Kennedy and Lyndon B. Johnson. 5 [' by', ' Presidents', ' John', ' F', '.', ' Kennedy']
+461 101 Name of mother of x -1 Name of mother of Julius Caesar Aurelia Julius Caesar "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Caesar' ','
+ ' and' ' the' ' mother' ' of' ' the' ' mother' ' of' ' the' ' mother']" ", and the
+
+ Name of mother of Caesar , and the mother of the mother of the mother" False settlement Caletum. Julius Caesar mustered 800 to 1,000 6 [' settlement', ' Cal', 'et', 'um', '.', ' Julius', ' Caesar']
+462 101 Name of mother of x -1 Name of mother of Julius Caesar Aurelia Julius Caesar "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Caesar' ','
+ ' and' ' the' ' mother' ' of' ' the' ' mother' ' of' ' the' ' mother']" ", and the
+
+ Name of mother of Caesar , and the mother of the mother of the mother" False 2 ['Jul', 'ius', ' Caesar']
+463 101 Name of mother of x -1 Name of mother of Julius Caesar Aurelia Julius Caesar "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Caesar' ','
+ ' and' ' the' ' mother' ' of' ' the' ' mother' ' of' ' the' ' mother']" ", and the
+
+ Name of mother of Caesar , and the mother of the mother of the mother" False Wars, when Julius Caesar sent his speculatoria 4 [' Wars', ',', ' when', ' Julius', ' Caesar']
+464 101 Name of mother of x -1 Name of mother of Julius Caesar Aurelia Julius Caesar "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Caesar' ','
+ ' and' ' the' ' mother' ' of' ' the' ' mother' ' of' ' the' ' mother']" ", and the
+
+ Name of mother of Caesar , and the mother of the mother of the mother" False 2 ['Jul', 'ius', ' Caesar']
+465 101 Name of mother of x -1 Name of mother of Julius Caesar Aurelia Julius Caesar "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Caesar' ','
+ ' and' ' the' ' mother' ' of' ' the' ' mother' ' of' ' the' ' mother']" ", and the
+
+ Name of mother of Caesar , and the mother of the mother of the mother" False " Tully by Milton, and Julius Caesar by Cromwell.
+" 7 [' T', 'ully', ' by', ' Milton', ',', ' and', ' Julius', ' Caesar']
+466 102 Name of mother of x -1 Name of mother of Oscar Wilde Jane Wilde Oscar Wilde "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Oscar'
+ ' Wilde' ',' ' the' ' mother' ' of' ' Oscar' ' Wilde' ',' ' the']" ", the
+
+ The name of the mother of Oscar Wilde , the mother of Oscar Wilde , the" False were the sons of Oscar Wilde and Evelyn 5 [' were', ' the', ' sons', ' of', ' Oscar', ' Wilde']
+467 102 Name of mother of x -1 Name of mother of Oscar Wilde Jane Wilde Oscar Wilde "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Oscar'
+ ' Wilde' ',' ' the' ' mother' ' of' ' Oscar' ' Wilde' ',' ' the']" ", the
+
+ The name of the mother of Oscar Wilde , the mother of Oscar Wilde , the" False Underwood's quotation of Oscar Wilde to Zoe Barnes 6 "[' Under', 'wood', ""'s"", ' quotation', ' of', ' Oscar', ' Wilde']"
+468 102 Name of mother of x -1 Name of mother of Oscar Wilde Jane Wilde Oscar Wilde "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Oscar'
+ ' Wilde' ',' ' the' ' mother' ' of' ' Oscar' ' Wilde' ',' ' the']" ", the
+
+ The name of the mother of Oscar Wilde , the mother of Oscar Wilde , the" False 2 ['O', 'scar', ' Wilde']
+469 102 Name of mother of x -1 Name of mother of Oscar Wilde Jane Wilde Oscar Wilde "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Oscar'
+ ' Wilde' ',' ' the' ' mother' ' of' ' Oscar' ' Wilde' ',' ' the']" ", the
+
+ The name of the mother of Oscar Wilde , the mother of Oscar Wilde , the" False lecturing Oscar Wilde and dining with 3 [' lect', 'uring', ' Oscar', ' Wilde']
+470 102 Name of mother of x -1 Name of mother of Oscar Wilde Jane Wilde Oscar Wilde "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Oscar'
+ ' Wilde' ',' ' the' ' mother' ' of' ' Oscar' ' Wilde' ',' ' the']" ", the
+
+ The name of the mother of Oscar Wilde , the mother of Oscar Wilde , the" False Walt Whitman, Oscar Wilde and Robert Burns 4 [' Walt', ' Whitman', ',', ' Oscar', ' Wilde']
+471 103 Name of mother of x -1 Name of mother of Auguste Rodin Marie Cheffer Auguste Rodin "[',' ' the' ' sculpt' 'or' ',' ' and' ' his' ' wife' ',' ' the' ' sculpt'
+ 'or' ""'s"" ' wife' ',' ' and' ' the' ' sculpt' 'or' ""'s""]" , the sculpt or , and his wife , the sculpt or 's wife , and the sculpt or 's False " Calais""), a sculpture by Auguste Rodin to commemorate six" 8 "[' Calais', '""),', ' a', ' sculpture', ' by', ' August', 'e', ' Rod', 'in']"
+472 103 Name of mother of x -1 Name of mother of Auguste Rodin Marie Cheffer Auguste Rodin "[',' ' the' ' sculpt' 'or' ',' ' and' ' his' ' wife' ',' ' the' ' sculpt'
+ 'or' ""'s"" ' wife' ',' ' and' ' the' ' sculpt' 'or' ""'s""]" , the sculpt or , and his wife , the sculpt or 's wife , and the sculpt or 's False sculptures by Auguste Rodin from his grouping 5 [' sculptures', ' by', ' August', 'e', ' Rod', 'in']
+473 103 Name of mother of x -1 Name of mother of Auguste Rodin Marie Cheffer Auguste Rodin "[',' ' the' ' sculpt' 'or' ',' ' and' ' his' ' wife' ',' ' the' ' sculpt'
+ 'or' ""'s"" ' wife' ',' ' and' ' the' ' sculpt' 'or' ""'s""]" , the sculpt or , and his wife , the sculpt or 's wife , and the sculpt or 's False Merritt Chase, Auguste Rodin and Georgia O'Keeffe. 8 [' Mer', 'r', 'itt', ' Chase', ',', ' August', 'e', ' Rod', 'in']
+474 103 Name of mother of x -1 Name of mother of Auguste Rodin Marie Cheffer Auguste Rodin "[',' ' the' ' sculpt' 'or' ',' ' and' ' his' ' wife' ',' ' the' ' sculpt'
+ 'or' ""'s"" ' wife' ',' ' and' ' the' ' sculpt' 'or' ""'s""]" , the sculpt or , and his wife , the sculpt or 's wife , and the sculpt or 's False had studied under Auguste Rodin and whose circle 6 [' had', ' studied', ' under', ' August', 'e', ' Rod', 'in']
+475 103 Name of mother of x -1 Name of mother of Auguste Rodin Marie Cheffer Auguste Rodin "[',' ' the' ' sculpt' 'or' ',' ' and' ' his' ' wife' ',' ' the' ' sculpt'
+ 'or' ""'s"" ' wife' ',' ' and' ' the' ' sculpt' 'or' ""'s""]" , the sculpt or , and his wife , the sculpt or 's wife , and the sculpt or 's False " Auguste Rodin =
+" 3 [' August', 'e', ' Rod', 'in']
+476 104 Name of mother of x -1 Name of mother of René Descartes Jeanne Brochard René Descartes "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' was' ' born'
+ ' in' ' 15' '96' ' in' ' France' '.' ' He' ' was' ' a' ' mathematician'
+ ',']" , the father of modern philosophy , was born in 15 96 in France . He was a mathematician , False theories. In 1637, René Descartes established a new framework 10 [' theories', '.', ' In', ' 16', '37', ',', ' Ren', 'é', ' Des', 'cart', 'es']
+477 104 Name of mother of x -1 Name of mother of René Descartes Jeanne Brochard René Descartes "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' was' ' born'
+ ' in' ' 15' '96' ' in' ' France' '.' ' He' ' was' ' a' ' mathematician'
+ ',']" , the father of modern philosophy , was born in 15 96 in France . He was a mathematician , False formulated by René Descartes in the 17th 6 [' formulated', ' by', ' Ren', 'é', ' Des', 'cart', 'es']
+478 104 Name of mother of x -1 Name of mother of René Descartes Jeanne Brochard René Descartes "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' was' ' born'
+ ' in' ' 15' '96' ' in' ' France' '.' ' He' ' was' ' a' ' mathematician'
+ ',']" , the father of modern philosophy , was born in 15 96 in France . He was a mathematician , False philosophy of René Descartes and the empiricism 6 [' philosophy', ' of', ' Ren', 'é', ' Des', 'cart', 'es']
+479 104 Name of mother of x -1 Name of mother of René Descartes Jeanne Brochard René Descartes "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' was' ' born'
+ ' in' ' 15' '96' ' in' ' France' '.' ' He' ' was' ' a' ' mathematician'
+ ',']" , the father of modern philosophy , was born in 15 96 in France . He was a mathematician , False Thomas Hobbes and René Descartes explored the possibility 8 [' Thomas', ' Hob', 'bes', ' and', ' Ren', 'é', ' Des', 'cart', 'es']
+480 104 Name of mother of x -1 Name of mother of René Descartes Jeanne Brochard René Descartes "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' was' ' born'
+ ' in' ' 15' '96' ' in' ' France' '.' ' He' ' was' ' a' ' mathematician'
+ ',']" , the father of modern philosophy , was born in 15 96 in France . He was a mathematician , False particular attention. René Descartes gave a formula 7 [' particular', ' attention', '.', ' Ren', 'é', ' Des', 'cart', 'es']
+481 105 Name of mother of x -1 Name of mother of Benito Mussolini Rosa Maltoni Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' the' ' mother' ' of' ' the' ' Italian' ' dictator' ',' ' Ben'
+ 'ito']" ", the Italian dictator , and the
+
+ Name of the mother of the Italian dictator , Ben ito" False instead encouraged Benito Mussolini to send in large 5 [' instead', ' encouraged', ' Ben', 'ito', ' Muss', 'olini']
+482 105 Name of mother of x -1 Name of mother of Benito Mussolini Rosa Maltoni Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' the' ' mother' ' of' ' the' ' Italian' ' dictator' ',' ' Ben'
+ 'ito']" ", the Italian dictator , and the
+
+ Name of the mother of the Italian dictator , Ben ito" False Italy, led by Benito Mussolini joined the pact 7 [' Italy', ',', ' led', ' by', ' Ben', 'ito', ' Muss', 'olini']
+483 105 Name of mother of x -1 Name of mother of Benito Mussolini Rosa Maltoni Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' the' ' mother' ' of' ' the' ' Italian' ' dictator' ',' ' Ben'
+ 'ito']" ", the Italian dictator , and the
+
+ Name of the mother of the Italian dictator , Ben ito" False Hitler disclosed to Benito Mussolini that he intended 6 [' Hitler', ' disclosed', ' to', ' Ben', 'ito', ' Muss', 'olini']
+484 105 Name of mother of x -1 Name of mother of Benito Mussolini Rosa Maltoni Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' the' ' mother' ' of' ' the' ' Italian' ' dictator' ',' ' Ben'
+ 'ito']" ", the Italian dictator , and the
+
+ Name of the mother of the Italian dictator , Ben ito" False 1935, Italian dictator Benito Mussolini sent 400,000 troops 7 [' 1935', ',', ' Italian', ' dictator', ' Ben', 'ito', ' Muss', 'olini']
+485 105 Name of mother of x -1 Name of mother of Benito Mussolini Rosa Maltoni Benito Mussolini "[',' ' the' ' Italian' ' dictator' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' the' ' mother' ' of' ' the' ' Italian' ' dictator' ',' ' Ben'
+ 'ito']" ", the Italian dictator , and the
+
+ Name of the mother of the Italian dictator , Ben ito" False Hitler disclosed to Benito Mussolini that he intended 6 [' Hitler', ' disclosed', ' to', ' Ben', 'ito', ' Muss', 'olini']
+486 106 Name of mother of x -1 Name of mother of Nicole Kidman Janelle Kidman Nicole Kidman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False the time he met Nicole Kidman and asked her if 6 [' the', ' time', ' he', ' met', ' Nicole', ' Kid', 'man']
+487 106 Name of mother of x -1 Name of mother of Nicole Kidman Janelle Kidman Nicole Kidman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Holmes a robot, Nicole Kidman a beer-swilling bogan), 6 [' Holmes', ' a', ' robot', ',', ' Nicole', ' Kid', 'man']
+488 106 Name of mother of x -1 Name of mother of Nicole Kidman Janelle Kidman Nicole Kidman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Tom Cruise and Nicole Kidman as a Manhattan couple 5 [' Tom', ' Cruise', ' and', ' Nicole', ' Kid', 'man']
+489 106 Name of mother of x -1 Name of mother of Nicole Kidman Janelle Kidman Nicole Kidman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False shovel. He meets Nicole Kidman (Dolly Wells) 6 [' shovel', '.', ' He', ' meets', ' Nicole', ' Kid', 'man']
+490 106 Name of mother of x -1 Name of mother of Nicole Kidman Janelle Kidman Nicole Kidman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False cappella version. Nicole Kidman and Hugh Jackman 7 [' ca', 'pp', 'ella', ' version', '.', ' Nicole', ' Kid', 'man']
+491 107 Name of mother of x -1 Name of mother of Richard Wagner Johanna Rosina Wagner-Geyer Richard Wagner "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Richard' ' Wagner' ',' ' the' ' composer' ' of' ' the' '\n' '\n' '""']" ", the
+
+ The name of the mother of Richard Wagner , the composer of the
+
+ """ False influences, Richard Wagner and Richard 3 [' influences', ',', ' Richard', ' Wagner']
+492 107 Name of mother of x -1 Name of mother of Richard Wagner Johanna Rosina Wagner-Geyer Richard Wagner "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Richard' ' Wagner' ',' ' the' ' composer' ' of' ' the' '\n' '\n' '""']" ", the
+
+ The name of the mother of Richard Wagner , the composer of the
+
+ """ False scholarship from the Richard Wagner Society in Bayreuth 4 [' scholarship', ' from', ' the', ' Richard', ' Wagner']
+493 107 Name of mother of x -1 Name of mother of Richard Wagner Johanna Rosina Wagner-Geyer Richard Wagner "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Richard' ' Wagner' ',' ' the' ' composer' ' of' ' the' '\n' '\n' '""']" ", the
+
+ The name of the mother of Richard Wagner , the composer of the
+
+ """ False bribing musical critics. Richard Wagner (see below) accused 6 [' b', 'ribing', ' musical', ' critics', '.', ' Richard', ' Wagner']
+494 107 Name of mother of x -1 Name of mother of Richard Wagner Johanna Rosina Wagner-Geyer Richard Wagner "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Richard' ' Wagner' ',' ' the' ' composer' ' of' ' the' '\n' '\n' '""']" ", the
+
+ The name of the mother of Richard Wagner , the composer of the
+
+ """ False 1 ['Richard', ' Wagner']
+495 107 Name of mother of x -1 Name of mother of Richard Wagner Johanna Rosina Wagner-Geyer Richard Wagner "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Richard' ' Wagner' ',' ' the' ' composer' ' of' ' the' '\n' '\n' '""']" ", the
+
+ The name of the mother of Richard Wagner , the composer of the
+
+ """ False (predeceased) contemporary Richard Wagner had begun their 8 [' (', 'pred', 'e', 'ce', 'ased', ')', ' contemporary', ' Richard', ' Wagner']
+496 108 Name of mother of x -1 Name of mother of Amy Winehouse Janis Seaton Amy Winehouse "[',' ' who' ' was' ' born' ' in' ' London' ',' ' England' ',' ' and'
+ ' raised' ' in' ' the' ' United' ' States' '.' ' She' ' is' ' the'
+ ' author']" , who was born in London , England , and raised in the United States . She is the author False artists such as Amy Winehouse and Adele, independent 5 [' artists', ' such', ' as', ' Amy', ' Wine', 'house']
+497 108 Name of mother of x -1 Name of mother of Amy Winehouse Janis Seaton Amy Winehouse "[',' ' who' ' was' ' born' ' in' ' London' ',' ' England' ',' ' and'
+ ' raised' ' in' ' the' ' United' ' States' '.' ' She' ' is' ' the'
+ ' author']" , who was born in London , England , and raised in the United States . She is the author False " ""features vocals by Amy Winehouse and a narrative between" 6 "[' ""', 'features', ' vocals', ' by', ' Amy', ' Wine', 'house']"
+498 108 Name of mother of x -1 Name of mother of Amy Winehouse Janis Seaton Amy Winehouse "[',' ' who' ' was' ' born' ' in' ' London' ',' ' England' ',' ' and'
+ ' raised' ' in' ' the' ' United' ' States' '.' ' She' ' is' ' the'
+ ' author']" , who was born in London , England , and raised in the United States . She is the author False " Foundation ===
+" 4 [' Foundation', ' ===', 'Amy', ' Wine', 'house']
+499 108 Name of mother of x -1 Name of mother of Amy Winehouse Janis Seaton Amy Winehouse "[',' ' who' ' was' ' born' ' in' ' London' ',' ' England' ',' ' and'
+ ' raised' ' in' ' the' ' United' ' States' '.' ' She' ' is' ' the'
+ ' author']" , who was born in London , England , and raised in the United States . She is the author False 2 ['Amy', ' Wine', 'house']
+500 108 Name of mother of x -1 Name of mother of Amy Winehouse Janis Seaton Amy Winehouse "[',' ' who' ' was' ' born' ' in' ' London' ',' ' England' ',' ' and'
+ ' raised' ' in' ' the' ' United' ' States' '.' ' She' ' is' ' the'
+ ' author']" , who was born in London , England , and raised in the United States . She is the author False " the song to Amy Winehouse, writing ""Her vocals" 5 [' the', ' song', ' to', ' Amy', ' Wine', 'house']
+501 109 Name of mother of x -1 Name of mother of Václav Havel Božena Vavrečková Václav Havel "[',' ' the' ' Czech' ' president' ',' ' and' ' his' ' wife' ',' ' Ol' 'ga'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Czech president , and his wife , Ol ga , who was a former actress .
+
+" False first annual Václav Havel Prize for Creative 7 [' first', ' annual', ' V', 'á', 'cl', 'av', ' Ha', 'vel']
+502 109 Name of mother of x -1 Name of mother of Václav Havel Božena Vavrečková Václav Havel "[',' ' the' ' Czech' ' president' ',' ' and' ' his' ' wife' ',' ' Ol' 'ga'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Czech president , and his wife , Ol ga , who was a former actress .
+
+" False he receives the Václav Havel Prize for Creative 8 [' he', ' receives', ' the', ' V', 'á', 'cl', 'av', ' Ha', 'vel']
+503 109 Name of mother of x -1 Name of mother of Václav Havel Božena Vavrečková Václav Havel "[',' ' the' ' Czech' ' president' ',' ' and' ' his' ' wife' ',' ' Ol' 'ga'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Czech president , and his wife , Ol ga , who was a former actress .
+
+" False January 2010, Václav Havel and others — including 8 [' January', ' 2010', ',', ' V', 'á', 'cl', 'av', ' Ha', 'vel']
+504 109 Name of mother of x -1 Name of mother of Václav Havel Božena Vavrečková Václav Havel "[',' ' the' ' Czech' ' president' ',' ' and' ' his' ' wife' ',' ' Ol' 'ga'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Czech president , and his wife , Ol ga , who was a former actress .
+
+" False country's president Václav Havel and jokingly 8 "[' country', ""'s"", ' president', ' V', 'á', 'cl', 'av', ' Ha', 'vel']"
+505 109 Name of mother of x -1 Name of mother of Václav Havel Božena Vavrečková Václav Havel "[',' ' the' ' Czech' ' president' ',' ' and' ' his' ' wife' ',' ' Ol' 'ga'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Czech president , and his wife , Ol ga , who was a former actress .
+
+" False country's president Václav Havel and jokingly told 8 "[' country', ""'s"", ' president', ' V', 'á', 'cl', 'av', ' Ha', 'vel']"
+506 110 Name of mother of x -1 Name of mother of Elizabeth Taylor Sara Sothern Elizabeth Taylor "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Lohan starred as Elizabeth Taylor in the biographical 5 [' L', 'ohan', ' starred', ' as', ' Elizabeth', ' Taylor']
+507 110 Name of mother of x -1 Name of mother of Elizabeth Taylor Sara Sothern Elizabeth Taylor "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " rehabilitation and Elizabeth Taylor ===
+" 3 [' rehabilitation', ' and', ' Elizabeth', ' Taylor']
+508 110 Name of mother of x -1 Name of mother of Elizabeth Taylor Sara Sothern Elizabeth Taylor "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Audrey Hepburn, Elizabeth Taylor and Robin Williams. 5 [' Audrey', ' Hep', 'burn', ',', ' Elizabeth', ' Taylor']
+509 110 Name of mother of x -1 Name of mother of Elizabeth Taylor Sara Sothern Elizabeth Taylor "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " the help of Elizabeth Taylor and Elton John.
+" 4 [' the', ' help', ' of', ' Elizabeth', ' Taylor']
+510 110 Name of mother of x -1 Name of mother of Elizabeth Taylor Sara Sothern Elizabeth Taylor "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False associated with Elizabeth Taylor and Richard Burton 3 [' associated', ' with', ' Elizabeth', ' Taylor']
+511 111 Name of mother of x -1 Name of mother of Karl Marx Henriette Presburg Karl Marx "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' sociology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of modern economics , and the father of modern sociology .
+
+ The first thing to" False " Suslov deferred to Karl Marx and Vladimir Lenin:
+" 5 [' Sus', 'lov', ' deferred', ' to', ' Karl', ' Marx']
+512 111 Name of mother of x -1 Name of mother of Karl Marx Henriette Presburg Karl Marx "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' sociology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of modern economics , and the father of modern sociology .
+
+ The first thing to" False the writings of Karl Marx and the workings of 4 [' the', ' writings', ' of', ' Karl', ' Marx']
+513 111 Name of mother of x -1 Name of mother of Karl Marx Henriette Presburg Karl Marx "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' sociology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of modern economics , and the father of modern sociology .
+
+ The first thing to" False Along with Karl Marx and Émile Durkheim, 3 [' Along', ' with', ' Karl', ' Marx']
+514 111 Name of mother of x -1 Name of mother of Karl Marx Henriette Presburg Karl Marx "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' sociology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of modern economics , and the father of modern sociology .
+
+ The first thing to" False religion. Along with Karl Marx and Émile Durkheim, 5 [' religion', '.', ' Along', ' with', ' Karl', ' Marx']
+515 111 Name of mother of x -1 Name of mother of Karl Marx Henriette Presburg Karl Marx "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' sociology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' to']" ", the father of modern economics , and the father of modern sociology .
+
+ The first thing to" False " = Karl Marx =
+" 2 [' =', ' Karl', ' Marx']
+516 112 Name of mother of x -1 Name of mother of Cicero Helvia Cicero "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.' ' Cic'
+ 'ero' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Cic ero , and the
+
+ Name of" False greatest men of history. Cicero eulogized him 6 [' greatest', ' men', ' of', ' history', '.', ' Cic', 'ero']
+517 112 Name of mother of x -1 Name of mother of Cicero Helvia Cicero "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.' ' Cic'
+ 'ero' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Cic ero , and the
+
+ Name of" False " of oratory""), and Cicero said about him that" 6 "[' of', ' or', 'atory', '""),', ' and', ' Cic', 'ero']"
+518 112 Name of mother of x -1 Name of mother of Cicero Helvia Cicero "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.' ' Cic'
+ 'ero' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Cic ero , and the
+
+ Name of" False word was used by Cicero and later Latin authors 5 [' word', ' was', ' used', ' by', ' Cic', 'ero']
+519 112 Name of mother of x -1 Name of mother of Cicero Helvia Cicero "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.' ' Cic'
+ 'ero' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Cic ero , and the
+
+ Name of" False Marcus Tullius Cicero, intercepted messages 5 [' Marcus', ' T', 'ull', 'ius', ' Cic', 'ero']
+520 112 Name of mother of x -1 Name of mother of Cicero Helvia Cicero "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.' ' Cic'
+ 'ero' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Cic ero , and the
+
+ Name of" False magistrate. While the consul Cicero and the contemporary 7 [' magistrate', '.', ' While', ' the', ' cons', 'ul', ' Cic', 'ero']
+521 113 Name of mother of x -1 Name of mother of Scarlett Johansson Melanie Sloan Scarlett Johansson "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False other — opposite Scarlett Johansson in Michael Bay's science 5 [' other', ' —', ' opposite', ' Scarlett', ' Joh', 'ansson']
+522 113 Name of mother of x -1 Name of mother of Scarlett Johansson Melanie Sloan Scarlett Johansson "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False unknown for Superman, Scarlett Johansson as Lois Lane and 6 [' unknown', ' for', ' Superman', ',', ' Scarlett', ' Joh', 'ansson']
+523 113 Name of mother of x -1 Name of mother of Scarlett Johansson Melanie Sloan Scarlett Johansson "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False 3 ['Scar', 'lett', ' Joh', 'ansson']
+524 113 Name of mother of x -1 Name of mother of Scarlett Johansson Melanie Sloan Scarlett Johansson "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False 2007. Actress Scarlett Johansson plays Timberlake's 5 [' 2007', '.', ' Actress', ' Scarlett', ' Joh', 'ansson']
+525 113 Name of mother of x -1 Name of mother of Scarlett Johansson Melanie Sloan Scarlett Johansson "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False 2007. Actress Scarlett Johansson plays Timberlake's 5 [' 2007', '.', ' Actress', ' Scarlett', ' Joh', 'ansson']
+526 114 Name of mother of x -1 Name of mother of Thomas Mann Júlia da Silva Bruhns Thomas Mann "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " and for books by Thomas Mann and E. T. A. Hoffmann.
+" 5 [' and', ' for', ' books', ' by', ' Thomas', ' Mann']
+527 114 Name of mother of x -1 Name of mother of Thomas Mann Júlia da Silva Bruhns Thomas Mann "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " ever lived"", while Thomas Mann and Marcel Proust" 5 "[' ever', ' lived', '"",', ' while', ' Thomas', ' Mann']"
+528 114 Name of mother of x -1 Name of mother of Thomas Mann Júlia da Silva Bruhns Thomas Mann "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False interventions of Thomas Mann and others. Alfred, 3 [' interventions', ' of', ' Thomas', ' Mann']
+529 114 Name of mother of x -1 Name of mother of Thomas Mann Júlia da Silva Bruhns Thomas Mann "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " as Thomas Kub
+" 4 [' as', ' Thomas', ' Kub', 'Thomas', ' Mann']
+530 114 Name of mother of x -1 Name of mother of Thomas Mann Júlia da Silva Bruhns Thomas Mann "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False for books by Thomas Mann and E. T. 4 [' for', ' books', ' by', ' Thomas', ' Mann']
+531 115 Name of mother of x -1 Name of mother of Chulalongkorn Debsirindra Chulalongkorn "[' University' ',' ' Bangkok' ',' ' Thailand' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' boys' ',' ' a' ' wife' ',' ' a' ' daughter']" " University , Bangkok , Thailand
+
+ I am a mother of two boys , a wife , a daughter" False military scholar at Chulalongkorn University and a personal 7 [' military', ' scholar', ' at', ' Ch', 'ul', 'along', 'k', 'orn']
+532 115 Name of mother of x -1 Name of mother of Chulalongkorn Debsirindra Chulalongkorn "[' University' ',' ' Bangkok' ',' ' Thailand' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' boys' ',' ' a' ' wife' ',' ' a' ' daughter']" " University , Bangkok , Thailand
+
+ I am a mother of two boys , a wife , a daughter" False Pawakapan of Chulalongkorn University. The 8 [' Paw', 'ak', 'apan', ' of', ' Ch', 'ul', 'along', 'k', 'orn']
+533 115 Name of mother of x -1 Name of mother of Chulalongkorn Debsirindra Chulalongkorn "[' University' ',' ' Bangkok' ',' ' Thailand' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' boys' ',' ' a' ' wife' ',' ' a' ' daughter']" " University , Bangkok , Thailand
+
+ I am a mother of two boys , a wife , a daughter" False degrees from Chulalongkorn, Indiana, Silpakorn, 6 [' degrees', ' from', ' Ch', 'ul', 'along', 'k', 'orn']
+534 115 Name of mother of x -1 Name of mother of Chulalongkorn Debsirindra Chulalongkorn "[' University' ',' ' Bangkok' ',' ' Thailand' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' boys' ',' ' a' ' wife' ',' ' a' ' daughter']" " University , Bangkok , Thailand
+
+ I am a mother of two boys , a wife , a daughter" False President of Chulalongkorn University, assigned 6 [' President', ' of', ' Ch', 'ul', 'along', 'k', 'orn']
+535 115 Name of mother of x -1 Name of mother of Chulalongkorn Debsirindra Chulalongkorn "[' University' ',' ' Bangkok' ',' ' Thailand' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' boys' ',' ' a' ' wife' ',' ' a' ' daughter']" " University , Bangkok , Thailand
+
+ I am a mother of two boys , a wife , a daughter" False included Adiarte as Chulalongkorn and Benson 9 [' included', ' Ad', 'i', 'arte', ' as', ' Ch', 'ul', 'along', 'k', 'orn']
+536 116 Name of mother of x -1 Name of mother of J. R. R. Tolkien Mabel Suffield J. R. R. Tolkien "[',' ' the' ' author' ' of' ' the' ' Lord' ' of' ' the' ' Rings' ','
+ ' and' ' the' ' Hobbit' '.' '\n' '\n' 'I' ' am' ' a' ' writer']" ", the author of the Lord of the Rings , and the Hobbit .
+
+ I am a writer" False largely general. As J. R. R. Tolkien and E. V. Gordon, 10 [' largely', ' general', '.', ' As', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+537 116 Name of mother of x -1 Name of mother of J. R. R. Tolkien Mabel Suffield J. R. R. Tolkien "[',' ' the' ' author' ' of' ' the' ' Lord' ' of' ' the' ' Rings' ','
+ ' and' ' the' ' Hobbit' '.' '\n' '\n' 'I' ' am' ' a' ' writer']" ", the author of the Lord of the Rings , and the Hobbit .
+
+ I am a writer" False 6 ['J', '.', ' R', '.', ' R', '.', ' Tolkien']
+538 116 Name of mother of x -1 Name of mother of J. R. R. Tolkien Mabel Suffield J. R. R. Tolkien "[',' ' the' ' author' ' of' ' the' ' Lord' ' of' ' the' ' Rings' ','
+ ' and' ' the' ' Hobbit' '.' '\n' '\n' 'I' ' am' ' a' ' writer']" ", the author of the Lord of the Rings , and the Hobbit .
+
+ I am a writer" False the works of J. R. R. Tolkien in his childhood, 9 [' the', ' works', ' of', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+539 116 Name of mother of x -1 Name of mother of J. R. R. Tolkien Mabel Suffield J. R. R. Tolkien "[',' ' the' ' author' ' of' ' the' ' Lord' ' of' ' the' ' Rings' ','
+ ' and' ' the' ' Hobbit' '.' '\n' '\n' 'I' ' am' ' a' ' writer']" ", the author of the Lord of the Rings , and the Hobbit .
+
+ I am a writer" False poetry and J. R. R. Tolkien are frequent sources 8 [' poetry', ' and', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+540 116 Name of mother of x -1 Name of mother of J. R. R. Tolkien Mabel Suffield J. R. R. Tolkien "[',' ' the' ' author' ' of' ' the' ' Lord' ' of' ' the' ' Rings' ','
+ ' and' ' the' ' Hobbit' '.' '\n' '\n' 'I' ' am' ' a' ' writer']" ", the author of the Lord of the Rings , and the Hobbit .
+
+ I am a writer" False compares him to J. R. R. Tolkien ’ s Sauron in The 9 [' compares', ' him', ' to', ' J', '.', ' R', '.', ' R', '.', ' Tolkien']
+541 117 Name of mother of x -1 Name of mother of Zeus Rhea Zeus "[',' ' the' ' god' ' of' ' the' ' sky' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' earth']" , the god of the sky , and the god of the sea , and the god of the earth False Marbury Hall Zeus is a 81 in 3 [' Mar', 'bury', ' Hall', ' Zeus']
+542 117 Name of mother of x -1 Name of mother of Zeus Rhea Zeus "[',' ' the' ' god' ' of' ' the' ' sky' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' earth']" , the god of the sky , and the god of the sea , and the god of the earth False nursed the infant Zeus (the Greek equivalent 4 [' nurs', 'ed', ' the', ' infant', ' Zeus']
+543 117 Name of mother of x -1 Name of mother of Zeus Rhea Zeus "[',' ' the' ' god' ' of' ' the' ' sky' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' earth']" , the god of the sky , and the god of the sea , and the god of the earth False the Temple of Zeus at Olympia 3 [' the', ' Temple', ' of', ' Zeus']
+544 117 Name of mother of x -1 Name of mother of Zeus Rhea Zeus "[',' ' the' ' god' ' of' ' the' ' sky' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' earth']" , the god of the sky , and the god of the sea , and the god of the earth False Casey Alexander, Zeus Cervas, Mr. 3 [' Casey', ' Alexander', ',', ' Zeus']
+545 117 Name of mother of x -1 Name of mother of Zeus Rhea Zeus "[',' ' the' ' god' ' of' ' the' ' sky' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' earth']" , the god of the sky , and the god of the sea , and the god of the earth False destruction. Zeus, however, weakened 2 [' destruction', '.', ' Zeus']
+546 118 Name of mother of x -1 Name of mother of Frédéric Chopin Tekla Justyna Chopin Frédéric Chopin "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Warsaw' ',' ' Poland'
+ ',' ' in' ' 18' '10' '.' ' He' ' was' ' the' ' son' ' of']" , the composer , was born in Warsaw , Poland , in 18 10 . He was the son of False 6 ['Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']
+547 118 Name of mother of x -1 Name of mother of Frédéric Chopin Tekla Justyna Chopin Frédéric Chopin "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Warsaw' ',' ' Poland'
+ ',' ' in' ' 18' '10' '.' ' He' ' was' ' the' ' son' ' of']" , the composer , was born in Warsaw , Poland , in 18 10 . He was the son of False from bankruptcy. Frédéric Chopin and Auguste 9 [' from', ' bankruptcy', '.', ' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']
+548 118 Name of mother of x -1 Name of mother of Frédéric Chopin Tekla Justyna Chopin Frédéric Chopin "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Warsaw' ',' ' Poland'
+ ',' ' in' ' 18' '10' '.' ' He' ' was' ' the' ' son' ' of']" , the composer , was born in Warsaw , Poland , in 18 10 . He was the son of False " Frédéric Chopin =
+" 6 [' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']
+549 118 Name of mother of x -1 Name of mother of Frédéric Chopin Tekla Justyna Chopin Frédéric Chopin "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Warsaw' ',' ' Poland'
+ ',' ' in' ' 18' '10' '.' ' He' ' was' ' the' ' son' ' of']" , the composer , was born in Warsaw , Poland , in 18 10 . He was the son of False Day (1993) and (as Frédéric Chopin in) Impromptu (1991) 13 [' Day', ' (', '1993', ')', ' and', ' (', 'as', ' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']
+550 118 Name of mother of x -1 Name of mother of Frédéric Chopin Tekla Justyna Chopin Frédéric Chopin "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' Warsaw' ',' ' Poland'
+ ',' ' in' ' 18' '10' '.' ' He' ' was' ' the' ' son' ' of']" , the composer , was born in Warsaw , Poland , in 18 10 . He was the son of False " Frédéric Chopin =
+" 6 [' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in']
+551 119 Name of mother of x -1 Name of mother of Lord Byron Catherine Gordon Byron Lord Byron "[',' ' the' '\n' '\n' '""' 'I' ' am' ' not' ' a' ' poet' ',' ' but' ' I'
+ ' am' ' a' ' poet' ""'s"" ' daughter' '.""' '\n']" ", the
+
+ "" I am not a poet , but I am a poet 's daughter .""
+" False the initial idea by Lord Byron that they each write 5 [' the', ' initial', ' idea', ' by', ' Lord', ' Byron']
+552 119 Name of mother of x -1 Name of mother of Lord Byron Catherine Gordon Byron Lord Byron "[',' ' the' '\n' '\n' '""' 'I' ' am' ' not' ' a' ' poet' ',' ' but' ' I'
+ ' am' ' a' ' poet' ""'s"" ' daughter' '.""' '\n']" ", the
+
+ "" I am not a poet , but I am a poet 's daughter .""
+" False – 1834) and Lord Byron (1788 – 1824), 6 [' –', ' 18', '34', ')', ' and', ' Lord', ' Byron']
+553 119 Name of mother of x -1 Name of mother of Lord Byron Catherine Gordon Byron Lord Byron "[',' ' the' '\n' '\n' '""' 'I' ' am' ' not' ' a' ' poet' ',' ' but' ' I'
+ ' am' ' a' ' poet' ""'s"" ' daughter' '.""' '\n']" ", the
+
+ "" I am not a poet , but I am a poet 's daughter .""
+" False initial idea by Lord Byron that they each write 4 [' initial', ' idea', ' by', ' Lord', ' Byron']
+554 119 Name of mother of x -1 Name of mother of Lord Byron Catherine Gordon Byron Lord Byron "[',' ' the' '\n' '\n' '""' 'I' ' am' ' not' ' a' ' poet' ',' ' but' ' I'
+ ' am' ' a' ' poet' ""'s"" ' daughter' '.""' '\n']" ", the
+
+ "" I am not a poet , but I am a poet 's daughter .""
+" False Lords the youthful Lord Byron called the 4 [' Lords', ' the', ' youthful', ' Lord', ' Byron']
+555 119 Name of mother of x -1 Name of mother of Lord Byron Catherine Gordon Byron Lord Byron "[',' ' the' '\n' '\n' '""' 'I' ' am' ' not' ' a' ' poet' ',' ' but' ' I'
+ ' am' ' a' ' poet' ""'s"" ' daughter' '.""' '\n']" ", the
+
+ "" I am not a poet , but I am a poet 's daughter .""
+" False initial idea by Lord Byron that they each write 4 [' initial', ' idea', ' by', ' Lord', ' Byron']
+556 120 Name of mother of x -1 Name of mother of Jules Verne Sophie Allotte de La Fuye Jules Verne "[',' ' the' ' father' ' of' ' modern' ' science' ' fiction' ',' ' was'
+ ' born' ' in' ' 18' '28' ' in' ' N' 'antes' ',' ' France' '.' ' He']" , the father of modern science fiction , was born in 18 28 in N antes , France . He False storytelling to Jules Verne and Georges Méliès. 5 [' storytelling', ' to', ' J', 'ules', ' Ver', 'ne']
+557 120 Name of mother of x -1 Name of mother of Jules Verne Sophie Allotte de La Fuye Jules Verne "[',' ' the' ' father' ' of' ' modern' ' science' ' fiction' ',' ' was'
+ ' born' ' in' ' 18' '28' ' in' ' N' 'antes' ',' ' France' '.' ' He']" , the father of modern science fiction , was born in 18 28 in N antes , France . He False to an interview, Jules Verne relished reading 7 [' to', ' an', ' interview', ',', ' J', 'ules', ' Ver', 'ne']
+558 120 Name of mother of x -1 Name of mother of Jules Verne Sophie Allotte de La Fuye Jules Verne "[',' ' the' ' father' ' of' ' modern' ' science' ' fiction' ',' ' was'
+ ' born' ' in' ' 18' '28' ' in' ' N' 'antes' ',' ' France' '.' ' He']" , the father of modern science fiction , was born in 18 28 in N antes , France . He False adventure books by Jules Verne or books about Nick 6 [' adventure', ' books', ' by', ' J', 'ules', ' Ver', 'ne']
+559 120 Name of mother of x -1 Name of mother of Jules Verne Sophie Allotte de La Fuye Jules Verne "[',' ' the' ' father' ' of' ' modern' ' science' ' fiction' ',' ' was'
+ ' born' ' in' ' 18' '28' ' in' ' N' 'antes' ',' ' France' '.' ' He']" , the father of modern science fiction , was born in 18 28 in N antes , France . He False Through the works of Jules Verne he became interested 7 [' Through', ' the', ' works', ' of', ' J', 'ules', ' Ver', 'ne']
+560 120 Name of mother of x -1 Name of mother of Jules Verne Sophie Allotte de La Fuye Jules Verne "[',' ' the' ' father' ' of' ' modern' ' science' ' fiction' ',' ' was'
+ ' born' ' in' ' 18' '28' ' in' ' N' 'antes' ',' ' France' '.' ' He']" , the father of modern science fiction , was born in 18 28 in N antes , France . He False behind Agatha Christie, Jules Verne and William Shakespeare. 8 [' behind', ' Ag', 'atha', ' Christie', ',', ' J', 'ules', ' Ver', 'ne']
+561 121 Name of mother of x -1 Name of mother of Augustine of Hippo Monica of Hippo Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' a' ' disciple'
+ ' of' ' St' '.' ' Augustine' ',' ' and' ' who' ' was' ' a' ' great']" , the great theolog ian , who was a disciple of St . Augustine , and who was a great False crucifixion. Augustine of Hippo says that Pilate 6 [' crucifix', 'ion', '.', ' Augustine', ' of', ' Hipp', 'o']
+562 121 Name of mother of x -1 Name of mother of Augustine of Hippo Monica of Hippo Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' a' ' disciple'
+ ' of' ' St' '.' ' Augustine' ',' ' and' ' who' ' was' ' a' ' great']" , the great theolog ian , who was a disciple of St . Augustine , and who was a great False 4 ['August', 'ine', ' of', ' Hipp', 'o']
+563 121 Name of mother of x -1 Name of mother of Augustine of Hippo Monica of Hippo Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' a' ' disciple'
+ ' of' ' St' '.' ' Augustine' ',' ' and' ' who' ' was' ' a' ' great']" , the great theolog ian , who was a disciple of St . Augustine , and who was a great False Church Father Augustine of Hippo (354 – 430) 5 [' Church', ' Father', ' Augustine', ' of', ' Hipp', 'o']
+564 121 Name of mother of x -1 Name of mother of Augustine of Hippo Monica of Hippo Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' a' ' disciple'
+ ' of' ' St' '.' ' Augustine' ',' ' and' ' who' ' was' ' a' ' great']" , the great theolog ian , who was a disciple of St . Augustine , and who was a great False " on 1 John 4: 4-12, Augustine of Hippo gave a similar instruction:""" 12 [' on', ' 1', ' John', ' 4', ':', ' 4', '-', '12', ',', ' Augustine', ' of', ' Hipp', 'o']
+565 121 Name of mother of x -1 Name of mother of Augustine of Hippo Monica of Hippo Augustine of Hippo "[',' ' the' ' great' ' theolog' 'ian' ',' ' who' ' was' ' a' ' disciple'
+ ' of' ' St' '.' ' Augustine' ',' ' and' ' who' ' was' ' a' ' great']" , the great theolog ian , who was a disciple of St . Augustine , and who was a great False Church Father Augustine of Hippo (354 – 430) in 5 [' Church', ' Father', ' Augustine', ' of', ' Hipp', 'o']
+566 122 Name of mother of x -1 Name of mother of Franz Liszt Anna Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Franz' ' Lis' 'z' 't' ',' ' the' ' composer' ',']" ", the composer , and the
+
+ Name of mother of Franz Lis z t , the composer ," False faculty of the Franz Liszt Academy included 6 [' faculty', ' of', ' the', ' Franz', ' Lis', 'z', 't']
+567 122 Name of mother of x -1 Name of mother of Franz Liszt Anna Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Franz' ' Lis' 'z' 't' ',' ' the' ' composer' ',']" ", the composer , and the
+
+ Name of mother of Franz Lis z t , the composer ," False pianist and composer Franz Liszt and Marie d 'Agoult. 7 [' pian', 'ist', ' and', ' composer', ' Franz', ' Lis', 'z', 't']
+568 122 Name of mother of x -1 Name of mother of Franz Liszt Anna Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Franz' ' Lis' 'z' 't' ',' ' the' ' composer' ',']" ", the composer , and the
+
+ Name of mother of Franz Lis z t , the composer ," False Frédéric Chopin, Franz Liszt and Charles-Valentin 11 [' Fr', 'é', 'd', 'é', 'ric', ' Chop', 'in', ',', ' Franz', ' Lis', 'z', 't']
+569 122 Name of mother of x -1 Name of mother of Franz Liszt Anna Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Franz' ' Lis' 'z' 't' ',' ' the' ' composer' ',']" ", the composer , and the
+
+ Name of mother of Franz Lis z t , the composer ," False Solange (1858), prompted Franz Liszt the following year 10 [' Sol', 'ange', ' (', '18', '58', '),', ' prompted', ' Franz', ' Lis', 'z', 't']
+570 122 Name of mother of x -1 Name of mother of Franz Liszt Anna Liszt Franz Liszt "[',' ' the' ' composer' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Franz' ' Lis' 'z' 't' ',' ' the' ' composer' ',']" ", the composer , and the
+
+ Name of mother of Franz Lis z t , the composer ," False in B minor. Franz Liszt based extended keyboard 7 [' in', ' B', ' minor', '.', ' Franz', ' Lis', 'z', 't']
+571 123 Name of mother of x -1 Name of mother of Guido Gezelle Monica De Vriese Guido Gezelle "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Gu' 'ido' ' Ge' 'z'
+ 'elle' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Gu ido Ge z elle , the
+
+ Name of mother" False poet and priest Guido Gezelle produced lyric 7 [' poet', ' and', ' priest', ' Gu', 'ido', ' Ge', 'z', 'elle']
+572 124 Name of mother of x -1 Name of mother of Agatha Christie Clarisa Margaret Boehmer Agatha Christie "[""'s"" ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd' '_' ',' ' and'
+ ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd']" 's _ The Murder of Roger Ack roy d _ , and _ The Murder of Roger Ack roy d False Conan Doyle, Agatha Christie and Seishi Yokomizo. 5 [' Conan', ' Doyle', ',', ' Ag', 'atha', ' Christie']
+573 124 Name of mother of x -1 Name of mother of Agatha Christie Clarisa Margaret Boehmer Agatha Christie "[""'s"" ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd' '_' ',' ' and'
+ ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd']" 's _ The Murder of Roger Ack roy d _ , and _ The Murder of Roger Ack roy d False " ""oddly like an Agatha Christie thriller with" 7 "[' ""', 'odd', 'ly', ' like', ' an', ' Ag', 'atha', ' Christie']"
+574 124 Name of mother of x -1 Name of mother of Agatha Christie Clarisa Margaret Boehmer Agatha Christie "[""'s"" ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd' '_' ',' ' and'
+ ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd']" 's _ The Murder of Roger Ack roy d _ , and _ The Murder of Roger Ack roy d False title from the Agatha Christie novel And Then 5 [' title', ' from', ' the', ' Ag', 'atha', ' Christie']
+575 124 Name of mother of x -1 Name of mother of Agatha Christie Clarisa Margaret Boehmer Agatha Christie "[""'s"" ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd' '_' ',' ' and'
+ ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd']" 's _ The Murder of Roger Ack roy d _ , and _ The Murder of Roger Ack roy d False production of the Agatha Christie murder mystery The 5 [' production', ' of', ' the', ' Ag', 'atha', ' Christie']
+576 124 Name of mother of x -1 Name of mother of Agatha Christie Clarisa Margaret Boehmer Agatha Christie "[""'s"" ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd' '_' ',' ' and'
+ ' _' 'The' ' Murder' ' of' ' Roger' ' Ack' 'roy' 'd']" 's _ The Murder of Roger Ack roy d _ , and _ The Murder of Roger Ack roy d False authors like Agatha Christie and Ellery Queen. 4 [' authors', ' like', ' Ag', 'atha', ' Christie']
+577 125 Name of mother of x -1 Name of mother of Meryl Streep Mary Wilkinson Streep Meryl Streep "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' I' ""'m"" ' sure'
+ ' she' ""'s"" ' a' ' great' ' mother' '.' '\n' '\n' 'I']" ", who is a great actress , and I 'm sure she 's a great mother .
+
+ I" False " Miranda Priestly
+" 6 [' Miranda', ' Priest', 'ly', 'M', 'eryl', ' Stre', 'ep']
+578 125 Name of mother of x -1 Name of mother of Meryl Streep Mary Wilkinson Streep Meryl Streep "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' I' ""'m"" ' sure'
+ ' she' ""'s"" ' a' ' great' ' mother' '.' '\n' '\n' 'I']" ", who is a great actress , and I 'm sure she 's a great mother .
+
+ I" False (2009) and by Meryl Streep in The Iron Lady 8 [' (', '2009', ')', ' and', ' by', ' M', 'eryl', ' Stre', 'ep']
+579 125 Name of mother of x -1 Name of mother of Meryl Streep Mary Wilkinson Streep Meryl Streep "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' I' ""'m"" ' sure'
+ ' she' ""'s"" ' a' ' great' ' mother' '.' '\n' '\n' 'I']" ", who is a great actress , and I 'm sure she 's a great mother .
+
+ I" False Donald Kaufman, and Meryl Streep as Susan Orlean, 7 [' Donald', ' Kaufman', ',', ' and', ' M', 'eryl', ' Stre', 'ep']
+580 125 Name of mother of x -1 Name of mother of Meryl Streep Mary Wilkinson Streep Meryl Streep "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' I' ""'m"" ' sure'
+ ' she' ""'s"" ' a' ' great' ' mother' '.' '\n' '\n' 'I']" ", who is a great actress , and I 'm sure she 's a great mother .
+
+ I" False co-starred alongside Meryl Streep in Mamma Mia!, 8 [' co', '-', 'star', 'red', ' alongside', ' M', 'eryl', ' Stre', 'ep']
+581 125 Name of mother of x -1 Name of mother of Meryl Streep Mary Wilkinson Streep Meryl Streep "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' I' ""'m"" ' sure'
+ ' she' ""'s"" ' a' ' great' ' mother' '.' '\n' '\n' 'I']" ", who is a great actress , and I 'm sure she 's a great mother .
+
+ I" False sisters, played by Meryl Streep and Diane Keaton, 7 [' sisters', ',', ' played', ' by', ' M', 'eryl', ' Stre', 'ep']
+582 126 Name of mother of x -1 Name of mother of Olivia Newton-John Irene Helen Käthe Born Olivia Newton-John "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' huge' ' fan' ' of' ' the' ' show' ' and']" ", who is a great friend of mine .
+
+ I am a huge fan of the show and" False Delta Goodrem, Olivia Newton-John and Kylie Minogue for 7 [' Delta', ' Good', 'rem', ',', ' Olivia', ' Newton', '-', 'John']
+583 126 Name of mother of x -1 Name of mother of Olivia Newton-John Irene Helen Käthe Born Olivia Newton-John "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' huge' ' fan' ' of' ' the' ' show' ' and']" ", who is a great friend of mine .
+
+ I am a huge fan of the show and" False " John Travolta and Olivia Newton-John perform ""You're the" 9 [' John', ' T', 'rav', 'olt', 'a', ' and', ' Olivia', ' Newton', '-', 'John']
+584 126 Name of mother of x -1 Name of mother of Olivia Newton-John Irene Helen Käthe Born Olivia Newton-John "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' huge' ' fan' ' of' ' the' ' show' ' and']" ", who is a great friend of mine .
+
+ I am a huge fan of the show and" False " version ===
+" 7 [' version', ' ===', 'O', 'liv', 'ia', ' Newton', '-', 'John']
+585 126 Name of mother of x -1 Name of mother of Olivia Newton-John Irene Helen Käthe Born Olivia Newton-John "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' huge' ' fan' ' of' ' the' ' show' ' and']" ", who is a great friend of mine .
+
+ I am a huge fan of the show and" False " Rosada"". The same year, Olivia Newton-John released the song" 10 "[' Ros', 'ada', '"".', ' The', ' same', ' year', ',', ' Olivia', ' Newton', '-', 'John']"
+586 126 Name of mother of x -1 Name of mother of Olivia Newton-John Irene Helen Käthe Born Olivia Newton-John "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' huge' ' fan' ' of' ' the' ' show' ' and']" ", who is a great friend of mine .
+
+ I am a huge fan of the show and" False guest star Olivia Newton-John appears as 5 [' guest', ' star', ' Olivia', ' Newton', '-', 'John']
+587 127 Name of mother of x -1 Name of mother of Sigmund Freud Amalia Freud Sigmund Freud "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' the' ' psych' 'oan' 'aly' 'tic' ' movement' ',' ' was'
+ ' born']" , the father of psycho analysis , and the father of the psych oan aly tic movement , was born False pervasiveness of Sigmund Freud in the film 7 [' per', 'vas', 'iveness', ' of', ' S', 'igm', 'und', ' Freud']
+588 127 Name of mother of x -1 Name of mother of Sigmund Freud Amalia Freud Sigmund Freud "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' the' ' psych' 'oan' 'aly' 'tic' ' movement' ',' ' was'
+ ' born']" , the father of psycho analysis , and the father of the psych oan aly tic movement , was born False Lane, Wilhelm Reich, Sigmund Freud and others on his 8 [' Lane', ',', ' Wilhelm', ' Reich', ',', ' S', 'igm', 'und', ' Freud']
+589 127 Name of mother of x -1 Name of mother of Sigmund Freud Amalia Freud Sigmund Freud "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' the' ' psych' 'oan' 'aly' 'tic' ' movement' ',' ' was'
+ ' born']" , the father of psycho analysis , and the father of the psych oan aly tic movement , was born False theories of Sigmund Freud and Carl Jung, the 5 [' theories', ' of', ' S', 'igm', 'und', ' Freud']
+590 127 Name of mother of x -1 Name of mother of Sigmund Freud Amalia Freud Sigmund Freud "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' the' ' psych' 'oan' 'aly' 'tic' ' movement' ',' ' was'
+ ' born']" , the father of psycho analysis , and the father of the psych oan aly tic movement , was born False also a student of Sigmund Freud and others, [who] 7 [' also', ' a', ' student', ' of', ' S', 'igm', 'und', ' Freud']
+591 127 Name of mother of x -1 Name of mother of Sigmund Freud Amalia Freud Sigmund Freud "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' the' ' psych' 'oan' 'aly' 'tic' ' movement' ',' ' was'
+ ' born']" , the father of psycho analysis , and the father of the psych oan aly tic movement , was born False Friedrich Nietzsche, Sigmund Freud and Carl Jung influence 6 [' Friedrich', ' Nietzsche', ',', ' S', 'igm', 'und', ' Freud']
+592 128 Name of mother of x -1 Name of mother of Batman Martha Wayne Batman "[',' ' and' ' the' ' other' ' is' ' the' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' child' ' is' ' the']" ", and the other is the mother of the child .
+
+ The mother of the child is the" False 1995, beating Batman Forever, Apollo 13 3 [' 1995', ',', ' beating', ' Batman']
+593 128 Name of mother of x -1 Name of mother of Batman Martha Wayne Batman "[',' ' and' ' the' ' other' ' is' ' the' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' child' ' is' ' the']" ", and the other is the mother of the child .
+
+ The mother of the child is the" False river water. Anarky and Batman ultimately 7 [' river', ' water', '.', ' An', 'ark', 'y', ' and', ' Batman']
+594 128 Name of mother of x -1 Name of mother of Batman Martha Wayne Batman "[',' ' and' ' the' ' other' ' is' ' the' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' child' ' is' ' the']" ", and the other is the mother of the child .
+
+ The mother of the child is the" False adapted to serve as Batman's adversary in 4 [' adapted', ' to', ' serve', ' as', ' Batman']
+595 128 Name of mother of x -1 Name of mother of Batman Martha Wayne Batman "[',' ' and' ' the' ' other' ' is' ' the' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' child' ' is' ' the']" ", and the other is the mother of the child .
+
+ The mother of the child is the" False Wayne's childhood. As Batman, Wayne investigates 5 "[' Wayne', ""'s"", ' childhood', '.', ' As', ' Batman']"
+596 128 Name of mother of x -1 Name of mother of Batman Martha Wayne Batman "[',' ' and' ' the' ' other' ' is' ' the' ' mother' ' of' ' the' ' child'
+ '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' child' ' is' ' the']" ", and the other is the mother of the child .
+
+ The mother of the child is the" False Lego-themed version of Batman appears in The Lego 5 [' Lego', '-', 'themed', ' version', ' of', ' Batman']
+597 129 Name of mother of x -1 Name of mother of Ronald Reagan Nelle Wilson Reagan Ronald Reagan "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False claiming that Ronald Reagan neglected to mention 3 [' claiming', ' that', ' Ronald', ' Reagan']
+598 129 Name of mother of x -1 Name of mother of Ronald Reagan Nelle Wilson Reagan Ronald Reagan "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False States President Ronald Reagan — and that 3 [' States', ' President', ' Ronald', ' Reagan']
+599 129 Name of mother of x -1 Name of mother of Ronald Reagan Nelle Wilson Reagan Ronald Reagan "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False President Ronald Reagan visited the state 2 [' President', ' Ronald', ' Reagan']
+600 129 Name of mother of x -1 Name of mother of Ronald Reagan Nelle Wilson Reagan Ronald Reagan "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False ed. A Companion to Ronald Reagan (Wiley-Blackwell, 6 [' ed', '.', ' A', ' Companion', ' to', ' Ronald', ' Reagan']
+601 129 Name of mother of x -1 Name of mother of Ronald Reagan Nelle Wilson Reagan Ronald Reagan "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False featured on MSNBC at the Ronald Reagan Presidential Library 6 [' featured', ' on', ' MSNBC', ' at', ' the', ' Ronald', ' Reagan']
+602 130 Name of mother of x -1 Name of mother of Mark Twain Jane Lampton Mark Twain "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mark' ' Twain'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Mark Twain , and the
+
+ Name of mother of" False back on whether Mark Twain had said that, obviously 4 [' back', ' on', ' whether', ' Mark', ' Twain']
+603 130 Name of mother of x -1 Name of mother of Mark Twain Jane Lampton Mark Twain "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mark' ' Twain'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Mark Twain , and the
+
+ Name of mother of" False Contributors included Mark Twain and Stephen Crane. 4 [' Contribut', 'ors', ' included', ' Mark', ' Twain']
+604 130 Name of mother of x -1 Name of mother of Mark Twain Jane Lampton Mark Twain "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mark' ' Twain'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Mark Twain , and the
+
+ Name of mother of" False father (My Father, Mark Twain in 1931) and of 6 [' father', ' (', 'My', ' Father', ',', ' Mark', ' Twain']
+605 130 Name of mother of x -1 Name of mother of Mark Twain Jane Lampton Mark Twain "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mark' ' Twain'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Mark Twain , and the
+
+ Name of mother of" False nominee for the 2005 Mark Twain Award and 2007 5 [' nominee', ' for', ' the', ' 2005', ' Mark', ' Twain']
+606 130 Name of mother of x -1 Name of mother of Mark Twain Jane Lampton Mark Twain "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mark' ' Twain'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Mark Twain , and the
+
+ Name of mother of" False " good repair. Even Mark Twain remarked, ""The" 5 [' good', ' repair', '.', ' Even', ' Mark', ' Twain']
+607 131 Name of mother of x -1 Name of mother of Franz Kafka Julie Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' novel']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The novel" False " Austrian literary Franz Kafka Prize
+" 3 [' Austrian', ' literary', ' Franz', ' Kafka']
+608 131 Name of mother of x -1 Name of mother of Franz Kafka Julie Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' novel']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The novel" False 2 ['Fran', 'z', ' Kafka']
+609 131 Name of mother of x -1 Name of mother of Franz Kafka Julie Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' novel']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The novel" False that the writer Franz Kafka had suffered from an 4 [' that', ' the', ' writer', ' Franz', ' Kafka']
+610 131 Name of mother of x -1 Name of mother of Franz Kafka Julie Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' novel']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The novel" False " businessman"" and by Franz Kafka as ""a true Kafka" 5 "[' businessman', '""', ' and', ' by', ' Franz', ' Kafka']"
+611 131 Name of mother of x -1 Name of mother of Franz Kafka Julie Kafka Franz Kafka "[',' ' the' ' author' ' of' ' The' ' Trial' ',' ' The' ' Castle' ','
+ ' and' ' The' ' Met' 'amorph' 'osis' '.' '\n' '\n' 'The' ' novel']" ", the author of The Trial , The Castle , and The Met amorph osis .
+
+ The novel" False Brandenfeld, to Franz Kafka and Felice Bauer. 6 [' Brand', 'en', 'feld', ',', ' to', ' Franz', ' Kafka']
+612 132 Name of mother of x -1 Name of mother of Hilary Duff Susan Colleen Cobb Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False the self-titled Hilary Duff (2004), received 7 [' the', ' self', '-', 't', 'itled', ' Hil', 'ary', ' Duff']
+613 132 Name of mother of x -1 Name of mother of Hilary Duff Susan Colleen Cobb Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False [Metamorphosis (2003) and Hilary Duff (2004)] have distinctive 10 [' [', 'Met', 'amorph', 'osis', ' (', '2003', ')', ' and', ' Hil', 'ary', ' Duff']
+614 132 Name of mother of x -1 Name of mother of Hilary Duff Susan Colleen Cobb Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False sound to those of Hilary Duff and Avril Lavigne. 6 [' sound', ' to', ' those', ' of', ' Hil', 'ary', ' Duff']
+615 132 Name of mother of x -1 Name of mother of Hilary Duff Susan Colleen Cobb Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " quality made past Hilary Duff concerts ""look" 5 [' quality', ' made', ' past', ' Hil', 'ary', ' Duff']
+616 132 Name of mother of x -1 Name of mother of Hilary Duff Susan Colleen Cobb Hilary Duff "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False hits album, Best of Hilary Duff (2008). It is also 7 [' hits', ' album', ',', ' Best', ' of', ' Hil', 'ary', ' Duff']
+617 133 Name of mother of x -1 Name of mother of Alexander von Humboldt Marie-Elisabeth von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' ',' ' who' ' was' ' born'
+ ' in' ' 17' '69' ',' ' and' ' died' ' in' ' 18' '59' '.']" , the great German natural ist , who was born in 17 69 , and died in 18 59 . False and the young Alexander von Humboldt started from Mainz 8 [' and', ' the', ' young', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+618 133 Name of mother of x -1 Name of mother of Alexander von Humboldt Marie-Elisabeth von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' ',' ' who' ' was' ' born'
+ ' in' ' 17' '69' ',' ' and' ' died' ' in' ' 18' '59' '.']" , the great German natural ist , who was born in 17 69 , and died in 18 59 . False Professors and the Alexander von Humboldt Association of America. 9 [' Prof', 'essors', ' and', ' the', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+619 133 Name of mother of x -1 Name of mother of Alexander von Humboldt Marie-Elisabeth von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' ',' ' who' ' was' ' born'
+ ' in' ' 17' '69' ',' ' and' ' died' ' in' ' 18' '59' '.']" , the great German natural ist , who was born in 17 69 , and died in 18 59 . False based on the work of Alexander von Humboldt (1769 – 1859). 10 [' based', ' on', ' the', ' work', ' of', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+620 133 Name of mother of x -1 Name of mother of Alexander von Humboldt Marie-Elisabeth von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' ',' ' who' ' was' ' born'
+ ' in' ' 17' '69' ',' ' and' ' died' ' in' ' 18' '59' '.']" , the great German natural ist , who was born in 17 69 , and died in 18 59 . False " origin"". German Alexander von Humboldt (1769 – 1859)" 8 "[' origin', '"".', ' German', ' Alexander', ' von', ' H', 'umb', 'old', 't']"
+621 133 Name of mother of x -1 Name of mother of Alexander von Humboldt Marie-Elisabeth von Humboldt Alexander von Humboldt "[',' ' the' ' great' ' German' ' natural' 'ist' ',' ' who' ' was' ' born'
+ ' in' ' 17' '69' ',' ' and' ' died' ' in' ' 18' '59' '.']" , the great German natural ist , who was born in 17 69 , and died in 18 59 . False convinced by Goethe and Alexander von Humboldt that making 10 [' convinced', ' by', ' Go', 'ethe', ' and', ' Alexander', ' von', ' H', 'umb', 'old', 't']
+622 134 Name of mother of x -1 Name of mother of Yukio Mishima Shizue Hiraoka Yukio Mishima "[',' ' the' ' Japanese' ' author' ' of' ' the' ' novel' ' ""' 'The'
+ ' Temple' ' of' ' the' ' Golden' ' Pavilion' '""' ' and' ' the' ' film'
+ ' ""' 'The']" ", the Japanese author of the novel "" The Temple of the Golden Pavilion "" and the film "" The" False Jun 'ichirō Tanizaki, Yukio Mishima and, more recently, 12 "[' Jun', "" '"", 'ich', 'ir', 'ō', ' Tan', 'iz', 'aki', ',', ' Yuk', 'io', ' Mish', 'ima']"
+623 134 Name of mother of x -1 Name of mother of Yukio Mishima Shizue Hiraoka Yukio Mishima "[',' ' the' ' Japanese' ' author' ' of' ' the' ' novel' ' ""' 'The'
+ ' Temple' ' of' ' the' ' Golden' ' Pavilion' '""' ' and' ' the' ' film'
+ ' ""' 'The']" ", the Japanese author of the novel "" The Temple of the Golden Pavilion "" and the film "" The" False 'ichirō Tanizaki, Yukio Mishima and, more recently, 11 "["" '"", 'ich', 'ir', 'ō', ' Tan', 'iz', 'aki', ',', ' Yuk', 'io', ' Mish', 'ima']"
+624 134 Name of mother of x -1 Name of mother of Yukio Mishima Shizue Hiraoka Yukio Mishima "[',' ' the' ' Japanese' ' author' ' of' ' the' ' novel' ' ""' 'The'
+ ' Temple' ' of' ' the' ' Golden' ' Pavilion' '""' ' and' ' the' ' film'
+ ' ""' 'The']" ", the Japanese author of the novel "" The Temple of the Golden Pavilion "" and the film "" The" False 'ichirō Tanizaki, Yukio Mishima and, more recently, 11 "["" '"", 'ich', 'ir', 'ō', ' Tan', 'iz', 'aki', ',', ' Yuk', 'io', ' Mish', 'ima']"
+625 135 Name of mother of x -1 Name of mother of Frans Hals Adriaentje van Geertenryck Frans Hals "[',' ' the' ' painter' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the
+ " False based on the Frans Hals painting, the 6 [' based', ' on', ' the', ' Fr', 'ans', ' H', 'als']
+626 135 Name of mother of x -1 Name of mother of Frans Hals Adriaentje van Geertenryck Frans Hals "[',' ' the' ' painter' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the
+ " False Happy Smoker after Frans Hals (1923) hangs in the 7 [' Happy', ' Sm', 'oker', ' after', ' Fr', 'ans', ' H', 'als']
+627 135 Name of mother of x -1 Name of mother of Frans Hals Adriaentje van Geertenryck Frans Hals "[',' ' the' ' painter' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the
+ " False the style of Frans Hals and the school 6 [' the', ' style', ' of', ' Fr', 'ans', ' H', 'als']
+628 135 Name of mother of x -1 Name of mother of Frans Hals Adriaentje van Geertenryck Frans Hals "[',' ' the' ' painter' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the
+ " False Rommelpotspeler after Frans Hals. The Frans Hals catalogue 9 [' R', 'ommel', 'pots', 'pel', 'er', ' after', ' Fr', 'ans', ' H', 'als']
+629 135 Name of mother of x -1 Name of mother of Frans Hals Adriaentje van Geertenryck Frans Hals "[',' ' the' ' painter' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and the
+ " False Happy Smoker after Frans Hals (1923) hangs 7 [' Happy', ' Sm', 'oker', ' after', ' Fr', 'ans', ' H', 'als']
+630 136 Name of mother of x -1 Name of mother of Isaac Newton Hannah Ayscough Isaac Newton "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in'
+ ' the' ' year' ' 16' '42' ',' ' and' ' died' ' in' ' 17' '27']" , the father of modern science , was born in the year 16 42 , and died in 17 27 False operated by the Isaac Newton Group of Telescopes 4 [' operated', ' by', ' the', ' Isaac', ' Newton']
+631 136 Name of mother of x -1 Name of mother of Isaac Newton Hannah Ayscough Isaac Newton "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in'
+ ' the' ' year' ' 16' '42' ',' ' and' ' died' ' in' ' 17' '27']" , the father of modern science , was born in the year 16 42 , and died in 17 27 False Astronomers since Isaac Newton have tried to estimate 4 [' Astron', 'omers', ' since', ' Isaac', ' Newton']
+632 136 Name of mother of x -1 Name of mother of Isaac Newton Hannah Ayscough Isaac Newton "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in'
+ ' the' ' year' ' 16' '42' ',' ' and' ' died' ' in' ' 17' '27']" , the father of modern science , was born in the year 16 42 , and died in 17 27 False telescopes of the Isaac Newton Group began operating 4 [' telescopes', ' of', ' the', ' Isaac', ' Newton']
+633 136 Name of mother of x -1 Name of mother of Isaac Newton Hannah Ayscough Isaac Newton "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in'
+ ' the' ' year' ' 16' '42' ',' ' and' ' died' ' in' ' 17' '27']" , the father of modern science , was born in the year 16 42 , and died in 17 27 False similarities to Isaac Newton but differs 3 [' similarities', ' to', ' Isaac', ' Newton']
+634 136 Name of mother of x -1 Name of mother of Isaac Newton Hannah Ayscough Isaac Newton "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in'
+ ' the' ' year' ' 16' '42' ',' ' and' ' died' ' in' ' 17' '27']" , the father of modern science , was born in the year 16 42 , and died in 17 27 False book Opticks, Isaac Newton reported Rømer's 5 [' book', ' Opt', 'icks', ',', ' Isaac', ' Newton']
+635 137 Name of mother of x -1 Name of mother of Charles III Elizabeth II Charles III "[',' ' the' ' son' ' of' ' Charles' ' I' ',' ' and' ' the' ' son' ' of'
+ ' Charles' ' II' ',' ' and' ' the' ' son' ' of' ' Charles' ' II']" , the son of Charles I , and the son of Charles II , and the son of Charles II False occasionally. Charles III was succeeded in 3 [' occasionally', '.', ' Charles', ' III']
+636 137 Name of mother of x -1 Name of mother of Charles III Elizabeth II Charles III "[',' ' the' ' son' ' of' ' Charles' ' I' ',' ' and' ' the' ' son' ' of'
+ ' Charles' ' II' ',' ' and' ' the' ' son' ' of' ' Charles' ' II']" , the son of Charles I , and the son of Charles II , and the son of Charles II False was ceded to Charles III of Spain in 5 [' was', ' c', 'eded', ' to', ' Charles', ' III']
+637 137 Name of mother of x -1 Name of mother of Charles III Elizabeth II Charles III "[',' ' the' ' son' ' of' ' Charles' ' I' ',' ' and' ' the' ' son' ' of'
+ ' Charles' ' II' ',' ' and' ' the' ' son' ' of' ' Charles' ' II']" , the son of Charles I , and the son of Charles II , and the son of Charles II False 1 ['Charles', ' III']
+638 137 Name of mother of x -1 Name of mother of Charles III Elizabeth II Charles III "[',' ' the' ' son' ' of' ' Charles' ' I' ',' ' and' ' the' ' son' ' of'
+ ' Charles' ' II' ',' ' and' ' the' ' son' ' of' ' Charles' ' II']" , the son of Charles I , and the son of Charles II , and the son of Charles II False They regarded Charles III of Naples as Louis 3 [' They', ' regarded', ' Charles', ' III']
+639 137 Name of mother of x -1 Name of mother of Charles III Elizabeth II Charles III "[',' ' the' ' son' ' of' ' Charles' ' I' ',' ' and' ' the' ' son' ' of'
+ ' Charles' ' II' ',' ' and' ' the' ' son' ' of' ' Charles' ' II']" , the son of Charles I , and the son of Charles II , and the son of Charles II False 1 ['Charles', ' III']
+640 138 Name of mother of x -1 Name of mother of Honoré de Balzac Anne-Charlotte-Laure Sallambier Honoré de Balzac "[',' ' the' ' author' ' of' ' the' ' _' 'Com' 'é' 'die' ' hum' 'aine' '_'
+ ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ']" ", the author of the _ Com é die hum aine _ , and the
+ " False playwright Honoré de Balzac (1799 – 1850). Set 6 [' play', 'wright', ' Honor', 'é', ' de', ' Bal', 'zac']
+641 138 Name of mother of x -1 Name of mother of Honoré de Balzac Anne-Charlotte-Laure Sallambier Honoré de Balzac "[',' ' the' ' author' ' of' ' the' ' _' 'Com' 'é' 'die' ' hum' 'aine' '_'
+ ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ']" ", the author of the _ Com é die hum aine _ , and the
+ " False 5 ['Hon', 'or', 'é', ' de', ' Bal', 'zac']
+642 138 Name of mother of x -1 Name of mother of Honoré de Balzac Anne-Charlotte-Laure Sallambier Honoré de Balzac "[',' ' the' ' author' ' of' ' the' ' _' 'Com' 'é' 'die' ' hum' 'aine' '_'
+ ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ']" ", the author of the _ Com é die hum aine _ , and the
+ " False 5 ['Hon', 'or', 'é', ' de', ' Bal', 'zac']
+643 138 Name of mother of x -1 Name of mother of Honoré de Balzac Anne-Charlotte-Laure Sallambier Honoré de Balzac "[',' ' the' ' author' ' of' ' the' ' _' 'Com' 'é' 'die' ' hum' 'aine' '_'
+ ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ']" ", the author of the _ Com é die hum aine _ , and the
+ " False novelist and playwright Honoré de Balzac (1799 – 1850), 8 [' novelist', ' and', ' play', 'wright', ' Honor', 'é', ' de', ' Bal', 'zac']
+644 138 Name of mother of x -1 Name of mother of Honoré de Balzac Anne-Charlotte-Laure Sallambier Honoré de Balzac "[',' ' the' ' author' ' of' ' the' ' _' 'Com' 'é' 'die' ' hum' 'aine' '_'
+ ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ']" ", the author of the _ Com é die hum aine _ , and the
+ " False " talent to be lucky."" Honoré de Balzac (in his novella" 9 "[' talent', ' to', ' be', ' lucky', '.""', ' Honor', 'é', ' de', ' Bal', 'zac']"
+645 139 Name of mother of x -1 Name of mother of Brigitte Bardot Anne-Marie Mucel Brigitte Bardot "[',' ' the' ' French' ' actress' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress' ' and' ' singer' ' Brig' 'itte' ' Bard' 'ot' '.' '\n' '\n']" ", the French actress , and the mother of the actress and singer Brig itte Bard ot .
+
+" False Marianne Faithfull and Brigitte Bardot set about destroying 8 [' Marian', 'ne', ' Faith', 'full', ' and', ' Brig', 'itte', ' Bard', 'ot']
+646 139 Name of mother of x -1 Name of mother of Brigitte Bardot Anne-Marie Mucel Brigitte Bardot "[',' ' the' ' French' ' actress' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress' ' and' ' singer' ' Brig' 'itte' ' Bard' 'ot' '.' '\n' '\n']" ", the French actress , and the mother of the actress and singer Brig itte Bard ot .
+
+" False Bettie Page with Brigitte Bardot and adding that little 7 [' Bett', 'ie', ' Page', ' with', ' Brig', 'itte', ' Bard', 'ot']
+647 139 Name of mother of x -1 Name of mother of Brigitte Bardot Anne-Marie Mucel Brigitte Bardot "[',' ' the' ' French' ' actress' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress' ' and' ' singer' ' Brig' 'itte' ' Bard' 'ot' '.' '\n' '\n']" ", the French actress , and the mother of the actress and singer Brig itte Bard ot .
+
+" False emulates Marilyn Monroe, Brigitte Bardot and Barbra Streisand 8 [' em', 'ulates', ' Marilyn', ' Monroe', ',', ' Brig', 'itte', ' Bard', 'ot']
+648 139 Name of mother of x -1 Name of mother of Brigitte Bardot Anne-Marie Mucel Brigitte Bardot "[',' ' the' ' French' ' actress' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress' ' and' ' singer' ' Brig' 'itte' ' Bard' 'ot' '.' '\n' '\n']" ", the French actress , and the mother of the actress and singer Brig itte Bard ot .
+
+" False emulates Marilyn Monroe, Brigitte Bardot and Barbra Streisand 8 [' em', 'ulates', ' Marilyn', ' Monroe', ',', ' Brig', 'itte', ' Bard', 'ot']
+649 139 Name of mother of x -1 Name of mother of Brigitte Bardot Anne-Marie Mucel Brigitte Bardot "[',' ' the' ' French' ' actress' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress' ' and' ' singer' ' Brig' 'itte' ' Bard' 'ot' '.' '\n' '\n']" ", the French actress , and the mother of the actress and singer Brig itte Bard ot .
+
+" False formed a duo with Brigitte Bardot in the French 7 [' formed', ' a', ' duo', ' with', ' Brig', 'itte', ' Bard', 'ot']
+650 140 Name of mother of x -1 Name of mother of August Strindberg Eleonora Ulrika Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Stockholm' ',' ' Sweden' ',' ' in' ' 18' '49' '.' '\n' '\n' 'The']" ", the Swedish writer , who was born in Stockholm , Sweden , in 18 49 .
+
+ The" False " Strindberg ===
+" 7 [' Str', 'ind', 'berg', ' ===', 'August', ' Str', 'ind', 'berg']
+651 140 Name of mother of x -1 Name of mother of August Strindberg Eleonora Ulrika Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Stockholm' ',' ' Sweden' ',' ' in' ' 18' '49' '.' '\n' '\n' 'The']" ", the Swedish writer , who was born in Stockholm , Sweden , in 18 49 .
+
+ The" False Swedish playwright August Strindberg and Norwegian playwright 6 [' Swedish', ' play', 'wright', ' August', ' Str', 'ind', 'berg']
+652 140 Name of mother of x -1 Name of mother of August Strindberg Eleonora Ulrika Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Stockholm' ',' ' Sweden' ',' ' in' ' 18' '49' '.' '\n' '\n' 'The']" ", the Swedish writer , who was born in Stockholm , Sweden , in 18 49 .
+
+ The" False " strongly."" Swedish author August Strindberg wrote: ""Linnaeus" 7 "[' strongly', '.""', ' Swedish', ' author', ' August', ' Str', 'ind', 'berg']"
+653 140 Name of mother of x -1 Name of mother of August Strindberg Eleonora Ulrika Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Stockholm' ',' ' Sweden' ',' ' in' ' 18' '49' '.' '\n' '\n' 'The']" ", the Swedish writer , who was born in Stockholm , Sweden , in 18 49 .
+
+ The" False " == Marriage to August Strindberg ==
+" 6 [' ==', ' Marriage', ' to', ' August', ' Str', 'ind', 'berg']
+654 140 Name of mother of x -1 Name of mother of August Strindberg Eleonora Ulrika Strindberg August Strindberg "[',' ' the' ' Swedish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' Stockholm' ',' ' Sweden' ',' ' in' ' 18' '49' '.' '\n' '\n' 'The']" ", the Swedish writer , who was born in Stockholm , Sweden , in 18 49 .
+
+ The" False Swedish dramatist August Strindberg (1849 – 1912). 7 [' Swedish', ' dram', 'at', 'ist', ' August', ' Str', 'ind', 'berg']
+655 141 Name of mother of x -1 Name of mother of Halle Berry Judith Ann Hawkins Halle Berry "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The actress ," False Ororo Munroe 6 [' Or', 'oro', ' Mun', 'ro', 'H', 'alle', ' Berry']
+656 141 Name of mother of x -1 Name of mother of Halle Berry Judith Ann Hawkins Halle Berry "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The actress ," False serial killer and Halle Berry as Jordan Turner, 5 [' serial', ' killer', ' and', ' H', 'alle', ' Berry']
+657 141 Name of mother of x -1 Name of mother of Halle Berry Judith Ann Hawkins Halle Berry "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The actress ," False Hugh Jackman, Halle Berry and Kelsey Grammer 6 [' Hugh', ' Jack', 'man', ',', ' H', 'alle', ' Berry']
+658 141 Name of mother of x -1 Name of mother of Halle Berry Judith Ann Hawkins Halle Berry "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The actress ," False Award winner Halle Berry guest starred 4 [' Award', ' winner', ' H', 'alle', ' Berry']
+659 141 Name of mother of x -1 Name of mother of Halle Berry Judith Ann Hawkins Halle Berry "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe winner , and the mother of two .
+
+ The actress ," False " = Halle Berry =
+" 3 [' =', ' H', 'alle', ' Berry']
+660 142 Name of mother of x -1 Name of mother of Friedrich Schiller Elisabeth Dorothea Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the
+ " False Shakespeare, Friedrich Schiller and Johann Wolfgang 4 [' Shakespeare', ',', ' Friedrich', ' Sch', 'iller']
+661 142 Name of mother of x -1 Name of mother of Friedrich Schiller Elisabeth Dorothea Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the
+ " False Shakespeare, Friedrich Schiller and Johann 4 [' Shakespeare', ',', ' Friedrich', ' Sch', 'iller']
+662 142 Name of mother of x -1 Name of mother of Friedrich Schiller Elisabeth Dorothea Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the
+ " False encouragement. Friedrich Schiller would also deal 4 [' encouragement', '.', ' Friedrich', ' Sch', 'iller']
+663 142 Name of mother of x -1 Name of mother of Friedrich Schiller Elisabeth Dorothea Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the
+ " False Napoleon's encouragement. Friedrich Schiller would also deal with 6 "[' Napoleon', ""'s"", ' encouragement', '.', ' Friedrich', ' Sch', 'iller']"
+664 142 Name of mother of x -1 Name of mother of Friedrich Schiller Elisabeth Dorothea Schiller Friedrich Schiller "[',' ' the' ' poet' ',' ' and' ' the' ' father' ' of' ' the' ' poet' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the father of the poet , the
+ " False William Shakespeare, Friedrich Schiller and Johann Wolfgang 5 [' William', ' Shakespeare', ',', ' Friedrich', ' Sch', 'iller']
+665 143 Name of mother of x -1 Name of mother of George Frideric Handel Dorothea Händel George Frideric Handel "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' H' 'alle' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the' ' son']" , the composer , was born in H alle , Germany , in 16 85 . He was the son False than 25 operas by George Frideric Handel premièred here. In 10 [' than', ' 25', ' oper', 'as', ' by', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+666 143 Name of mother of x -1 Name of mother of George Frideric Handel Dorothea Händel George Frideric Handel "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' H' 'alle' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the' ' son']" , the composer , was born in H alle , Germany , in 16 85 . He was the son False eighteen-year-old composer George Frideric Handel took up residence 11 [' eighteen', '-', 'year', '-', 'old', ' composer', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+667 143 Name of mother of x -1 Name of mother of George Frideric Handel Dorothea Händel George Frideric Handel "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' H' 'alle' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the' ' son']" , the composer , was born in H alle , Germany , in 16 85 . He was the son False 1727. The composer George Frideric Handel was commissioned to 10 [' 17', '27', '.', ' The', ' composer', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+668 143 Name of mother of x -1 Name of mother of George Frideric Handel Dorothea Händel George Frideric Handel "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' H' 'alle' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the' ' son']" , the composer , was born in H alle , Germany , in 16 85 . He was the son False eighteen-year-old composer George Frideric Handel took up residence 11 [' eighteen', '-', 'year', '-', 'old', ' composer', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+669 143 Name of mother of x -1 Name of mother of George Frideric Handel Dorothea Händel George Frideric Handel "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' H' 'alle' ','
+ ' Germany' ',' ' in' ' 16' '85' '.' ' He' ' was' ' the' ' son']" , the composer , was born in H alle , Germany , in 16 85 . He was the son False period. German-born George Frideric Handel became a British 10 [' period', '.', ' German', '-', 'born', ' George', ' Fr', 'ider', 'ic', ' Hand', 'el']
+670 144 Name of mother of x -1 Name of mother of Charlie Chaplin Hannah Chaplin Charlie Chaplin "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False host to the annual Charlie Chaplin Comedy Film 6 [' host', ' to', ' the', ' annual', ' Charlie', ' Cha', 'plin']
+671 144 Name of mother of x -1 Name of mother of Charlie Chaplin Hannah Chaplin Charlie Chaplin "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False format to the films of Charlie Chaplin and Laurel 7 [' format', ' to', ' the', ' films', ' of', ' Charlie', ' Cha', 'plin']
+672 144 Name of mother of x -1 Name of mother of Charlie Chaplin Hannah Chaplin Charlie Chaplin "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False States after comedian Charlie Chaplin parodied it 5 [' States', ' after', ' comedian', ' Charlie', ' Cha', 'plin']
+673 144 Name of mother of x -1 Name of mother of Charlie Chaplin Hannah Chaplin Charlie Chaplin "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False " Hutchinson, he admired Charlie Chaplin ""to the point of" 6 [' Hutchinson', ',', ' he', ' admired', ' Charlie', ' Cha', 'plin']
+674 144 Name of mother of x -1 Name of mother of Charlie Chaplin Hannah Chaplin Charlie Chaplin "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False was hired by Charlie Chaplin to adapt the 5 [' was', ' hired', ' by', ' Charlie', ' Cha', 'plin']
+675 145 Name of mother of x -1 Name of mother of Bertrand Russell Katharine Russell, Viscountess Amberley Bertrand Russell "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Bert' 'rand' ' Russell' ',' ' the' ' great'
+ ' philosopher']" ", the great philosopher , and the
+
+ Name of mother of Bert rand Russell , the great philosopher" False " The philosopher Bertrand Russell thought Eden ""not" 4 [' The', ' philosopher', ' Bert', 'rand', ' Russell']
+676 145 Name of mother of x -1 Name of mother of Bertrand Russell Katharine Russell, Viscountess Amberley Bertrand Russell "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Bert' 'rand' ' Russell' ',' ' the' ' great'
+ ' philosopher']" ", the great philosopher , and the
+
+ Name of mother of Bert rand Russell , the great philosopher" False " '. There is also a Bertrand Russell quotation: ""Mathematics," 7 "["" '."", ' There', ' is', ' also', ' a', ' Bert', 'rand', ' Russell']"
+677 145 Name of mother of x -1 Name of mother of Bertrand Russell Katharine Russell, Viscountess Amberley Bertrand Russell "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Bert' 'rand' ' Russell' ',' ' the' ' great'
+ ' philosopher']" ", the great philosopher , and the
+
+ Name of mother of Bert rand Russell , the great philosopher" False " Bertrand Russell ===
+" 2 [' Bert', 'rand', ' Russell']
+678 145 Name of mother of x -1 Name of mother of Bertrand Russell Katharine Russell, Viscountess Amberley Bertrand Russell "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Bert' 'rand' ' Russell' ',' ' the' ' great'
+ ' philosopher']" ", the great philosopher , and the
+
+ Name of mother of Bert rand Russell , the great philosopher" False " Russell ===
+" 5 [' Russell', ' ===', 'B', 'ert', 'rand', ' Russell']
+679 145 Name of mother of x -1 Name of mother of Bertrand Russell Katharine Russell, Viscountess Amberley Bertrand Russell "[',' ' the' ' great' ' philosopher' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Bert' 'rand' ' Russell' ',' ' the' ' great'
+ ' philosopher']" ", the great philosopher , and the
+
+ Name of mother of Bert rand Russell , the great philosopher" False humorous comment on a Bertrand Russell quote. Both 6 [' humorous', ' comment', ' on', ' a', ' Bert', 'rand', ' Russell']
+680 146 Name of mother of x -1 Name of mother of George Washington Mary Ball Washington George Washington "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' George' ' Washington' ',' '\n']" ", the
+
+ The following is a list of the names of the children of George Washington ,
+" False House, and George Washington of Virginia 4 [' House', ',', ' and', ' George', ' Washington']
+681 146 Name of mother of x -1 Name of mother of George Washington Mary Ball Washington George Washington "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' George' ' Washington' ',' '\n']" ", the
+
+ The following is a list of the names of the children of George Washington ,
+" False 1 ['George', ' Washington']
+682 146 Name of mother of x -1 Name of mother of George Washington Mary Ball Washington George Washington "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' George' ' Washington' ',' '\n']" ", the
+
+ The following is a list of the names of the children of George Washington ,
+" False other colonies. George Washington grew hemp at 4 [' other', ' colonies', '.', ' George', ' Washington']
+683 146 Name of mother of x -1 Name of mother of George Washington Mary Ball Washington George Washington "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' George' ' Washington' ',' '\n']" ", the
+
+ The following is a list of the names of the children of George Washington ,
+" False 1 ['George', ' Washington']
+684 146 Name of mother of x -1 Name of mother of George Washington Mary Ball Washington George Washington "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' George' ' Washington' ',' '\n']" ", the
+
+ The following is a list of the names of the children of George Washington ,
+" False " Papers of George Washington brought a ""fresh" 3 [' Papers', ' of', ' George', ' Washington']
+685 147 Name of mother of x -1 Name of mother of Edvard Munch Laura Cathrine Munch Edvard Munch "[',' ' the' ' son' ' of' ' a' ' Norwegian' ' immigrant' ',' ' was' ' born'
+ ' in' ' the' ' United' ' States' '.' ' He' ' was' ' a' ' painter' ',']" , the son of a Norwegian immigrant , was born in the United States . He was a painter , False August Strindberg, Edvard Munch and Paul Gauguin. 8 [' August', ' Str', 'ind', 'berg', ',', ' Ed', 'vard', ' M', 'unch']
+686 147 Name of mother of x -1 Name of mother of Edvard Munch Laura Cathrine Munch Edvard Munch "[',' ' the' ' son' ' of' ' a' ' Norwegian' ' immigrant' ',' ' was' ' born'
+ ' in' ' the' ' United' ' States' '.' ' He' ' was' ' a' ' painter' ',']" , the son of a Norwegian immigrant , was born in the United States . He was a painter , False " the figure from the Edvard Munch painting The Scream.
+" 7 [' the', ' figure', ' from', ' the', ' Ed', 'vard', ' M', 'unch']
+687 147 Name of mother of x -1 Name of mother of Edvard Munch Laura Cathrine Munch Edvard Munch "[',' ' the' ' son' ' of' ' a' ' Norwegian' ' immigrant' ',' ' was' ' born'
+ ' in' ' the' ' United' ' States' '.' ' He' ' was' ' a' ' painter' ',']" , the son of a Norwegian immigrant , was born in the United States . He was a painter , False Norwegian Symbolist Edvard Munch (1863 – 1944) would 6 [' Norwegian', ' Symbol', 'ist', ' Ed', 'vard', ' M', 'unch']
+688 147 Name of mother of x -1 Name of mother of Edvard Munch Laura Cathrine Munch Edvard Munch "[',' ' the' ' son' ' of' ' a' ' Norwegian' ' immigrant' ',' ' was' ' born'
+ ' in' ' the' ' United' ' States' '.' ' He' ' was' ' a' ' painter' ',']" , the son of a Norwegian immigrant , was born in the United States . He was a painter , False Norwegian Symbolist Edvard Munch (1863 – 1944) 6 [' Norwegian', ' Symbol', 'ist', ' Ed', 'vard', ' M', 'unch']
+689 147 Name of mother of x -1 Name of mother of Edvard Munch Laura Cathrine Munch Edvard Munch "[',' ' the' ' son' ' of' ' a' ' Norwegian' ' immigrant' ',' ' was' ' born'
+ ' in' ' the' ' United' ' States' '.' ' He' ' was' ' a' ' painter' ',']" , the son of a Norwegian immigrant , was born in the United States . He was a painter , False August Strindberg, Edvard Munch and Paul Gauguin. 8 [' August', ' Str', 'ind', 'berg', ',', ' Ed', 'vard', ' M', 'unch']
+690 148 Name of mother of x -1 Name of mother of Mahatma Gandhi Putlibai Karamchand Gandhi Mahatma Gandhi "[',' ' the' ' great' ' leader' ' of' ' the' ' Indian' ' National'
+ ' Congress' ',' ' was' ' born' ' in' ' 18' '69' '.' '\n' '\n' 'The'
+ ' first']" ", the great leader of the Indian National Congress , was born in 18 69 .
+
+ The first" False Chinatown and the Mahatma Gandhi District. Both areas 6 [' Chinatown', ' and', ' the', ' Mah', 'at', 'ma', ' Gandhi']
+691 148 Name of mother of x -1 Name of mother of Mahatma Gandhi Putlibai Karamchand Gandhi Mahatma Gandhi "[',' ' the' ' great' ' leader' ' of' ' the' ' Indian' ' National'
+ ' Congress' ',' ' was' ' born' ' in' ' 18' '69' '.' '\n' '\n' 'The'
+ ' first']" ", the great leader of the Indian National Congress , was born in 18 69 .
+
+ The first" False Temple, the Mahatma Gandhi Kashi Vidyapith, the 6 [' Temple', ',', ' the', ' Mah', 'at', 'ma', ' Gandhi']
+692 148 Name of mother of x -1 Name of mother of Mahatma Gandhi Putlibai Karamchand Gandhi Mahatma Gandhi "[',' ' the' ' great' ' leader' ' of' ' the' ' Indian' ' National'
+ ' Congress' ',' ' was' ' born' ' in' ' 18' '69' '.' '\n' '\n' 'The'
+ ' first']" ", the great leader of the Indian National Congress , was born in 18 69 .
+
+ The first" False Committee confirmed that Mahatma Gandhi was nominated 6 [' Committee', ' confirmed', ' that', ' Mah', 'at', 'ma', ' Gandhi']
+693 148 Name of mother of x -1 Name of mother of Mahatma Gandhi Putlibai Karamchand Gandhi Mahatma Gandhi "[',' ' the' ' great' ' leader' ' of' ' the' ' Indian' ' National'
+ ' Congress' ',' ' was' ' born' ' in' ' 18' '69' '.' '\n' '\n' 'The'
+ ' first']" ", the great leader of the Indian National Congress , was born in 18 69 .
+
+ The first" False articles relating to Mahatma Gandhi and Sardar Patel. 6 [' articles', ' relating', ' to', ' Mah', 'at', 'ma', ' Gandhi']
+694 148 Name of mother of x -1 Name of mother of Mahatma Gandhi Putlibai Karamchand Gandhi Mahatma Gandhi "[',' ' the' ' great' ' leader' ' of' ' the' ' Indian' ' National'
+ ' Congress' ',' ' was' ' born' ' in' ' 18' '69' '.' '\n' '\n' 'The'
+ ' first']" ", the great leader of the Indian National Congress , was born in 18 69 .
+
+ The first" False Most popularly, Mahatma Gandhi strongly believed 7 [' Most', ' popular', 'ly', ',', ' Mah', 'at', 'ma', ' Gandhi']
+695 149 Name of mother of x -1 Name of mother of Carl Linnaeus Christina Brodersonia Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' tax' 'onomy' '.' '\n' '\n' 'The' ' Lin' 'na' 'ean']" ", the bot an ist , and the father of modern tax onomy .
+
+ The Lin na ean" False introduced by Carl Linnaeus in 1735 also 5 [' introduced', ' by', ' Carl', ' Lin', 'na', 'eus']
+696 149 Name of mother of x -1 Name of mother of Carl Linnaeus Christina Brodersonia Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' tax' 'onomy' '.' '\n' '\n' 'The' ' Lin' 'na' 'ean']" ", the bot an ist , and the father of modern tax onomy .
+
+ The Lin na ean" False 3 ['Carl', ' Lin', 'na', 'eus']
+697 149 Name of mother of x -1 Name of mother of Carl Linnaeus Christina Brodersonia Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' tax' 'onomy' '.' '\n' '\n' 'The' ' Lin' 'na' 'ean']" ", the bot an ist , and the father of modern tax onomy .
+
+ The Lin na ean" False Swedish naturalist Carl Linnaeus who gave it the name 6 [' Swedish', ' natural', 'ist', ' Carl', ' Lin', 'na', 'eus']
+698 149 Name of mother of x -1 Name of mother of Carl Linnaeus Christina Brodersonia Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' tax' 'onomy' '.' '\n' '\n' 'The' ' Lin' 'na' 'ean']" ", the bot an ist , and the father of modern tax onomy .
+
+ The Lin na ean" False longifolia (described by Carl Linnaeus in 1753), 9 [' long', 'if', 'olia', ' (', 'described', ' by', ' Carl', ' Lin', 'na', 'eus']
+699 149 Name of mother of x -1 Name of mother of Carl Linnaeus Christina Brodersonia Carl Linnaeus "[',' ' the' ' bot' 'an' 'ist' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' tax' 'onomy' '.' '\n' '\n' 'The' ' Lin' 'na' 'ean']" ", the bot an ist , and the father of modern tax onomy .
+
+ The Lin na ean" False longifolia (described by Carl Linnaeus in 1753), which 9 [' long', 'if', 'olia', ' (', 'described', ' by', ' Carl', ' Lin', 'na', 'eus']
+700 150 Name of mother of x -1 Name of mother of Miley Cyrus Tish Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Liam' ' Hem'
+ 'sworth' ',' ' are' ' expecting' ' their' ' first' ' child' ' together'
+ '.' '\n']" ", the singer , and her husband , Liam Hem sworth , are expecting their first child together .
+" False " version of ""Feedback."" Miley Cyrus performed a dance" 8 "[' version', ' of', ' ""', 'Feed', 'back', '.""', ' M', 'iley', ' Cyrus']"
+701 150 Name of mother of x -1 Name of mother of Miley Cyrus Tish Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Liam' ' Hem'
+ 'sworth' ',' ' are' ' expecting' ' their' ' first' ' child' ' together'
+ '.' '\n']" ", the singer , and her husband , Liam Hem sworth , are expecting their first child together .
+" False " Montana 2: Meet Miley Cyrus =
+" 6 [' Montana', ' 2', ':', ' Meet', ' M', 'iley', ' Cyrus']
+702 150 Name of mother of x -1 Name of mother of Miley Cyrus Tish Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Liam' ' Hem'
+ 'sworth' ',' ' are' ' expecting' ' their' ' first' ' child' ' together'
+ '.' '\n']" ", the singer , and her husband , Liam Hem sworth , are expecting their first child together .
+" False American singer Miley Cyrus for her first extended 4 [' American', ' singer', ' M', 'iley', ' Cyrus']
+703 150 Name of mother of x -1 Name of mother of Miley Cyrus Tish Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Liam' ' Hem'
+ 'sworth' ',' ' are' ' expecting' ' their' ' first' ' child' ' together'
+ '.' '\n']" ", the singer , and her husband , Liam Hem sworth , are expecting their first child together .
+" False albums Meet Miley Cyrus (2007), Breakout (2008), 4 [' albums', ' Meet', ' M', 'iley', ' Cyrus']
+704 150 Name of mother of x -1 Name of mother of Miley Cyrus Tish Cyrus Miley Cyrus "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' Liam' ' Hem'
+ 'sworth' ',' ' are' ' expecting' ' their' ' first' ' child' ' together'
+ '.' '\n']" ", the singer , and her husband , Liam Hem sworth , are expecting their first child together .
+" False recording artist Miley Cyrus for her second 4 [' recording', ' artist', ' M', 'iley', ' Cyrus']
+705 151 Name of mother of x -1 Name of mother of George Bernard Shaw Lucinda Elizabeth Shaw George Bernard Shaw "[',' ' the' ' author' ' of' ' the' ' play' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' George' ' Bernard' ' Shaw' ',' ' the']" ", the author of the play , and the
+
+ Name of mother of George Bernard Shaw , the" False political allegory, George Bernard Shaw describes 6 [' political', ' alleg', 'ory', ',', ' George', ' Bernard', ' Shaw']
+706 151 Name of mother of x -1 Name of mother of George Bernard Shaw Lucinda Elizabeth Shaw George Bernard Shaw "[',' ' the' ' author' ' of' ' the' ' play' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' George' ' Bernard' ' Shaw' ',' ' the']" ", the author of the play , and the
+
+ Name of mother of George Bernard Shaw , the" False Thus, for example, George Bernard Shaw wrote in The 7 [' Thus', ',', ' for', ' example', ',', ' George', ' Bernard', ' Shaw']
+707 151 Name of mother of x -1 Name of mother of George Bernard Shaw Lucinda Elizabeth Shaw George Bernard Shaw "[',' ' the' ' author' ' of' ' the' ' play' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' George' ' Bernard' ' Shaw' ',' ' the']" ", the author of the play , and the
+
+ Name of mother of George Bernard Shaw , the" False G. Wells and George Bernard Shaw and started 6 [' G', '.', ' Wells', ' and', ' George', ' Bernard', ' Shaw']
+708 151 Name of mother of x -1 Name of mother of George Bernard Shaw Lucinda Elizabeth Shaw George Bernard Shaw "[',' ' the' ' author' ' of' ' the' ' play' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' George' ' Bernard' ' Shaw' ',' ' the']" ", the author of the play , and the
+
+ Name of mother of George Bernard Shaw , the" False after opening. George Bernard Shaw wrote of one appearance: 5 [' after', ' opening', '.', ' George', ' Bernard', ' Shaw']
+709 151 Name of mother of x -1 Name of mother of George Bernard Shaw Lucinda Elizabeth Shaw George Bernard Shaw "[',' ' the' ' author' ' of' ' the' ' play' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' George' ' Bernard' ' Shaw' ',' ' the']" ", the author of the play , and the
+
+ Name of mother of George Bernard Shaw , the" False that included George Bernard Shaw and the Irish 4 [' that', ' included', ' George', ' Bernard', ' Shaw']
+710 152 Name of mother of x -1 Name of mother of George Sand Sophie Victoire Delaborde George Sand "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Sand' ',' '\n' '\n' 'I' ' have' ' not' ' forgotten' ' you' ',']" ", and the
+
+ Name of mother of George Sand ,
+
+ I have not forgotten you ," False ambiguous. Author George Sand was portrayed as 4 [' ambiguous', '.', ' Author', ' George', ' Sand']
+711 152 Name of mother of x -1 Name of mother of George Sand Sophie Victoire Delaborde George Sand "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Sand' ',' '\n' '\n' 'I' ' have' ' not' ' forgotten' ' you' ',']" ", and the
+
+ Name of mother of George Sand ,
+
+ I have not forgotten you ," False the feminist writer George Sand and her lover Jules 4 [' the', ' feminist', ' writer', ' George', ' Sand']
+712 152 Name of mother of x -1 Name of mother of George Sand Sophie Victoire Delaborde George Sand "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Sand' ',' '\n' '\n' 'I' ' have' ' not' ' forgotten' ' you' ',']" ", and the
+
+ Name of mother of George Sand ,
+
+ I have not forgotten you ," False feminist writer George Sand and her lover 3 [' feminist', ' writer', ' George', ' Sand']
+713 152 Name of mother of x -1 Name of mother of George Sand Sophie Victoire Delaborde George Sand "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Sand' ',' '\n' '\n' 'I' ' have' ' not' ' forgotten' ' you' ',']" ", and the
+
+ Name of mother of George Sand ,
+
+ I have not forgotten you ," False relations with George Sand have been fictionalized 3 [' relations', ' with', ' George', ' Sand']
+714 152 Name of mother of x -1 Name of mother of George Sand Sophie Victoire Delaborde George Sand "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Sand' ',' '\n' '\n' 'I' ' have' ' not' ' forgotten' ' you' ',']" ", and the
+
+ Name of mother of George Sand ,
+
+ I have not forgotten you ," False about the life of George Sand that starred 5 [' about', ' the', ' life', ' of', ' George', ' Sand']
+715 153 Name of mother of x -1 Name of mother of Henri de Toulouse-Lautrec Adèle Zoë Tapié de Céleyran Henri de Toulouse-Lautrec "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and his wife , the
+ " False Louis Anquetin and Henri de Toulouse-Lautrec – who painted 13 [' Louis', ' An', 'quet', 'in', ' and', ' Henri', ' de', ' T', 'oul', 'ouse', '-', 'L', 'aut', 'rec']
+716 153 Name of mother of x -1 Name of mother of Henri de Toulouse-Lautrec Adèle Zoë Tapié de Céleyran Henri de Toulouse-Lautrec "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the painter , and his wife , the
+ " False Louis Anquetin and Henri de Toulouse-Lautrec – who painted 13 [' Louis', ' An', 'quet', 'in', ' and', ' Henri', ' de', ' T', 'oul', 'ouse', '-', 'L', 'aut', 'rec']
+717 154 Name of mother of x -1 Name of mother of Victoria Princess Victoria of Saxe-Coburg-Saalfeld Victoria "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False players from the PCHA's Victoria Aristocrats, only 6 "[' players', ' from', ' the', ' P', 'CHA', ""'s"", ' Victoria']"
+718 154 Name of mother of x -1 Name of mother of Victoria Princess Victoria of Saxe-Coburg-Saalfeld Victoria "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False In 2004 a national Victoria Cross and George Cross 4 [' In', ' 2004', ' a', ' national', ' Victoria']
+719 154 Name of mother of x -1 Name of mother of Victoria Princess Victoria of Saxe-Coburg-Saalfeld Victoria "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False exhibition at the Victoria and Albert 3 [' exhibition', ' at', ' the', ' Victoria']
+720 154 Name of mother of x -1 Name of mother of Victoria Princess Victoria of Saxe-Coburg-Saalfeld Victoria "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Tasmania. In the final, Victoria batted first and 6 [' Tasmania', '.', ' In', ' the', ' final', ',', ' Victoria']
+721 154 Name of mother of x -1 Name of mother of Victoria Princess Victoria of Saxe-Coburg-Saalfeld Victoria "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False in the order. When Victoria fielded, Miller 5 [' in', ' the', ' order', '.', ' When', ' Victoria']
+722 155 Name of mother of x -1 Name of mother of Superman Martha Kent Superman "[',' ' and' ' the' ' other' ' one' ' is' ' the' ' mother' ' of'
+ ' Superman' '.' '\n' '\n' 'The' ' first' ' one' ' is' ' the' ' mother'
+ ' of']" ", and the other one is the mother of Superman .
+
+ The first one is the mother of" False to change into his Superman cape before 4 [' to', ' change', ' into', ' his', ' Superman']
+723 155 Name of mother of x -1 Name of mother of Superman Martha Kent Superman "[',' ' and' ' the' ' other' ' one' ' is' ' the' ' mother' ' of'
+ ' Superman' '.' '\n' '\n' 'The' ' first' ' one' ' is' ' the' ' mother'
+ ' of']" ", and the other one is the mother of Superman .
+
+ The first one is the mother of" False " Superman V ===
+" 0 [' Superman']
+724 155 Name of mother of x -1 Name of mother of Superman Martha Kent Superman "[',' ' and' ' the' ' other' ' one' ' is' ' the' ' mother' ' of'
+ ' Superman' '.' '\n' '\n' 'The' ' first' ' one' ' is' ' the' ' mother'
+ ' of']" ", and the other one is the mother of Superman .
+
+ The first one is the mother of" False the post-Crisis Superman, whose Clark 5 [' the', ' post', '-', 'C', 'risis', ' Superman']
+725 155 Name of mother of x -1 Name of mother of Superman Martha Kent Superman "[',' ' and' ' the' ' other' ' one' ' is' ' the' ' mother' ' of'
+ ' Superman' '.' '\n' '\n' 'The' ' first' ' one' ' is' ' the' ' mother'
+ ' of']" ", and the other one is the mother of Superman .
+
+ The first one is the mother of" False which utilizes Superman as a Christ figure 2 [' which', ' utilizes', ' Superman']
+726 155 Name of mother of x -1 Name of mother of Superman Martha Kent Superman "[',' ' and' ' the' ' other' ' one' ' is' ' the' ' mother' ' of'
+ ' Superman' '.' '\n' '\n' 'The' ' first' ' one' ' is' ' the' ' mother'
+ ' of']" ", and the other one is the mother of Superman .
+
+ The first one is the mother of" False the debut of Superman Returns, the Academy 3 [' the', ' debut', ' of', ' Superman']
+727 156 Name of mother of x -1 Name of mother of Joshua Reynolds Theophila Potter Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Joshua' ' Reynolds' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Joshua Reynolds , the painter , and the" False years later, Sir Joshua Reynolds painted a portrait 5 [' years', ' later', ',', ' Sir', ' Joshua', ' Reynolds']
+728 156 Name of mother of x -1 Name of mother of Joshua Reynolds Theophila Potter Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Joshua' ' Reynolds' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Joshua Reynolds , the painter , and the" False Burney, and artists Sir Joshua Reynolds and Sir Thomas 7 [' Burn', 'ey', ',', ' and', ' artists', ' Sir', ' Joshua', ' Reynolds']
+729 156 Name of mother of x -1 Name of mother of Joshua Reynolds Theophila Potter Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Joshua' ' Reynolds' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Joshua Reynolds , the painter , and the" False strongly influenced by Joshua Reynolds (1723 – 1792), 4 [' strongly', ' influenced', ' by', ' Joshua', ' Reynolds']
+730 156 Name of mother of x -1 Name of mother of Joshua Reynolds Theophila Potter Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Joshua' ' Reynolds' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Joshua Reynolds , the painter , and the" False " was described by Joshua Reynolds as ""like Caravaggio" 4 [' was', ' described', ' by', ' Joshua', ' Reynolds']
+731 156 Name of mother of x -1 Name of mother of Joshua Reynolds Theophila Potter Joshua Reynolds "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Joshua' ' Reynolds' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Joshua Reynolds , the painter , and the" False including fifteen Joshua Reynolds portraits and a miniature 3 [' including', ' fifteen', ' Joshua', ' Reynolds']
+732 157 Name of mother of x -1 Name of mother of Aldous Huxley Julia Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' ',' ' and' ' the'
+ ' author' ' of' ' the' ' book' ' that' ' inspired' ' the' ' movie' ','
+ ' Brave']" , the author of Brave New World , and the author of the book that inspired the movie , Brave False people like Aldous Huxley and Alfred Adler. 6 [' people', ' like', ' Ald', 'ous', ' H', 'ux', 'ley']
+733 157 Name of mother of x -1 Name of mother of Aldous Huxley Julia Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' ',' ' and' ' the'
+ ' author' ' of' ' the' ' book' ' that' ' inspired' ' the' ' movie' ','
+ ' Brave']" , the author of Brave New World , and the author of the book that inspired the movie , Brave False Leary, Alan Watts, Aldous Huxley and Arthur Koestler, 10 [' Lear', 'y', ',', ' Alan', ' Watts', ',', ' Ald', 'ous', ' H', 'ux', 'ley']
+734 157 Name of mother of x -1 Name of mother of Aldous Huxley Julia Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' ',' ' and' ' the'
+ ' author' ' of' ' the' ' book' ' that' ' inspired' ' the' ' movie' ','
+ ' Brave']" , the author of Brave New World , and the author of the book that inspired the movie , Brave False influenced the work of Aldous Huxley and Robert Graves. 8 [' influenced', ' the', ' work', ' of', ' Ald', 'ous', ' H', 'ux', 'ley']
+735 157 Name of mother of x -1 Name of mother of Aldous Huxley Julia Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' ',' ' and' ' the'
+ ' author' ' of' ' the' ' book' ' that' ' inspired' ' the' ' movie' ','
+ ' Brave']" , the author of Brave New World , and the author of the book that inspired the movie , Brave False English novelist Aldous Huxley and American screenwriter 6 [' English', ' novelist', ' Ald', 'ous', ' H', 'ux', 'ley']
+736 157 Name of mother of x -1 Name of mother of Aldous Huxley Julia Huxley Aldous Huxley "[',' ' the' ' author' ' of' ' Brave' ' New' ' World' ',' ' and' ' the'
+ ' author' ' of' ' the' ' book' ' that' ' inspired' ' the' ' movie' ','
+ ' Brave']" , the author of Brave New World , and the author of the book that inspired the movie , Brave False influenced the work of Aldous Huxley and Robert Graves. 8 [' influenced', ' the', ' work', ' of', ' Ald', 'ous', ' H', 'ux', 'ley']
+737 158 Name of mother of x -1 Name of mother of Jorge Luis Borges Leonor Rita Acevedo Suárez Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you' ' about'
+ ' the']" ", the Argentine writer , who died in 1986 .
+
+ The first thing that strikes you about the" False ontological labyrinth that Jorge Luis Borges might have made. 7 [' ont', 'ological', ' labyrinth', ' that', ' Jorge', ' Luis', ' Borg', 'es']
+738 158 Name of mother of x -1 Name of mother of Jorge Luis Borges Leonor Rita Acevedo Suárez Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you' ' about'
+ ' the']" ", the Argentine writer , who died in 1986 .
+
+ The first thing that strikes you about the" False " Carlos Fuentes, Jorge Luis Borges and Mario Vargas Llosa.
+" 8 [' Carlos', ' Fu', 'ent', 'es', ',', ' Jorge', ' Luis', ' Borg', 'es']
+739 158 Name of mother of x -1 Name of mother of Jorge Luis Borges Leonor Rita Acevedo Suárez Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you' ' about'
+ ' the']" ", the Argentine writer , who died in 1986 .
+
+ The first thing that strikes you about the" False ontological labyrinth that Jorge Luis Borges might have made. 7 [' ont', 'ological', ' labyrinth', ' that', ' Jorge', ' Luis', ' Borg', 'es']
+740 158 Name of mother of x -1 Name of mother of Jorge Luis Borges Leonor Rita Acevedo Suárez Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you' ' about'
+ ' the']" ", the Argentine writer , who died in 1986 .
+
+ The first thing that strikes you about the" False artists such as Jorge Luis Borges and Julio Cortázar. 6 [' artists', ' such', ' as', ' Jorge', ' Luis', ' Borg', 'es']
+741 158 Name of mother of x -1 Name of mother of Jorge Luis Borges Leonor Rita Acevedo Suárez Jorge Luis Borges "[',' ' the' ' Argentine' ' writer' ',' ' who' ' died' ' in' ' 1986' '.'
+ '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you' ' about'
+ ' the']" ", the Argentine writer , who died in 1986 .
+
+ The first thing that strikes you about the" False and fabulist, Jorge Luis Borges (1899 – 1986), 8 [' and', ' fab', 'ul', 'ist', ',', ' Jorge', ' Luis', ' Borg', 'es']
+742 159 Name of mother of x -1 Name of mother of Diego Velázquez Jerónima Velázquez Diego Velázquez "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Diego' ' Vel' 'á' 'z' 'quez' ',' ' the' ' painter']" ", the painter , and the
+
+ Name of mother of Diego Vel á z quez , the painter" False and the Spanish Diego Velázquez as influences 7 [' and', ' the', ' Spanish', ' Diego', ' Vel', 'á', 'z', 'quez']
+743 159 Name of mother of x -1 Name of mother of Diego Velázquez Jerónima Velázquez Diego Velázquez "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Diego' ' Vel' 'á' 'z' 'quez' ',' ' the' ' painter']" ", the painter , and the
+
+ Name of mother of Diego Vel á z quez , the painter" False participated in Diego Velázquez de Cuéllar's and Pánfilo 6 [' participated', ' in', ' Diego', ' Vel', 'á', 'z', 'quez']
+744 159 Name of mother of x -1 Name of mother of Diego Velázquez Jerónima Velázquez Diego Velázquez "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Diego' ' Vel' 'á' 'z' 'quez' ',' ' the' ' painter']" ", the painter , and the
+
+ Name of mother of Diego Vel á z quez , the painter" False participated in Diego Velázquez de Cuéllar's and 6 [' participated', ' in', ' Diego', ' Vel', 'á', 'z', 'quez']
+745 159 Name of mother of x -1 Name of mother of Diego Velázquez Jerónima Velázquez Diego Velázquez "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Diego' ' Vel' 'á' 'z' 'quez' ',' ' the' ' painter']" ", the painter , and the
+
+ Name of mother of Diego Vel á z quez , the painter" False dwarves of Diego Velázquez to Pablo Picasso's 6 [' dwarves', ' of', ' Diego', ' Vel', 'á', 'z', 'quez']
+746 159 Name of mother of x -1 Name of mother of Diego Velázquez Jerónima Velázquez Diego Velázquez "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Diego' ' Vel' 'á' 'z' 'quez' ',' ' the' ' painter']" ", the painter , and the
+
+ Name of mother of Diego Vel á z quez , the painter" False participated in Diego Velázquez de Cuéllar's and 6 [' participated', ' in', ' Diego', ' Vel', 'á', 'z', 'quez']
+747 160 Name of mother of x -1 Name of mother of Jacques-Louis David Marie-Geneviève Buron Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Jacques' '-' 'Louis' ' David' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Jacques - Louis David , the painter ," False influence of Jacques-Louis David can be seen in the 5 [' influence', ' of', ' Jacques', '-', 'Louis', ' David']
+748 160 Name of mother of x -1 Name of mother of Jacques-Louis David Marie-Geneviève Buron Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Jacques' '-' 'Louis' ' David' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Jacques - Louis David , the painter ," False monumental approach of Jacques-Louis David (1748 – 1825) and 6 [' monumental', ' approach', ' of', ' Jacques', '-', 'Louis', ' David']
+749 160 Name of mother of x -1 Name of mother of Jacques-Louis David Marie-Geneviève Buron Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Jacques' '-' 'Louis' ' David' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Jacques - Louis David , the painter ," False approach of Jacques-Louis David (1748 – 1825) 5 [' approach', ' of', ' Jacques', '-', 'Louis', ' David']
+750 160 Name of mother of x -1 Name of mother of Jacques-Louis David Marie-Geneviève Buron Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Jacques' '-' 'Louis' ' David' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Jacques - Louis David , the painter ," False on his fingers. Jacques-Louis David also explored 7 [' on', ' his', ' fingers', '.', ' Jacques', '-', 'Louis', ' David']
+751 160 Name of mother of x -1 Name of mother of Jacques-Louis David Marie-Geneviève Buron Jacques-Louis David "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Jacques' '-' 'Louis' ' David' ',' ' the' ' painter' ',']" ", the painter , and the
+
+ Name of mother of Jacques - Louis David , the painter ," False " French artist Jacques-Louis David (1748 – 1825).
+" 5 [' French', ' artist', ' Jacques', '-', 'Louis', ' David']
+752 161 Name of mother of x -1 Name of mother of Jane Austen Cassandra Austen Jane Austen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Jane' ' Aust'
+ 'en' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Jane Aust en , and the
+
+ Name of mother" False 2 ['Jane', ' Aust', 'en']
+753 161 Name of mother of x -1 Name of mother of Jane Austen Cassandra Austen Jane Austen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Jane' ' Aust'
+ 'en' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Jane Aust en , and the
+
+ Name of mother" False " Stillman's Jane Austen comedy Love & Friendship.
+" 5 "[' Still', 'man', ""'s"", ' Jane', ' Aust', 'en']"
+754 161 Name of mother of x -1 Name of mother of Jane Austen Cassandra Austen Jane Austen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Jane' ' Aust'
+ 'en' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Jane Aust en , and the
+
+ Name of mother" False food was what Jane Austen and her contemporaries 5 [' food', ' was', ' what', ' Jane', ' Aust', 'en']
+755 161 Name of mother of x -1 Name of mother of Jane Austen Cassandra Austen Jane Austen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Jane' ' Aust'
+ 'en' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Jane Aust en , and the
+
+ Name of mother" False in 2010 for the Jane Austen Society of 6 [' in', ' 2010', ' for', ' the', ' Jane', ' Aust', 'en']
+756 161 Name of mother of x -1 Name of mother of Jane Austen Cassandra Austen Jane Austen "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Jane' ' Aust'
+ 'en' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Jane Aust en , and the
+
+ Name of mother" False " A Companion to Jane Austen observed that the ""physicality""" 5 [' A', ' Companion', ' to', ' Jane', ' Aust', 'en']
+757 162 Name of mother of x -1 Name of mother of Galileo Galilei Giulia Ammannati Galileo Galilei "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first' ' of'
+ ' the']" ", the father of modern science , and the father of modern physics .
+
+ The first of the" False discovered in 1610 by Galileo Galilei and was named 7 [' discovered', ' in', ' 16', '10', ' by', ' Galileo', ' Galile', 'i']
+758 162 Name of mother of x -1 Name of mother of Galileo Galilei Giulia Ammannati Galileo Galilei "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first' ' of'
+ ' the']" ", the father of modern science , and the father of modern physics .
+
+ The first of the" False from within. Galileo Galilei first resolved 5 [' from', ' within', '.', ' Galileo', ' Galile', 'i']
+759 162 Name of mother of x -1 Name of mother of Galileo Galilei Giulia Ammannati Galileo Galilei "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first' ' of'
+ ' the']" ", the father of modern science , and the father of modern physics .
+
+ The first of the" False credited to Galileo Galilei in a February 4 [' credited', ' to', ' Galileo', ' Galile', 'i']
+760 162 Name of mother of x -1 Name of mother of Galileo Galilei Giulia Ammannati Galileo Galilei "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first' ' of'
+ ' the']" ", the father of modern science , and the father of modern physics .
+
+ The first of the" False of Io was made by Galileo Galilei on 7 January 7 [' of', ' Io', ' was', ' made', ' by', ' Galileo', ' Galile', 'i']
+761 162 Name of mother of x -1 Name of mother of Galileo Galilei Giulia Ammannati Galileo Galilei "[',' ' the' ' father' ' of' ' modern' ' science' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' physics' '.' '\n' '\n' 'The' ' first' ' of'
+ ' the']" ", the father of modern science , and the father of modern physics .
+
+ The first of the" False Thomas Harriot, Galileo Galilei and other astronomers. 6 [' Thomas', ' Har', 'riot', ',', ' Galileo', ' Galile', 'i']
+762 163 Name of mother of x -1 Name of mother of Joe Biden Jean Biden Joe Biden "[',' ' the' ' former' ' vice' ' president' ',' ' and' ' his' ' wife' ','
+ ' Jill' ' Biden' ',' ' are' ' both' ' Democrats' '.' '\n' '\n' 'The']" ", the former vice president , and his wife , Jill Biden , are both Democrats .
+
+ The" False " Obama eventually chose Joe Biden as his running mate.
+" 4 [' Obama', ' eventually', ' chose', ' Joe', ' Biden']
+763 163 Name of mother of x -1 Name of mother of Joe Biden Jean Biden Joe Biden "[',' ' the' ' former' ' vice' ' president' ',' ' and' ' his' ' wife' ','
+ ' Jill' ' Biden' ',' ' are' ' both' ' Democrats' '.' '\n' '\n' 'The']" ", the former vice president , and his wife , Jill Biden , are both Democrats .
+
+ The" False Obama, Vice President Joe Biden and Dr. Jill 5 [' Obama', ',', ' Vice', ' President', ' Joe', ' Biden']
+764 163 Name of mother of x -1 Name of mother of Joe Biden Jean Biden Joe Biden "[',' ' the' ' former' ' vice' ' president' ',' ' and' ' his' ' wife' ','
+ ' Jill' ' Biden' ',' ' are' ' both' ' Democrats' '.' '\n' '\n' 'The']" ", the former vice president , and his wife , Jill Biden , are both Democrats .
+
+ The" False Vice President Joe Biden and Secretary 3 [' Vice', ' President', ' Joe', ' Biden']
+765 163 Name of mother of x -1 Name of mother of Joe Biden Jean Biden Joe Biden "[',' ' the' ' former' ' vice' ' president' ',' ' and' ' his' ' wife' ','
+ ' Jill' ' Biden' ',' ' are' ' both' ' Democrats' '.' '\n' '\n' 'The']" ", the former vice president , and his wife , Jill Biden , are both Democrats .
+
+ The" False Obama and Joe Biden were formally 3 [' Obama', ' and', ' Joe', ' Biden']
+766 163 Name of mother of x -1 Name of mother of Joe Biden Jean Biden Joe Biden "[',' ' the' ' former' ' vice' ' president' ',' ' and' ' his' ' wife' ','
+ ' Jill' ' Biden' ',' ' are' ' both' ' Democrats' '.' '\n' '\n' 'The']" ", the former vice president , and his wife , Jill Biden , are both Democrats .
+
+ The" False by Vice President Joe Biden on January 3, 2013. 4 [' by', ' Vice', ' President', ' Joe', ' Biden']
+767 164 Name of mother of x -1 Name of mother of H. P. Lovecraft Sarah Susan Phillips Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' ',' ' and'
+ ' the' ' creator' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n']" ", the author of the Cthulhu Myth os , and the creator of the Cthulhu Myth os .
+
+" False elongated caricature of H. P. Lovecraft designed by cartoonist 8 [' elong', 'ated', ' caricature', ' of', ' H', '.', ' P', '.', ' Lovecraft']
+768 164 Name of mother of x -1 Name of mother of H. P. Lovecraft Sarah Susan Phillips Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' ',' ' and'
+ ' the' ' creator' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n']" ", the author of the Cthulhu Myth os , and the creator of the Cthulhu Myth os .
+
+" False E. Howard, H. P. Lovecraft and Arthur C. 8 [' E', '.', ' Howard', ',', ' H', '.', ' P', '.', ' Lovecraft']
+769 164 Name of mother of x -1 Name of mother of H. P. Lovecraft Sarah Susan Phillips Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' ',' ' and'
+ ' the' ' creator' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n']" ", the author of the Cthulhu Myth os , and the creator of the Cthulhu Myth os .
+
+" False fantasy author H. P. Lovecraft in June 1920. 6 [' fantasy', ' author', ' H', '.', ' P', '.', ' Lovecraft']
+770 164 Name of mother of x -1 Name of mother of H. P. Lovecraft Sarah Susan Phillips Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' ',' ' and'
+ ' the' ' creator' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n']" ", the author of the Cthulhu Myth os , and the creator of the Cthulhu Myth os .
+
+" False influenced both H. P. Lovecraft and A. Merritt, both 6 [' influenced', ' both', ' H', '.', ' P', '.', ' Lovecraft']
+771 164 Name of mother of x -1 Name of mother of H. P. Lovecraft Sarah Susan Phillips Lovecraft H. P. Lovecraft "[',' ' the' ' author' ' of' ' the' ' Cthulhu' ' Myth' 'os' ',' ' and'
+ ' the' ' creator' ' of' ' the' ' Cthulhu' ' Myth' 'os' '.' '\n' '\n']" ", the author of the Cthulhu Myth os , and the creator of the Cthulhu Myth os .
+
+" False influenced both H. P. Lovecraft and A. Merritt, 6 [' influenced', ' both', ' H', '.', ' P', '.', ' Lovecraft']
+772 165 Name of mother of x -1 Name of mother of T. S. Eliot Charlotte Champe Stearns T. S. Eliot "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Olivia to support T. S. Eliot and James Joyce. 7 [' Olivia', ' to', ' support', ' T', '.', ' S', '.', ' Eliot']
+773 165 Name of mother of x -1 Name of mother of T. S. Eliot Charlotte Champe Stearns T. S. Eliot "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Arthurian plays, and T. S. Eliot alludes to the Arthur 9 [' Arthur', 'ian', ' plays', ',', ' and', ' T', '.', ' S', '.', ' Eliot']
+774 165 Name of mother of x -1 Name of mother of T. S. Eliot Charlotte Champe Stearns T. S. Eliot "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False humanities, such as T. S. Eliot and George F. 8 [' humanities', ',', ' such', ' as', ' T', '.', ' S', '.', ' Eliot']
+775 165 Name of mother of x -1 Name of mother of T. S. Eliot Charlotte Champe Stearns T. S. Eliot "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False the work of T. S. Eliot and Ezra Pound. Kilmer's 7 [' the', ' work', ' of', ' T', '.', ' S', '.', ' Eliot']
+776 165 Name of mother of x -1 Name of mother of T. S. Eliot Charlotte Champe Stearns T. S. Eliot "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Birkbeck Hill and T. S. Eliot came to regard 9 [' Bir', 'k', 'beck', ' Hill', ' and', ' T', '.', ' S', '.', ' Eliot']
+777 166 Name of mother of x -1 Name of mother of Walter Scott Anne Rutherford Walter Scott "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' events' ' in' ' the' ' life' ' of' ' Walter'
+ ' Scott']" ", the
+
+ The following is a list of the most important events in the life of Walter Scott" False Midleton. Sir Walter Scott includes this 5 [' Mid', 'leton', '.', ' Sir', ' Walter', ' Scott']
+778 166 Name of mother of x -1 Name of mother of Walter Scott Anne Rutherford Walter Scott "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' events' ' in' ' the' ' life' ' of' ' Walter'
+ ' Scott']" ", the
+
+ The following is a list of the most important events in the life of Walter Scott" False October 1897, the son of Walter Scott Marshall and his 7 [' October', ' 1897', ',', ' the', ' son', ' of', ' Walter', ' Scott']
+779 166 Name of mother of x -1 Name of mother of Walter Scott Anne Rutherford Walter Scott "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' events' ' in' ' the' ' life' ' of' ' Walter'
+ ' Scott']" ", the
+
+ The following is a list of the most important events in the life of Walter Scott" False (1823) by Sir Walter Scott is a historical novel 7 [' (', '18', '23', ')', ' by', ' Sir', ' Walter', ' Scott']
+780 166 Name of mother of x -1 Name of mother of Walter Scott Anne Rutherford Walter Scott "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' events' ' in' ' the' ' life' ' of' ' Walter'
+ ' Scott']" ", the
+
+ The following is a list of the most important events in the life of Walter Scott" False In August 1814, Sir Walter Scott visited Dunvegan 7 [' In', ' August', ' 18', '14', ',', ' Sir', ' Walter', ' Scott']
+781 166 Name of mother of x -1 Name of mother of Walter Scott Anne Rutherford Walter Scott "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' most' ' important' ' events' ' in' ' the' ' life' ' of' ' Walter'
+ ' Scott']" ", the
+
+ The following is a list of the most important events in the life of Walter Scott" False Abbotsford House for Walter Scott from 1816, and a 6 [' Ab', 'bots', 'ford', ' House', ' for', ' Walter', ' Scott']
+782 167 Name of mother of x -1 Name of mother of Kate Winslet Sally Ann Bridges Kate Winslet "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False support to Kate Winslet ’ s Golden Hat 4 [' support', ' to', ' Kate', ' Wins', 'let']
+783 167 Name of mother of x -1 Name of mother of Kate Winslet Sally Ann Bridges Kate Winslet "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False lent her support to Kate Winslet ’ s Golden Hat 6 [' lent', ' her', ' support', ' to', ' Kate', ' Wins', 'let']
+784 167 Name of mother of x -1 Name of mother of Kate Winslet Sally Ann Bridges Kate Winslet "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Helen Mirren, Kate Winslet and Daniel Day-Lewis. 6 [' Helen', ' Mir', 'ren', ',', ' Kate', ' Wins', 'let']
+785 167 Name of mother of x -1 Name of mother of Kate Winslet Sally Ann Bridges Kate Winslet "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Leonardo DiCaprio and Kate Winslet as members of 7 [' Leonardo', ' Di', 'Cap', 'rio', ' and', ' Kate', ' Wins', 'let']
+786 167 Name of mother of x -1 Name of mother of Kate Winslet Sally Ann Bridges Kate Winslet "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False her support to Kate Winslet ’ s Golden Hat Foundation 5 [' her', ' support', ' to', ' Kate', ' Wins', 'let']
+787 168 Name of mother of x -1 Name of mother of Francis Bacon Anne Bacon Francis Bacon "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Francis' ' Bacon' ','
+ ' the' ' famous' ' philosopher' ',' ' and' ' the' '\n' '\n' '1' '.']" ", the son of the late Sir Francis Bacon , the famous philosopher , and the
+
+ 1 ." False philosopher and scientist Francis Bacon (1561 – 1626), 4 [' philosopher', ' and', ' scientist', ' Francis', ' Bacon']
+788 168 Name of mother of x -1 Name of mother of Francis Bacon Anne Bacon Francis Bacon "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Francis' ' Bacon' ','
+ ' the' ' famous' ' philosopher' ',' ' and' ' the' '\n' '\n' '1' '.']" ", the son of the late Sir Francis Bacon , the famous philosopher , and the
+
+ 1 ." False As an artist, Francis Bacon was a late starter. 5 [' As', ' an', ' artist', ',', ' Francis', ' Bacon']
+789 168 Name of mother of x -1 Name of mother of Francis Bacon Anne Bacon Francis Bacon "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Francis' ' Bacon' ','
+ ' the' ' famous' ' philosopher' ',' ' and' ' the' '\n' '\n' '1' '.']" ", the son of the late Sir Francis Bacon , the famous philosopher , and the
+
+ 1 ." False might be evidence that Francis Bacon wrote the plays 5 [' might', ' be', ' evidence', ' that', ' Francis', ' Bacon']
+790 168 Name of mother of x -1 Name of mother of Francis Bacon Anne Bacon Francis Bacon "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Francis' ' Bacon' ','
+ ' the' ' famous' ' philosopher' ',' ' and' ' the' '\n' '\n' '1' '.']" ", the son of the late Sir Francis Bacon , the famous philosopher , and the
+
+ 1 ." False 2 ['Franc', 'is', ' Bacon']
+791 168 Name of mother of x -1 Name of mother of Francis Bacon Anne Bacon Francis Bacon "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Francis' ' Bacon' ','
+ ' the' ' famous' ' philosopher' ',' ' and' ' the' '\n' '\n' '1' '.']" ", the son of the late Sir Francis Bacon , the famous philosopher , and the
+
+ 1 ." False William Whewell and Francis Bacon on the theology 5 [' William', ' Whe', 'well', ' and', ' Francis', ' Bacon']
+792 169 Name of mother of x -1 Name of mother of Nelson Mandela Nosekeni Fanny Nelson Mandela "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the father of the nation , and the father of the nation .
+
+ The first time I" False Zuma, although the Nelson Mandela Foundation 6 [' Z', 'uma', ',', ' although', ' the', ' Nelson', ' Mandela']
+793 169 Name of mother of x -1 Name of mother of Nelson Mandela Nosekeni Fanny Nelson Mandela "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the father of the nation , and the father of the nation .
+
+ The first time I" False Treason Trial of Nelson Mandela and others in 1956, 5 [' Tre', 'ason', ' Trial', ' of', ' Nelson', ' Mandela']
+794 169 Name of mother of x -1 Name of mother of Nelson Mandela Nosekeni Fanny Nelson Mandela "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the father of the nation , and the father of the nation .
+
+ The first time I" False South Africa Nelson Mandela in the late 1990s. 3 [' South', ' Africa', ' Nelson', ' Mandela']
+795 169 Name of mother of x -1 Name of mother of Nelson Mandela Nosekeni Fanny Nelson Mandela "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the father of the nation , and the father of the nation .
+
+ The first time I" False period, including Nelson Mandela and the 1976 4 [' period', ',', ' including', ' Nelson', ' Mandela']
+796 169 Name of mother of x -1 Name of mother of Nelson Mandela Nosekeni Fanny Nelson Mandela "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' first' ' time' ' I']" ", the father of the nation , and the father of the nation .
+
+ The first time I" False inquiry. In 1999, Nelson Mandela awarded Ames the 6 [' inquiry', '.', ' In', ' 1999', ',', ' Nelson', ' Mandela']
+797 170 Name of mother of x -1 Name of mother of Miguel de Cervantes Leonor de Cortinas Miguel de Cervantes "[',' ' the' ' author' ' of' ' the' ' _' 'Qu' 'ix' 'ote' '_' ',' ' and'
+ ' the' ' _' 'Ex' 'empl' 'ary' ' Life' '_' ',']" , the author of the _ Qu ix ote _ , and the _ Ex empl ary Life _ , False selected by author Miguel de Cervantes as the first 7 [' selected', ' by', ' author', ' Miguel', ' de', ' C', 'erv', 'antes']
+798 170 Name of mother of x -1 Name of mother of Miguel de Cervantes Leonor de Cortinas Miguel de Cervantes "[',' ' the' ' author' ' of' ' the' ' _' 'Qu' 'ix' 'ote' '_' ',' ' and'
+ ' the' ' _' 'Ex' 'empl' 'ary' ' Life' '_' ',']" , the author of the _ Qu ix ote _ , and the _ Ex empl ary Life _ , False 6 ['M', 'ig', 'uel', ' de', ' C', 'erv', 'antes']
+799 170 Name of mother of x -1 Name of mother of Miguel de Cervantes Leonor de Cortinas Miguel de Cervantes "[',' ' the' ' author' ' of' ' the' ' _' 'Qu' 'ix' 'ote' '_' ',' ' and'
+ ' the' ' _' 'Ex' 'empl' 'ary' ' Life' '_' ',']" , the author of the _ Qu ix ote _ , and the _ Ex empl ary Life _ , False winning the Premio Miguel de Cervantes in 1989, Spanish 8 [' winning', ' the', ' Prem', 'io', ' Miguel', ' de', ' C', 'erv', 'antes']
+800 170 Name of mother of x -1 Name of mother of Miguel de Cervantes Leonor de Cortinas Miguel de Cervantes "[',' ' the' ' author' ' of' ' the' ' _' 'Qu' 'ix' 'ote' '_' ',' ' and'
+ ' the' ' _' 'Ex' 'empl' 'ary' ' Life' '_' ',']" , the author of the _ Qu ix ote _ , and the _ Ex empl ary Life _ , False winning the Premio Miguel de Cervantes in 1989, Spanish 8 [' winning', ' the', ' Prem', 'io', ' Miguel', ' de', ' C', 'erv', 'antes']
+801 170 Name of mother of x -1 Name of mother of Miguel de Cervantes Leonor de Cortinas Miguel de Cervantes "[',' ' the' ' author' ' of' ' the' ' _' 'Qu' 'ix' 'ote' '_' ',' ' and'
+ ' the' ' _' 'Ex' 'empl' 'ary' ' Life' '_' ',']" , the author of the _ Qu ix ote _ , and the _ Ex empl ary Life _ , False received is the 1994 Miguel de Cervantes Prize, considered 8 [' received', ' is', ' the', ' 1994', ' Miguel', ' de', ' C', 'erv', 'antes']
+802 171 Name of mother of x -1 Name of mother of Yoko Ono Isoko Ono Yoko Ono "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' and' ' the' ' artist' ""'s"" ' father' ',' ' and' ' the' ' artist' ""'s""]" , the artist , and the artist 's mother , and the artist 's father , and the artist 's False McCartney when Lennon and Yoko Ono were also invited 7 [' McCartney', ' when', ' Lennon', ' and', ' Y', 'oko', ' On', 'o']
+803 171 Name of mother of x -1 Name of mother of Yoko Ono Isoko Ono Yoko Ono "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' and' ' the' ' artist' ""'s"" ' father' ',' ' and' ' the' ' artist' ""'s""]" , the artist , and the artist 's mother , and the artist 's father , and the artist 's False discuss the games. Yoko Ono and Olivia Harrison 7 [' discuss', ' the', ' games', '.', ' Y', 'oko', ' On', 'o']
+804 171 Name of mother of x -1 Name of mother of Yoko Ono Isoko Ono Yoko Ono "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' and' ' the' ' artist' ""'s"" ' father' ',' ' and' ' the' ' artist' ""'s""]" , the artist , and the artist 's mother , and the artist 's father , and the artist 's False Jane, part Yoko Ono for Generation X, and 6 [' Jane', ',', ' part', ' Y', 'oko', ' On', 'o']
+805 171 Name of mother of x -1 Name of mother of Yoko Ono Isoko Ono Yoko Ono "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' and' ' the' ' artist' ""'s"" ' father' ',' ' and' ' the' ' artist' ""'s""]" , the artist , and the artist 's mother , and the artist 's father , and the artist 's False " = Yoko Ono =
+" 4 [' =', ' Y', 'oko', ' On', 'o']
+806 171 Name of mother of x -1 Name of mother of Yoko Ono Isoko Ono Yoko Ono "[',' ' the' ' artist' ',' ' and' ' the' ' artist' ""'s"" ' mother' ','
+ ' and' ' the' ' artist' ""'s"" ' father' ',' ' and' ' the' ' artist' ""'s""]" , the artist , and the artist 's mother , and the artist 's father , and the artist 's False territory, with Yoko Ono reciting some 6 [' territory', ',', ' with', ' Y', 'oko', ' On', 'o']
+807 172 Name of mother of x -1 Name of mother of Pyotr Ilyich Tchaikovsky Aleksandra Tchaikovskaya Pyotr Ilyich Tchaikovsky "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Isaac Albéniz, Pyotr Ilyich Tchaikovsky and Sergei 13 [' Isaac', ' Alb', 'én', 'iz', ',', ' Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+808 172 Name of mother of x -1 Name of mother of Pyotr Ilyich Tchaikovsky Aleksandra Tchaikovskaya Pyotr Ilyich Tchaikovsky "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Dvořák, Isaac Albéniz, Pyotr Ilyich Tchaikovsky and Sergei 20 [' D', 'vo', '�', '�', 'á', 'k', ',', ' Isaac', ' Alb', 'én', 'iz', ',', ' Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+809 172 Name of mother of x -1 Name of mother of Pyotr Ilyich Tchaikovsky Aleksandra Tchaikovskaya Pyotr Ilyich Tchaikovsky "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Isaac Albéniz, Pyotr Ilyich Tchaikovsky and Sergei Rachmaninoff, 13 [' Isaac', ' Alb', 'én', 'iz', ',', ' Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+810 172 Name of mother of x -1 Name of mother of Pyotr Ilyich Tchaikovsky Aleksandra Tchaikovskaya Pyotr Ilyich Tchaikovsky "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " Minor, Op. 30"" by Pyotr Ilyich Tchaikovsky and ""String Quartet" 15 "[' Minor', ',', ' Op', '.', ' 30', '""', ' by', ' Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']"
+811 172 Name of mother of x -1 Name of mother of Pyotr Ilyich Tchaikovsky Aleksandra Tchaikovskaya Pyotr Ilyich Tchaikovsky "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 8 ['Py', 'ot', 'r', ' Ily', 'ich', ' T', 'cha', 'ik', 'ovsky']
+812 173 Name of mother of x -1 Name of mother of John Cage Crete Cage John Cage "[',' ' the' ' father' ' of' ' modern' ' music' ',' ' and' ' the' ' father'
+ ' of' ' the' ' av' 'ant' '-' 'gard' 'e' ',' ' was' ' born']" , the father of modern music , and the father of the av ant - gard e , was born False avant-garde composers John Cage and Karlheinz Stockhausen. 8 [' av', 'ant', '-', 'gard', 'e', ' compos', 'ers', ' John', ' Cage']
+813 173 Name of mother of x -1 Name of mother of John Cage Crete Cage John Cage "[',' ' the' ' father' ' of' ' modern' ' music' ',' ' and' ' the' ' father'
+ ' of' ' the' ' av' 'ant' '-' 'gard' 'e' ',' ' was' ' born']" , the father of modern music , and the father of the av ant - gard e , was born False avant-garde composers such as John Cage and Karlheinz 10 [' av', 'ant', '-', 'gard', 'e', ' compos', 'ers', ' such', ' as', ' John', ' Cage']
+814 173 Name of mother of x -1 Name of mother of John Cage Crete Cage John Cage "[',' ' the' ' father' ' of' ' modern' ' music' ',' ' and' ' the' ' father'
+ ' of' ' the' ' av' 'ant' '-' 'gard' 'e' ',' ' was' ' born']" , the father of modern music , and the father of the av ant - gard e , was born False " avant-garde composer John Cage often attended.
+" 7 [' av', 'ant', '-', 'gard', 'e', ' composer', ' John', ' Cage']
+815 173 Name of mother of x -1 Name of mother of John Cage Crete Cage John Cage "[',' ' the' ' father' ' of' ' modern' ' music' ',' ' and' ' the' ' father'
+ ' of' ' the' ' av' 'ant' '-' 'gard' 'e' ',' ' was' ' born']" , the father of modern music , and the father of the av ant - gard e , was born False scores for a book John Cage was working 5 [' scores', ' for', ' a', ' book', ' John', ' Cage']
+816 173 Name of mother of x -1 Name of mother of John Cage Crete Cage John Cage "[',' ' the' ' father' ' of' ' modern' ' music' ',' ' and' ' the' ' father'
+ ' of' ' the' ' av' 'ant' '-' 'gard' 'e' ',' ' was' ' born']" , the father of modern music , and the father of the av ant - gard e , was born False 1 ['John', ' Cage']
+817 174 Name of mother of x -1 Name of mother of Catherine Zeta-Jones Pat Fair Catherine Zeta-Jones "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' '.' '\n' '\n' 'I' ' have'
+ ' a' ' friend' ' who' ' is' ' a' ' friend' ' of' ' mine']" ", who is a friend of mine .
+
+ I have a friend who is a friend of mine" False 5 ['C', 'atherine', ' Z', 'eta', '-', 'Jones']
+818 174 Name of mother of x -1 Name of mother of Catherine Zeta-Jones Pat Fair Catherine Zeta-Jones "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' '.' '\n' '\n' 'I' ' have'
+ ' a' ' friend' ' who' ' is' ' a' ' friend' ' of' ' mine']" ", who is a friend of mine .
+
+ I have a friend who is a friend of mine" False Motion Picture, Catherine Zeta-Jones for Best Supporting 7 [' Motion', ' Picture', ',', ' Catherine', ' Z', 'eta', '-', 'Jones']
+819 174 Name of mother of x -1 Name of mother of Catherine Zeta-Jones Pat Fair Catherine Zeta-Jones "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' '.' '\n' '\n' 'I' ' have'
+ ' a' ' friend' ' who' ' is' ' a' ' friend' ' of' ' mine']" ", who is a friend of mine .
+
+ I have a friend who is a friend of mine" False 2011. She replaced Catherine Zeta-Jones in the role. The New 8 [' 2011', '.', ' She', ' replaced', ' Catherine', ' Z', 'eta', '-', 'Jones']
+820 174 Name of mother of x -1 Name of mother of Catherine Zeta-Jones Pat Fair Catherine Zeta-Jones "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' '.' '\n' '\n' 'I' ' have'
+ ' a' ' friend' ' who' ' is' ' a' ' friend' ' of' ' mine']" ", who is a friend of mine .
+
+ I have a friend who is a friend of mine" False Russell Crowe, Catherine Zeta-Jones and Mark Wahlberg, 8 [' Russell', ' Cro', 'we', ',', ' Catherine', ' Z', 'eta', '-', 'Jones']
+821 174 Name of mother of x -1 Name of mother of Catherine Zeta-Jones Pat Fair Catherine Zeta-Jones "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' '.' '\n' '\n' 'I' ' have'
+ ' a' ' friend' ' who' ' is' ' a' ' friend' ' of' ' mine']" ", who is a friend of mine .
+
+ I have a friend who is a friend of mine" False " Catherine Zeta-Jones =
+" 4 [' Catherine', ' Z', 'eta', '-', 'Jones']
+822 176 Name of mother of x -1 Name of mother of George Harrison Louise Anne Harrison George Harrison "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was held" False " McCartney, and George Harrison (then known as ""the" 4 [' McCartney', ',', ' and', ' George', ' Harrison']
+823 176 Name of mother of x -1 Name of mother of George Harrison Louise Anne Harrison George Harrison "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was held" False the music of George Harrison is most compelling 4 [' the', ' music', ' of', ' George', ' Harrison']
+824 176 Name of mother of x -1 Name of mother of George Harrison Louise Anne Harrison George Harrison "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was held" False " exclusivity of Songs by George Harrison by saying, ""in a world" 6 [' exclus', 'ivity', ' of', ' Songs', ' by', ' George', ' Harrison']
+825 176 Name of mother of x -1 Name of mother of George Harrison Louise Anne Harrison George Harrison "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was held" False " You"", the Songs by George Harrison EP remains the sole" 6 "[' You', '"",', ' the', ' Songs', ' by', ' George', ' Harrison']"
+826 176 Name of mother of x -1 Name of mother of George Harrison Louise Anne Harrison George Harrison "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was held" False " performance, George Harrison said, ""And what brings" 3 [' performance', ',', ' George', ' Harrison']
+827 177 Name of mother of x -1 Name of mother of Alexandre Dumas Marie-Louise-Élisabeth Labouret Dumas Alexandre Dumas "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 18'
+ '02' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 18 02 .
+
+ The first of the great French" False bridge. The novelist Alexandre Dumas was strongly 7 [' bridge', '.', ' The', ' novelist', ' Alexand', 're', ' Dum', 'as']
+828 177 Name of mother of x -1 Name of mother of Alexandre Dumas Marie-Louise-Élisabeth Labouret Dumas Alexandre Dumas "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 18'
+ '02' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 18 02 .
+
+ The first of the great French" False world, including Alexandre Dumas and Honoré de Balzac 6 [' world', ',', ' including', ' Alexand', 're', ' Dum', 'as']
+829 177 Name of mother of x -1 Name of mother of Alexandre Dumas Marie-Louise-Élisabeth Labouret Dumas Alexandre Dumas "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 18'
+ '02' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 18 02 .
+
+ The first of the great French" False Cressida, and also by Alexandre Dumas in The Man in 10 [' C', 'ress', 'ida', ',', ' and', ' also', ' by', ' Alexand', 're', ' Dum', 'as']
+830 177 Name of mother of x -1 Name of mother of Alexandre Dumas Marie-Louise-Élisabeth Labouret Dumas Alexandre Dumas "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 18'
+ '02' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 18 02 .
+
+ The first of the great French" False classics, ranging from Alexandre Dumas over Buffalo 7 [' classics', ',', ' ranging', ' from', ' Alexand', 're', ' Dum', 'as']
+831 177 Name of mother of x -1 Name of mother of Alexandre Dumas Marie-Louise-Élisabeth Labouret Dumas Alexandre Dumas "[',' ' the' ' great' ' French' ' novelist' ',' ' was' ' born' ' in' ' 18'
+ '02' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great' ' French']" ", the great French novelist , was born in 18 02 .
+
+ The first of the great French" False with the opera. Alexandre Dumas set a chapter 7 [' with', ' the', ' opera', '.', ' Alexand', 're', ' Dum', 'as']
+832 178 Name of mother of x -1 Name of mother of Abraham Lincoln Nancy Hanks Lincoln Abraham Lincoln "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Abraham' ' Lincoln' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Abraham Lincoln ,
+" False re-election bid against Abraham Lincoln convinced Douglas 6 [' re', '-', 'election', ' bid', ' against', ' Abraham', ' Lincoln']
+833 178 Name of mother of x -1 Name of mother of Abraham Lincoln Nancy Hanks Lincoln Abraham Lincoln "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Abraham' ' Lincoln' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Abraham Lincoln ,
+" False themselves in favour of Abraham Lincoln and his Union. Ultimately, 5 [' themselves', ' in', ' favour', ' of', ' Abraham', ' Lincoln']
+834 178 Name of mother of x -1 Name of mother of Abraham Lincoln Nancy Hanks Lincoln Abraham Lincoln "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Abraham' ' Lincoln' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Abraham Lincoln ,
+" False Political Genius of Abraham Lincoln is a 2005 book 4 [' Political', ' Genius', ' of', ' Abraham', ' Lincoln']
+835 178 Name of mother of x -1 Name of mother of Abraham Lincoln Nancy Hanks Lincoln Abraham Lincoln "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Abraham' ' Lincoln' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Abraham Lincoln ,
+" False slaves. President Abraham Lincoln signed the 4 [' slaves', '.', ' President', ' Abraham', ' Lincoln']
+836 178 Name of mother of x -1 Name of mother of Abraham Lincoln Nancy Hanks Lincoln Abraham Lincoln "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Abraham' ' Lincoln' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Abraham Lincoln ,
+" False slavery critic Abraham Lincoln to the presidency. 3 [' slavery', ' critic', ' Abraham', ' Lincoln']
+837 179 Name of mother of x -1 Name of mother of Robert Louis Stevenson Matilde Margaret Isabella Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' '\n'
+ '\n' 'The' ' Adventures' ' of' ' Tom' ' Sawyer' ',' ' and' ' other'
+ ' Stories']" ", the author of Treasure Island , and the
+
+ The Adventures of Tom Sawyer , and other Stories" False of it, with the Robert Louis Stevenson tale serving as a 7 [' of', ' it', ',', ' with', ' the', ' Robert', ' Louis', ' Stevenson']
+838 179 Name of mother of x -1 Name of mother of Robert Louis Stevenson Matilde Margaret Isabella Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' '\n'
+ '\n' 'The' ' Adventures' ' of' ' Tom' ' Sawyer' ',' ' and' ' other'
+ ' Stories']" ", the author of Treasure Island , and the
+
+ The Adventures of Tom Sawyer , and other Stories" False " with a quotation by Robert Louis Stevenson — ""The cruelest" 6 [' with', ' a', ' quotation', ' by', ' Robert', ' Louis', ' Stevenson']
+839 179 Name of mother of x -1 Name of mother of Robert Louis Stevenson Matilde Margaret Isabella Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' '\n'
+ '\n' 'The' ' Adventures' ' of' ' Tom' ' Sawyer' ',' ' and' ' other'
+ ' Stories']" ", the author of Treasure Island , and the
+
+ The Adventures of Tom Sawyer , and other Stories" False book report was on the Robert Louis Stevenson novel Treasure 7 [' book', ' report', ' was', ' on', ' the', ' Robert', ' Louis', ' Stevenson']
+840 179 Name of mother of x -1 Name of mother of Robert Louis Stevenson Matilde Margaret Isabella Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' '\n'
+ '\n' 'The' ' Adventures' ' of' ' Tom' ' Sawyer' ',' ' and' ' other'
+ ' Stories']" ", the author of Treasure Island , and the
+
+ The Adventures of Tom Sawyer , and other Stories" False Henley, Rodin met Robert Louis Stevenson and Robert Browning, 8 [' Hen', 'ley', ',', ' Rod', 'in', ' met', ' Robert', ' Louis', ' Stevenson']
+841 179 Name of mother of x -1 Name of mother of Robert Louis Stevenson Matilde Margaret Isabella Stevenson Robert Louis Stevenson "[',' ' the' ' author' ' of' ' Treasure' ' Island' ',' ' and' ' the' '\n'
+ '\n' 'The' ' Adventures' ' of' ' Tom' ' Sawyer' ',' ' and' ' other'
+ ' Stories']" ", the author of Treasure Island , and the
+
+ The Adventures of Tom Sawyer , and other Stories" False forgotten; although Robert Louis Stevenson and his wife discovered 5 [' forgotten', ';', ' although', ' Robert', ' Louis', ' Stevenson']
+842 181 Name of mother of x -1 Name of mother of Arnold Schwarzenegger Aurelia Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False California Governor Arnold Schwarzenegger and First Lady 3 [' California', ' Governor', ' Arnold', ' Schwarzenegger']
+843 181 Name of mother of x -1 Name of mother of Arnold Schwarzenegger Aurelia Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False the soundtrack of the Arnold Schwarzenegger movie Last Action 5 [' the', ' soundtrack', ' of', ' the', ' Arnold', ' Schwarzenegger']
+844 181 Name of mother of x -1 Name of mother of Arnold Schwarzenegger Aurelia Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False 3 ['Ar', 'n', 'old', ' Schwarzenegger']
+845 181 Name of mother of x -1 Name of mother of Arnold Schwarzenegger Aurelia Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False Bruce Willis, and Arnold Schwarzenegger in the action 5 [' Bruce', ' Willis', ',', ' and', ' Arnold', ' Schwarzenegger']
+846 181 Name of mother of x -1 Name of mother of Arnold Schwarzenegger Aurelia Schwarzenegger Arnold Schwarzenegger "[',' ' the' ' Terminator' ',' ' and' ' the' ' Terminator' ' 2' ':'
+ ' Judgment' ' Day' '.' '\n' '\n' 'The' ' Terminator' ' is' ' a' ' 1984'
+ ' American']" ", the Terminator , and the Terminator 2 : Judgment Day .
+
+ The Terminator is a 1984 American" False conflict with actor Arnold Schwarzenegger caused filming 4 [' conflict', ' with', ' actor', ' Arnold', ' Schwarzenegger']
+847 182 Name of mother of x -1 Name of mother of Penélope Cruz Encarna Sánchez Penélope Cruz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Spanish actress Penélope Cruz [...] is much 6 [' Spanish', ' actress', ' Pen', 'é', 'l', 'ope', ' Cruz']
+848 182 Name of mother of x -1 Name of mother of Penélope Cruz Encarna Sánchez Penélope Cruz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False hired Spanish actress Penélope Cruz to pose as Belle 7 [' hired', ' Spanish', ' actress', ' Pen', 'é', 'l', 'ope', ' Cruz']
+849 182 Name of mother of x -1 Name of mother of Penélope Cruz Encarna Sánchez Penélope Cruz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Satan's helper Carmen, Penélope Cruz doesn 't hold 9 "[' Satan', ""'s"", ' helper', ' Carmen', ',', ' Pen', 'é', 'l', 'ope', ' Cruz']"
+850 182 Name of mother of x -1 Name of mother of Penélope Cruz Encarna Sánchez Penélope Cruz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Jeff Bridges, Penélope Cruz and John Goodman. 7 [' Jeff', ' Bridges', ',', ' Pen', 'é', 'l', 'ope', ' Cruz']
+851 182 Name of mother of x -1 Name of mother of Penélope Cruz Encarna Sánchez Penélope Cruz "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Spanish actress Penélope Cruz [...] is much 6 [' Spanish', ' actress', ' Pen', 'é', 'l', 'ope', ' Cruz']
+852 183 Name of mother of x -1 Name of mother of Elton John Sheila Eileen Farebrother Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the singer , and his wife , the actress , and the two of them were in the audience False performances by Elton John and James Taylor 4 [' performances', ' by', ' El', 'ton', ' John']
+853 183 Name of mother of x -1 Name of mother of Elton John Sheila Eileen Farebrother Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the singer , and his wife , the actress , and the two of them were in the audience False Cougar Mellencamp, Elton John and Michael Jackson, 8 [' Cou', 'gar', ' Mell', 'enc', 'amp', ',', ' El', 'ton', ' John']
+854 183 Name of mother of x -1 Name of mother of Elton John Sheila Eileen Farebrother Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the singer , and his wife , the actress , and the two of them were in the audience False overseas teams. Elton John has also used 5 [' overseas', ' teams', '.', ' El', 'ton', ' John']
+855 183 Name of mother of x -1 Name of mother of Elton John Sheila Eileen Farebrother Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the singer , and his wife , the actress , and the two of them were in the audience False performers ranging from Elton John and Keith Moon 5 [' performers', ' ranging', ' from', ' El', 'ton', ' John']
+856 183 Name of mother of x -1 Name of mother of Elton John Sheila Eileen Farebrother Elton John "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the singer , and his wife , the actress , and the two of them were in the audience False " Life""), and Elton John (""Funeral for" 5 "[' Life', '""),', ' and', ' El', 'ton', ' John']"
+857 184 Name of mother of x -1 Name of mother of Elvis Presley Gladys Presley Elvis Presley "[',' ' the' ' King' ' of' ' Rock' ' and' ' Roll' ',' ' and' ' the' ' King'
+ ' of' ' Rock' ' and' ' Roll' ',' ' Elvis' ' Pres' 'ley' ',']" , the King of Rock and Roll , and the King of Rock and Roll , Elvis Pres ley , False this Spanish-infused Elvis Presley cover, even though 7 [' this', ' Spanish', '-', 'inf', 'used', ' Elvis', ' Pres', 'ley']
+858 184 Name of mother of x -1 Name of mother of Elvis Presley Gladys Presley Elvis Presley "[',' ' the' ' King' ' of' ' Rock' ' and' ' Roll' ',' ' and' ' the' ' King'
+ ' of' ' Rock' ' and' ' Roll' ',' ' Elvis' ' Pres' 'ley' ',']" , the King of Rock and Roll , and the King of Rock and Roll , Elvis Pres ley , False going to see Elvis Presley is that you 5 [' going', ' to', ' see', ' Elvis', ' Pres', 'ley']
+859 184 Name of mother of x -1 Name of mother of Elvis Presley Gladys Presley Elvis Presley "[',' ' the' ' King' ' of' ' Rock' ' and' ' Roll' ',' ' and' ' the' ' King'
+ ' of' ' Rock' ' and' ' Roll' ',' ' Elvis' ' Pres' 'ley' ',']" , the King of Rock and Roll , and the King of Rock and Roll , Elvis Pres ley , False and its connection to Elvis Presley is featured in the 6 [' and', ' its', ' connection', ' to', ' Elvis', ' Pres', 'ley']
+860 184 Name of mother of x -1 Name of mother of Elvis Presley Gladys Presley Elvis Presley "[',' ' the' ' King' ' of' ' Rock' ' and' ' Roll' ',' ' and' ' the' ' King'
+ ' of' ' Rock' ' and' ' Roll' ',' ' Elvis' ' Pres' 'ley' ',']" , the King of Rock and Roll , and the King of Rock and Roll , Elvis Pres ley , False CBS special on Elvis Presley receiving ratings 5 [' CBS', ' special', ' on', ' Elvis', ' Pres', 'ley']
+861 184 Name of mother of x -1 Name of mother of Elvis Presley Gladys Presley Elvis Presley "[',' ' the' ' King' ' of' ' Rock' ' and' ' Roll' ',' ' and' ' the' ' King'
+ ' of' ' Rock' ' and' ' Roll' ',' ' Elvis' ' Pres' 'ley' ',']" , the King of Rock and Roll , and the King of Rock and Roll , Elvis Pres ley , False guitar after seeing Elvis Presley on television. 5 [' guitar', ' after', ' seeing', ' Elvis', ' Pres', 'ley']
+862 185 Name of mother of x -1 Name of mother of Rudyard Kipling Alice MacDonald Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' poem' ',' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' poem' ' was' ' written']" ", the author of the poem , was a great admire r of the
+
+ The poem was written" False Arthur Conan Doyle, Rudyard Kipling as well as newspaper 7 [' Arthur', ' Conan', ' Doyle', ',', ' Rud', 'yard', ' Ki', 'pling']
+863 185 Name of mother of x -1 Name of mother of Rudyard Kipling Alice MacDonald Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' poem' ',' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' poem' ' was' ' written']" ", the author of the poem , was a great admire r of the
+
+ The poem was written" False Ed Whitfield, Rudyard Kipling and President of 7 [' Ed', ' Whit', 'field', ',', ' Rud', 'yard', ' Ki', 'pling']
+864 185 Name of mother of x -1 Name of mother of Rudyard Kipling Alice MacDonald Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' poem' ',' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' poem' ' was' ' written']" ", the author of the poem , was a great admire r of the
+
+ The poem was written" False September 1939, the Rudyard Kipling left Fleetwood 7 [' September', ' 1939', ',', ' the', ' Rud', 'yard', ' Ki', 'pling']
+865 185 Name of mother of x -1 Name of mother of Rudyard Kipling Alice MacDonald Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' poem' ',' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' poem' ' was' ' written']" ", the author of the poem , was a great admire r of the
+
+ The poem was written" False the 20th century, Rudyard Kipling referred to lines 8 [' the', ' 20', 'th', ' century', ',', ' Rud', 'yard', ' Ki', 'pling']
+866 185 Name of mother of x -1 Name of mother of Rudyard Kipling Alice MacDonald Kipling Rudyard Kipling "[',' ' the' ' author' ' of' ' the' ' poem' ',' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' poem' ' was' ' written']" ", the author of the poem , was a great admire r of the
+
+ The poem was written" False Fells in mid-1895 was Rudyard Kipling — and later in 11 [' F', 'ells', ' in', ' mid', '-', '18', '95', ' was', ' Rud', 'yard', ' Ki', 'pling']
+867 186 Name of mother of x -1 Name of mother of Stephen King Nellie Pillsbury Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' movie' ',' ' and' ' the' ' author' ' of' ' the' ' movie']" , the author of the book , and the author of the movie , and the author of the movie False based on three Stephen King stories: Stand by 4 [' based', ' on', ' three', ' Stephen', ' King']
+868 186 Name of mother of x -1 Name of mother of Stephen King Nellie Pillsbury Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' movie' ',' ' and' ' the' ' author' ' of' ' the' ' movie']" , the author of the book , and the author of the movie , and the author of the movie False mentioned in the Stephen King novels The 4 [' mentioned', ' in', ' the', ' Stephen', ' King']
+869 186 Name of mother of x -1 Name of mother of Stephen King Nellie Pillsbury Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' movie' ',' ' and' ' the' ' author' ' of' ' the' ' movie']" , the author of the book , and the author of the movie , and the author of the movie False the black cloak', Stephen King confirmed in an interview 5 "[' the', ' black', ' cloak', ""',"", ' Stephen', ' King']"
+870 186 Name of mother of x -1 Name of mother of Stephen King Nellie Pillsbury Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' movie' ',' ' and' ' the' ' author' ' of' ' the' ' movie']" , the author of the book , and the author of the movie , and the author of the movie False 1 ['Stephen', ' King']
+871 186 Name of mother of x -1 Name of mother of Stephen King Nellie Pillsbury Stephen King "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' movie' ',' ' and' ' the' ' author' ' of' ' the' ' movie']" , the author of the book , and the author of the movie , and the author of the movie False infamous 1986 Stephen King movie Maximum 3 [' infamous', ' 1986', ' Stephen', ' King']
+872 187 Name of mother of x -1 Name of mother of Leonard Cohen Marsha Klonitsky Leonard Cohen "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' is' ' that']" ", the father of the modern world .
+
+ The first thing I noticed about the book is that" False authors like Leonard Cohen and Walt Whitman 3 [' authors', ' like', ' Leonard', ' Cohen']
+873 187 Name of mother of x -1 Name of mother of Leonard Cohen Marsha Klonitsky Leonard Cohen "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' is' ' that']" ", the father of the modern world .
+
+ The first thing I noticed about the book is that" False " 't Forget"" to the Leonard Cohen tribute album" 7 "["" '"", 't', ' Forget', '""', ' to', ' the', ' Leonard', ' Cohen']"
+874 187 Name of mother of x -1 Name of mother of Leonard Cohen Marsha Klonitsky Leonard Cohen "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' is' ' that']" ", the father of the modern world .
+
+ The first thing I noticed about the book is that" False Suzanne, a song by Leonard Cohen which told a tale of 6 [' Suzanne', ',', ' a', ' song', ' by', ' Leonard', ' Cohen']
+875 187 Name of mother of x -1 Name of mother of Leonard Cohen Marsha Klonitsky Leonard Cohen "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' is' ' that']" ", the father of the modern world .
+
+ The first thing I noticed about the book is that" False " Kurt Cobain, Leonard Cohen and Bob Dylan.
+" 5 [' Kurt', ' Cob', 'ain', ',', ' Leonard', ' Cohen']
+876 187 Name of mother of x -1 Name of mother of Leonard Cohen Marsha Klonitsky Leonard Cohen "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' I' ' noticed' ' about' ' the' ' book' ' is' ' that']" ", the father of the modern world .
+
+ The first thing I noticed about the book is that" False " Can 't Forget"" to the Leonard Cohen tribute album I'm Your" 8 "[' Can', "" '"", 't', ' Forget', '""', ' to', ' the', ' Leonard', ' Cohen']"
+877 188 Name of mother of x -1 Name of mother of Henry VIII of England Elizabeth of York Henry VIII of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Henry' ' VIII' ' of' ' England' ',' ' and' ' the' '\n']" ", and the
+
+ The
+
+ Name of mother of Henry VIII of England , and the
+" False from Pope Leo X to Henry VIII of England flatters the 8 [' from', ' Pope', ' Leo', ' X', ' to', ' Henry', ' VIII', ' of', ' England']
+878 188 Name of mother of x -1 Name of mother of Henry VIII of England Elizabeth of York Henry VIII of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Henry' ' VIII' ' of' ' England' ',' ' and' ' the' '\n']" ", and the
+
+ The
+
+ Name of mother of Henry VIII of England , and the
+" False requested annulment of Henry VIII of England from Catherine 8 [' requested', ' ann', 'ul', 'ment', ' of', ' Henry', ' VIII', ' of', ' England']
+879 188 Name of mother of x -1 Name of mother of Henry VIII of England Elizabeth of York Henry VIII of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Henry' ' VIII' ' of' ' England' ',' ' and' ' the' '\n']" ", and the
+
+ The
+
+ Name of mother of Henry VIII of England , and the
+" False since the meeting of Henry VIII of England and Francis 7 [' since', ' the', ' meeting', ' of', ' Henry', ' VIII', ' of', ' England']
+880 188 Name of mother of x -1 Name of mother of Henry VIII of England Elizabeth of York Henry VIII of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Henry' ' VIII' ' of' ' England' ',' ' and' ' the' '\n']" ", and the
+
+ The
+
+ Name of mother of Henry VIII of England , and the
+" False the film he played Henry VIII of England opposite Scarlett 7 [' the', ' film', ' he', ' played', ' Henry', ' VIII', ' of', ' England']
+881 188 Name of mother of x -1 Name of mother of Henry VIII of England Elizabeth of York Henry VIII of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Henry' ' VIII' ' of' ' England' ',' ' and' ' the' '\n']" ", and the
+
+ The
+
+ Name of mother of Henry VIII of England , and the
+" False portraits of Henry VIII of England and Alessandro del 5 [' portraits', ' of', ' Henry', ' VIII', ' of', ' England']
+882 189 Name of mother of x -1 Name of mother of Milla Jovovich Galina Jovovich Milla Jovovich "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' for' ' over' ' a' ' decade']" ", who is a former model and actress .
+
+ The couple have been married for over a decade" False Leeloo, Besson chose Milla Jovovich from the 200 to 11 [' Le', 'el', 'oo', ',', ' B', 'esson', ' chose', ' M', 'illa', ' J', 'ov', 'ovich']
+883 189 Name of mother of x -1 Name of mother of Milla Jovovich Galina Jovovich Milla Jovovich "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' for' ' over' ' a' ' decade']" ", who is a former model and actress .
+
+ The couple have been married for over a decade" False 4 ['M', 'illa', ' J', 'ov', 'ovich']
+884 189 Name of mother of x -1 Name of mother of Milla Jovovich Galina Jovovich Milla Jovovich "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' for' ' over' ' a' ' decade']" ", who is a former model and actress .
+
+ The couple have been married for over a decade" False discussed by Milla Jovovich and Paul W. S. Anderson 6 [' discussed', ' by', ' M', 'illa', ' J', 'ov', 'ovich']
+885 189 Name of mother of x -1 Name of mother of Milla Jovovich Galina Jovovich Milla Jovovich "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' for' ' over' ' a' ' decade']" ", who is a former model and actress .
+
+ The couple have been married for over a decade" False available in 2008. Milla Jovovich and Oded Fehr filmed 8 [' available', ' in', ' 2008', '.', ' M', 'illa', ' J', 'ov', 'ovich']
+886 189 Name of mother of x -1 Name of mother of Milla Jovovich Galina Jovovich Milla Jovovich "[',' ' who' ' is' ' a' ' former' ' model' ' and' ' actress' '.' '\n' '\n'
+ 'The' ' couple' ' have' ' been' ' married' ' for' ' over' ' a' ' decade']" ", who is a former model and actress .
+
+ The couple have been married for over a decade" False direct the sequel. Milla Jovovich confirmed her 8 [' direct', ' the', ' sequel', '.', ' M', 'illa', ' J', 'ov', 'ovich']
+887 190 Name of mother of x -1 Name of mother of Ivan Turgenev Varvara Petrovna Turgeneva Ivan Turgenev "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '18' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' brothers' ',']" ", the Russian writer , was born in 18 18 .
+
+ The first of the three brothers ," False outlet for both Ivan Turgenev and Leo Tolstoy. However, 7 [' outlet', ' for', ' both', ' Ivan', ' T', 'urg', 'ene', 'v']
+888 191 Name of mother of x -1 Name of mother of Frank Sinatra Dolly Sinatra Frank Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' Frank' ' Sin' 'atra' ' Jr' '.' '\n' '\n' 'Frank' ' Sin']" ", the singer , and the father of the singer , Frank Sin atra Jr .
+
+ Frank Sin" False " starred Kelly, Frank Sinatra and Jules Munshin.
+" 5 [' starred', ' Kelly', ',', ' Frank', ' Sin', 'atra']
+889 191 Name of mother of x -1 Name of mother of Frank Sinatra Dolly Sinatra Frank Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' Frank' ' Sin' 'atra' ' Jr' '.' '\n' '\n' 'Frank' ' Sin']" ", the singer , and the father of the singer , Frank Sin atra Jr .
+
+ Frank Sin" False psychoanalyst in which Frank Sinatra is smothering 7 [' psych', 'oan', 'alyst', ' in', ' which', ' Frank', ' Sin', 'atra']
+890 191 Name of mother of x -1 Name of mother of Frank Sinatra Dolly Sinatra Frank Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' Frank' ' Sin' 'atra' ' Jr' '.' '\n' '\n' 'Frank' ' Sin']" ", the singer , and the father of the singer , Frank Sin atra Jr .
+
+ Frank Sin" False were recorded by Frank Sinatra but both critics 5 [' were', ' recorded', ' by', ' Frank', ' Sin', 'atra']
+891 191 Name of mother of x -1 Name of mother of Frank Sinatra Dolly Sinatra Frank Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' Frank' ' Sin' 'atra' ' Jr' '.' '\n' '\n' 'Frank' ' Sin']" ", the singer , and the father of the singer , Frank Sin atra Jr .
+
+ Frank Sin" False a guest on The Frank Sinatra Timex Special 6 [' a', ' guest', ' on', ' The', ' Frank', ' Sin', 'atra']
+892 191 Name of mother of x -1 Name of mother of Frank Sinatra Dolly Sinatra Frank Sinatra "[',' ' the' ' singer' ',' ' and' ' the' ' father' ' of' ' the' ' singer'
+ ',' ' Frank' ' Sin' 'atra' ' Jr' '.' '\n' '\n' 'Frank' ' Sin']" ", the singer , and the father of the singer , Frank Sin atra Jr .
+
+ Frank Sin" False Sinatra released Frank Sinatra Sings for Only 5 [' Sin', 'atra', ' released', ' Frank', ' Sin', 'atra']
+893 192 Name of mother of x -1 Name of mother of Otto von Bismarck Wilhelmine Luise Mencken Otto von Bismarck "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' Empire' ',' ' and' ' who']" ", the German chancellor , who was a great admire r of the
+
+ German Empire , and who" False Chancellor Otto von Bismarck introduced 6 [' Chancellor', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+894 192 Name of mother of x -1 Name of mother of Otto von Bismarck Wilhelmine Luise Mencken Otto von Bismarck "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' Empire' ',' ' and' ' who']" ", the German chancellor , who was a great admire r of the
+
+ German Empire , and who" False debate whether Otto von Bismarck — Minister President 7 [' debate', ' whether', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+895 192 Name of mother of x -1 Name of mother of Otto von Bismarck Wilhelmine Luise Mencken Otto von Bismarck "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' Empire' ',' ' and' ' who']" ", the German chancellor , who was a great admire r of the
+
+ German Empire , and who" False German Chancellor Otto von Bismarck introduced protective 7 [' German', ' Chancellor', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+896 192 Name of mother of x -1 Name of mother of Otto von Bismarck Wilhelmine Luise Mencken Otto von Bismarck "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' Empire' ',' ' and' ' who']" ", the German chancellor , who was a great admire r of the
+
+ German Empire , and who" False Dash. In 1896, Otto von Bismarck purchased a King 10 [' Dash', '.', ' In', ' 1896', ',', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+897 192 Name of mother of x -1 Name of mother of Otto von Bismarck Wilhelmine Luise Mencken Otto von Bismarck "[',' ' the' ' German' ' chancellor' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' the' '\n' '\n' 'German' ' Empire' ',' ' and' ' who']" ", the German chancellor , who was a great admire r of the
+
+ German Empire , and who" False foreigners: Otto von Bismarck (1911), Ivar Aasen 7 [' foreigners', ':', ' Otto', ' von', ' B', 'ism', 'ar', 'ck']
+898 193 Name of mother of x -1 Name of mother of Sean Connery Euphemia McLean Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the'
+ ' original' ' Bond' ' films' ',' ' has' ' died' ' at' ' the' ' age' ' of'
+ ' 89']" , the actor who played James Bond in the original Bond films , has died at the age of 89 False Bond actor, Sean Connery and the latter partly 5 [' Bond', ' actor', ',', ' Sean', ' Con', 'nery']
+899 193 Name of mother of x -1 Name of mother of Sean Connery Euphemia McLean Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the'
+ ' original' ' Bond' ' films' ',' ' has' ' died' ' at' ' the' ' age' ' of'
+ ' 89']" , the actor who played James Bond in the original Bond films , has died at the age of 89 False Dr. No, starring Sean Connery as Bond. As of 2016, 7 [' Dr', '.', ' No', ',', ' starring', ' Sean', ' Con', 'nery']
+900 193 Name of mother of x -1 Name of mother of Sean Connery Euphemia McLean Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the'
+ ' original' ' Bond' ' films' ',' ' has' ' died' ' at' ' the' ' age' ' of'
+ ' 89']" , the actor who played James Bond in the original Bond films , has died at the age of 89 False starred opposite Sean Connery as a seductive insurance 4 [' starred', ' opposite', ' Sean', ' Con', 'nery']
+901 193 Name of mother of x -1 Name of mother of Sean Connery Euphemia McLean Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the'
+ ' original' ' Bond' ' films' ',' ' has' ' died' ' at' ' the' ' age' ' of'
+ ' 89']" , the actor who played James Bond in the original Bond films , has died at the age of 89 False " ""It's not just that Sean Connery looks a lot more" 8 "[' ""', 'It', ""'s"", ' not', ' just', ' that', ' Sean', ' Con', 'nery']"
+902 193 Name of mother of x -1 Name of mother of Sean Connery Euphemia McLean Sean Connery "[',' ' the' ' actor' ' who' ' played' ' James' ' Bond' ' in' ' the'
+ ' original' ' Bond' ' films' ',' ' has' ' died' ' at' ' the' ' age' ' of'
+ ' 89']" , the actor who played James Bond in the original Bond films , has died at the age of 89 False Bond series, with Sean Connery in the lead role; 6 [' Bond', ' series', ',', ' with', ' Sean', ' Con', 'nery']
+903 194 Name of mother of x -1 Name of mother of Ernest Hemingway Grace Hall Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False authors such as Ernest Hemingway and Lillian Hellman, 6 [' authors', ' such', ' as', ' Ernest', ' Hem', 'ing', 'way']
+904 194 Name of mother of x -1 Name of mother of Ernest Hemingway Grace Hall Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False Coco Chanel and Ernest Hemingway who lived at the 7 [' Coco', ' Chan', 'el', ' and', ' Ernest', ' Hem', 'ing', 'way']
+905 194 Name of mother of x -1 Name of mother of Ernest Hemingway Grace Hall Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False pisco sours. Ernest Hemingway and Orson Welles 8 [' p', 'isco', ' s', 'ours', '.', ' Ernest', ' Hem', 'ing', 'way']
+906 194 Name of mother of x -1 Name of mother of Ernest Hemingway Grace Hall Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " Hemingway =
+" 9 [' Hem', 'ing', 'way', ' =', 'Er', 'n', 'est', ' Hem', 'ing', 'way']
+907 194 Name of mother of x -1 Name of mother of Ernest Hemingway Grace Hall Hemingway Ernest Hemingway "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False spent time in Paris, Ernest Hemingway never made an appearance 8 [' spent', ' time', ' in', ' Paris', ',', ' Ernest', ' Hem', 'ing', 'way']
+908 195 Name of mother of x -1 Name of mother of John Singer Sargent Mary Newbold Sargent John Singer Sargent "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False university commissioned John Singer Sargent to paint, within 6 [' university', ' commissioned', ' John', ' Singer', ' S', 'arg', 'ent']
+909 195 Name of mother of x -1 Name of mother of John Singer Sargent Mary Newbold Sargent John Singer Sargent "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False painted by John Singer Sargent in 1907, and 6 [' painted', ' by', ' John', ' Singer', ' S', 'arg', 'ent']
+910 195 Name of mother of x -1 Name of mother of John Singer Sargent Mary Newbold Sargent John Singer Sargent "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False portrait, painted by John Singer Sargent in 1907, and his 8 [' portrait', ',', ' painted', ' by', ' John', ' Singer', ' S', 'arg', 'ent']
+911 195 Name of mother of x -1 Name of mother of John Singer Sargent Mary Newbold Sargent John Singer Sargent "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False 1870s, while Renoir and John Singer Sargent visited in the 1880s. 11 [' 1870', 's', ',', ' while', ' Ren', 'oir', ' and', ' John', ' Singer', ' S', 'arg', 'ent']
+912 195 Name of mother of x -1 Name of mother of John Singer Sargent Mary Newbold Sargent John Singer Sargent "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False James, artist John Singer Sargent, the Rothschild 7 [' James', ',', ' artist', ' John', ' Singer', ' S', 'arg', 'ent']
+913 196 Name of mother of x -1 Name of mother of Benjamin Franklin Abiah Folger Benjamin Franklin "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Benjamin' ' Franklin' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Benjamin Franklin ,
+" False 2 ['Ben', 'jamin', ' Franklin']
+914 196 Name of mother of x -1 Name of mother of Benjamin Franklin Abiah Folger Benjamin Franklin "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Benjamin' ' Franklin' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Benjamin Franklin ,
+" False recruits the help of Benjamin Franklin and Dr. Fred's ancestor, 5 [' recruits', ' the', ' help', ' of', ' Benjamin', ' Franklin']
+915 196 Name of mother of x -1 Name of mother of Benjamin Franklin Abiah Folger Benjamin Franklin "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Benjamin' ' Franklin' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Benjamin Franklin ,
+" False " hand-operated Benjamin Franklin printing press.
+" 4 [' hand', '-', 'operated', ' Benjamin', ' Franklin']
+916 196 Name of mother of x -1 Name of mother of Benjamin Franklin Abiah Folger Benjamin Franklin "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Benjamin' ' Franklin' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Benjamin Franklin ,
+" False for the Federalists; Benjamin Franklin Bache and Philip 6 [' for', ' the', ' Federal', 'ists', ';', ' Benjamin', ' Franklin']
+917 196 Name of mother of x -1 Name of mother of Benjamin Franklin Abiah Folger Benjamin Franklin "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Benjamin' ' Franklin' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Benjamin Franklin ,
+" False were made to Benjamin Franklin for possible 4 [' were', ' made', ' to', ' Benjamin', ' Franklin']
+918 197 Name of mother of x -1 Name of mother of Émile Zola Émilie Aubert Émile Zola "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the French writer , who was a friend of the family , and who had been a friend of False " perfection"". Émile Zola commented on Offenbach" 5 "[' perfection', '"".', ' É', 'mile', ' Z', 'ola']"
+919 197 Name of mother of x -1 Name of mother of Émile Zola Émilie Aubert Émile Zola "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the French writer , who was a friend of the family , and who had been a friend of False literature. Novelist Émile Zola called it an important 7 [' literature', '.', ' Novel', 'ist', ' É', 'mile', ' Z', 'ola']
+920 197 Name of mother of x -1 Name of mother of Émile Zola Émilie Aubert Émile Zola "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the French writer , who was a friend of the family , and who had been a friend of False working-class girl. Émile Zola also approved, 8 [' working', '-', 'class', ' girl', '.', ' É', 'mile', ' Z', 'ola']
+921 197 Name of mother of x -1 Name of mother of Émile Zola Émilie Aubert Émile Zola "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the French writer , who was a friend of the family , and who had been a friend of False Leo Tolstoy, Émile Zola and Victor Hugo, 8 [' Leo', ' Tol', 'st', 'oy', ',', ' É', 'mile', ' Z', 'ola']
+922 197 Name of mother of x -1 Name of mother of Émile Zola Émilie Aubert Émile Zola "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the French writer , who was a friend of the family , and who had been a friend of False The French novelist Émile Zola lived in what 6 [' The', ' French', ' novelist', ' É', 'mile', ' Z', 'ola']
+923 198 Name of mother of x -1 Name of mother of Jonathan Swift Abigail Erick Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' ' G' 'ull' 'iver' ""'s"" ' Travels']" , the author of G ull iver 's Travels , and the author of the G ull iver 's Travels False reportedly provoked Jonathan Swift to laughter, a rare 3 [' reportedly', ' provoked', ' Jonathan', ' Swift']
+924 198 Name of mother of x -1 Name of mother of Jonathan Swift Abigail Erick Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' ' G' 'ull' 'iver' ""'s"" ' Travels']" , the author of G ull iver 's Travels , and the author of the G ull iver 's Travels False asked a friend of Jonathan Swift to plead with 5 [' asked', ' a', ' friend', ' of', ' Jonathan', ' Swift']
+925 198 Name of mother of x -1 Name of mother of Jonathan Swift Abigail Erick Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' ' G' 'ull' 'iver' ""'s"" ' Travels']" , the author of G ull iver 's Travels , and the author of the G ull iver 's Travels False anti-war picture that Jonathan Swift was already presenting 6 [' anti', '-', 'war', ' picture', ' that', ' Jonathan', ' Swift']
+926 198 Name of mother of x -1 Name of mother of Jonathan Swift Abigail Erick Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' ' G' 'ull' 'iver' ""'s"" ' Travels']" , the author of G ull iver 's Travels , and the author of the G ull iver 's Travels False " Reverend Dr. Jonathan Swift ""carefully corrected""" 4 [' Reverend', ' Dr', '.', ' Jonathan', ' Swift']
+927 198 Name of mother of x -1 Name of mother of Jonathan Swift Abigail Erick Jonathan Swift "[',' ' the' ' author' ' of' ' G' 'ull' 'iver' ""'s"" ' Travels' ',' ' and'
+ ' the' ' author' ' of' ' the' ' G' 'ull' 'iver' ""'s"" ' Travels']" , the author of G ull iver 's Travels , and the author of the G ull iver 's Travels False of licensing. Jonathan Swift was a strong advocate 4 [' of', ' licensing', '.', ' Jonathan', ' Swift']
+928 199 Name of mother of x -1 Name of mother of Petrarch Eietta Canigiani Petrarch "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ ' Petr' 'arch' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Petr arch , and the
+
+ Name of" False biographies: Petrarch and Machiavelli. 4 [' bi', 'ographies', ':', ' Petr', 'arch']
+929 199 Name of mother of x -1 Name of mother of Petrarch Eietta Canigiani Petrarch "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ ' Petr' 'arch' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Petr arch , and the
+
+ Name of" False of Dante and Petrarch in establishing 4 [' of', ' Dante', ' and', ' Petr', 'arch']
+930 199 Name of mother of x -1 Name of mother of Petrarch Eietta Canigiani Petrarch "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ ' Petr' 'arch' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Petr arch , and the
+
+ Name of" False true in her essays on Petrarch and Vincenzo Monti. 6 [' true', ' in', ' her', ' essays', ' on', ' Petr', 'arch']
+931 199 Name of mother of x -1 Name of mother of Petrarch Eietta Canigiani Petrarch "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ ' Petr' 'arch' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Petr arch , and the
+
+ Name of" False poetry from Petrarch to Poliziano. 3 [' poetry', ' from', ' Petr', 'arch']
+932 199 Name of mother of x -1 Name of mother of Petrarch Eietta Canigiani Petrarch "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' and' ' Mrs' '.'
+ ' Petr' 'arch' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", the son of the late Mr . and Mrs . Petr arch , and the
+
+ Name of" False form, as developed by Petrarch and Dante, was 6 [' form', ',', ' as', ' developed', ' by', ' Petr', 'arch']
+933 200 Name of mother of x -1 Name of mother of Virginia Woolf Julia Stephen Virginia Woolf "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' '\n' '\n' 'I' ' am' ' not' ' sure' ' if' ' I']" ", and the other is the name of the mother of the
+
+ I am not sure if I" False book was enjoyed by Virginia Woolf and acted as a source 6 [' book', ' was', ' enjoyed', ' by', ' Virginia', ' Wool', 'f']
+934 200 Name of mother of x -1 Name of mother of Virginia Woolf Julia Stephen Virginia Woolf "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' '\n' '\n' 'I' ' am' ' not' ' sure' ' if' ' I']" ", and the other is the name of the mother of the
+
+ I am not sure if I" False of One's Own, Virginia Woolf created a 7 "[' of', ' One', ""'s"", ' Own', ',', ' Virginia', ' Wool', 'f']"
+935 200 Name of mother of x -1 Name of mother of Virginia Woolf Julia Stephen Virginia Woolf "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' '\n' '\n' 'I' ' am' ' not' ' sure' ' if' ' I']" ", and the other is the name of the mother of the
+
+ I am not sure if I" False in Who's Afraid of Virginia Woolf ?; as the episode 8 "[' in', ' Who', ""'s"", ' Af', 'raid', ' of', ' Virginia', ' Wool', 'f']"
+936 200 Name of mother of x -1 Name of mother of Virginia Woolf Julia Stephen Virginia Woolf "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' '\n' '\n' 'I' ' am' ' not' ' sure' ' if' ' I']" ", and the other is the name of the mother of the
+
+ I am not sure if I" False each other as Virginia Woolf and Emma Goldman 5 [' each', ' other', ' as', ' Virginia', ' Wool', 'f']
+937 200 Name of mother of x -1 Name of mother of Virginia Woolf Julia Stephen Virginia Woolf "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' '\n' '\n' 'I' ' am' ' not' ' sure' ' if' ' I']" ", and the other is the name of the mother of the
+
+ I am not sure if I" False Brooke's own life such as Virginia Woolf (then Virginia 8 "[' Brooke', ""'s"", ' own', ' life', ' such', ' as', ' Virginia', ' Wool', 'f']"
+938 201 Name of mother of x -1 Name of mother of C. S. Lewis Florence Augusta Hamilton C. S. Lewis "[',' ' the' ' author' ' of' ' the' ' N' 'arn' 'ia' ' Chronicles' ','
+ ' and' ' the' ' author' ' of' ' the' ' Chronicles' ' of' ' N' 'arn' 'ia']" , the author of the N arn ia Chronicles , and the author of the Chronicles of N arn ia False " lifelong love of C. S. Lewis. ""Even now," 7 [' lifelong', ' love', ' of', ' C', '.', ' S', '.', ' Lewis']
+939 201 Name of mother of x -1 Name of mother of C. S. Lewis Florence Augusta Hamilton C. S. Lewis "[',' ' the' ' author' ' of' ' the' ' N' 'arn' 'ia' ' Chronicles' ','
+ ' and' ' the' ' author' ' of' ' the' ' Chronicles' ' of' ' N' 'arn' 'ia']" , the author of the N arn ia Chronicles , and the author of the Chronicles of N arn ia False " lifelong love of C. S. Lewis. ""Even now, if" 7 [' lifelong', ' love', ' of', ' C', '.', ' S', '.', ' Lewis']
+940 201 Name of mother of x -1 Name of mother of C. S. Lewis Florence Augusta Hamilton C. S. Lewis "[',' ' the' ' author' ' of' ' the' ' N' 'arn' 'ia' ' Chronicles' ','
+ ' and' ' the' ' author' ' of' ' the' ' Chronicles' ' of' ' N' 'arn' 'ia']" , the author of the N arn ia Chronicles , and the author of the Chronicles of N arn ia False " a lifelong love of C. S. Lewis. ""Even now," 8 [' a', ' lifelong', ' love', ' of', ' C', '.', ' S', '.', ' Lewis']
+941 201 Name of mother of x -1 Name of mother of C. S. Lewis Florence Augusta Hamilton C. S. Lewis "[',' ' the' ' author' ' of' ' the' ' N' 'arn' 'ia' ' Chronicles' ','
+ ' and' ' the' ' author' ' of' ' the' ' Chronicles' ' of' ' N' 'arn' 'ia']" , the author of the N arn ia Chronicles , and the author of the Chronicles of N arn ia False chronological novel in C. S. Lewis's epic fantasy series, 7 [' chronological', ' novel', ' in', ' C', '.', ' S', '.', ' Lewis']
+942 201 Name of mother of x -1 Name of mother of C. S. Lewis Florence Augusta Hamilton C. S. Lewis "[',' ' the' ' author' ' of' ' the' ' N' 'arn' 'ia' ' Chronicles' ','
+ ' and' ' the' ' author' ' of' ' the' ' Chronicles' ' of' ' N' 'arn' 'ia']" , the author of the N arn ia Chronicles , and the author of the Chronicles of N arn ia False " Christianity, C. S. Lewis argued that ""conscience" 6 [' Christianity', ',', ' C', '.', ' S', '.', ' Lewis']
+943 202 Name of mother of x -1 Name of mother of Anton Chekhov Evgenia Chekhova Anton Chekhov "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '91' ' in'
+ ' T' 'agan' 'rog' ',' ' Russia' '.' ' He' ' was' ' the']" , the Russian writer , was born in 18 91 in T agan rog , Russia . He was the False 4 ['Ant', 'on', ' Che', 'kh', 'ov']
+944 202 Name of mother of x -1 Name of mother of Anton Chekhov Evgenia Chekhova Anton Chekhov "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '91' ' in'
+ ' T' 'agan' 'rog' ',' ' Russia' '.' ' He' ' was' ' the']" , the Russian writer , was born in 18 91 in T agan rog , Russia . He was the False 4 ['Ant', 'on', ' Che', 'kh', 'ov']
+945 202 Name of mother of x -1 Name of mother of Anton Chekhov Evgenia Chekhova Anton Chekhov "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '91' ' in'
+ ' T' 'agan' 'rog' ',' ' Russia' '.' ' He' ' was' ' the']" , the Russian writer , was born in 18 91 in T agan rog , Russia . He was the False " Chekhov =
+" 8 [' Che', 'kh', 'ov', ' =', 'Ant', 'on', ' Che', 'kh', 'ov']
+946 202 Name of mother of x -1 Name of mother of Anton Chekhov Evgenia Chekhova Anton Chekhov "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '91' ' in'
+ ' T' 'agan' 'rog' ',' ' Russia' '.' ' He' ' was' ' the']" , the Russian writer , was born in 18 91 in T agan rog , Russia . He was the False Baudelaire, Leo Tolstoy, Anton Chekhov and Sigmund Freud, 13 [' B', 'aud', 'el', 'aire', ',', ' Leo', ' Tol', 'st', 'oy', ',', ' Anton', ' Che', 'kh', 'ov']
+947 202 Name of mother of x -1 Name of mother of Anton Chekhov Evgenia Chekhova Anton Chekhov "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '91' ' in'
+ ' T' 'agan' 'rog' ',' ' Russia' '.' ' He' ' was' ' the']" , the Russian writer , was born in 18 91 in T agan rog , Russia . He was the False 4 ['Ant', 'on', ' Che', 'kh', 'ov']
+948 203 Name of mother of x -1 Name of mother of William Hogarth Anne Gibbons William Hogarth "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' William' ' Hog' 'arth' ',']" ", the father of the
+
+ The following is a list of the children of William Hog arth ," False work of friend William Hogarth and poetry by 5 [' work', ' of', ' friend', ' William', ' Hog', 'arth']
+949 203 Name of mother of x -1 Name of mother of William Hogarth Anne Gibbons William Hogarth "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' William' ' Hog' 'arth' ',']" ", the father of the
+
+ The following is a list of the children of William Hog arth ," False of satirists William Hogarth and Honoré Daumier. 5 [' of', ' satir', 'ists', ' William', ' Hog', 'arth']
+950 203 Name of mother of x -1 Name of mother of William Hogarth Anne Gibbons William Hogarth "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' William' ' Hog' 'arth' ',']" ", the father of the
+
+ The following is a list of the children of William Hog arth ," False co-founded in 1736 by William Hogarth at the Theatre 9 [' co', '-', 'founded', ' in', ' 17', '36', ' by', ' William', ' Hog', 'arth']
+951 203 Name of mother of x -1 Name of mother of William Hogarth Anne Gibbons William Hogarth "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' William' ' Hog' 'arth' ',']" ", the father of the
+
+ The following is a list of the children of William Hog arth ," False Marshalsea on 25 March. William Hogarth accompanied the 9 [' Marsh', 'alse', 'a', ' on', ' 25', ' March', '.', ' William', ' Hog', 'arth']
+952 203 Name of mother of x -1 Name of mother of William Hogarth Anne Gibbons William Hogarth "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' William' ' Hog' 'arth' ',']" ", the father of the
+
+ The following is a list of the children of William Hog arth ," False 1700 and artist William Hogarth resided at No 5 [' 1700', ' and', ' artist', ' William', ' Hog', 'arth']
+953 204 Name of mother of x -1 Name of mother of Alanis Morissette Georgia Feuerstein Alanis Morissette "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' J' '.' 'P' '.'
+ ' (' 'J' '.' 'P' '.' ' is' ' the' ' name']" , the singer , and her husband , J . P . ( J . P . is the name False during a tour with Alanis Morissette in September 8 [' during', ' a', ' tour', ' with', ' Alan', 'is', ' Mor', 'iss', 'ette']
+954 204 Name of mother of x -1 Name of mother of Alanis Morissette Georgia Feuerstein Alanis Morissette "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' J' '.' 'P' '.'
+ ' (' 'J' '.' 'P' '.' ' is' ' the' ' name']" , the singer , and her husband , J . P . ( J . P . is the name False team included Alanis Morissette and David Blaine. 6 [' team', ' included', ' Alan', 'is', ' Mor', 'iss', 'ette']
+955 204 Name of mother of x -1 Name of mother of Alanis Morissette Georgia Feuerstein Alanis Morissette "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' J' '.' 'P' '.'
+ ' (' 'J' '.' 'P' '.' ' is' ' the' ' name']" , the singer , and her husband , J . P . ( J . P . is the name False during a tour with Alanis Morissette in September 1996. 8 [' during', ' a', ' tour', ' with', ' Alan', 'is', ' Mor', 'iss', 'ette']
+956 204 Name of mother of x -1 Name of mother of Alanis Morissette Georgia Feuerstein Alanis Morissette "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' J' '.' 'P' '.'
+ ' (' 'J' '.' 'P' '.' ' is' ' the' ' name']" , the singer , and her husband , J . P . ( J . P . is the name False " singer songwriter Alanis Morissette and ""cracked" 7 [' singer', ' song', 'writer', ' Alan', 'is', ' Mor', 'iss', 'ette']
+957 204 Name of mother of x -1 Name of mother of Alanis Morissette Georgia Feuerstein Alanis Morissette "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' J' '.' 'P' '.'
+ ' (' 'J' '.' 'P' '.' ' is' ' the' ' name']" , the singer , and her husband , J . P . ( J . P . is the name False 4 ['Alan', 'is', ' Mor', 'iss', 'ette']
+958 205 Name of mother of x -1 Name of mother of Antonín Dvořák Anna Dvořáková Antonín Dvořák "[',' ' the' ' Czech' ' composer' ',' ' was' ' born' ' in' ' 18' '41' ' in'
+ ' the' ' Czech' ' Republic' '.' ' He' ' was' ' a' ' composer' ' of']" , the Czech composer , was born in 18 41 in the Czech Republic . He was a composer of False " E-Minor (B.19)"" by Antonín Dvořák are also present" 16 "[' E', '-', 'Minor', ' (', 'B', '.', '19', ')""', ' by', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']"
+959 205 Name of mother of x -1 Name of mother of Antonín Dvořák Anna Dvořáková Antonín Dvořák "[',' ' the' ' Czech' ' composer' ',' ' was' ' born' ' in' ' 18' '41' ' in'
+ ' the' ' Czech' ' Republic' '.' ' He' ' was' ' a' ' composer' ' of']" , the Czech composer , was born in 18 41 in the Czech Republic . He was a composer of False Johannes Brahms, Antonín Dvořák and Jules Massenet. 11 [' Johannes', ' Brah', 'ms', ',', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']
+960 205 Name of mother of x -1 Name of mother of Antonín Dvořák Anna Dvořáková Antonín Dvořák "[',' ' the' ' Czech' ' composer' ',' ' was' ' born' ' in' ' 18' '41' ' in'
+ ' the' ' Czech' ' Republic' '.' ' He' ' was' ' a' ' composer' ' of']" , the Czech composer , was born in 18 41 in the Czech Republic . He was a composer of False " E-Minor (B.19)"" by Antonín Dvořák are also present" 16 "[' E', '-', 'Minor', ' (', 'B', '.', '19', ')""', ' by', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']"
+961 205 Name of mother of x -1 Name of mother of Antonín Dvořák Anna Dvořáková Antonín Dvořák "[',' ' the' ' Czech' ' composer' ',' ' was' ' born' ' in' ' 18' '41' ' in'
+ ' the' ' Czech' ' Republic' '.' ' He' ' was' ' a' ' composer' ' of']" , the Czech composer , was born in 18 41 in the Czech Republic . He was a composer of False " (B.19)"" by Antonín Dvořák are also present" 13 "[' (', 'B', '.', '19', ')""', ' by', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']"
+962 205 Name of mother of x -1 Name of mother of Antonín Dvořák Anna Dvořáková Antonín Dvořák "[',' ' the' ' Czech' ' composer' ',' ' was' ' born' ' in' ' 18' '41' ' in'
+ ' the' ' Czech' ' Republic' '.' ' He' ' was' ' a' ' composer' ' of']" , the Czech composer , was born in 18 41 in the Czech Republic . He was a composer of False " renowned composer Antonín Dvořák wrote his Ninth (""New" 9 [' renowned', ' composer', ' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k']
+963 206 Name of mother of x -1 Name of mother of Alec Baldwin Carol Baldwin Alec Baldwin "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Alec' ' Baldwin']" ", who is a very good friend of mine .
+
+ I am a big fan of Alec Baldwin" False on October 4, 2012. Alec Baldwin reportedly approached 7 [' on', ' October', ' 4', ',', ' 2012', '.', ' Alec', ' Baldwin']
+964 206 Name of mother of x -1 Name of mother of Alec Baldwin Carol Baldwin Alec Baldwin "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Alec' ' Baldwin']" ", who is a very good friend of mine .
+
+ I am a big fan of Alec Baldwin" False all, however, was the Alec Baldwin – Julianne Moore 7 [' all', ',', ' however', ',', ' was', ' the', ' Alec', ' Baldwin']
+965 206 Name of mother of x -1 Name of mother of Alec Baldwin Carol Baldwin Alec Baldwin "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Alec' ' Baldwin']" ", who is a very good friend of mine .
+
+ I am a big fan of Alec Baldwin" False funniest things Alec Baldwin has done on the show 4 [' funn', 'iest', ' things', ' Alec', ' Baldwin']
+966 206 Name of mother of x -1 Name of mother of Alec Baldwin Carol Baldwin Alec Baldwin "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Alec' ' Baldwin']" ", who is a very good friend of mine .
+
+ I am a big fan of Alec Baldwin" False Night Live. Alec Baldwin has also hosted Saturday 4 [' Night', ' Live', '.', ' Alec', ' Baldwin']
+967 206 Name of mother of x -1 Name of mother of Alec Baldwin Carol Baldwin Alec Baldwin "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Alec' ' Baldwin']" ", who is a very good friend of mine .
+
+ I am a big fan of Alec Baldwin" False " Baldwin as Gray Edwards
+" 6 [' Baldwin', ' as', ' Gray', ' Edwards', 'Ale', 'c', ' Baldwin']
+968 207 Name of mother of x -1 Name of mother of Felix Mendelssohn Lea Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' son' ' of' ' the' ' famous'
+ ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ' in']" - Bar th old y , the son of the famous composer , was born in Le ip zig in False von Weber and Felix Mendelssohn were important 7 [' von', ' Weber', ' and', ' Felix', ' Mend', 'els', 'so', 'hn']
+969 207 Name of mother of x -1 Name of mother of Felix Mendelssohn Lea Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' son' ' of' ' the' ' famous'
+ ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ' in']" - Bar th old y , the son of the famous composer , was born in Le ip zig in False " E.S. Engelsberg, Felix Mendelssohn and Richard Wagner.
+" 11 [' E', '.', 'S', '.', ' Engels', 'berg', ',', ' Felix', ' Mend', 'els', 'so', 'hn']
+970 207 Name of mother of x -1 Name of mother of Felix Mendelssohn Lea Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' son' ' of' ' the' ' famous'
+ ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ' in']" - Bar th old y , the son of the famous composer , was born in Le ip zig in False Palace. The composer Felix Mendelssohn is known to have 8 [' Palace', '.', ' The', ' composer', ' Felix', ' Mend', 'els', 'so', 'hn']
+971 207 Name of mother of x -1 Name of mother of Felix Mendelssohn Lea Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' son' ' of' ' the' ' famous'
+ ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ' in']" - Bar th old y , the son of the famous composer , was born in Le ip zig in False " Mendelssohn =
+" 10 [' Mend', 'els', 'so', 'hn', ' =', 'Fel', 'ix', ' Mend', 'els', 'so', 'hn']
+972 207 Name of mother of x -1 Name of mother of Felix Mendelssohn Lea Mendelssohn Bartholdy Felix Mendelssohn "['-' 'Bar' 'th' 'old' 'y' ',' ' the' ' son' ' of' ' the' ' famous'
+ ' composer' ',' ' was' ' born' ' in' ' Le' 'ip' 'zig' ' in']" - Bar th old y , the son of the famous composer , was born in Le ip zig in False 5 ['Fel', 'ix', ' Mend', 'els', 'so', 'hn']
+973 208 Name of mother of x -1 Name of mother of Hillary Clinton Dorothy Howell Rodham Hillary Clinton "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' her'
+ ' husband' ',' ' former' ' president' ' Bill' ' Clinton' ',' ' are'
+ ' both' ' in' ' the']" , the former secretary of state , and her husband , former president Bill Clinton , are both in the False by the time of Hillary Clinton's 2008 presidential 5 [' by', ' the', ' time', ' of', ' Hillary', ' Clinton']
+974 208 Name of mother of x -1 Name of mother of Hillary Clinton Dorothy Howell Rodham Hillary Clinton "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' her'
+ ' husband' ',' ' former' ' president' ' Bill' ' Clinton' ',' ' are'
+ ' both' ' in' ' the']" , the former secretary of state , and her husband , former president Bill Clinton , are both in the False with U.S. Senator Hillary Clinton and sorority member 7 [' with', ' U', '.', 'S', '.', ' Senator', ' Hillary', ' Clinton']
+975 208 Name of mother of x -1 Name of mother of Hillary Clinton Dorothy Howell Rodham Hillary Clinton "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' her'
+ ' husband' ',' ' former' ' president' ' Bill' ' Clinton' ',' ' are'
+ ' both' ' in' ' the']" , the former secretary of state , and her husband , former president Bill Clinton , are both in the False Former state senator Hillary Clinton has used the recording 4 [' Former', ' state', ' senator', ' Hillary', ' Clinton']
+976 208 Name of mother of x -1 Name of mother of Hillary Clinton Dorothy Howell Rodham Hillary Clinton "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' her'
+ ' husband' ',' ' former' ' president' ' Bill' ' Clinton' ',' ' are'
+ ' both' ' in' ' the']" , the former secretary of state , and her husband , former president Bill Clinton , are both in the False 2 was condemned by Hillary Clinton over fears that children 5 [' 2', ' was', ' condemned', ' by', ' Hillary', ' Clinton']
+977 208 Name of mother of x -1 Name of mother of Hillary Clinton Dorothy Howell Rodham Hillary Clinton "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' and' ' her'
+ ' husband' ',' ' former' ' president' ' Bill' ' Clinton' ',' ' are'
+ ' both' ' in' ' the']" , the former secretary of state , and her husband , former president Bill Clinton , are both in the False Secretary of State Hillary Clinton called Bahrain ’ s 4 [' Secretary', ' of', ' State', ' Hillary', ' Clinton']
+978 209 Name of mother of x -1 Name of mother of Rita Hayworth Volga Hayworth Rita Hayworth "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False as those played by Rita Hayworth in Gilda (1946), Lana 6 [' as', ' those', ' played', ' by', ' Rita', ' Hay', 'worth']
+979 209 Name of mother of x -1 Name of mother of Rita Hayworth Volga Hayworth Rita Hayworth "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Puig's Betrayed by Rita Hayworth (1971). A crucial 9 "[' Pu', 'ig', ""'s"", ' Bet', 'ray', 'ed', ' by', ' Rita', ' Hay', 'worth']"
+980 209 Name of mother of x -1 Name of mother of Rita Hayworth Volga Hayworth Rita Hayworth "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False gloves like Rita Hayworth in Gilda and points 4 [' gloves', ' like', ' Rita', ' Hay', 'worth']
+981 209 Name of mother of x -1 Name of mother of Rita Hayworth Volga Hayworth Rita Hayworth "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False those played by Rita Hayworth in Gilda (1946), 5 [' those', ' played', ' by', ' Rita', ' Hay', 'worth']
+982 209 Name of mother of x -1 Name of mother of Rita Hayworth Volga Hayworth Rita Hayworth "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False gloves like Rita Hayworth in Gilda and points 4 [' gloves', ' like', ' Rita', ' Hay', 'worth']
+983 210 Name of mother of x -1 Name of mother of Charlotte Brontë Maria Branwell Charlotte Brontë "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Charlotte'
+ ' Br' 'ont' 'ë' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", and the
+
+ Name of mother of Charlotte Br ont ë , and the
+
+ Name of" False Ballycloran and Charlotte Brontë who married a 8 [' B', 'ally', 'cl', 'oran', ' and', ' Charlotte', ' Br', 'ont', 'ë']
+984 210 Name of mother of x -1 Name of mother of Charlotte Brontë Maria Branwell Charlotte Brontë "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Charlotte'
+ ' Br' 'ont' 'ë' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", and the
+
+ Name of mother of Charlotte Br ont ë , and the
+
+ Name of" False Ballycloran and Charlotte Brontë who married a curate 8 [' B', 'ally', 'cl', 'oran', ' and', ' Charlotte', ' Br', 'ont', 'ë']
+985 210 Name of mother of x -1 Name of mother of Charlotte Brontë Maria Branwell Charlotte Brontë "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Charlotte'
+ ' Br' 'ont' 'ë' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", and the
+
+ Name of mother of Charlotte Br ont ë , and the
+
+ Name of" False presented papers on Charlotte Brontë and Samuel 6 [' presented', ' papers', ' on', ' Charlotte', ' Br', 'ont', 'ë']
+986 210 Name of mother of x -1 Name of mother of Charlotte Brontë Maria Branwell Charlotte Brontë "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Charlotte'
+ ' Br' 'ont' 'ë' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", and the
+
+ Name of mother of Charlotte Br ont ë , and the
+
+ Name of" False presented papers on Charlotte Brontë and Samuel Taylor 6 [' presented', ' papers', ' on', ' Charlotte', ' Br', 'ont', 'ë']
+987 210 Name of mother of x -1 Name of mother of Charlotte Brontë Maria Branwell Charlotte Brontë "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Charlotte'
+ ' Br' 'ont' 'ë' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" ", and the
+
+ Name of mother of Charlotte Br ont ë , and the
+
+ Name of" False Shorter prepared to write Charlotte Brontë and Her Circle, 8 [' Sh', 'orter', ' prepared', ' to', ' write', ' Charlotte', ' Br', 'ont', 'ë']
+988 211 Name of mother of x -1 Name of mother of Catherine Deneuve Renée Simonot Catherine Deneuve "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Hunger (1983), with Catherine Deneuve and Susan Sarandon. 8 [' Hunger', ' (', '1983', '),', ' with', ' Catherine', ' D', 'ene', 'uve']
+989 211 Name of mother of x -1 Name of mother of Catherine Deneuve Renée Simonot Catherine Deneuve "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Mankiewicz claims that Catherine Deneuve wanted to 8 [' M', 'ank', 'iewicz', ' claims', ' that', ' Catherine', ' D', 'ene', 'uve']
+990 211 Name of mother of x -1 Name of mother of Catherine Deneuve Renée Simonot Catherine Deneuve "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False (1983), with Catherine Deneuve and Susan Sarandon. 7 [' (', '1983', '),', ' with', ' Catherine', ' D', 'ene', 'uve']
+991 211 Name of mother of x -1 Name of mother of Catherine Deneuve Renée Simonot Catherine Deneuve "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False claims that Catherine Deneuve wanted to play the 5 [' claims', ' that', ' Catherine', ' D', 'ene', 'uve']
+992 211 Name of mother of x -1 Name of mother of Catherine Deneuve Renée Simonot Catherine Deneuve "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Hunger (1983), with Catherine Deneuve and Susan 8 [' Hunger', ' (', '1983', '),', ' with', ' Catherine', ' D', 'ene', 'uve']
+993 212 Name of mother of x -1 Name of mother of Dwight D. Eisenhower Ida Stover Eisenhower Dwight D. Eisenhower "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' was' ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' was']" ", the first president of the United States , was born in this house .
+
+ The house was" False London. General Dwight D. Eisenhower took a suite on 6 [' London', '.', ' General', ' Dwight', ' D', '.', ' Eisenhower']
+994 212 Name of mother of x -1 Name of mother of Dwight D. Eisenhower Ida Stover Eisenhower Dwight D. Eisenhower "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' was' ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' was']" ", the first president of the United States , was born in this house .
+
+ The house was" False by president Dwight D. Eisenhower to the United Nations, 5 [' by', ' president', ' Dwight', ' D', '.', ' Eisenhower']
+995 212 Name of mother of x -1 Name of mother of Dwight D. Eisenhower Ida Stover Eisenhower Dwight D. Eisenhower "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' was' ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' was']" ", the first president of the United States , was born in this house .
+
+ The house was" False was renamed to honor Dwight D. Eisenhower and Edwin C. Johnson. 7 [' was', ' renamed', ' to', ' honor', ' Dwight', ' D', '.', ' Eisenhower']
+996 212 Name of mother of x -1 Name of mother of Dwight D. Eisenhower Ida Stover Eisenhower Dwight D. Eisenhower "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' was' ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' was']" ", the first president of the United States , was born in this house .
+
+ The house was" False direction of General Dwight D. Eisenhower (Supreme Commander 6 [' direction', ' of', ' General', ' Dwight', ' D', '.', ' Eisenhower']
+997 212 Name of mother of x -1 Name of mother of Dwight D. Eisenhower Ida Stover Eisenhower Dwight D. Eisenhower "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' was' ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' was']" ", the first president of the United States , was born in this house .
+
+ The house was" False " commander"". General Dwight D. Eisenhower called him ""the" 6 "[' commander', '"".', ' General', ' Dwight', ' D', '.', ' Eisenhower']"
+998 213 Name of mother of x -1 Name of mother of Mario Vargas Llosa Dora Llosa Ureta Mario Vargas Llosa "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' The' ' Per' 'uvian' ' writer' ',' ' who' ' has' ' been' ' living']" , the author of the book , and the film , The Per uvian writer , who has been living False " Vargas Llosa =
+" 9 [' Varg', 'as', ' Ll', 'osa', ' =', 'Mario', ' Varg', 'as', ' Ll', 'osa']
+999 213 Name of mother of x -1 Name of mother of Mario Vargas Llosa Dora Llosa Ureta Mario Vargas Llosa "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' The' ' Per' 'uvian' ' writer' ',' ' who' ' has' ' been' ' living']" , the author of the book , and the film , The Per uvian writer , who has been living False " Llosa =
+" 7 [' Ll', 'osa', ' =', 'Mario', ' Varg', 'as', ' Ll', 'osa']
+1000 213 Name of mother of x -1 Name of mother of Mario Vargas Llosa Dora Llosa Ureta Mario Vargas Llosa "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' The' ' Per' 'uvian' ' writer' ',' ' who' ' has' ' been' ' living']" , the author of the book , and the film , The Per uvian writer , who has been living False " Mario Vargas Llosa =
+" 4 [' Mario', ' Varg', 'as', ' Ll', 'osa']
+1001 213 Name of mother of x -1 Name of mother of Mario Vargas Llosa Dora Llosa Ureta Mario Vargas Llosa "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' The' ' Per' 'uvian' ' writer' ',' ' who' ' has' ' been' ' living']" , the author of the book , and the film , The Per uvian writer , who has been living False commonly known as Mario Vargas Llosa (Spanish: [ˈmaɾjo 7 [' commonly', ' known', ' as', ' Mario', ' Varg', 'as', ' Ll', 'osa']
+1002 213 Name of mother of x -1 Name of mother of Mario Vargas Llosa Dora Llosa Ureta Mario Vargas Llosa "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' The' ' Per' 'uvian' ' writer' ',' ' who' ' has' ' been' ' living']" , the author of the book , and the film , The Per uvian writer , who has been living False laureate writer Mario Vargas Llosa has the protagonist 6 [' laureate', ' writer', ' Mario', ' Varg', 'as', ' Ll', 'osa']
+1003 214 Name of mother of x -1 Name of mother of Helen Mirren Kitty Rogers Helen Mirren "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' actress' '.' '\n' '\n' 'I' ' have' ' seen' ' her' ' in' ' many']" ", who is a great actress , and a great actress .
+
+ I have seen her in many" False Hopkins as Hitchcock, Helen Mirren as Reville, 6 [' Hopkins', ' as', ' Hitchcock', ',', ' Helen', ' Mir', 'ren']
+1004 214 Name of mother of x -1 Name of mother of Helen Mirren Kitty Rogers Helen Mirren "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' actress' '.' '\n' '\n' 'I' ' have' ' seen' ' her' ' in' ' many']" ", who is a great actress , and a great actress .
+
+ I have seen her in many" False Taylor-Compton alongside Helen Mirren and Joe Pesci in Love 7 [' Taylor', '-', 'Com', 'pton', ' alongside', ' Helen', ' Mir', 'ren']
+1005 214 Name of mother of x -1 Name of mother of Helen Mirren Kitty Rogers Helen Mirren "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' actress' '.' '\n' '\n' 'I' ' have' ' seen' ' her' ' in' ' many']" ", who is a great actress , and a great actress .
+
+ I have seen her in many" False " Jane Lapotaire, Helen Mirren and Janet Suzman.
+" 7 [' Jane', ' Lap', 'ota', 'ire', ',', ' Helen', ' Mir', 'ren']
+1006 214 Name of mother of x -1 Name of mother of Helen Mirren Kitty Rogers Helen Mirren "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' actress' '.' '\n' '\n' 'I' ' have' ' seen' ' her' ' in' ' many']" ", who is a great actress , and a great actress .
+
+ I have seen her in many" False Hopkins as Hitchcock, Helen Mirren as Reville, Scarlett 6 [' Hopkins', ' as', ' Hitchcock', ',', ' Helen', ' Mir', 'ren']
+1007 214 Name of mother of x -1 Name of mother of Helen Mirren Kitty Rogers Helen Mirren "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' actress' '.' '\n' '\n' 'I' ' have' ' seen' ' her' ' in' ' many']" ", who is a great actress , and a great actress .
+
+ I have seen her in many" False produced since the star Helen Mirren quit in 1995. 6 [' produced', ' since', ' the', ' star', ' Helen', ' Mir', 'ren']
+1008 215 Name of mother of x -1 Name of mother of Charlotte Gainsbourg Jane Birkin Charlotte Gainsbourg "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' her' ' work']" ", who is a very good friend of mine .
+
+ I have been a fan of her work" False John's wife. Charlotte Gainsbourg was originally 7 "[' John', ""'s"", ' wife', '.', ' Charlotte', ' G', 'ains', 'bourg']"
+1009 215 Name of mother of x -1 Name of mother of Charlotte Gainsbourg Jane Birkin Charlotte Gainsbourg "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' her' ' work']" ", who is a very good friend of mine .
+
+ I have been a fan of her work" False collaborated with Charlotte Gainsbourg on her album IRM, 5 [' collaborated', ' with', ' Charlotte', ' G', 'ains', 'bourg']
+1010 215 Name of mother of x -1 Name of mother of Charlotte Gainsbourg Jane Birkin Charlotte Gainsbourg "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' her' ' work']" ", who is a very good friend of mine .
+
+ I have been a fan of her work" False collaborated with Charlotte Gainsbourg on her album 5 [' collaborated', ' with', ' Charlotte', ' G', 'ains', 'bourg']
+1011 215 Name of mother of x -1 Name of mother of Charlotte Gainsbourg Jane Birkin Charlotte Gainsbourg "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' her' ' work']" ", who is a very good friend of mine .
+
+ I have been a fan of her work" False Connor, John's wife. Charlotte Gainsbourg was originally set 9 "[' Connor', ',', ' John', ""'s"", ' wife', '.', ' Charlotte', ' G', 'ains', 'bourg']"
+1012 215 Name of mother of x -1 Name of mother of Charlotte Gainsbourg Jane Birkin Charlotte Gainsbourg "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' her' ' work']" ", who is a very good friend of mine .
+
+ I have been a fan of her work" False Connor, John's wife. Charlotte Gainsbourg was originally 9 "[' Connor', ',', ' John', ""'s"", ' wife', '.', ' Charlotte', ' G', 'ains', 'bourg']"
+1013 216 Name of mother of x -1 Name of mother of Angela Lansbury Moyna Macgill Angela Lansbury "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' British' ' Parliament'
+ ',' ' and' ' a' ' member' ' of' ' the' ' British' ' Parliament' ','
+ ' and' ' a']" , who was a member of the British Parliament , and a member of the British Parliament , and a False Menken asked Angela Lansbury to perform 5 [' Men', 'ken', ' asked', ' Angela', ' Lans', 'bury']
+1014 216 Name of mother of x -1 Name of mother of Angela Lansbury Moyna Macgill Angela Lansbury "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' British' ' Parliament'
+ ',' ' and' ' a' ' member' ' of' ' the' ' British' ' Parliament' ','
+ ' and' ' a']" , who was a member of the British Parliament , and a member of the British Parliament , and a False Stars, featuring Angela Lansbury walking on hot 5 [' Stars', ',', ' featuring', ' Angela', ' Lans', 'bury']
+1015 216 Name of mother of x -1 Name of mother of Angela Lansbury Moyna Macgill Angela Lansbury "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' British' ' Parliament'
+ ',' ' and' ' a' ' member' ' of' ' the' ' British' ' Parliament' ','
+ ' and' ' a']" , who was a member of the British Parliament , and a member of the British Parliament , and a False people (Zeta-Jones, Angela Lansbury and Alexander Hanson) 9 [' people', ' (', 'Z', 'eta', '-', 'Jones', ',', ' Angela', ' Lans', 'bury']
+1016 216 Name of mother of x -1 Name of mother of Angela Lansbury Moyna Macgill Angela Lansbury "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' British' ' Parliament'
+ ',' ' and' ' a' ' member' ' of' ' the' ' British' ' Parliament' ','
+ ' and' ' a']" , who was a member of the British Parliament , and a member of the British Parliament , and a False " Our Guest"", sung by Angela Lansbury in the 1991" 7 "[' Our', ' Guest', '"",', ' sung', ' by', ' Angela', ' Lans', 'bury']"
+1017 216 Name of mother of x -1 Name of mother of Angela Lansbury Moyna Macgill Angela Lansbury "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' British' ' Parliament'
+ ',' ' and' ' a' ' member' ' of' ' the' ' British' ' Parliament' ','
+ ' and' ' a']" , who was a member of the British Parliament , and a member of the British Parliament , and a False Menken asked Angela Lansbury to perform the 5 [' Men', 'ken', ' asked', ' Angela', ' Lans', 'bury']
+1018 217 Name of mother of x -1 Name of mother of Thomas Jefferson Jane Randolph Jefferson Thomas Jefferson "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Thomas']" ", the first president of the United States .
+
+ The first president of the United States , Thomas" False unpopular astronomer Thomas Jefferson Jackson See), 3 [' unpopular', ' astronomer', ' Thomas', ' Jefferson']
+1019 217 Name of mother of x -1 Name of mother of Thomas Jefferson Jane Randolph Jefferson Thomas Jefferson "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Thomas']" ", the first president of the United States .
+
+ The first president of the United States , Thomas" False under President Thomas Jefferson was present at 3 [' under', ' President', ' Thomas', ' Jefferson']
+1020 217 Name of mother of x -1 Name of mother of Thomas Jefferson Jane Randolph Jefferson Thomas Jefferson "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Thomas']" ", the first president of the United States .
+
+ The first president of the United States , Thomas" False extent that Thomas Jefferson personally deposited 3 [' extent', ' that', ' Thomas', ' Jefferson']
+1021 217 Name of mother of x -1 Name of mother of Thomas Jefferson Jane Randolph Jefferson Thomas Jefferson "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Thomas']" ", the first president of the United States .
+
+ The first president of the United States , Thomas" False president, Thomas Jefferson directed that plans 3 [' president', ',', ' Thomas', ' Jefferson']
+1022 217 Name of mother of x -1 Name of mother of Thomas Jefferson Jane Randolph Jefferson Thomas Jefferson "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Thomas']" ", the first president of the United States .
+
+ The first president of the United States , Thomas" False oldest child of Thomas Jefferson Butcher and 4 [' oldest', ' child', ' of', ' Thomas', ' Jefferson']
+1023 218 Name of mother of x -1 Name of mother of Nikolai Gogol Mariia Hohol Nikolai Gogol "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '09' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' brothers' ',']" ", the Russian writer , was born in 18 09 .
+
+ The first of the three brothers ," False comparison with Nikolai Gogol (one of his favourite 6 [' comparison', ' with', ' Nikol', 'ai', ' G', 'og', 'ol']
+1024 218 Name of mother of x -1 Name of mother of Nikolai Gogol Mariia Hohol Nikolai Gogol "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '09' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' brothers' ',']" ", the Russian writer , was born in 18 09 .
+
+ The first of the three brothers ," False Russian novelist Nikolai Gogol in Venice. He then 6 [' Russian', ' novelist', ' Nikol', 'ai', ' G', 'og', 'ol']
+1025 218 Name of mother of x -1 Name of mother of Nikolai Gogol Mariia Hohol Nikolai Gogol "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '09' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' brothers' ',']" ", the Russian writer , was born in 18 09 .
+
+ The first of the three brothers ," False comparison with Nikolai Gogol (one of his favourite 6 [' comparison', ' with', ' Nikol', 'ai', ' G', 'og', 'ol']
+1026 218 Name of mother of x -1 Name of mother of Nikolai Gogol Mariia Hohol Nikolai Gogol "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '09' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' brothers' ',']" ", the Russian writer , was born in 18 09 .
+
+ The first of the three brothers ," False comparison with Nikolai Gogol (one of his favourite 6 [' comparison', ' with', ' Nikol', 'ai', ' G', 'og', 'ol']
+1027 218 Name of mother of x -1 Name of mother of Nikolai Gogol Mariia Hohol Nikolai Gogol "[',' ' the' ' Russian' ' writer' ',' ' was' ' born' ' in' ' 18' '09' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' brothers' ',']" ", the Russian writer , was born in 18 09 .
+
+ The first of the three brothers ," False met Russian novelist Nikolai Gogol in Venice. He then 7 [' met', ' Russian', ' novelist', ' Nikol', 'ai', ' G', 'og', 'ol']
+1028 219 Name of mother of x -1 Name of mother of Kobe Bryant Pam Bryant Kobe Bryant "[',' ' the' ' Lakers' ""'"" ' star' ' player' ',' ' was' ' arrested' ' for'
+ ' DUI' ' in' ' 2003' '.' '\n' '\n' 'The' ' Lakers' ""'"" ' star']" ", the Lakers ' star player , was arrested for DUI in 2003 .
+
+ The Lakers ' star" False " === 1996 – 2016: The Kobe Bryant era ===
+" 7 [' ===', ' 1996', ' –', ' 2016', ':', ' The', ' Kobe', ' Bryant']
+1029 219 Name of mother of x -1 Name of mother of Kobe Bryant Pam Bryant Kobe Bryant "[',' ' the' ' Lakers' ""'"" ' star' ' player' ',' ' was' ' arrested' ' for'
+ ' DUI' ' in' ' 2003' '.' '\n' '\n' 'The' ' Lakers' ""'"" ' star']" ", the Lakers ' star player , was arrested for DUI in 2003 .
+
+ The Lakers ' star" False the season when Kobe Bryant posted 60 4 [' the', ' season', ' when', ' Kobe', ' Bryant']
+1030 219 Name of mother of x -1 Name of mother of Kobe Bryant Pam Bryant Kobe Bryant "[',' ' the' ' Lakers' ""'"" ' star' ' player' ',' ' was' ' arrested' ' for'
+ ' DUI' ' in' ' 2003' '.' '\n' '\n' 'The' ' Lakers' ""'"" ' star']" ", the Lakers ' star player , was arrested for DUI in 2003 .
+
+ The Lakers ' star" False first team since Kobe Bryant and Shaquille 4 [' first', ' team', ' since', ' Kobe', ' Bryant']
+1031 219 Name of mother of x -1 Name of mother of Kobe Bryant Pam Bryant Kobe Bryant "[',' ' the' ' Lakers' ""'"" ' star' ' player' ',' ' was' ' arrested' ' for'
+ ' DUI' ' in' ' 2003' '.' '\n' '\n' 'The' ' Lakers' ""'"" ' star']" ", the Lakers ' star player , was arrested for DUI in 2003 .
+
+ The Lakers ' star" False the night that Kobe Bryant made his farewell 4 [' the', ' night', ' that', ' Kobe', ' Bryant']
+1032 219 Name of mother of x -1 Name of mother of Kobe Bryant Pam Bryant Kobe Bryant "[',' ' the' ' Lakers' ""'"" ' star' ' player' ',' ' was' ' arrested' ' for'
+ ' DUI' ' in' ' 2003' '.' '\n' '\n' 'The' ' Lakers' ""'"" ' star']" ", the Lakers ' star player , was arrested for DUI in 2003 .
+
+ The Lakers ' star" False questioning if Kobe Bryant could shoot a jump 3 [' questioning', ' if', ' Kobe', ' Bryant']
+1033 220 Name of mother of x -1 Name of mother of Antoine de Saint-Exupéry Marie Louise Andrée Boyer de Fonscolombe Antoine de Saint-Exupéry "[',' ' the' ' French' ' av' 'iator' ' and' ' author' ' of' ' the'
+ ' classic' ' _' 'The' ' Little' ' Prince' '_' ',' ' who' ' died' ' in'
+ ' 1944']" , the French av iator and author of the classic _ The Little Prince _ , who died in 1944 False Deserta (1888). Antoine de Saint-Exupéry described both 14 [' Desert', 'a', ' (', '18', '88', ').', ' Ant', 'oine', ' de', ' Saint', '-', 'Ex', 'up', 'é', 'ry']
+1034 220 Name of mother of x -1 Name of mother of Antoine de Saint-Exupéry Marie Louise Andrée Boyer de Fonscolombe Antoine de Saint-Exupéry "[',' ' the' ' French' ' av' 'iator' ' and' ' author' ' of' ' the'
+ ' classic' ' _' 'The' ' Little' ' Prince' '_' ',' ' who' ' died' ' in'
+ ' 1944']" , the French av iator and author of the classic _ The Little Prince _ , who died in 1944 False for the future, Antoine de Saint-Exupéry writes of how deeply 12 [' for', ' the', ' future', ',', ' Ant', 'oine', ' de', ' Saint', '-', 'Ex', 'up', 'é', 'ry']
+1035 220 Name of mother of x -1 Name of mother of Antoine de Saint-Exupéry Marie Louise Andrée Boyer de Fonscolombe Antoine de Saint-Exupéry "[',' ' the' ' French' ' av' 'iator' ' and' ' author' ' of' ' the'
+ ' classic' ' _' 'The' ' Little' ' Prince' '_' ',' ' who' ' died' ' in'
+ ' 1944']" , the French av iator and author of the classic _ The Little Prince _ , who died in 1944 False Arabia Deserta (1888). Antoine de Saint-Exupéry described 15 [' Arabia', ' Desert', 'a', ' (', '18', '88', ').', ' Ant', 'oine', ' de', ' Saint', '-', 'Ex', 'up', 'é', 'ry']
+1036 220 Name of mother of x -1 Name of mother of Antoine de Saint-Exupéry Marie Louise Andrée Boyer de Fonscolombe Antoine de Saint-Exupéry "[',' ' the' ' French' ' av' 'iator' ' and' ' author' ' of' ' the'
+ ' classic' ' _' 'The' ' Little' ' Prince' '_' ',' ' who' ' died' ' in'
+ ' 1944']" , the French av iator and author of the classic _ The Little Prince _ , who died in 1944 False hopes for the future, Antoine de Saint-Exupéry writes of how 13 [' hopes', ' for', ' the', ' future', ',', ' Ant', 'oine', ' de', ' Saint', '-', 'Ex', 'up', 'é', 'ry']
+1037 220 Name of mother of x -1 Name of mother of Antoine de Saint-Exupéry Marie Louise Andrée Boyer de Fonscolombe Antoine de Saint-Exupéry "[',' ' the' ' French' ' av' 'iator' ' and' ' author' ' of' ' the'
+ ' classic' ' _' 'The' ' Little' ' Prince' '_' ',' ' who' ' died' ' in'
+ ' 1944']" , the French av iator and author of the classic _ The Little Prince _ , who died in 1944 False Deserta (1888). Antoine de Saint-Exupéry described both 14 [' Desert', 'a', ' (', '18', '88', ').', ' Ant', 'oine', ' de', ' Saint', '-', 'Ex', 'up', 'é', 'ry']
+1038 221 Name of mother of x -1 Name of mother of Mary Shelley Mary Wollstonecraft Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' Frankenstein' ',' ' and' ' the' ' author' ' of' ' Frankenstein'
+ ',' ' and' ' the']" , the author of Frankenstein , and the author of Frankenstein , and the author of Frankenstein , and the False money. Although Mary Shelley wrote twenty-one 4 [' money', '.', ' Although', ' Mary', ' Shelley']
+1039 221 Name of mother of x -1 Name of mother of Mary Shelley Mary Wollstonecraft Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' Frankenstein' ',' ' and' ' the' ' author' ' of' ' Frankenstein'
+ ',' ' and' ' the']" , the author of Frankenstein , and the author of Frankenstein , and the author of Frankenstein , and the False " straight out of a Mary Shelley novel."" The cutscene" 5 [' straight', ' out', ' of', ' a', ' Mary', ' Shelley']
+1040 221 Name of mother of x -1 Name of mother of Mary Shelley Mary Wollstonecraft Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' Frankenstein' ',' ' and' ' the' ' author' ' of' ' Frankenstein'
+ ',' ' and' ' the']" , the author of Frankenstein , and the author of Frankenstein , and the author of Frankenstein , and the False second trip, Mary Shelley spent time in 4 [' second', ' trip', ',', ' Mary', ' Shelley']
+1041 221 Name of mother of x -1 Name of mother of Mary Shelley Mary Wollstonecraft Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' Frankenstein' ',' ' and' ' the' ' author' ' of' ' Frankenstein'
+ ',' ' and' ' the']" , the author of Frankenstein , and the author of Frankenstein , and the author of Frankenstein , and the False at Bagnacavallo. Mary Shelley was distracted 8 [' at', ' B', 'agn', 'ac', 'av', 'allo', '.', ' Mary', ' Shelley']
+1042 221 Name of mother of x -1 Name of mother of Mary Shelley Mary Wollstonecraft Mary Shelley "[',' ' the' ' author' ' of' ' Frankenstein' ',' ' and' ' the' ' author'
+ ' of' ' Frankenstein' ',' ' and' ' the' ' author' ' of' ' Frankenstein'
+ ',' ' and' ' the']" , the author of Frankenstein , and the author of Frankenstein , and the author of Frankenstein , and the False " incarnate romance,"" Mary Shelley recalled in 1826." 5 "[' incarn', 'ate', ' romance', ',""', ' Mary', ' Shelley']"
+1043 223 Name of mother of x -1 Name of mother of Sophia Loren Romilda Villani Sophia Loren "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Gwyneth Paltrow, Sophia Loren and Celine Dion) 8 [' G', 'wyn', 'eth', ' P', 'alt', 'row', ',', ' Sophia', ' Loren']
+1044 223 Name of mother of x -1 Name of mother of Sophia Loren Romilda Villani Sophia Loren "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Jerry Lewis and Sophia Loren before having 4 [' Jerry', ' Lewis', ' and', ' Sophia', ' Loren']
+1045 223 Name of mother of x -1 Name of mother of Sophia Loren Romilda Villani Sophia Loren "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False famous picture of Sophia Loren staring at Mansfield's 4 [' famous', ' picture', ' of', ' Sophia', ' Loren']
+1046 223 Name of mother of x -1 Name of mother of Sophia Loren Romilda Villani Sophia Loren "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False compared her to Sophia Loren and described her 4 [' compared', ' her', ' to', ' Sophia', ' Loren']
+1047 223 Name of mother of x -1 Name of mother of Sophia Loren Romilda Villani Sophia Loren "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False (1994), compared her to Sophia Loren and described 7 [' (', '1994', '),', ' compared', ' her', ' to', ' Sophia', ' Loren']
+1048 224 Name of mother of x -1 Name of mother of Bhumibol Adulyadej Srinagarindra, Princess Mother Bhumibol Adulyadej "[',' ' the' ' King' ' of' ' Thailand' '.' '\n' '\n' 'The' ' King' ' of'
+ ' Thailand' ' is' ' the' ' longest' ' reigning' ' monarch' ' in' ' the'
+ ' world']" ", the King of Thailand .
+
+ The King of Thailand is the longest reigning monarch in the world" False " state after King Bhumibol Adulyadej of Thailand.
+" 11 [' state', ' after', ' King', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1049 224 Name of mother of x -1 Name of mother of Bhumibol Adulyadej Srinagarindra, Princess Mother Bhumibol Adulyadej "[',' ' the' ' King' ' of' ' Thailand' '.' '\n' '\n' 'The' ' King' ' of'
+ ' Thailand' ' is' ' the' ' longest' ' reigning' ' monarch' ' in' ' the'
+ ' world']" ", the King of Thailand .
+
+ The King of Thailand is the longest reigning monarch in the world" False present monarch, King Bhumibol Adulyadej (Rama IX), currently 12 [' present', ' monarch', ',', ' King', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1050 224 Name of mother of x -1 Name of mother of Bhumibol Adulyadej Srinagarindra, Princess Mother Bhumibol Adulyadej "[',' ' the' ' King' ' of' ' Thailand' '.' '\n' '\n' 'The' ' King' ' of'
+ ' Thailand' ' is' ' the' ' longest' ' reigning' ' monarch' ' in' ' the'
+ ' world']" ", the King of Thailand .
+
+ The King of Thailand is the longest reigning monarch in the world" False current king, Bhumibol Adulyadej is a first-cousin 11 [' current', ' king', ',', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1051 224 Name of mother of x -1 Name of mother of Bhumibol Adulyadej Srinagarindra, Princess Mother Bhumibol Adulyadej "[',' ' the' ' King' ' of' ' Thailand' '.' '\n' '\n' 'The' ' King' ' of'
+ ' Thailand' ' is' ' the' ' longest' ' reigning' ' monarch' ' in' ' the'
+ ' world']" ", the King of Thailand .
+
+ The King of Thailand is the longest reigning monarch in the world" False opium zone. King Bhumibol Adulyadej and other members 12 [' opium', ' zone', '.', ' King', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1052 224 Name of mother of x -1 Name of mother of Bhumibol Adulyadej Srinagarindra, Princess Mother Bhumibol Adulyadej "[',' ' the' ' King' ' of' ' Thailand' '.' '\n' '\n' 'The' ' King' ' of'
+ ' Thailand' ' is' ' the' ' longest' ' reigning' ' monarch' ' in' ' the'
+ ' world']" ", the King of Thailand .
+
+ The King of Thailand is the longest reigning monarch in the world" False " state after King Bhumibol Adulyadej of Thailand.
+" 11 [' state', ' after', ' King', ' Bh', 'um', 'ib', 'ol', ' A', 'du', 'ly', 'ade', 'j']
+1053 225 Name of mother of x -1 Name of mother of Hermann Hesse Marie Hesse Hermann Hesse "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " Eichendorff and Hermann Hesse (1943)
+" 9 [' E', 'iche', 'nd', 'or', 'ff', ' and', ' Herman', 'n', ' H', 'esse']
+1054 225 Name of mother of x -1 Name of mother of Hermann Hesse Marie Hesse Hermann Hesse "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False this species. Hermann Hesse mentioned this bird 6 [' this', ' species', '.', ' Herman', 'n', ' H', 'esse']
+1055 225 Name of mother of x -1 Name of mother of Hermann Hesse Marie Hesse Hermann Hesse "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False Gottfried Keller, Hermann Hesse and other writers have 7 [' Gott', 'fried', ' Keller', ',', ' Herman', 'n', ' H', 'esse']
+1056 225 Name of mother of x -1 Name of mother of Hermann Hesse Marie Hesse Hermann Hesse "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " Eichendorff and Hermann Hesse (1943)
+" 9 [' E', 'iche', 'nd', 'or', 'ff', ' and', ' Herman', 'n', ' H', 'esse']
+1057 225 Name of mother of x -1 Name of mother of Hermann Hesse Marie Hesse Hermann Hesse "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False this species. Hermann Hesse mentioned this bird 6 [' this', ' species', '.', ' Herman', 'n', ' H', 'esse']
+1058 227 Name of mother of x -1 Name of mother of Thomas More Agnes Graunger Thomas More "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False Defense Fund at the Thomas More Law Center On April 5 [' Defense', ' Fund', ' at', ' the', ' Thomas', ' More']
+1059 227 Name of mother of x -1 Name of mother of Thomas More Agnes Graunger Thomas More "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False Tyndale and Thomas More over the translation 5 [' Ty', 'nd', 'ale', ' and', ' Thomas', ' More']
+1060 227 Name of mother of x -1 Name of mother of Thomas More Agnes Graunger Thomas More "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False Catholic lawyer Thomas More in his struggle with 3 [' Catholic', ' lawyer', ' Thomas', ' More']
+1061 227 Name of mother of x -1 Name of mother of Thomas More Agnes Graunger Thomas More "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False Machiavelli and Thomas More revived the 6 [' Mach', 'ia', 've', 'lli', ' and', ' Thomas', ' More']
+1062 227 Name of mother of x -1 Name of mother of Thomas More Agnes Graunger Thomas More "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False philosophers Thomas More and Philip 2 [' philosophers', ' Thomas', ' More']
+1063 228 Name of mother of x -1 Name of mother of George Orwell Ida Mabel Limouzin George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' author' ' of'
+ ' Animal' ' Farm' ',' ' and' ' the' ' author' ' of' ' the' ' classic'
+ ' novel']" , the author of 1984 , and the author of Animal Farm , and the author of the classic novel False others. In 1942, George Orwell noted that 6 [' others', '.', ' In', ' 1942', ',', ' George', ' Orwell']
+1064 228 Name of mother of x -1 Name of mother of George Orwell Ida Mabel Limouzin George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' author' ' of'
+ ' Animal' ' Farm' ',' ' and' ' the' ' author' ' of' ' the' ' classic'
+ ' novel']" , the author of 1984 , and the author of Animal Farm , and the author of the classic novel False Lawrence and George Orwell both wrote 3 [' Lawrence', ' and', ' George', ' Orwell']
+1065 228 Name of mother of x -1 Name of mother of George Orwell Ida Mabel Limouzin George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' author' ' of'
+ ' Animal' ' Farm' ',' ' and' ' the' ' author' ' of' ' the' ' classic'
+ ' novel']" , the author of 1984 , and the author of Animal Farm , and the author of the classic novel False 1 ['George', ' Orwell']
+1066 228 Name of mother of x -1 Name of mother of George Orwell Ida Mabel Limouzin George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' author' ' of'
+ ' Animal' ' Farm' ',' ' and' ' the' ' author' ' of' ' the' ' classic'
+ ' novel']" , the author of 1984 , and the author of Animal Farm , and the author of the classic novel False " other thing""). George Orwell wrote his novel Nineteen" 4 "[' other', ' thing', '"").', ' George', ' Orwell']"
+1067 228 Name of mother of x -1 Name of mother of George Orwell Ida Mabel Limouzin George Orwell "[',' ' the' ' author' ' of' ' 1984' ',' ' and' ' the' ' author' ' of'
+ ' Animal' ' Farm' ',' ' and' ' the' ' author' ' of' ' the' ' classic'
+ ' novel']" , the author of 1984 , and the author of Animal Farm , and the author of the classic novel False Two months later George Orwell wrote the 4 [' Two', ' months', ' later', ' George', ' Orwell']
+1068 229 Name of mother of x -1 Name of mother of Jessica Simpson Tina Ann Drew Jessica Simpson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Wedding']" , the actress who played the role of the mother of the bride in the movie � � The Wedding False American singer Jessica Simpson was dating 3 [' American', ' singer', ' Jessica', ' Simpson']
+1069 229 Name of mother of x -1 Name of mother of Jessica Simpson Tina Ann Drew Jessica Simpson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Wedding']" , the actress who played the role of the mother of the bride in the movie � � The Wedding False Christina Aguilera, Jessica Simpson and Mandy Moore. 6 [' Christina', ' Agu', 'iler', 'a', ',', ' Jessica', ' Simpson']
+1070 229 Name of mother of x -1 Name of mother of Jessica Simpson Tina Ann Drew Jessica Simpson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Wedding']" , the actress who played the role of the mother of the bride in the movie � � The Wedding False recording artist Jessica Simpson that Sony Music 3 [' recording', ' artist', ' Jessica', ' Simpson']
+1071 229 Name of mother of x -1 Name of mother of Jessica Simpson Tina Ann Drew Jessica Simpson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Wedding']" , the actress who played the role of the mother of the bride in the movie � � The Wedding False People in the News: Jessica Simpson and Nick Lachey, 6 [' People', ' in', ' the', ' News', ':', ' Jessica', ' Simpson']
+1072 229 Name of mother of x -1 Name of mother of Jessica Simpson Tina Ann Drew Jessica Simpson "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Wedding']" , the actress who played the role of the mother of the bride in the movie � � The Wedding False when American singer Jessica Simpson was dating Dallas 4 [' when', ' American', ' singer', ' Jessica', ' Simpson']
+1073 230 Name of mother of x -1 Name of mother of Louis XIV of France Anne of Austria Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False Utrecht, King Louis XIV of France recognised the Hanoverian 8 [' Ut', 're', 'cht', ',', ' King', ' Louis', ' XIV', ' of', ' France']
+1074 230 Name of mother of x -1 Name of mother of Louis XIV of France Anne of Austria Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " Normandy"" by King Louis XIV of France on 31 December" 7 "[' Normandy', '""', ' by', ' King', ' Louis', ' XIV', ' of', ' France']"
+1075 230 Name of mother of x -1 Name of mother of Louis XIV of France Anne of Austria Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " provided by King Louis XIV of France (""The Sun King"")." 6 [' provided', ' by', ' King', ' Louis', ' XIV', ' of', ' France']
+1076 230 Name of mother of x -1 Name of mother of Louis XIV of France Anne of Austria Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False 74). This time Louis XIV of France was a key English 7 [' 74', ').', ' This', ' time', ' Louis', ' XIV', ' of', ' France']
+1077 230 Name of mother of x -1 Name of mother of Louis XIV of France Anne of Austria Louis XIV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False father's death by Louis XIV of France and James's remaining 7 "[' father', ""'s"", ' death', ' by', ' Louis', ' XIV', ' of', ' France']"
+1078 231 Name of mother of x -1 Name of mother of Confucius Yan Zhengzai Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' great' ' sage' ' of'
+ ' the' ' East' ',' '\n' '\n' 'The' ' great' ' sage' ' of' ' the']" ", the great sage , and the great sage of the East ,
+
+ The great sage of the" False attributed to Confucius, possibly based 4 [' attributed', ' to', ' Conf', 'u', 'cius']
+1079 231 Name of mother of x -1 Name of mother of Confucius Yan Zhengzai Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' great' ' sage' ' of'
+ ' the' ' East' ',' '\n' '\n' 'The' ' great' ' sage' ' of' ' the']" ", the great sage , and the great sage of the East ,
+
+ The great sage of the" False by Kongzi, or Confucius (551 – 479 BCE), 7 [' by', ' Kong', 'zi', ',', ' or', ' Conf', 'u', 'cius']
+1080 231 Name of mother of x -1 Name of mother of Confucius Yan Zhengzai Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' great' ' sage' ' of'
+ ' the' ' East' ',' '\n' '\n' 'The' ' great' ' sage' ' of' ' the']" ", the great sage , and the great sage of the East ,
+
+ The great sage of the" False " ""moral-ethical precepts of Confucius seem to have" 9 "[' ""', 'moral', '-', 'ethical', ' precept', 's', ' of', ' Conf', 'u', 'cius']"
+1081 231 Name of mother of x -1 Name of mother of Confucius Yan Zhengzai Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' great' ' sage' ' of'
+ ' the' ' East' ',' '\n' '\n' 'The' ' great' ' sage' ' of' ' the']" ", the great sage , and the great sage of the East ,
+
+ The great sage of the" False an image of Confucius or tablet with his 5 [' an', ' image', ' of', ' Conf', 'u', 'cius']
+1082 231 Name of mother of x -1 Name of mother of Confucius Yan Zhengzai Confucius "[',' ' the' ' great' ' sage' ',' ' and' ' the' ' great' ' sage' ' of'
+ ' the' ' East' ',' '\n' '\n' 'The' ' great' ' sage' ' of' ' the']" ", the great sage , and the great sage of the East ,
+
+ The great sage of the" False the tunnel under the Confucius Apartments 6 [' the', ' tunnel', ' under', ' the', ' Conf', 'u', 'cius']
+1083 232 Name of mother of x -1 Name of mother of Charles de Gaulle Jeanne Maillot Charles de Gaulle "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ',' ' who' ' are' ' the' ' most']" ", the French president , and the French people .
+
+ The French people , who are the most" False French General Charles de Gaulle or Vichy General 5 [' French', ' General', ' Charles', ' de', ' Gaul', 'le']
+1084 232 Name of mother of x -1 Name of mother of Charles de Gaulle Jeanne Maillot Charles de Gaulle "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ',' ' who' ' are' ' the' ' most']" ", the French president , and the French people .
+
+ The French people , who are the most" False interview with General Charles de Gaulle after the latter's 6 [' interview', ' with', ' General', ' Charles', ' de', ' Gaul', 'le']
+1085 232 Name of mother of x -1 Name of mother of Charles de Gaulle Jeanne Maillot Charles de Gaulle "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ',' ' who' ' are' ' the' ' most']" ", the French president , and the French people .
+
+ The French people , who are the most" False French president Charles de Gaulle criticized the UN, 5 [' French', ' president', ' Charles', ' de', ' Gaul', 'le']
+1086 232 Name of mother of x -1 Name of mother of Charles de Gaulle Jeanne Maillot Charles de Gaulle "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ',' ' who' ' are' ' the' ' most']" ", the French president , and the French people .
+
+ The French people , who are the most" False occasions, General Charles de Gaulle expressed his 6 [' occasions', ',', ' General', ' Charles', ' de', ' Gaul', 'le']
+1087 232 Name of mother of x -1 Name of mother of Charles de Gaulle Jeanne Maillot Charles de Gaulle "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ',' ' who' ' are' ' the' ' most']" ", the French president , and the French people .
+
+ The French people , who are the most" False between French President Charles de Gaulle and the Commission's 6 [' between', ' French', ' President', ' Charles', ' de', ' Gaul', 'le']
+1088 233 Name of mother of x -1 Name of mother of Leonard Bernstein Jennie Resnick Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ' Bernstein' ',' ' who' ' was' ' a' ' great' ' friend' ' of'
+ ' the']" , the famous composer , and his wife , Fel icia Bernstein , who was a great friend of the False " has been cited by Leonard Bernstein as ""the most" 5 [' has', ' been', ' cited', ' by', ' Leonard', ' Bernstein']
+1089 233 Name of mother of x -1 Name of mother of Leonard Bernstein Jennie Resnick Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ' Bernstein' ',' ' who' ' was' ' a' ' great' ' friend' ' of'
+ ' the']" , the famous composer , and his wife , Fel icia Bernstein , who was a great friend of the False 1960s through Leonard Bernstein and others. In 4 [' 1960', 's', ' through', ' Leonard', ' Bernstein']
+1090 233 Name of mother of x -1 Name of mother of Leonard Bernstein Jennie Resnick Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ' Bernstein' ',' ' who' ' was' ' a' ' great' ' friend' ' of'
+ ' the']" , the famous composer , and his wife , Fel icia Bernstein , who was a great friend of the False Symphony, congratulating Leonard Bernstein and the New York 5 [' Symphony', ',', ' congrat', 'ulating', ' Leonard', ' Bernstein']
+1091 233 Name of mother of x -1 Name of mother of Leonard Bernstein Jennie Resnick Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ' Bernstein' ',' ' who' ' was' ' a' ' great' ' friend' ' of'
+ ' the']" , the famous composer , and his wife , Fel icia Bernstein , who was a great friend of the False Commissioned by Leonard Bernstein in 1962 to commemorate 4 [' Commission', 'ed', ' by', ' Leonard', ' Bernstein']
+1092 233 Name of mother of x -1 Name of mother of Leonard Bernstein Jennie Resnick Leonard Bernstein "[',' ' the' ' famous' ' composer' ',' ' and' ' his' ' wife' ',' ' Fel'
+ 'icia' ' Bernstein' ',' ' who' ' was' ' a' ' great' ' friend' ' of'
+ ' the']" , the famous composer , and his wife , Fel icia Bernstein , who was a great friend of the False Symphony, congratulating Leonard Bernstein and the New York 5 [' Symphony', ',', ' congrat', 'ulating', ' Leonard', ' Bernstein']
+1093 234 Name of mother of x -1 Name of mother of Spock Amanda Grayson Spock "[',' ' the' ' son' ' of' ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a'
+ ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a' ' Vulcan' ',']" , the son of Vulcan , and the son of a Vulcan , and the son of a Vulcan , False Kirk, Scotty, and Spock when they transport 6 [' Kirk', ',', ' Scot', 'ty', ',', ' and', ' Spock']
+1094 234 Name of mother of x -1 Name of mother of Spock Amanda Grayson Spock "[',' ' the' ' son' ' of' ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a'
+ ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a' ' Vulcan' ',']" , the son of Vulcan , and the son of a Vulcan , and the son of a Vulcan , False Meanwhile, Captain Spock begins to act protectively 3 [' Meanwhile', ',', ' Captain', ' Spock']
+1095 234 Name of mother of x -1 Name of mother of Spock Amanda Grayson Spock "[',' ' the' ' son' ' of' ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a'
+ ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a' ' Vulcan' ',']" , the son of Vulcan , and the son of a Vulcan , and the son of a Vulcan , False version of Spock distinguished 2 [' version', ' of', ' Spock']
+1096 234 Name of mother of x -1 Name of mother of Spock Amanda Grayson Spock "[',' ' the' ' son' ' of' ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a'
+ ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a' ' Vulcan' ',']" , the son of Vulcan , and the son of a Vulcan , and the son of a Vulcan , False The Search for Spock was released, 3 [' The', ' Search', ' for', ' Spock']
+1097 234 Name of mother of x -1 Name of mother of Spock Amanda Grayson Spock "[',' ' the' ' son' ' of' ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a'
+ ' Vulcan' ',' ' and' ' the' ' son' ' of' ' a' ' Vulcan' ',']" , the son of Vulcan , and the son of a Vulcan , and the son of a Vulcan , False all disagreed that Spock and McCoy would 3 [' all', ' disagreed', ' that', ' Spock']
+1098 235 Name of mother of x -1 Name of mother of Leonhard Euler Marguerite Brucker Leonhard Euler "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '07'
+ ',' ' and' ' died' ' in' ' 17' '83' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 07 , and died in 17 83 .
+
+ The" False Observatory's 1.2 m Leonhard Euler Telescope on March 9 "[' Observatory', ""'s"", ' 1', '.', '2', ' m', ' Leon', 'hard', ' E', 'uler']"
+1099 235 Name of mother of x -1 Name of mother of Leonhard Euler Marguerite Brucker Leonhard Euler "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '07'
+ ',' ' and' ' died' ' in' ' 17' '83' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 07 , and died in 17 83 .
+
+ The" False dimensions, and Leonhard Euler was the first 6 [' dimensions', ',', ' and', ' Leon', 'hard', ' E', 'uler']
+1100 235 Name of mother of x -1 Name of mother of Leonhard Euler Marguerite Brucker Leonhard Euler "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '07'
+ ',' ' and' ' died' ' in' ' 17' '83' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 07 , and died in 17 83 .
+
+ The" False Observatory's 1.2 m Leonhard Euler Telescope on 9 "[' Observatory', ""'s"", ' 1', '.', '2', ' m', ' Leon', 'hard', ' E', 'uler']"
+1101 235 Name of mother of x -1 Name of mother of Leonhard Euler Marguerite Brucker Leonhard Euler "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '07'
+ ',' ' and' ' died' ' in' ' 17' '83' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 07 , and died in 17 83 .
+
+ The" False mid-18th century, Leonhard Euler wrote what he admitted 9 [' mid', '-', '18', 'th', ' century', ',', ' Leon', 'hard', ' E', 'uler']
+1102 235 Name of mother of x -1 Name of mother of Leonhard Euler Marguerite Brucker Leonhard Euler "[',' ' the' ' mathematician' ',' ' who' ' was' ' born' ' in' ' 17' '07'
+ ',' ' and' ' died' ' in' ' 17' '83' '.' '\n' '\n' 'The']" ", the mathematician , who was born in 17 07 , and died in 17 83 .
+
+ The" False spectrograph on the 1.2m Leonhard Euler Telescope at Chile's 12 [' spect', 'rog', 'raph', ' on', ' the', ' 1', '.', '2', 'm', ' Leon', 'hard', ' E', 'uler']
+1103 236 Name of mother of x -1 Name of mother of Heinrich Heine Betty Heine Heinrich Heine "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ' in' ' D' 'ü' 's'
+ 'se' 'ld' 'orf' ',' ' Germany' '.' ' He']" , the poet , was born in 18 44 in D ü s se ld orf , Germany . He False Jew-hatred '). Writing to Heinrich Heine in 1839, he offered 11 "[' Jew', '-', 'hat', 'red', "" '"", ').', ' Writing', ' to', ' Hein', 'rich', ' He', 'ine']"
+1104 236 Name of mother of x -1 Name of mother of Heinrich Heine Betty Heine Heinrich Heine "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ' in' ' D' 'ü' 's'
+ 'se' 'ld' 'orf' ',' ' Germany' '.' ' He']" , the poet , was born in 18 44 in D ü s se ld orf , Germany . He False Jew-hatred '). Writing to Heinrich Heine in 1839, he offered 11 "[' Jew', '-', 'hat', 'red', "" '"", ').', ' Writing', ' to', ' Hein', 'rich', ' He', 'ine']"
+1105 236 Name of mother of x -1 Name of mother of Heinrich Heine Betty Heine Heinrich Heine "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ' in' ' D' 'ü' 's'
+ 'se' 'ld' 'orf' ',' ' Germany' '.' ' He']" , the poet , was born in 18 44 in D ü s se ld orf , Germany . He False Jew-hatred '). Writing to Heinrich Heine in 1839, he offered 11 "[' Jew', '-', 'hat', 'red', "" '"", ').', ' Writing', ' to', ' Hein', 'rich', ' He', 'ine']"
+1106 236 Name of mother of x -1 Name of mother of Heinrich Heine Betty Heine Heinrich Heine "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ' in' ' D' 'ü' 's'
+ 'se' 'ld' 'orf' ',' ' Germany' '.' ' He']" , the poet , was born in 18 44 in D ü s se ld orf , Germany . He False Gambara) and Heinrich Heine (in his poem Angélique) 7 [' Gamb', 'ara', ')', ' and', ' Hein', 'rich', ' He', 'ine']
+1107 236 Name of mother of x -1 Name of mother of Heinrich Heine Betty Heine Heinrich Heine "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '44' ' in' ' D' 'ü' 's'
+ 'se' 'ld' 'orf' ',' ' Germany' '.' ' He']" , the poet , was born in 18 44 in D ü s se ld orf , Germany . He False Europe pianists who, as Heinrich Heine wrote, invaded 9 [' Europe', ' pian', 'ists', ' who', ',', ' as', ' Hein', 'rich', ' He', 'ine']
+1108 237 Name of mother of x -1 Name of mother of Vladimir Nabokov Yelena Rukavishnikova Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' the' ' author' ' of'
+ ' the' ' novel' ' that' ' made' ' him' ' famous' ',' ' was' ' born']" , the author of Lol ita , and the author of the novel that made him famous , was born False connection also made by Vladimir Nabokov in his Lectures 7 [' connection', ' also', ' made', ' by', ' Vladimir', ' Nab', 'ok', 'ov']
+1109 237 Name of mother of x -1 Name of mother of Vladimir Nabokov Yelena Rukavishnikova Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' the' ' author' ' of'
+ ' the' ' novel' ' that' ' made' ' him' ' famous' ',' ' was' ' born']" , the author of Lol ita , and the author of the novel that made him famous , was born False also made by Vladimir Nabokov in his Lectures 6 [' also', ' made', ' by', ' Vladimir', ' Nab', 'ok', 'ov']
+1110 237 Name of mother of x -1 Name of mother of Vladimir Nabokov Yelena Rukavishnikova Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' the' ' author' ' of'
+ ' the' ' novel' ' that' ' made' ' him' ' famous' ',' ' was' ' born']" , the author of Lol ita , and the author of the novel that made him famous , was born False Russian-American novelist Vladimir Nabokov being accused 7 [' Russian', '-', 'American', ' novelist', ' Vladimir', ' Nab', 'ok', 'ov']
+1111 237 Name of mother of x -1 Name of mother of Vladimir Nabokov Yelena Rukavishnikova Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' the' ' author' ' of'
+ ' the' ' novel' ' that' ' made' ' him' ' famous' ',' ' was' ' born']" , the author of Lol ita , and the author of the novel that made him famous , was born False Russian-American novelist Vladimir Nabokov being accused 7 [' Russian', '-', 'American', ' novelist', ' Vladimir', ' Nab', 'ok', 'ov']
+1112 237 Name of mother of x -1 Name of mother of Vladimir Nabokov Yelena Rukavishnikova Vladimir Nabokov "[',' ' the' ' author' ' of' ' Lol' 'ita' ',' ' and' ' the' ' author' ' of'
+ ' the' ' novel' ' that' ' made' ' him' ' famous' ',' ' was' ' born']" , the author of Lol ita , and the author of the novel that made him famous , was born False Jonathan Swift. Vladimir Nabokov was less enthusiastic 6 [' Jonathan', ' Swift', '.', ' Vladimir', ' Nab', 'ok', 'ov']
+1113 238 Name of mother of x -1 Name of mother of Georg Wilhelm Friedrich Hegel Maria Magdalena Louisa Fromm Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '70' ','
+ ' and' ' died' ' in' ' 18' '31' '.' '\n' '\n' 'The']" ", the philosopher , who was born in 17 70 , and died in 18 31 .
+
+ The" False Gottlieb Fichte, Georg Wilhelm Friedrich Hegel and Friedrich 10 [' Gott', 'lie', 'b', ' F', 'ich', 'te', ',', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1114 238 Name of mother of x -1 Name of mother of Georg Wilhelm Friedrich Hegel Maria Magdalena Louisa Fromm Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '70' ','
+ ' and' ' died' ' in' ' 18' '31' '.' '\n' '\n' 'The']" ", the philosopher , who was born in 17 70 , and died in 18 31 .
+
+ The" False influence exercised by Georg Wilhelm Friedrich Hegel on Romanian thought, 6 [' influence', ' exercised', ' by', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1115 238 Name of mother of x -1 Name of mother of Georg Wilhelm Friedrich Hegel Maria Magdalena Louisa Fromm Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '70' ','
+ ' and' ' died' ' in' ' 18' '31' '.' '\n' '\n' 'The']" ", the philosopher , who was born in 17 70 , and died in 18 31 .
+
+ The" False Gottlieb Fichte, Georg Wilhelm Friedrich Hegel and Friedrich 10 [' Gott', 'lie', 'b', ' F', 'ich', 'te', ',', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1116 238 Name of mother of x -1 Name of mother of Georg Wilhelm Friedrich Hegel Maria Magdalena Louisa Fromm Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '70' ','
+ ' and' ' died' ' in' ' 18' '31' '.' '\n' '\n' 'The']" ", the philosopher , who was born in 17 70 , and died in 18 31 .
+
+ The" False Gottlieb Fichte, Georg Wilhelm Friedrich Hegel and Friedrich Wilhelm 10 [' Gott', 'lie', 'b', ' F', 'ich', 'te', ',', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1117 238 Name of mother of x -1 Name of mother of Georg Wilhelm Friedrich Hegel Maria Magdalena Louisa Fromm Georg Wilhelm Friedrich Hegel "[',' ' the' ' philosopher' ',' ' who' ' was' ' born' ' in' ' 17' '70' ','
+ ' and' ' died' ' in' ' 18' '31' '.' '\n' '\n' 'The']" ", the philosopher , who was born in 17 70 , and died in 18 31 .
+
+ The" False Gottlieb Fichte, Georg Wilhelm Friedrich Hegel and Friedrich Wilhelm 10 [' Gott', 'lie', 'b', ' F', 'ich', 'te', ',', ' Georg', ' Wilhelm', ' Friedrich', ' Hegel']
+1118 239 Name of mother of x -1 Name of mother of James McNeill Whistler Anna McNeill Whistler James McNeill Whistler "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' mother' ',' ' and' ' his' ' sister' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's mother , and his sister , the painter 's False Oscar Wilde, James McNeill Whistler and others in the 8 [' Oscar', ' Wilde', ',', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1119 239 Name of mother of x -1 Name of mother of James McNeill Whistler Anna McNeill Whistler James McNeill Whistler "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' mother' ',' ' and' ' his' ' sister' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's mother , and his sister , the painter 's False Grossmith, Matthew Arnold, James McNeill Whistler and Oscar Wilde. 11 [' Gross', 'mith', ',', ' Matthew', ' Arnold', ',', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1120 239 Name of mother of x -1 Name of mother of James McNeill Whistler Anna McNeill Whistler James McNeill Whistler "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' mother' ',' ' and' ' his' ' sister' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's mother , and his sister , the painter 's False bridge. A statue of James McNeill Whistler by Nicholas 10 [' bridge', '.', ' A', ' statue', ' of', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1121 239 Name of mother of x -1 Name of mother of James McNeill Whistler Anna McNeill Whistler James McNeill Whistler "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' mother' ',' ' and' ' his' ' sister' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's mother , and his sister , the painter 's False Rossetti, Oscar Wilde, James McNeill Whistler and others 11 [' Ross', 'etti', ',', ' Oscar', ' Wilde', ',', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1122 239 Name of mother of x -1 Name of mother of James McNeill Whistler Anna McNeill Whistler James McNeill Whistler "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' mother' ',' ' and' ' his' ' sister' ',' ' the' ' painter' ""'s""]" , the painter , and his wife , the painter 's mother , and his sister , the painter 's False Grossmith, Matthew Arnold, James McNeill Whistler and Oscar Wilde. Hesketh 11 [' Gross', 'mith', ',', ' Matthew', ' Arnold', ',', ' James', ' Mc', 'Neill', ' Wh', 'ist', 'ler']
+1123 240 Name of mother of x -1 Name of mother of Prince Mattie Shaw Prince "[' William' ' and' ' Prince' ' William' ',' ' and' ' the' ' other' ' two'
+ ' are' ' the' ' sons' ' of' ' the' ' King' ' of' ' England' '.' '\n' '\n']" " William and Prince William , and the other two are the sons of the King of England .
+
+" False gave birth to a son, Prince Edward, the 6 [' gave', ' birth', ' to', ' a', ' son', ',', ' Prince']
+1124 240 Name of mother of x -1 Name of mother of Prince Mattie Shaw Prince "[' William' ' and' ' Prince' ' William' ',' ' and' ' the' ' other' ' two'
+ ' are' ' the' ' sons' ' of' ' the' ' King' ' of' ' England' '.' '\n' '\n']" " William and Prince William , and the other two are the sons of the King of England .
+
+" False 0 ['Prince']
+1125 240 Name of mother of x -1 Name of mother of Prince Mattie Shaw Prince "[' William' ' and' ' Prince' ' William' ',' ' and' ' the' ' other' ' two'
+ ' are' ' the' ' sons' ' of' ' the' ' King' ' of' ' England' '.' '\n' '\n']" " William and Prince William , and the other two are the sons of the King of England .
+
+" False 0 ['Prince']
+1126 240 Name of mother of x -1 Name of mother of Prince Mattie Shaw Prince "[' William' ' and' ' Prince' ' William' ',' ' and' ' the' ' other' ' two'
+ ' are' ' the' ' sons' ' of' ' the' ' King' ' of' ' England' '.' '\n' '\n']" " William and Prince William , and the other two are the sons of the King of England .
+
+" False eventually renamed Prince of Wales. Unfortunately, 2 [' eventually', ' renamed', ' Prince']
+1127 240 Name of mother of x -1 Name of mother of Prince Mattie Shaw Prince "[' William' ' and' ' Prince' ' William' ',' ' and' ' the' ' other' ' two'
+ ' are' ' the' ' sons' ' of' ' the' ' King' ' of' ' England' '.' '\n' '\n']" " William and Prince William , and the other two are the sons of the King of England .
+
+" False " (anglicized as Prince Escalus).
+" 5 [' (', 'ang', 'lic', 'ized', ' as', ' Prince']
+1128 241 Name of mother of x -1 Name of mother of Theodore Roosevelt Martha Bulloch Roosevelt Theodore Roosevelt "[',' ' the' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' '\n' '\n'
+ 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' (' '18']" ", the
+
+ The odore Roosevelt , Jr .
+
+ The odore Roosevelt , Jr . ( 18" False for President Theodore Roosevelt to celebrate the 3 [' for', ' President', ' Theodore', ' Roosevelt']
+1129 241 Name of mother of x -1 Name of mother of Theodore Roosevelt Martha Bulloch Roosevelt Theodore Roosevelt "[',' ' the' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' '\n' '\n'
+ 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' (' '18']" ", the
+
+ The odore Roosevelt , Jr .
+
+ The odore Roosevelt , Jr . ( 18" False review for President Theodore Roosevelt in Oyster Bay in 4 [' review', ' for', ' President', ' Theodore', ' Roosevelt']
+1130 241 Name of mother of x -1 Name of mother of Theodore Roosevelt Martha Bulloch Roosevelt Theodore Roosevelt "[',' ' the' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' '\n' '\n'
+ 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' (' '18']" ", the
+
+ The odore Roosevelt , Jr .
+
+ The odore Roosevelt , Jr . ( 18" False students included Theodore Roosevelt and other 3 [' students', ' included', ' Theodore', ' Roosevelt']
+1131 241 Name of mother of x -1 Name of mother of Theodore Roosevelt Martha Bulloch Roosevelt Theodore Roosevelt "[',' ' the' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' '\n' '\n'
+ 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' (' '18']" ", the
+
+ The odore Roosevelt , Jr .
+
+ The odore Roosevelt , Jr . ( 18" False 2 ['The', 'odore', ' Roosevelt']
+1132 241 Name of mother of x -1 Name of mother of Theodore Roosevelt Martha Bulloch Roosevelt Theodore Roosevelt "[',' ' the' '\n' '\n' 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' '\n' '\n'
+ 'The' 'odore' ' Roosevelt' ',' ' Jr' '.' ' (' '18']" ", the
+
+ The odore Roosevelt , Jr .
+
+ The odore Roosevelt , Jr . ( 18" False 2 ['The', 'odore', ' Roosevelt']
+1133 242 Name of mother of x -1 Name of mother of Charles Baudelaire Caroline Aupick Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False d'Aurevilly, Charles Baudelaire and José María de 11 "[' d', ""'"", 'A', 'ure', 'v', 'illy', ',', ' Charles', ' B', 'aud', 'el', 'aire']"
+1134 242 Name of mother of x -1 Name of mother of Charles Baudelaire Caroline Aupick Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False " inaugurated by Charles Baudelaire with the ""destruction" 7 [' inaug', 'urated', ' by', ' Charles', ' B', 'aud', 'el', 'aire']
+1135 242 Name of mother of x -1 Name of mother of Charles Baudelaire Caroline Aupick Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False " experiments inaugurated by Charles Baudelaire with the ""destruction" 8 [' experiments', ' inaug', 'urated', ' by', ' Charles', ' B', 'aud', 'el', 'aire']
+1136 242 Name of mother of x -1 Name of mother of Charles Baudelaire Caroline Aupick Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False Symboliste poet Charles Baudelaire marks the 7 [' Symbol', 'iste', ' poet', ' Charles', ' B', 'aud', 'el', 'aire']
+1137 242 Name of mother of x -1 Name of mother of Charles Baudelaire Caroline Aupick Charles Baudelaire "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ' of' ' the' '\n' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the poet , and the poet of the
+ " False composers, such as Charles Baudelaire and Maurice Ravel, 9 [' compos', 'ers', ',', ' such', ' as', ' Charles', ' B', 'aud', 'el', 'aire']
+1138 244 Name of mother of x -1 Name of mother of Shirley MacLaine Kathlyn Corinne Maclean Shirley MacLaine "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False considerable. Shirley MacLaine said that McCracken 5 [' considerable', '.', ' Shirley', ' Mac', 'L', 'aine']
+1139 244 Name of mother of x -1 Name of mother of Shirley MacLaine Kathlyn Corinne Maclean Shirley MacLaine "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False starred with Shirley MacLaine in the western 5 [' starred', ' with', ' Shirley', ' Mac', 'L', 'aine']
+1140 244 Name of mother of x -1 Name of mother of Shirley MacLaine Kathlyn Corinne Maclean Shirley MacLaine "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False considerable. Shirley MacLaine said that McCracken 5 [' considerable', '.', ' Shirley', ' Mac', 'L', 'aine']
+1141 244 Name of mother of x -1 Name of mother of Shirley MacLaine Kathlyn Corinne Maclean Shirley MacLaine "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Despite speculation, Shirley MacLaine did not attend and 6 [' Despite', ' speculation', ',', ' Shirley', ' Mac', 'L', 'aine']
+1142 244 Name of mother of x -1 Name of mother of Shirley MacLaine Kathlyn Corinne Maclean Shirley MacLaine "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Broadway parts. Shirley MacLaine described her 6 [' Broadway', ' parts', '.', ' Shirley', ' Mac', 'L', 'aine']
+1143 245 Name of mother of x -1 Name of mother of Leonid Brezhnev Natalia Denisovna Brezhneva Leonid Brezhnev "[',' ' the' ' Soviet' ' leader' ',' ' and' ' his' ' wife' ',' ' Gal' 'ina'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Soviet leader , and his wife , Gal ina , who was a former actress .
+
+" False residence for Leonid Brezhnev and Volodymyr Shcherbytsky, 7 [' residence', ' for', ' Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1144 245 Name of mother of x -1 Name of mother of Leonid Brezhnev Natalia Denisovna Brezhneva Leonid Brezhnev "[',' ' the' ' Soviet' ' leader' ',' ' and' ' his' ' wife' ',' ' Gal' 'ina'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Soviet leader , and his wife , Gal ina , who was a former actress .
+
+" False 5 ['Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1145 245 Name of mother of x -1 Name of mother of Leonid Brezhnev Natalia Denisovna Brezhneva Leonid Brezhnev "[',' ' the' ' Soviet' ' leader' ',' ' and' ' his' ' wife' ',' ' Gal' 'ina'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Soviet leader , and his wife , Gal ina , who was a former actress .
+
+" False worldwide, Soviet Premier Leonid Brezhnev ordered the destruction 9 [' worldwide', ',', ' Soviet', ' Premier', ' Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1146 245 Name of mother of x -1 Name of mother of Leonid Brezhnev Natalia Denisovna Brezhneva Leonid Brezhnev "[',' ' the' ' Soviet' ' leader' ',' ' and' ' his' ' wife' ',' ' Gal' 'ina'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Soviet leader , and his wife , Gal ina , who was a former actress .
+
+" False incident with Leonid Brezhnev in which he said, 7 [' incident', ' with', ' Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1147 245 Name of mother of x -1 Name of mother of Leonid Brezhnev Natalia Denisovna Brezhneva Leonid Brezhnev "[',' ' the' ' Soviet' ' leader' ',' ' and' ' his' ' wife' ',' ' Gal' 'ina'
+ ',' ' who' ' was' ' a' ' former' ' actress' '.' '\n' '\n']" ", the Soviet leader , and his wife , Gal ina , who was a former actress .
+
+" False past relations with Leonid Brezhnev were now seriously 8 [' past', ' relations', ' with', ' Leon', 'id', ' Bre', 'zh', 'ne', 'v']
+1148 246 Name of mother of x -1 Name of mother of Bill Clinton Virginia Clinton Kelley Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False with President Bill Clinton at the White 3 [' with', ' President', ' Bill', ' Clinton']
+1149 246 Name of mother of x -1 Name of mother of Bill Clinton Virginia Clinton Kelley Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False speech by President Bill Clinton and a fly-over by 4 [' speech', ' by', ' President', ' Bill', ' Clinton']
+1150 246 Name of mother of x -1 Name of mother of Bill Clinton Virginia Clinton Kelley Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False which US President Bill Clinton was discovered to 4 [' which', ' US', ' President', ' Bill', ' Clinton']
+1151 246 Name of mother of x -1 Name of mother of Bill Clinton Virginia Clinton Kelley Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False allegations against Bill Clinton in the Whitewater 3 [' allegations', ' against', ' Bill', ' Clinton']
+1152 246 Name of mother of x -1 Name of mother of Bill Clinton Virginia Clinton Kelley Bill Clinton "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False United States Bill Clinton is a 1968 3 [' United', ' States', ' Bill', ' Clinton']
+1153 247 Name of mother of x -1 Name of mother of André Gide Juliette Gide André Gide "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' his' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of' ' his']" , the French writer , who was a friend of his , and who had been a friend of his False the likes of André Gide and Albert Londres 6 [' the', ' likes', ' of', ' And', 'ré', ' G', 'ide']
+1154 247 Name of mother of x -1 Name of mother of André Gide Juliette Gide André Gide "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' his' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of' ' his']" , the French writer , who was a friend of his , and who had been a friend of his False literature by the likes of André Gide and Albert Londres 8 [' literature', ' by', ' the', ' likes', ' of', ' And', 'ré', ' G', 'ide']
+1155 248 Name of mother of x -1 Name of mother of H. G. Wells Sarah Neal H. G. Wells "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' H' '.' ' G' '.' ' Wells']" ", the author of the book , and the
+
+ Name of mother of H . G . Wells" False in fictional works. H. G. Wells references Disney in 8 [' in', ' fictional', ' works', '.', ' H', '.', ' G', '.', ' Wells']
+1156 248 Name of mother of x -1 Name of mother of H. G. Wells Sarah Neal H. G. Wells "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' H' '.' ' G' '.' ' Wells']" ", the author of the book , and the
+
+ Name of mother of H . G . Wells" False included the author H. G. Wells and the philosopher 7 [' included', ' the', ' author', ' H', '.', ' G', '.', ' Wells']
+1157 248 Name of mother of x -1 Name of mother of H. G. Wells Sarah Neal H. G. Wells "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' H' '.' ' G' '.' ' Wells']" ", the author of the book , and the
+
+ Name of mother of H . G . Wells" False " ""Poe's greatest work"". H. G. Wells noted that ""Pym" 11 "[' ""', 'P', 'oe', ""'s"", ' greatest', ' work', '"".', ' H', '.', ' G', '.', ' Wells']"
+1158 248 Name of mother of x -1 Name of mother of H. G. Wells Sarah Neal H. G. Wells "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' H' '.' ' G' '.' ' Wells']" ", the author of the book , and the
+
+ Name of mother of H . G . Wells" False " the movement. H. G. Wells remarked upon ""the" 7 [' the', ' movement', '.', ' H', '.', ' G', '.', ' Wells']
+1159 248 Name of mother of x -1 Name of mother of H. G. Wells Sarah Neal H. G. Wells "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' H' '.' ' G' '.' ' Wells']" ", the author of the book , and the
+
+ Name of mother of H . G . Wells" False Home Office by H. G. Wells and George Bernard 7 [' Home', ' Office', ' by', ' H', '.', ' G', '.', ' Wells']
+1160 249 Name of mother of x -1 Name of mother of Giorgio Vasari Maddalena Tacci Giorgio Vasari "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '11' ',' ' and'
+ ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The' ' first']" ", the painter , was born in 15 11 , and died in 15 74 .
+
+ The first" False language used by Giorgio Vasari in his work 7 [' language', ' used', ' by', ' G', 'ior', 'gio', ' Vas', 'ari']
+1161 249 Name of mother of x -1 Name of mother of Giorgio Vasari Maddalena Tacci Giorgio Vasari "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '11' ',' ' and'
+ ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The' ' first']" ", the painter , was born in 15 11 , and died in 15 74 .
+
+ The first" False 16th-century art historian Giorgio Vasari claimed van 10 [' 16', 'th', '-', 'century', ' art', ' historian', ' G', 'ior', 'gio', ' Vas', 'ari']
+1162 249 Name of mother of x -1 Name of mother of Giorgio Vasari Maddalena Tacci Giorgio Vasari "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '11' ',' ' and'
+ ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The' ' first']" ", the painter , was born in 15 11 , and died in 15 74 .
+
+ The first" False 16th-century art historian Giorgio Vasari proposed that 10 [' 16', 'th', '-', 'century', ' art', ' historian', ' G', 'ior', 'gio', ' Vas', 'ari']
+1163 249 Name of mother of x -1 Name of mother of Giorgio Vasari Maddalena Tacci Giorgio Vasari "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '11' ',' ' and'
+ ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The' ' first']" ", the painter , was born in 15 11 , and died in 15 74 .
+
+ The first" False artist and critic Giorgio Vasari (1511 – 1574) first 7 [' artist', ' and', ' critic', ' G', 'ior', 'gio', ' Vas', 'ari']
+1164 249 Name of mother of x -1 Name of mother of Giorgio Vasari Maddalena Tacci Giorgio Vasari "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '11' ',' ' and'
+ ' died' ' in' ' 15' '74' '.' '\n' '\n' 'The' ' first']" ", the painter , was born in 15 11 , and died in 15 74 .
+
+ The first" False the language used by Giorgio Vasari in his work Lives 8 [' the', ' language', ' used', ' by', ' G', 'ior', 'gio', ' Vas', 'ari']
+1165 250 Name of mother of x -1 Name of mother of Charlemagne Bertrada of Laon Charlemagne "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False Dardanelles, although Charlemagne did not participate 9 [' D', 'ard', 'an', 'ell', 'es', ',', ' although', ' Char', 'lem', 'agne']
+1166 250 Name of mother of x -1 Name of mother of Charlemagne Bertrada of Laon Charlemagne "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False in Greece. Charlemagne was relieved for 5 [' in', ' Greece', '.', ' Char', 'lem', 'agne']
+1167 250 Name of mother of x -1 Name of mother of Charlemagne Bertrada of Laon Charlemagne "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False 2 ['Char', 'lem', 'agne']
+1168 250 Name of mother of x -1 Name of mother of Charlemagne Bertrada of Laon Charlemagne "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False coronation of Charlemagne (reigned as 5 [' coron', 'ation', ' of', ' Char', 'lem', 'agne']
+1169 250 Name of mother of x -1 Name of mother of Charlemagne Bertrada of Laon Charlemagne "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False 2 ['Char', 'lem', 'agne']
+1170 251 Name of mother of x -1 Name of mother of William Butler Yeats Susan Pollexfen William Butler Yeats "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ',' ' the' ' poet' ',']" , the poet , and his wife , the poet ess , and the poet ess , the poet , False William Blake, William Butler Yeats and Henry Vaughan, 6 [' William', ' Blake', ',', ' William', ' Butler', ' Ye', 'ats']
+1171 251 Name of mother of x -1 Name of mother of William Butler Yeats Susan Pollexfen William Butler Yeats "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ',' ' the' ' poet' ',']" , the poet , and his wife , the poet ess , and the poet ess , the poet , False service. He paraphrased William Butler Yeats by saying of his 9 [' service', '.', ' He', ' paraph', 'r', 'ased', ' William', ' Butler', ' Ye', 'ats']
+1172 251 Name of mother of x -1 Name of mother of William Butler Yeats Susan Pollexfen William Butler Yeats "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ',' ' the' ' poet' ',']" , the poet , and his wife , the poet ess , and the poet ess , the poet , False " recitation of the William Butler Yeats poem ""Mother" 7 [' rec', 'itation', ' of', ' the', ' William', ' Butler', ' Ye', 'ats']
+1173 251 Name of mother of x -1 Name of mother of William Butler Yeats Susan Pollexfen William Butler Yeats "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ',' ' the' ' poet' ',']" , the poet , and his wife , the poet ess , and the poet ess , the poet , False began to die: in 1939 William Butler Yeats and Ford Madox 9 [' began', ' to', ' die', ':', ' in', ' 1939', ' William', ' Butler', ' Ye', 'ats']
+1174 251 Name of mother of x -1 Name of mother of William Butler Yeats Susan Pollexfen William Butler Yeats "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' the' ' poet' 'ess' ','
+ ' and' ' the' ' poet' 'ess' ',' ' the' ' poet' ',']" , the poet , and his wife , the poet ess , and the poet ess , the poet , False Rudyard Kipling, William Butler Yeats and others further 8 [' Rud', 'yard', ' Ki', 'pling', ',', ' William', ' Butler', ' Ye', 'ats']
+1175 253 Name of mother of x -1 Name of mother of Franklin Delano Roosevelt Sara Roosevelt Franklin Delano Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False his preference for Franklin Delano Roosevelt as his running mate. 6 [' his', ' preference', ' for', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1176 253 Name of mother of x -1 Name of mother of Franklin Delano Roosevelt Sara Roosevelt Franklin Delano Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False the United States Franklin Delano Roosevelt and British Prime 6 [' the', ' United', ' States', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1177 253 Name of mother of x -1 Name of mother of Franklin Delano Roosevelt Sara Roosevelt Franklin Delano Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False 1934, President Franklin Delano Roosevelt began endorsing 6 [' 1934', ',', ' President', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1178 253 Name of mother of x -1 Name of mother of Franklin Delano Roosevelt Sara Roosevelt Franklin Delano Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False father's fifth cousin, Franklin Delano Roosevelt (1882 – 1945), 8 "[' father', ""'s"", ' fifth', ' cousin', ',', ' Franklin', ' Del', 'ano', ' Roosevelt']"
+1179 253 Name of mother of x -1 Name of mother of Franklin Delano Roosevelt Sara Roosevelt Franklin Delano Roosevelt "[',' ' the' ' first' ' President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' to' ' be' ' elected' ' to' ' the' ' presidency'
+ ' without' ' a']" , the first President of the United States , and the first to be elected to the presidency without a False Newport, President Franklin Delano Roosevelt reviewed the 6 [' Newport', ',', ' President', ' Franklin', ' Del', 'ano', ' Roosevelt']
+1180 254 Name of mother of x -1 Name of mother of Poseidon Rhea Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False the Boeing P-8A Poseidon maritime patrol 7 [' the', ' Boeing', ' P', '-', '8', 'A', ' Pose', 'idon']
+1181 254 Name of mother of x -1 Name of mother of Poseidon Rhea Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False the sea-god Poseidon and Euryale, 5 [' the', ' sea', '-', 'god', ' Pose', 'idon']
+1182 254 Name of mother of x -1 Name of mother of Poseidon Rhea Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False stepfather Asterion. Poseidon sent up from the 6 [' step', 'father', ' Aster', 'ion', '.', ' Pose', 'idon']
+1183 254 Name of mother of x -1 Name of mother of Poseidon Rhea Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False planet is called Poseidon (Ποσειδώνας, Poseidonas), 4 [' planet', ' is', ' called', ' Pose', 'idon']
+1184 254 Name of mother of x -1 Name of mother of Poseidon Rhea Poseidon "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False submarine Poseidon had been sunk there 2 [' submarine', ' Pose', 'idon']
+1185 255 Name of mother of x -1 Name of mother of Emily Brontë Maria Branwell Emily Brontë "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' Emily' ' Br' 'ont' 'ë' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' Emily' ' Br']" ", and the
+
+ Name of Emily Br ont ë , and the
+
+ Name of Emily Br" False juvenilia of Emily Brontë at eight o 'clock 6 [' juven', 'ilia', ' of', ' Emily', ' Br', 'ont', 'ë']
+1186 255 Name of mother of x -1 Name of mother of Emily Brontë Maria Branwell Emily Brontë "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' Emily' ' Br' 'ont' 'ë' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' Emily' ' Br']" ", and the
+
+ Name of Emily Br ont ë , and the
+
+ Name of Emily Br" False " Mine"", a poem by Emily Brontë that had been a" 8 "[' Mine', '"",', ' a', ' poem', ' by', ' Emily', ' Br', 'ont', 'ë']"
+1187 256 Name of mother of x -1 Name of mother of Audrey Hepburn Ella van Heemstra Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False romantic role opposite Audrey Hepburn in Charade. Grant 5 [' romantic', ' role', ' opposite', ' Audrey', ' Hep', 'burn']
+1188 256 Name of mother of x -1 Name of mother of Audrey Hepburn Ella van Heemstra Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False non-singing star, Audrey Hepburn, to play the leading 8 [' non', '-', 'sing', 'ing', ' star', ',', ' Audrey', ' Hep', 'burn']
+1189 256 Name of mother of x -1 Name of mother of Audrey Hepburn Ella van Heemstra Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False supermodel Twiggy, actress Audrey Hepburn and others, demarcating 9 [' super', 'model', ' Tw', 'ig', 'gy', ',', ' actress', ' Audrey', ' Hep', 'burn']
+1190 256 Name of mother of x -1 Name of mother of Audrey Hepburn Ella van Heemstra Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False the Afternoon with Audrey Hepburn and Maurice 6 [' the', ' After', 'noon', ' with', ' Audrey', ' Hep', 'burn']
+1191 256 Name of mother of x -1 Name of mother of Audrey Hepburn Ella van Heemstra Audrey Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Children's Hour with Audrey Hepburn and Shirley 6 "[' Children', ""'s"", ' Hour', ' with', ' Audrey', ' Hep', 'burn']"
+1192 257 Name of mother of x -1 Name of mother of John Ruskin Margaret Cock Ruskin John Ruskin "[',' ' the' ' great' ' English' ' critic' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' John' ' Rus' 'kin' ',' ' the' ' great']" ", the great English critic , and the
+
+ Name of mother of John Rus kin , the great" False The critic John Ruskin compared the 4 [' The', ' critic', ' John', ' Rus', 'kin']
+1193 257 Name of mother of x -1 Name of mother of John Ruskin Margaret Cock Ruskin John Ruskin "[',' ' the' ' great' ' English' ' critic' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' John' ' Rus' 'kin' ',' ' the' ' great']" ", the great English critic , and the
+
+ Name of mother of John Rus kin , the great" False Temeraire's wood. John Ruskin foreshadowed 8 "[' Tem', 'er', 'aire', ""'s"", ' wood', '.', ' John', ' Rus', 'kin']"
+1194 257 Name of mother of x -1 Name of mother of John Ruskin Margaret Cock Ruskin John Ruskin "[',' ' the' ' great' ' English' ' critic' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' John' ' Rus' 'kin' ',' ' the' ' great']" ", the great English critic , and the
+
+ Name of mother of John Rus kin , the great" False from Temeraire's wood. John Ruskin foreshadowed 9 "[' from', ' Tem', 'er', 'aire', ""'s"", ' wood', '.', ' John', ' Rus', 'kin']"
+1195 257 Name of mother of x -1 Name of mother of John Ruskin Margaret Cock Ruskin John Ruskin "[',' ' the' ' great' ' English' ' critic' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' John' ' Rus' 'kin' ',' ' the' ' great']" ", the great English critic , and the
+
+ Name of mother of John Rus kin , the great" False by the art critic John Ruskin in The Stones of Venice, 6 [' by', ' the', ' art', ' critic', ' John', ' Rus', 'kin']
+1196 257 Name of mother of x -1 Name of mother of John Ruskin Margaret Cock Ruskin John Ruskin "[',' ' the' ' great' ' English' ' critic' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' John' ' Rus' 'kin' ',' ' the' ' great']" ", the great English critic , and the
+
+ Name of mother of John Rus kin , the great" False 2 ['John', ' Rus', 'kin']
+1197 259 Name of mother of x -1 Name of mother of Niccolò Machiavelli Bartolomea di Stefano Nelli Niccolò Machiavelli "[',' ' the' ' Flore' 'nt' 'ine' ' states' 'man' ' and' ' author' ' of'
+ ' the' ' _' 'Prin' 'ci' 'pe' '_' ',' ' and' ' the' ' _']" , the Flore nt ine states man and author of the _ Prin ci pe _ , and the _ False project on which Niccolò Machiavelli also worked. Leonardo's 10 [' project', ' on', ' which', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1198 259 Name of mother of x -1 Name of mother of Niccolò Machiavelli Bartolomea di Stefano Nelli Niccolò Machiavelli "[',' ' the' ' Flore' 'nt' 'ine' ' states' 'man' ' and' ' author' ' of'
+ ' the' ' _' 'Prin' 'ci' 'pe' '_' ',' ' and' ' the' ' _']" , the Flore nt ine states man and author of the _ Prin ci pe _ , and the _ False In The Art of War, Niccolò Machiavelli observed that 13 [' In', ' The', ' Art', ' of', ' War', ',', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1199 259 Name of mother of x -1 Name of mother of Niccolò Machiavelli Bartolomea di Stefano Nelli Niccolò Machiavelli "[',' ' the' ' Flore' 'nt' 'ine' ' states' 'man' ' and' ' author' ' of'
+ ' the' ' _' 'Prin' 'ci' 'pe' '_' ',' ' and' ' the' ' _']" , the Flore nt ine states man and author of the _ Prin ci pe _ , and the _ False " the Republic, Niccolò Machiavelli remarked
+" 10 [' the', ' Republic', ',', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1200 259 Name of mother of x -1 Name of mother of Niccolò Machiavelli Bartolomea di Stefano Nelli Niccolò Machiavelli "[',' ' the' ' Flore' 'nt' 'ine' ' states' 'man' ' and' ' author' ' of'
+ ' the' ' _' 'Prin' 'ci' 'pe' '_' ',' ' and' ' the' ' _']" , the Flore nt ine states man and author of the _ Prin ci pe _ , and the _ False philosophers such as Niccolò Machiavelli and Thomas More revived 10 [' philosophers', ' such', ' as', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1201 259 Name of mother of x -1 Name of mother of Niccolò Machiavelli Bartolomea di Stefano Nelli Niccolò Machiavelli "[',' ' the' ' Flore' 'nt' 'ine' ' states' 'man' ' and' ' author' ' of'
+ ' the' ' _' 'Prin' 'ci' 'pe' '_' ',' ' and' ' the' ' _']" , the Flore nt ine states man and author of the _ Prin ci pe _ , and the _ False Italian philosopher Niccolò Machiavelli remarked upon the 9 [' Italian', ' philosopher', ' Nic', 'col', '�', '�', ' Mach', 'ia', 've', 'lli']
+1202 260 Name of mother of x -1 Name of mother of Charles V Juana I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' the' ' French' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Emperor of the French , and the
+ " False the Emperor Charles V himself was approaching 3 [' the', ' Emperor', ' Charles', ' V']
+1203 260 Name of mother of x -1 Name of mother of Charles V Juana I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' the' ' French' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Emperor of the French , and the
+ " False Francis and Charles V in August 1516, 3 [' Francis', ' and', ' Charles', ' V']
+1204 260 Name of mother of x -1 Name of mother of Charles V Juana I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' the' ' French' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Emperor of the French , and the
+ " False the direct cause of Charles V granting permission 5 [' the', ' direct', ' cause', ' of', ' Charles', ' V']
+1205 260 Name of mother of x -1 Name of mother of Charles V Juana I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' the' ' French' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Emperor of the French , and the
+ " False and Emperor Charles V and date of construction 3 [' and', ' Emperor', ' Charles', ' V']
+1206 260 Name of mother of x -1 Name of mother of Charles V Juana I of Castile Charles V "[',' ' the' ' Emperor' ' of' ' the' ' French' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the Emperor of the French , and the
+ " False 1 ['Charles', ' V']
+1207 261 Name of mother of x -1 Name of mother of Johannes Brahms Johanna Henrica Christian jester Johannes Brahms "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Johannes' ' Brah' 'ms' ',' ' the']" ", the composer , and his wife , the
+
+ Name of mother of Johannes Brah ms , the" False Robert Schumann, Johannes Brahms and Ferruccio 6 [' Robert', ' Sch', 'umann', ',', ' Johannes', ' Brah', 'ms']
+1208 261 Name of mother of x -1 Name of mother of Johannes Brahms Johanna Henrica Christian jester Johannes Brahms "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Johannes' ' Brah' 'ms' ',' ' the']" ", the composer , and his wife , the
+
+ Name of mother of Johannes Brah ms , the" False " anything like that."" Johannes Brahms also admired" 6 "[' anything', ' like', ' that', '.""', ' Johannes', ' Brah', 'ms']"
+1209 261 Name of mother of x -1 Name of mother of Johannes Brahms Johanna Henrica Christian jester Johannes Brahms "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Johannes' ' Brah' 'ms' ',' ' the']" ", the composer , and his wife , the
+
+ Name of mother of Johannes Brah ms , the" False Robert Schumann, Johannes Brahms and Ferruccio 6 [' Robert', ' Sch', 'umann', ',', ' Johannes', ' Brah', 'ms']
+1210 261 Name of mother of x -1 Name of mother of Johannes Brahms Johanna Henrica Christian jester Johannes Brahms "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Johannes' ' Brah' 'ms' ',' ' the']" ", the composer , and his wife , the
+
+ Name of mother of Johannes Brah ms , the" False Heinrich Schütz and Johannes Brahms set to music, Brahms 8 [' Hein', 'rich', ' Sch', 'ü', 'tz', ' and', ' Johannes', ' Brah', 'ms']
+1211 261 Name of mother of x -1 Name of mother of Johannes Brahms Johanna Henrica Christian jester Johannes Brahms "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Johannes' ' Brah' 'ms' ',' ' the']" ", the composer , and his wife , the
+
+ Name of mother of Johannes Brah ms , the" False Robert Schumann, Johannes Brahms and Ferruccio 6 [' Robert', ' Sch', 'umann', ',', ' Johannes', ' Brah', 'ms']
+1212 262 Name of mother of x -1 Name of mother of Cristiano Ronaldo Maria Dolores dos Santos Viveiros da Aveiro Cristiano Ronaldo "[',' ' the' ' Portuguese' ' forward' ',' ' who' ' has' ' been' ' linked'
+ ' with' ' a' ' move' ' to' ' the' ' Premier' ' League' ',' ' has' ' been'
+ ' linked']" , the Portuguese forward , who has been linked with a move to the Premier League , has been linked False and the signings of Cristiano Ronaldo and Kaká. Real 6 [' and', ' the', ' signings', ' of', ' Crist', 'iano', ' Ronaldo']
+1213 262 Name of mother of x -1 Name of mother of Cristiano Ronaldo Maria Dolores dos Santos Viveiros da Aveiro Cristiano Ronaldo "[',' ' the' ' Portuguese' ' forward' ',' ' who' ' has' ' been' ' linked'
+ ' with' ' a' ' move' ' to' ' the' ' Premier' ' League' ',' ' has' ' been'
+ ' linked']" , the Portuguese forward , who has been linked with a move to the Premier League , has been linked False of April 2016, Cristiano Ronaldo is the all-time 6 [' of', ' April', ' 2016', ',', ' Crist', 'iano', ' Ronaldo']
+1214 262 Name of mother of x -1 Name of mother of Cristiano Ronaldo Maria Dolores dos Santos Viveiros da Aveiro Cristiano Ronaldo "[',' ' the' ' Portuguese' ' forward' ',' ' who' ' has' ' been' ' linked'
+ ' with' ' a' ' move' ' to' ' the' ' Premier' ' League' ',' ' has' ' been'
+ ' linked']" , the Portuguese forward , who has been linked with a move to the Premier League , has been linked False " Held jointly with Cristiano Ronaldo and Luis Suárez.
+" 5 [' Held', ' jointly', ' with', ' Crist', 'iano', ' Ronaldo']
+1215 262 Name of mother of x -1 Name of mother of Cristiano Ronaldo Maria Dolores dos Santos Viveiros da Aveiro Cristiano Ronaldo "[',' ' the' ' Portuguese' ' forward' ',' ' who' ' has' ' been' ' linked'
+ ' with' ' a' ' move' ' to' ' the' ' Premier' ' League' ',' ' has' ' been'
+ ' linked']" , the Portuguese forward , who has been linked with a move to the Premier League , has been linked False Tournament, marking Cristiano Ronaldo when the team 5 [' Tournament', ',', ' marking', ' Crist', 'iano', ' Ronaldo']
+1216 262 Name of mother of x -1 Name of mother of Cristiano Ronaldo Maria Dolores dos Santos Viveiros da Aveiro Cristiano Ronaldo "[',' ' the' ' Portuguese' ' forward' ',' ' who' ' has' ' been' ' linked'
+ ' with' ' a' ' move' ' to' ' the' ' Premier' ' League' ',' ' has' ' been'
+ ' linked']" , the Portuguese forward , who has been linked with a move to the Premier League , has been linked False Portuguese compatriot Cristiano Ronaldo — with whom he 5 [' Portuguese', ' compat', 'riot', ' Crist', 'iano', ' Ronaldo']
+1217 263 Name of mother of x -1 Name of mother of Martin Luther King Jr. Alberta Williams King Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False / Magneto and Martin Luther King Jr. / Malcolm X stating 8 [' /', ' Magnet', 'o', ' and', ' Martin', ' Luther', ' King', ' Jr', '.']
+1218 263 Name of mother of x -1 Name of mother of Martin Luther King Jr. Alberta Williams King Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False committee to promote Martin Luther King Jr. Day as a national 7 [' committee', ' to', ' promote', ' Martin', ' Luther', ' King', ' Jr', '.']
+1219 263 Name of mother of x -1 Name of mother of Martin Luther King Jr. Alberta Williams King Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False the funeral of Dr. Martin Luther King Jr. He considered himself 9 [' the', ' funeral', ' of', ' Dr', '.', ' Martin', ' Luther', ' King', ' Jr', '.']
+1220 263 Name of mother of x -1 Name of mother of Martin Luther King Jr. Alberta Williams King Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False May 29, 1964 Martin Luther King Jr. spoke at the 8 [' May', ' 29', ',', ' 1964', ' Martin', ' Luther', ' King', ' Jr', '.']
+1221 263 Name of mother of x -1 Name of mother of Martin Luther King Jr. Alberta Williams King Martin Luther King Jr. "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False neighborhoods into Lansing as Martin Luther King Jr. Boulevard. The 8 [' neighborhoods', ' into', ' Lansing', ' as', ' Martin', ' Luther', ' King', ' Jr', '.']
+1222 264 Name of mother of x -1 Name of mother of Henrik Ibsen Marichen Altenburg Henrik Ibsen "[',' ' the' ' Norwegian' ' play' 'wright' ',' ' was' ' born' ' in' ' 18'
+ '28' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' children']" ", the Norwegian play wright , was born in 18 28 .
+
+ The first of the three children" False Marcus Thrane in 1917, Henrik Ibsen in two volumes in 10 [' Marcus', ' Th', 'rane', ' in', ' 1917', ',', ' Hen', 'rik', ' I', 'bs', 'en']
+1223 264 Name of mother of x -1 Name of mother of Henrik Ibsen Marichen Altenburg Henrik Ibsen "[',' ' the' ' Norwegian' ' play' 'wright' ',' ' was' ' born' ' in' ' 18'
+ '28' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' children']" ", the Norwegian play wright , was born in 18 28 .
+
+ The first of the three children" False Norwegian dramatists Henrik Ibsen and Gunnar Heiberg. 8 [' Norwegian', ' dram', 'at', 'ists', ' Hen', 'rik', ' I', 'bs', 'en']
+1224 264 Name of mother of x -1 Name of mother of Henrik Ibsen Marichen Altenburg Henrik Ibsen "[',' ' the' ' Norwegian' ' play' 'wright' ',' ' was' ' born' ' in' ' 18'
+ '28' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' children']" ", the Norwegian play wright , was born in 18 28 .
+
+ The first of the three children" False critics. Along with Henrik Ibsen and August Strindberg, 8 [' critics', '.', ' Along', ' with', ' Hen', 'rik', ' I', 'bs', 'en']
+1225 264 Name of mother of x -1 Name of mother of Henrik Ibsen Marichen Altenburg Henrik Ibsen "[',' ' the' ' Norwegian' ' play' 'wright' ',' ' was' ' born' ' in' ' 18'
+ '28' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' children']" ", the Norwegian play wright , was born in 18 28 .
+
+ The first of the three children" False Sverdrup and Henrik Ibsen spanned several 9 [' S', 'ver', 'd', 'rup', ' and', ' Hen', 'rik', ' I', 'bs', 'en']
+1226 264 Name of mother of x -1 Name of mother of Henrik Ibsen Marichen Altenburg Henrik Ibsen "[',' ' the' ' Norwegian' ' play' 'wright' ',' ' was' ' born' ' in' ' 18'
+ '28' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ' children']" ", the Norwegian play wright , was born in 18 28 .
+
+ The first of the three children" False Thrane in 1917, Henrik Ibsen in two volumes 9 [' Th', 'rane', ' in', ' 1917', ',', ' Hen', 'rik', ' I', 'bs', 'en']
+1227 265 Name of mother of x -1 Name of mother of Niels Bohr Ellen Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' Nobel' ' Prize'
+ ' in']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The Nobel Prize in" False spectrometer. Niels Bohr won the Physics 7 [' spect', 'rom', 'eter', '.', ' Ni', 'els', ' Boh', 'r']
+1228 265 Name of mother of x -1 Name of mother of Niels Bohr Ellen Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' Nobel' ' Prize'
+ ' in']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The Nobel Prize in" False After a year at the Niels Bohr Institute in Denmark, 8 [' After', ' a', ' year', ' at', ' the', ' Ni', 'els', ' Boh', 'r']
+1229 265 Name of mother of x -1 Name of mother of Niels Bohr Ellen Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' Nobel' ' Prize'
+ ' in']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The Nobel Prize in" False the uranium program. Niels Bohr and John Wheeler 7 [' the', ' uranium', ' program', '.', ' Ni', 'els', ' Boh', 'r']
+1230 265 Name of mother of x -1 Name of mother of Niels Bohr Ellen Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' Nobel' ' Prize'
+ ' in']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The Nobel Prize in" False also worked with Niels Bohr in explaining 6 [' also', ' worked', ' with', ' Ni', 'els', ' Boh', 'r']
+1231 265 Name of mother of x -1 Name of mother of Niels Bohr Ellen Bohr Niels Bohr "[',' ' the' ' Danish' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' Physics' ' in' ' 1922' '.' '\n' '\n' 'The' ' Nobel' ' Prize'
+ ' in']" ", the Danish physicist who won the Nobel Prize for Physics in 1922 .
+
+ The Nobel Prize in" False eager to invite Niels Bohr to the Tube Alloys 6 [' eager', ' to', ' invite', ' Ni', 'els', ' Boh', 'r']
+1232 266 Name of mother of x -1 Name of mother of George W. Bush Barbara Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False In December 2006 George W. Bush showed dissatisfaction 6 [' In', ' December', ' 2006', ' George', ' W', '.', ' Bush']
+1233 266 Name of mother of x -1 Name of mother of George W. Bush Barbara Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False statement, President George W. Bush stated in 6 [' statement', ',', ' President', ' George', ' W', '.', ' Bush']
+1234 266 Name of mother of x -1 Name of mother of George W. Bush Barbara Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False United States George W. Bush and United States 5 [' United', ' States', ' George', ' W', '.', ' Bush']
+1235 266 Name of mother of x -1 Name of mother of George W. Bush Barbara Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False President George W. Bush who were perceived 4 [' President', ' George', ' W', '.', ' Bush']
+1236 266 Name of mother of x -1 Name of mother of George W. Bush Barbara Bush George W. Bush "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False that President George W. Bush did not have 5 [' that', ' President', ' George', ' W', '.', ' Bush']
+1237 267 Name of mother of x -1 Name of mother of Peter Gabriel Edith Irene Allen Peter Gabriel "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the mother of the groom .
+
+ The bride and groom" False original frontman Peter Gabriel and guitarist Steve 3 [' original', ' frontman', ' Peter', ' Gabriel']
+1238 267 Name of mother of x -1 Name of mother of Peter Gabriel Edith Irene Allen Peter Gabriel "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the mother of the groom .
+
+ The bride and groom" False " to Earth"" by Peter Gabriel and ""O ... Saya""," 5 "[' to', ' Earth', '""', ' by', ' Peter', ' Gabriel']"
+1239 267 Name of mother of x -1 Name of mother of Peter Gabriel Edith Irene Allen Peter Gabriel "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the mother of the groom .
+
+ The bride and groom" False Genesis front man Peter Gabriel asked how they could 4 [' Genesis', ' front', ' man', ' Peter', ' Gabriel']
+1240 267 Name of mother of x -1 Name of mother of Peter Gabriel Edith Irene Allen Peter Gabriel "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the mother of the groom .
+
+ The bride and groom" False original frontman Peter Gabriel and guitarist Steve 3 [' original', ' frontman', ' Peter', ' Gabriel']
+1241 267 Name of mother of x -1 Name of mother of Peter Gabriel Edith Irene Allen Peter Gabriel "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the father of the bride , and the mother of the groom .
+
+ The bride and groom" False " Streets, and even Peter Gabriel is quite fun."" The" 5 [' Streets', ',', ' and', ' even', ' Peter', ' Gabriel']
+1242 268 Name of mother of x -1 Name of mother of Douglas Adams Janet Adams Douglas Adams "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False collaborated with author Douglas Adams in a week-long series 4 [' collaborated', ' with', ' author', ' Douglas', ' Adams']
+1243 268 Name of mother of x -1 Name of mother of Douglas Adams Janet Adams Douglas Adams "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " from viewers, Douglas Adams wrote, ""If the" 4 [' from', ' viewers', ',', ' Douglas', ' Adams']
+1244 268 Name of mother of x -1 Name of mother of Douglas Adams Janet Adams Douglas Adams "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False like most Douglas Adams material, is overrated 3 [' like', ' most', ' Douglas', ' Adams']
+1245 268 Name of mother of x -1 Name of mother of Douglas Adams Janet Adams Douglas Adams "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " criticisms from viewers, Douglas Adams wrote, ""If" 5 [' criticisms', ' from', ' viewers', ',', ' Douglas', ' Adams']
+1246 268 Name of mother of x -1 Name of mother of Douglas Adams Janet Adams Douglas Adams "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Chapman and Douglas Adams wrote a pilot 3 [' Chapman', ' and', ' Douglas', ' Adams']
+1247 269 Name of mother of x -1 Name of mother of Nicolaus Copernicus Barbara Koppernigk Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' Tor' 'u' '�'
+ '�' ',' ' Poland' ',' ' in' ' 14' '73' '.' '\n']" ", the astronomer , who was born in Tor u � � , Poland , in 14 73 .
+" False 16th century, Nicolaus Copernicus proposed a heliocentric 8 [' 16', 'th', ' century', ',', ' Nicola', 'us', ' Cop', 'ern', 'icus']
+1248 269 Name of mother of x -1 Name of mother of Nicolaus Copernicus Barbara Koppernigk Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' Tor' 'u' '�'
+ '�' ',' ' Poland' ',' ' in' ' 14' '73' '.' '\n']" ", the astronomer , who was born in Tor u � � , Poland , in 14 73 .
+" False Reformation. Nicolaus Copernicus dedicated De revolutionibus 7 [' Re', 'formation', '.', ' Nicola', 'us', ' Cop', 'ern', 'icus']
+1249 269 Name of mother of x -1 Name of mother of Nicolaus Copernicus Barbara Koppernigk Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' Tor' 'u' '�'
+ '�' ',' ' Poland' ',' ' in' ' 14' '73' '.' '\n']" ", the astronomer , who was born in Tor u � � , Poland , in 14 73 .
+" False together with Nicolaus Copernicus with a feast day 6 [' together', ' with', ' Nicola', 'us', ' Cop', 'ern', 'icus']
+1250 269 Name of mother of x -1 Name of mother of Nicolaus Copernicus Barbara Koppernigk Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' Tor' 'u' '�'
+ '�' ',' ' Poland' ',' ' in' ' 14' '73' '.' '\n']" ", the astronomer , who was born in Tor u � � , Poland , in 14 73 .
+" False 5 ['Nic', 'ola', 'us', ' Cop', 'ern', 'icus']
+1251 269 Name of mother of x -1 Name of mother of Nicolaus Copernicus Barbara Koppernigk Nicolaus Copernicus "[',' ' the' ' astronomer' ',' ' who' ' was' ' born' ' in' ' Tor' 'u' '�'
+ '�' ',' ' Poland' ',' ' in' ' 14' '73' '.' '\n']" ", the astronomer , who was born in Tor u � � , Poland , in 14 73 .
+" False Chemistry after Nicolaus Copernicus on February 19, 6 [' Chemistry', ' after', ' Nicola', 'us', ' Cop', 'ern', 'icus']
+1252 270 Name of mother of x -1 Name of mother of Karl Popper Jenny Popper Karl Popper "[',' ' the' ' great' ' philosopher' ' of' ' the' ' 20' 'th' ' century' ','
+ ' who' ' was' ' a' ' great' ' admire' 'r' ' of' ' the' ' work' ' of']" , the great philosopher of the 20 th century , who was a great admire r of the work of False " 1976, the philosopher Karl Popper said that ""Darwinism" 6 [' 1976', ',', ' the', ' philosopher', ' Karl', ' Po', 'pper']
+1253 270 Name of mother of x -1 Name of mother of Karl Popper Jenny Popper Karl Popper "[',' ' the' ' great' ' philosopher' ' of' ' the' ' 20' 'th' ' century' ','
+ ' who' ' was' ' a' ' great' ' admire' 'r' ' of' ' the' ' work' ' of']" , the great philosopher of the 20 th century , who was a great admire r of the work of False macroeconomics? Karl Popper called this the 6 [' macro', 'econom', 'ics', '?', ' Karl', ' Po', 'pper']
+1254 270 Name of mother of x -1 Name of mother of Karl Popper Jenny Popper Karl Popper "[',' ' the' ' great' ' philosopher' ' of' ' the' ' 20' 'th' ' century' ','
+ ' who' ' was' ' a' ' great' ' admire' 'r' ' of' ' the' ' work' ' of']" , the great philosopher of the 20 th century , who was a great admire r of the work of False Stephen Toulmin, and Karl Popper — have repeatedly 8 [' Stephen', ' T', 'oul', 'min', ',', ' and', ' Karl', ' Po', 'pper']
+1255 270 Name of mother of x -1 Name of mother of Karl Popper Jenny Popper Karl Popper "[',' ' the' ' great' ' philosopher' ' of' ' the' ' 20' 'th' ' century' ','
+ ' who' ' was' ' a' ' great' ' admire' 'r' ' of' ' the' ' work' ' of']" , the great philosopher of the 20 th century , who was a great admire r of the work of False " the philosopher Karl Popper said that ""Darwinism" 4 [' the', ' philosopher', ' Karl', ' Po', 'pper']
+1256 270 Name of mother of x -1 Name of mother of Karl Popper Jenny Popper Karl Popper "[',' ' the' ' great' ' philosopher' ' of' ' the' ' 20' 'th' ' century' ','
+ ' who' ' was' ' a' ' great' ' admire' 'r' ' of' ' the' ' work' ' of']" , the great philosopher of the 20 th century , who was a great admire r of the work of False tutors including Karl Popper and Harold Laski; 5 [' tut', 'ors', ' including', ' Karl', ' Po', 'pper']
+1257 271 Name of mother of x -1 Name of mother of George H. W. Bush Dorothy Walker Bush George H. W. Bush "[',' ' the' ' former' ' president' ',' ' and' ' his' ' wife' ','
+ ' Barbara' ',' ' were' ' in' ' the' ' audience' '.' '\n' '\n' 'The'
+ ' event']" ", the former president , and his wife , Barbara , were in the audience .
+
+ The event" False President George H. W. Bush designated him the 6 [' President', ' George', ' H', '.', ' W', '.', ' Bush']
+1258 271 Name of mother of x -1 Name of mother of George H. W. Bush Dorothy Walker Bush George H. W. Bush "[',' ' the' ' former' ' president' ',' ' and' ' his' ' wife' ','
+ ' Barbara' ',' ' were' ' in' ' the' ' audience' '.' '\n' '\n' 'The'
+ ' event']" ", the former president , and his wife , Barbara , were in the audience .
+
+ The event" False Republican candidate George H. W. Bush accused Democratic 7 [' Republican', ' candidate', ' George', ' H', '.', ' W', '.', ' Bush']
+1259 271 Name of mother of x -1 Name of mother of George H. W. Bush Dorothy Walker Bush George H. W. Bush "[',' ' the' ' former' ' president' ',' ' and' ' his' ' wife' ','
+ ' Barbara' ',' ' were' ' in' ' the' ' audience' '.' '\n' '\n' 'The'
+ ' event']" ", the former president , and his wife , Barbara , were in the audience .
+
+ The event" False with President George H. W. Bush again, the fourth 7 [' with', ' President', ' George', ' H', '.', ' W', '.', ' Bush']
+1260 271 Name of mother of x -1 Name of mother of George H. W. Bush Dorothy Walker Bush George H. W. Bush "[',' ' the' ' former' ' president' ',' ' and' ' his' ' wife' ','
+ ' Barbara' ',' ' were' ' in' ' the' ' audience' '.' '\n' '\n' 'The'
+ ' event']" ", the former president , and his wife , Barbara , were in the audience .
+
+ The event" False Republican incumbent George H. W. Bush (37.4 percent of the 7 [' Republican', ' incumbent', ' George', ' H', '.', ' W', '.', ' Bush']
+1261 271 Name of mother of x -1 Name of mother of George H. W. Bush Dorothy Walker Bush George H. W. Bush "[',' ' the' ' former' ' president' ',' ' and' ' his' ' wife' ','
+ ' Barbara' ',' ' were' ' in' ' the' ' audience' '.' '\n' '\n' 'The'
+ ' event']" ", the former president , and his wife , Barbara , were in the audience .
+
+ The event" False " by President George H. W. Bush before sentencing.
+" 7 [' by', ' President', ' George', ' H', '.', ' W', '.', ' Bush']
+1262 272 Name of mother of x -1 Name of mother of Seneca Helvia Seneca "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' a' ' huge' ',' ' sprawling' ',' ' and' ' very']" ", the Roman emperor , and the
+
+ The Roman Empire was a huge , sprawling , and very" False continue due north to Seneca Falls. NY 96, 5 [' continue', ' due', ' north', ' to', ' Sen', 'eca']
+1263 272 Name of mother of x -1 Name of mother of Seneca Helvia Seneca "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' a' ' huge' ',' ' sprawling' ',' ' and' ' very']" ", the Roman emperor , and the
+
+ The Roman Empire was a huge , sprawling , and very" False commentary of Seneca the Younger's De 3 [' commentary', ' of', ' Sen', 'eca']
+1264 272 Name of mother of x -1 Name of mother of Seneca Helvia Seneca "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' a' ' huge' ',' ' sprawling' ',' ' and' ' very']" ", the Roman emperor , and the
+
+ The Roman Empire was a huge , sprawling , and very" False annihilated by the Seneca people from New 5 [' annihil', 'ated', ' by', ' the', ' Sen', 'eca']
+1265 272 Name of mother of x -1 Name of mother of Seneca Helvia Seneca "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' a' ' huge' ',' ' sprawling' ',' ' and' ' very']" ", the Roman emperor , and the
+
+ The Roman Empire was a huge , sprawling , and very" False valuable player, but Seneca settled for 5 [' valuable', ' player', ',', ' but', ' Sen', 'eca']
+1266 272 Name of mother of x -1 Name of mother of Seneca Helvia Seneca "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' a' ' huge' ',' ' sprawling' ',' ' and' ' very']" ", the Roman emperor , and the
+
+ The Roman Empire was a huge , sprawling , and very" False 1 ['Sen', 'eca']
+1267 273 Name of mother of x -1 Name of mother of Julia Roberts Betty Lou Bredemus Julia Roberts "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'Pretty'
+ ' Woman']" , the actress who played the role of the mother of the bride in the movie � � Pretty Woman False and Andie MacDowell, Julia Roberts and Madonna for 8 [' and', ' And', 'ie', ' Mac', 'D', 'owell', ',', ' Julia', ' Roberts']
+1268 273 Name of mother of x -1 Name of mother of Julia Roberts Betty Lou Bredemus Julia Roberts "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'Pretty'
+ ' Woman']" , the actress who played the role of the mother of the bride in the movie � � Pretty Woman False Drew Barrymore and Julia Roberts to lower their asking 5 [' Drew', ' Barry', 'more', ' and', ' Julia', ' Roberts']
+1269 273 Name of mother of x -1 Name of mother of Julia Roberts Betty Lou Bredemus Julia Roberts "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'Pretty'
+ ' Woman']" , the actress who played the role of the mother of the bride in the movie � � Pretty Woman False spokesmodel, along with Julia Roberts and Winslet. The campaign 6 [' spokes', 'model', ',', ' along', ' with', ' Julia', ' Roberts']
+1270 273 Name of mother of x -1 Name of mother of Julia Roberts Betty Lou Bredemus Julia Roberts "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'Pretty'
+ ' Woman']" , the actress who played the role of the mother of the bride in the movie � � Pretty Woman False getting a date with Julia Roberts as doing My Fair 5 [' getting', ' a', ' date', ' with', ' Julia', ' Roberts']
+1271 273 Name of mother of x -1 Name of mother of Julia Roberts Betty Lou Bredemus Julia Roberts "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'Pretty'
+ ' Woman']" , the actress who played the role of the mother of the bride in the movie � � Pretty Woman False Clooney cast Julia Roberts as the mysterious 4 [' Clo', 'oney', ' cast', ' Julia', ' Roberts']
+1272 274 Name of mother of x -1 Name of mother of Thomas Gainsborough NN Burroughs Thomas Gainsborough "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False Renaissance prevailed — Thomas Gainsborough and Joshua Reynolds 6 [' Renaissance', ' prevailed', ' —', ' Thomas', ' G', 'ains', 'borough']
+1273 274 Name of mother of x -1 Name of mother of Thomas Gainsborough NN Burroughs Thomas Gainsborough "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False Richard Cosway and Thomas Gainsborough lived at Schomberg 7 [' Richard', ' Cos', 'way', ' and', ' Thomas', ' G', 'ains', 'borough']
+1274 274 Name of mother of x -1 Name of mother of Thomas Gainsborough NN Burroughs Thomas Gainsborough "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False portrait by Thomas Gainsborough shows him 5 [' portrait', ' by', ' Thomas', ' G', 'ains', 'borough']
+1275 274 Name of mother of x -1 Name of mother of Thomas Gainsborough NN Burroughs Thomas Gainsborough "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False room where artist Thomas Gainsborough was completing 6 [' room', ' where', ' artist', ' Thomas', ' G', 'ains', 'borough']
+1276 274 Name of mother of x -1 Name of mother of Thomas Gainsborough NN Burroughs Thomas Gainsborough "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' and']" , the painter , and his wife , the painter 's wife , and the painter 's mother , and False prevailed — Thomas Gainsborough and Joshua Reynolds 5 [' prevailed', ' —', ' Thomas', ' G', 'ains', 'borough']
+1277 275 Name of mother of x -1 Name of mother of Ingmar Bergman Karin Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',']" , the Swedish director of the film , and the film 's producer , and the film 's writer , False the director Ingmar Bergman referenced Magnolia 5 [' the', ' director', ' Ing', 'mar', ' Berg', 'man']
+1278 275 Name of mother of x -1 Name of mother of Ingmar Bergman Karin Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',']" , the Swedish director of the film , and the film 's producer , and the film 's writer , False Steven Spielberg. Ingmar Bergman disliked the film 6 [' Steven', ' Spielberg', '.', ' Ing', 'mar', ' Berg', 'man']
+1279 275 Name of mother of x -1 Name of mother of Ingmar Bergman Karin Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',']" , the Swedish director of the film , and the film 's producer , and the film 's writer , False " Krzysztof Kieślowski, Ingmar Bergman and J. Mahendran.
+" 16 [' Kr', 'z', 'ys', 'z', 'to', 'f', ' K', 'ie', '�', '�', 'l', 'owski', ',', ' Ing', 'mar', ' Berg', 'man']
+1280 275 Name of mother of x -1 Name of mother of Ingmar Bergman Karin Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',']" , the Swedish director of the film , and the film 's producer , and the film 's writer , False and white films of Ingmar Bergman bear a resemblance 7 [' and', ' white', ' films', ' of', ' Ing', 'mar', ' Berg', 'man']
+1281 275 Name of mother of x -1 Name of mother of Ingmar Bergman Karin Bergman Ingmar Bergman "[',' ' the' ' Swedish' ' director' ' of' ' the' ' film' ',' ' and' ' the'
+ ' film' ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',']" , the Swedish director of the film , and the film 's producer , and the film 's writer , False Aghed, the director Ingmar Bergman referenced Magnolia 8 [' Ag', 'hed', ',', ' the', ' director', ' Ing', 'mar', ' Berg', 'man']
+1282 276 Name of mother of x -1 Name of mother of Leon Trotsky Anna Bronstein Leon Trotsky "[',' ' the' ' son' ' of' ' a' ' Russian' ' Jew' ',' ' and' ' a' ' Jew'
+ 'ess' ',' ' and' ' a' ' Jew' 'ess' ',' ' and' ' a']" , the son of a Russian Jew , and a Jew ess , and a Jew ess , and a False Lenin protégé Leon Trotsky cited Kropotkin's 6 [' Lenin', ' prot', 'é', 'g', 'é', ' Leon', ' Trotsky']
+1283 276 Name of mother of x -1 Name of mother of Leon Trotsky Anna Bronstein Leon Trotsky "[',' ' the' ' son' ' of' ' a' ' Russian' ' Jew' ',' ' and' ' a' ' Jew'
+ 'ess' ',' ' and' ' a' ' Jew' 'ess' ',' ' and' ' a']" , the son of a Russian Jew , and a Jew ess , and a Jew ess , and a False supporters of Leon Trotsky against those 3 [' supporters', ' of', ' Leon', ' Trotsky']
+1284 276 Name of mother of x -1 Name of mother of Leon Trotsky Anna Bronstein Leon Trotsky "[',' ' the' ' son' ' of' ' a' ' Russian' ' Jew' ',' ' and' ' a' ' Jew'
+ 'ess' ',' ' and' ' a' ' Jew' 'ess' ',' ' and' ' a']" , the son of a Russian Jew , and a Jew ess , and a Jew ess , and a False Vladimir Lenin and Leon Trotsky and many others, and 4 [' Vladimir', ' Lenin', ' and', ' Leon', ' Trotsky']
+1285 276 Name of mother of x -1 Name of mother of Leon Trotsky Anna Bronstein Leon Trotsky "[',' ' the' ' son' ' of' ' a' ' Russian' ' Jew' ',' ' and' ' a' ' Jew'
+ 'ess' ',' ' and' ' a' ' Jew' 'ess' ',' ' and' ' a']" , the son of a Russian Jew , and a Jew ess , and a Jew ess , and a False Stepney Green. Leon Trotsky and Vladimir Lenin 5 [' Step', 'ney', ' Green', '.', ' Leon', ' Trotsky']
+1286 276 Name of mother of x -1 Name of mother of Leon Trotsky Anna Bronstein Leon Trotsky "[',' ' the' ' son' ' of' ' a' ' Russian' ' Jew' ',' ' and' ' a' ' Jew'
+ 'ess' ',' ' and' ' a' ' Jew' 'ess' ',' ' and' ' a']" , the son of a Russian Jew , and a Jew ess , and a Jew ess , and a False " revolutionary terror"" Leon Trotsky for an edition" 4 "[' revolutionary', ' terror', '""', ' Leon', ' Trotsky']"
+1287 277 Name of mother of x -1 Name of mother of Wilhelm II Victoria, Princess Royal Wilhelm II "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' '\n' '\n' 'The'
+ ' Emperor' ' of' ' Germany' ',' ' the' ' Emperor' ' of' ' Germany' ',']" ", the Emperor of Germany , and the
+
+ The Emperor of Germany , the Emperor of Germany ," False Germany's Kaiser Wilhelm II had advocated a fast 4 "[' Germany', ""'s"", ' Kaiser', ' Wilhelm', ' II']"
+1288 277 Name of mother of x -1 Name of mother of Wilhelm II Victoria, Princess Royal Wilhelm II "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' '\n' '\n' 'The'
+ ' Emperor' ' of' ' Germany' ',' ' the' ' Emperor' ' of' ' Germany' ',']" ", the Emperor of Germany , and the
+
+ The Emperor of Germany , the Emperor of Germany ," False voyage. Kaiser Wilhelm II celebrated the results 4 [' voyage', '.', ' Kaiser', ' Wilhelm', ' II']
+1289 277 Name of mother of x -1 Name of mother of Wilhelm II Victoria, Princess Royal Wilhelm II "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' '\n' '\n' 'The'
+ ' Emperor' ' of' ' Germany' ',' ' the' ' Emperor' ' of' ' Germany' ',']" ", the Emperor of Germany , and the
+
+ The Emperor of Germany , the Emperor of Germany ," False the chase. Kaiser Wilhelm II was enraged 5 [' the', ' chase', '.', ' Kaiser', ' Wilhelm', ' II']
+1290 277 Name of mother of x -1 Name of mother of Wilhelm II Victoria, Princess Royal Wilhelm II "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' '\n' '\n' 'The'
+ ' Emperor' ' of' ' Germany' ',' ' the' ' Emperor' ' of' ' Germany' ',']" ", the Emperor of Germany , and the
+
+ The Emperor of Germany , the Emperor of Germany ," False August for Wilhelm II to have a meeting 3 [' August', ' for', ' Wilhelm', ' II']
+1291 277 Name of mother of x -1 Name of mother of Wilhelm II Victoria, Princess Royal Wilhelm II "[',' ' the' ' Emperor' ' of' ' Germany' ',' ' and' ' the' '\n' '\n' 'The'
+ ' Emperor' ' of' ' Germany' ',' ' the' ' Emperor' ' of' ' Germany' ',']" ", the Emperor of Germany , and the
+
+ The Emperor of Germany , the Emperor of Germany ," False orders from Kaiser Wilhelm II to avoid risking the 4 [' orders', ' from', ' Kaiser', ' Wilhelm', ' II']
+1292 278 Name of mother of x -1 Name of mother of Michael Jordan Deloris Jordan Michael Jordan "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the father of the nation , and the father of the nation .
+
+ The father of the" False " forward and center), and Michael Jordan praised his ""game," 6 [' forward', ' and', ' center', '),', ' and', ' Michael', ' Jordan']
+1293 278 Name of mother of x -1 Name of mother of Michael Jordan Deloris Jordan Michael Jordan "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the father of the nation , and the father of the nation .
+
+ The father of the" False " (three games)
+" 5 [' (', 'three', ' games', ')', 'Michael', ' Jordan']
+1294 278 Name of mother of x -1 Name of mother of Michael Jordan Deloris Jordan Michael Jordan "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the father of the nation , and the father of the nation .
+
+ The father of the" False same year Michael Jordan entered the 3 [' same', ' year', ' Michael', ' Jordan']
+1295 278 Name of mother of x -1 Name of mother of Michael Jordan Deloris Jordan Michael Jordan "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the father of the nation , and the father of the nation .
+
+ The father of the" False Players of the Year in Michael Jordan and Ewing. 6 [' Players', ' of', ' the', ' Year', ' in', ' Michael', ' Jordan']
+1296 278 Name of mother of x -1 Name of mother of Michael Jordan Deloris Jordan Michael Jordan "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' nation' '.' '\n' '\n' 'The' ' father' ' of' ' the']" ", the father of the nation , and the father of the nation .
+
+ The father of the" False representing sports icon Michael Jordan for the entirety of 4 [' representing', ' sports', ' icon', ' Michael', ' Jordan']
+1297 279 Name of mother of x -1 Name of mother of Orson Welles Beatrice Ives Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ ',' ' in' ' the' ' film' ' no' 'ir' ' classic' ' _' 'The']" , the actor , and his wife , Rita Hay worth , in the film no ir classic _ The False Stanley Kubrick, Orson Welles and Max Ophüls, 6 [' Stanley', ' Kubrick', ',', ' Or', 'son', ' Well', 'es']
+1298 279 Name of mother of x -1 Name of mother of Orson Welles Beatrice Ives Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ ',' ' in' ' the' ' film' ' no' 'ir' ' classic' ' _' 'The']" , the actor , and his wife , Rita Hay worth , in the film no ir classic _ The False directed by Orson Welles and featuring 5 [' directed', ' by', ' Or', 'son', ' Well', 'es']
+1299 279 Name of mother of x -1 Name of mother of Orson Welles Beatrice Ives Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ ',' ' in' ' the' ' film' ' no' 'ir' ' classic' ' _' 'The']" , the actor , and his wife , Rita Hay worth , in the film no ir classic _ The False 3 ['Or', 'son', ' Well', 'es']
+1300 279 Name of mother of x -1 Name of mother of Orson Welles Beatrice Ives Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ ',' ' in' ' the' ' film' ' no' 'ir' ' classic' ' _' 'The']" , the actor , and his wife , Rita Hay worth , in the film no ir classic _ The False RKO lot that the Orson Welles deal will end 8 [' R', 'KO', ' lot', ' that', ' the', ' Or', 'son', ' Well', 'es']
+1301 279 Name of mother of x -1 Name of mother of Orson Welles Beatrice Ives Welles Orson Welles "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Rita' ' Hay' 'worth'
+ ',' ' in' ' the' ' film' ' no' 'ir' ' classic' ' _' 'The']" , the actor , and his wife , Rita Hay worth , in the film no ir classic _ The False services of Orson Welles to narrate a documentary. 5 [' services', ' of', ' Or', 'son', ' Well', 'es']
+1302 280 Name of mother of x -1 Name of mother of Ella Fitzgerald Temperance Ella Fitzgerald "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False American songbook Ella Fitzgerald owned by then, 5 [' American', ' song', 'book', ' Ell', 'a', ' Fitzgerald']
+1303 280 Name of mother of x -1 Name of mother of Ella Fitzgerald Temperance Ella Fitzgerald "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False accompanied Ella Fitzgerald in 1956, for around 3 [' accompanied', ' Ell', 'a', ' Fitzgerald']
+1304 280 Name of mother of x -1 Name of mother of Ella Fitzgerald Temperance Ella Fitzgerald "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False concert performed by Ella Fitzgerald at the Royal 5 [' concert', ' performed', ' by', ' Ell', 'a', ' Fitzgerald']
+1305 280 Name of mother of x -1 Name of mother of Ella Fitzgerald Temperance Ella Fitzgerald "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False the American songbook Ella Fitzgerald owned by then, 6 [' the', ' American', ' song', 'book', ' Ell', 'a', ' Fitzgerald']
+1306 280 Name of mother of x -1 Name of mother of Ella Fitzgerald Temperance Ella Fitzgerald "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False American songbook Ella Fitzgerald owned by then, but 5 [' American', ' song', 'book', ' Ell', 'a', ' Fitzgerald']
+1307 281 Name of mother of x -1 Name of mother of Marcus Aurelius Calvisia Domitia Lucilla Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Marcus' ' Aure' 'l' 'ius' ',' ' the' ' Emperor' ',']" ", the Emperor , and the
+
+ Name of mother of Marcus Aure l ius , the Emperor ," False represented Roman emperor Marcus Aurelius Antoninus (Andun 6 [' represented', ' Roman', ' emperor', ' Marcus', ' Aure', 'l', 'ius']
+1308 281 Name of mother of x -1 Name of mother of Marcus Aurelius Calvisia Domitia Lucilla Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Marcus' ' Aure' 'l' 'ius' ',' ' the' ' Emperor' ',']" ", the Emperor , and the
+
+ Name of mother of Marcus Aure l ius , the Emperor ," False 3 ['Marcus', ' Aure', 'l', 'ius']
+1309 281 Name of mother of x -1 Name of mother of Marcus Aurelius Calvisia Domitia Lucilla Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Marcus' ' Aure' 'l' 'ius' ',' ' the' ' Emperor' ',']" ", the Emperor , and the
+
+ Name of mother of Marcus Aure l ius , the Emperor ," False " ""collegial"" system that Marcus Aurelius had first used," 10 "[' ""', 'col', 'leg', 'ial', '""', ' system', ' that', ' Marcus', ' Aure', 'l', 'ius']"
+1310 281 Name of mother of x -1 Name of mother of Marcus Aurelius Calvisia Domitia Lucilla Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Marcus' ' Aure' 'l' 'ius' ',' ' the' ' Emperor' ',']" ", the Emperor , and the
+
+ Name of mother of Marcus Aure l ius , the Emperor ," False " Caracalla's names, Marcus Aurelius Antoninus.
+" 9 "[' Car', 'ac', 'alla', ""'s"", ' names', ',', ' Marcus', ' Aure', 'l', 'ius']"
+1311 281 Name of mother of x -1 Name of mother of Marcus Aurelius Calvisia Domitia Lucilla Marcus Aurelius "[',' ' the' ' Emperor' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Marcus' ' Aure' 'l' 'ius' ',' ' the' ' Emperor' ',']" ", the Emperor , and the
+
+ Name of mother of Marcus Aure l ius , the Emperor ," False in Gladiator, is Marcus Aurelius's intended successor. 7 [' in', ' Gladiator', ',', ' is', ' Marcus', ' Aure', 'l', 'ius']
+1312 282 Name of mother of x -1 Name of mother of Céline Dion Thérèse Dion Céline Dion "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Ren' 'é' ' Ang' 'é'
+ 'l' 'il' ',' ' who' ' is' ' the' ' father' ' of' ' her']" , the singer , and her husband Ren é Ang é l il , who is the father of her False " single is a bad Céline Dion song with ""barrel-turning" 7 [' single', ' is', ' a', ' bad', ' C', 'é', 'line', ' Dion']
+1313 282 Name of mother of x -1 Name of mother of Céline Dion Thérèse Dion Céline Dion "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Ren' 'é' ' Ang' 'é'
+ 'l' 'il' ',' ' who' ' is' ' the' ' father' ' of' ' her']" , the singer , and her husband Ren é Ang é l il , who is the father of her False taped La spéciale Céline Dion in Paris, France 9 [' taped', ' La', ' sp', 'é', 'cial', 'e', ' C', 'é', 'line', ' Dion']
+1314 282 Name of mother of x -1 Name of mother of Céline Dion Thérèse Dion Céline Dion "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Ren' 'é' ' Ang' 'é'
+ 'l' 'il' ',' ' who' ' is' ' the' ' father' ' of' ' her']" , the singer , and her husband Ren é Ang é l il , who is the father of her False her early career: Céline Dion chante Noël (1981) 7 [' her', ' early', ' career', ':', ' C', 'é', 'line', ' Dion']
+1315 282 Name of mother of x -1 Name of mother of Céline Dion Thérèse Dion Céline Dion "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Ren' 'é' ' Ang' 'é'
+ 'l' 'il' ',' ' who' ' is' ' the' ' father' ' of' ' her']" , the singer , and her husband Ren é Ang é l il , who is the father of her False viewers, entitled Céline Dion à tout prix. She 6 [' viewers', ',', ' entitled', ' C', 'é', 'line', ' Dion']
+1316 282 Name of mother of x -1 Name of mother of Céline Dion Thérèse Dion Céline Dion "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ' Ren' 'é' ' Ang' 'é'
+ 'l' 'il' ',' ' who' ' is' ' the' ' father' ' of' ' her']" , the singer , and her husband Ren é Ang é l il , who is the father of her False viewers, entitled Céline Dion à tout prix. 6 [' viewers', ',', ' entitled', ' C', 'é', 'line', ' Dion']
+1317 283 Name of mother of x -1 Name of mother of Glenn Close Bettine Moore Close Glenn Close "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' hers' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of hers for" False Vanessa Redgrave, Glenn Close and her eldest daughter 5 [' Vanessa', ' Red', 'grave', ',', ' Glenn', ' Close']
+1318 283 Name of mother of x -1 Name of mother of Glenn Close Bettine Moore Close Glenn Close "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' hers' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of hers for" False (née Olsen), voiced by Glenn Close, is Homer's long-lost 8 [' (', 'n', 'ée', ' Olsen', '),', ' voiced', ' by', ' Glenn', ' Close']
+1319 283 Name of mother of x -1 Name of mother of Glenn Close Bettine Moore Close Glenn Close "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' hers' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of hers for" False voiced by Glenn Close, is Homer's 3 [' voiced', ' by', ' Glenn', ' Close']
+1320 283 Name of mother of x -1 Name of mother of Glenn Close Bettine Moore Close Glenn Close "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' hers' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of hers for" False of Det. Bill Stork; Glenn Close replaced Sally Struthers 8 [' of', ' Det', '.', ' Bill', ' St', 'ork', ';', ' Glenn', ' Close']
+1321 283 Name of mother of x -1 Name of mother of Glenn Close Bettine Moore Close Glenn Close "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' hers' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of hers for" False their graves. Glenn Close is up to the material, 4 [' their', ' graves', '.', ' Glenn', ' Close']
+1322 284 Name of mother of x -1 Name of mother of William Ewart Gladstone Anne MacKenzie Robertson William Ewart Gladstone "[',' ' the' ' great' ' Liberal' ' states' 'man' ',' ' who' ' was' ' a'
+ ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first' ' of']" ", the great Liberal states man , who was a great
+
+ 1 .
+
+ The first of" False support of William Ewart Gladstone over the issue of Irish 6 [' support', ' of', ' William', ' E', 'wart', ' Glad', 'stone']
+1323 284 Name of mother of x -1 Name of mother of William Ewart Gladstone Anne MacKenzie Robertson William Ewart Gladstone "[',' ' the' ' great' ' Liberal' ' states' 'man' ',' ' who' ' was' ' a'
+ ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first' ' of']" ", the great Liberal states man , who was a great
+
+ 1 .
+
+ The first of" False Prime Minister William Ewart Gladstone and poet Robert 6 [' Prime', ' Minister', ' William', ' E', 'wart', ' Glad', 'stone']
+1324 284 Name of mother of x -1 Name of mother of William Ewart Gladstone Anne MacKenzie Robertson William Ewart Gladstone "[',' ' the' ' great' ' Liberal' ' states' 'man' ',' ' who' ' was' ' a'
+ ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first' ' of']" ", the great Liberal states man , who was a great
+
+ 1 .
+
+ The first of" False Prime Minister William Ewart Gladstone remembered his arrival 6 [' Prime', ' Minister', ' William', ' E', 'wart', ' Glad', 'stone']
+1325 284 Name of mother of x -1 Name of mother of William Ewart Gladstone Anne MacKenzie Robertson William Ewart Gladstone "[',' ' the' ' great' ' Liberal' ' states' 'man' ',' ' who' ' was' ' a'
+ ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first' ' of']" ", the great Liberal states man , who was a great
+
+ 1 .
+
+ The first of" False " described by William Ewart Gladstone as ""the greatest" 6 [' described', ' by', ' William', ' E', 'wart', ' Glad', 'stone']
+1326 284 Name of mother of x -1 Name of mother of William Ewart Gladstone Anne MacKenzie Robertson William Ewart Gladstone "[',' ' the' ' great' ' Liberal' ' states' 'man' ',' ' who' ' was' ' a'
+ ' great' '\n' '\n' '1' '.' '\n' '\n' 'The' ' first' ' of']" ", the great Liberal states man , who was a great
+
+ 1 .
+
+ The first of" False British Prime Minister William Ewart Gladstone said that Pedro 7 [' British', ' Prime', ' Minister', ' William', ' E', 'wart', ' Glad', 'stone']
+1327 285 Name of mother of x -1 Name of mother of Lily Allen Alison Owen Lily Allen "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' am' ' so'
+ ' sorry' ' to' ' hear' ' about' ' your' ' loss' '.' ' I']" , who is a friend of mine , and I am so sorry to hear about your loss . I False British pop star Lily Allen on her second 4 [' British', ' pop', ' star', ' Lily', ' Allen']
+1328 285 Name of mother of x -1 Name of mother of Lily Allen Alison Owen Lily Allen "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' am' ' so'
+ ' sorry' ' to' ' hear' ' about' ' your' ' loss' '.' ' I']" , who is a friend of mine , and I am so sorry to hear about your loss . I False on anything Lily Allen has ever penned and 3 [' on', ' anything', ' Lily', ' Allen']
+1329 285 Name of mother of x -1 Name of mother of Lily Allen Alison Owen Lily Allen "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' am' ' so'
+ ' sorry' ' to' ' hear' ' about' ' your' ' loss' '.' ' I']" , who is a friend of mine , and I am so sorry to hear about your loss . I False British recording artist Lily Allen from her debut studio 4 [' British', ' recording', ' artist', ' Lily', ' Allen']
+1330 285 Name of mother of x -1 Name of mother of Lily Allen Alison Owen Lily Allen "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' am' ' so'
+ ' sorry' ' to' ' hear' ' about' ' your' ' loss' '.' ' I']" , who is a friend of mine , and I am so sorry to hear about your loss . I False height on anything Lily Allen has ever penned and 4 [' height', ' on', ' anything', ' Lily', ' Allen']
+1331 285 Name of mother of x -1 Name of mother of Lily Allen Alison Owen Lily Allen "[',' ' who' ' is' ' a' ' friend' ' of' ' mine' ',' ' and' ' I' ' am' ' so'
+ ' sorry' ' to' ' hear' ' about' ' your' ' loss' '.' ' I']" , who is a friend of mine , and I am so sorry to hear about your loss . I False British singer Lily Allen also covered the 3 [' British', ' singer', ' Lily', ' Allen']
+1332 286 Name of mother of x -1 Name of mother of Uma Thurman Nena von Schlebrügge Uma Thurman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False and also starred Uma Thurman and Jeremy Northam. 6 [' and', ' also', ' starred', ' U', 'ma', ' Thur', 'man']
+1333 286 Name of mother of x -1 Name of mother of Uma Thurman Nena von Schlebrügge Uma Thurman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Patrick Stewart and Uma Thurman had been cast. Natalie 6 [' Patrick', ' Stewart', ' and', ' U', 'ma', ' Thur', 'man']
+1334 286 Name of mother of x -1 Name of mother of Uma Thurman Nena von Schlebrügge Uma Thurman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Mr. Freeze, while Uma Thurman starred as 8 [' Mr', '.', ' Freeze', ',', ' while', ' U', 'ma', ' Thur', 'man']
+1335 286 Name of mother of x -1 Name of mother of Uma Thurman Nena von Schlebrügge Uma Thurman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False 3 ['U', 'ma', ' Thur', 'man']
+1336 286 Name of mother of x -1 Name of mother of Uma Thurman Nena von Schlebrügge Uma Thurman "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False scenes pay homage to Uma Thurman as The Bride 7 [' scenes', ' pay', ' homage', ' to', ' U', 'ma', ' Thur', 'man']
+1337 287 Name of mother of x -1 Name of mother of Alexander the Great Olympias Alexander the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Alexander' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Alexander the Great , and the name of the" False ancient Greek ruler Alexander the Great in 4th century 5 [' ancient', ' Greek', ' ruler', ' Alexander', ' the', ' Great']
+1338 287 Name of mother of x -1 Name of mother of Alexander the Great Olympias Alexander the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Alexander' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Alexander the Great , and the name of the" False conquests of Alexander the Great (r. 336 – 323 5 [' conqu', 'ests', ' of', ' Alexander', ' the', ' Great']
+1339 287 Name of mother of x -1 Name of mother of Alexander the Great Olympias Alexander the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Alexander' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Alexander the Great , and the name of the" False was in power. After Alexander the Great conquered Egypt 7 [' was', ' in', ' power', '.', ' After', ' Alexander', ' the', ' Great']
+1340 287 Name of mother of x -1 Name of mother of Alexander the Great Olympias Alexander the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Alexander' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Alexander the Great , and the name of the" False luminaries as Alexander the Great and Hannibal, 5 [' lumin', 'aries', ' as', ' Alexander', ' the', ' Great']
+1341 287 Name of mother of x -1 Name of mother of Alexander the Great Olympias Alexander the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Alexander' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Alexander the Great , and the name of the" False " Macedonian ruler Alexander the Great without a fight.
+" 5 [' Maced', 'onian', ' ruler', ' Alexander', ' the', ' Great']
+1342 289 Name of mother of x -1 Name of mother of Dante Gabriel Rossetti Frances Polidori Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ""'s"" ' brother' ',']" , the poet , and his wife , Christina , who was the daughter of the poet 's brother , False pre-Raphaelite artists Dante Gabriel Rossetti and Thomas Woolner, 9 [' pre', '-', 'R', 'aphael', 'ite', ' artists', ' Dante', ' Gabriel', ' Ross', 'etti']
+1343 289 Name of mother of x -1 Name of mother of Dante Gabriel Rossetti Frances Polidori Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ""'s"" ' brother' ',']" , the poet , and his wife , Christina , who was the daughter of the poet 's brother , False Burne-Jones and Dante Gabriel Rossetti and with the Neo-Gothic 8 [' Burn', 'e', '-', 'Jones', ' and', ' Dante', ' Gabriel', ' Ross', 'etti']
+1344 289 Name of mother of x -1 Name of mother of Dante Gabriel Rossetti Frances Polidori Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ""'s"" ' brother' ',']" , the poet , and his wife , Christina , who was the daughter of the poet 's brother , False life. There he met Dante Gabriel Rossetti and other members 8 [' life', '.', ' There', ' he', ' met', ' Dante', ' Gabriel', ' Ross', 'etti']
+1345 289 Name of mother of x -1 Name of mother of Dante Gabriel Rossetti Frances Polidori Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ""'s"" ' brother' ',']" , the poet , and his wife , Christina , who was the daughter of the poet 's brother , False occasioned comment; Dante Gabriel Rossetti composing a limerick 7 [' occasion', 'ed', ' comment', ';', ' Dante', ' Gabriel', ' Ross', 'etti']
+1346 289 Name of mother of x -1 Name of mother of Dante Gabriel Rossetti Frances Polidori Dante Gabriel Rossetti "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Christina' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' poet' ""'s"" ' brother' ',']" , the poet , and his wife , Christina , who was the daughter of the poet 's brother , False Burne-Jones and Dante Gabriel Rossetti and with the Neo-Gothic 8 [' Burn', 'e', '-', 'Jones', ' and', ' Dante', ' Gabriel', ' Ross', 'etti']
+1347 290 Name of mother of x -1 Name of mother of Stendhal Henriette Gagnon Stendhal "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St' 'end' 'hal' ','
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St']" ", the
+
+ Name of mother of St end hal , the
+
+ Name of mother of St" False " Voltaire, Goethe, Stendhal and Victor Hugo.
+" 8 [' Volt', 'aire', ',', ' Go', 'ethe', ',', ' St', 'end', 'hal']
+1348 290 Name of mother of x -1 Name of mother of Stendhal Henriette Gagnon Stendhal "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St' 'end' 'hal' ','
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St']" ", the
+
+ Name of mother of St end hal , the
+
+ Name of mother of St" False " Voltaire, Goethe, Stendhal and Victor Hugo.
+" 8 [' Volt', 'aire', ',', ' Go', 'ethe', ',', ' St', 'end', 'hal']
+1349 290 Name of mother of x -1 Name of mother of Stendhal Henriette Gagnon Stendhal "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St' 'end' 'hal' ','
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St']" ", the
+
+ Name of mother of St end hal , the
+
+ Name of mother of St" False recreation when Stendhal visited in 1827. A 4 [' recreation', ' when', ' St', 'end', 'hal']
+1350 290 Name of mother of x -1 Name of mother of Stendhal Henriette Gagnon Stendhal "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St' 'end' 'hal' ','
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St']" ", the
+
+ Name of mother of St end hal , the
+
+ Name of mother of St" False " Voltaire, Goethe, Stendhal and Victor Hugo.
+" 8 [' Volt', 'aire', ',', ' Go', 'ethe', ',', ' St', 'end', 'hal']
+1351 290 Name of mother of x -1 Name of mother of Stendhal Henriette Gagnon Stendhal "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St' 'end' 'hal' ','
+ ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' St']" ", the
+
+ Name of mother of St end hal , the
+
+ Name of mother of St" False of recreation when Stendhal visited in 1827. 5 [' of', ' recreation', ' when', ' St', 'end', 'hal']
+1352 291 Name of mother of x -1 Name of mother of Jack London Flora London Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False others, writer Jack London and, although 4 [' others', ',', ' writer', ' Jack', ' London']
+1353 291 Name of mother of x -1 Name of mother of Jack London Flora London Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False direct rebuttals. As Jack London would later write, 6 [' direct', ' rebutt', 'als', '.', ' As', ' Jack', ' London']
+1354 291 Name of mother of x -1 Name of mother of Jack London Flora London Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False define. When Jack London came to London in 1902 4 [' define', '.', ' When', ' Jack', ' London']
+1355 291 Name of mother of x -1 Name of mother of Jack London Flora London Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False readers including Jack London and Isadora 3 [' readers', ' including', ' Jack', ' London']
+1356 291 Name of mother of x -1 Name of mother of Jack London Flora London Jack London "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False other authors like Jack London and Roberts, 4 [' other', ' authors', ' like', ' Jack', ' London']
+1357 292 Name of mother of x -1 Name of mother of Werner Heisenberg Annie Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False Pascual Jordan, Werner Heisenberg and an elegant 8 [' P', 'asc', 'ual', ' Jordan', ',', ' Werner', ' He', 'isen', 'berg']
+1358 292 Name of mother of x -1 Name of mother of Werner Heisenberg Annie Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False Walther Gerlach, Werner Heisenberg and Carl Friedrich 9 [' Wal', 'ther', ' Ger', 'l', 'ach', ',', ' Werner', ' He', 'isen', 'berg']
+1359 292 Name of mother of x -1 Name of mother of Werner Heisenberg Annie Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False " a standard form."" Werner Heisenberg subsequently" 7 "[' a', ' standard', ' form', '.""', ' Werner', ' He', 'isen', 'berg']"
+1360 292 Name of mother of x -1 Name of mother of Werner Heisenberg Annie Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False BKS model inspired Werner Heisenberg in his development 7 [' B', 'KS', ' model', ' inspired', ' Werner', ' He', 'isen', 'berg']
+1361 292 Name of mother of x -1 Name of mother of Werner Heisenberg Annie Heisenberg Werner Heisenberg "[',' ' the' ' German' ' physicist' ' who' ' won' ' the' ' Nobel' ' Prize'
+ ' for' ' his' ' work' ' on' ' the' ' quantum' ' theory' ' of' ' the'
+ ' atom' '.']" , the German physicist who won the Nobel Prize for his work on the quantum theory of the atom . False Hahn, Max von Laue, Werner Heisenberg and Carl Friedrich 11 [' H', 'ahn', ',', ' Max', ' von', ' La', 'ue', ',', ' Werner', ' He', 'isen', 'berg']
+1362 294 Name of mother of x -1 Name of mother of Ernest Rutherford Martha Thompson Ernest Rutherford "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' nuclear' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' nuclear' ' reactor']" ", the father of modern physics , and the father of nuclear physics .
+
+ The first nuclear reactor" False studied under Sir Ernest Rutherford at the University 4 [' studied', ' under', ' Sir', ' Ernest', ' Rutherford']
+1363 294 Name of mother of x -1 Name of mother of Ernest Rutherford Martha Thompson Ernest Rutherford "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' nuclear' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' nuclear' ' reactor']" ", the father of modern physics , and the father of nuclear physics .
+
+ The first nuclear reactor" False documented solution offered. Ernest Rutherford suggested that 5 [' documented', ' solution', ' offered', '.', ' Ernest', ' Rutherford']
+1364 294 Name of mother of x -1 Name of mother of Ernest Rutherford Martha Thompson Ernest Rutherford "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' nuclear' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' nuclear' ' reactor']" ", the father of modern physics , and the father of nuclear physics .
+
+ The first nuclear reactor" False pioneering work of Ernest Rutherford at McGill University 4 [' pioneering', ' work', ' of', ' Ernest', ' Rutherford']
+1365 294 Name of mother of x -1 Name of mother of Ernest Rutherford Martha Thompson Ernest Rutherford "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' nuclear' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' nuclear' ' reactor']" ", the father of modern physics , and the father of nuclear physics .
+
+ The first nuclear reactor" False physicist Ernest Rutherford who discovered 2 [' physicist', ' Ernest', ' Rutherford']
+1366 294 Name of mother of x -1 Name of mother of Ernest Rutherford Martha Thompson Ernest Rutherford "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' nuclear' ' physics' '.' '\n' '\n' 'The' ' first'
+ ' nuclear' ' reactor']" ", the father of modern physics , and the father of nuclear physics .
+
+ The first nuclear reactor" False provided men like Ernest Rutherford with sources 4 [' provided', ' men', ' like', ' Ernest', ' Rutherford']
+1367 295 Name of mother of x -1 Name of mother of Gwyneth Paltrow Blythe Danner Gwyneth Paltrow "[',' ' who' ' is' ' a' ' big' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ 'The' ' show' ' is' ' a' ' great' ' show' '.' ' I']" ", who is a big fan of the show .
+
+ The show is a great show . I" False Matt Damon, Gwyneth Paltrow and Jude Law. The 8 [' Matt', ' Damon', ',', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1368 295 Name of mother of x -1 Name of mother of Gwyneth Paltrow Blythe Danner Gwyneth Paltrow "[',' ' who' ' is' ' a' ' big' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ 'The' ' show' ' is' ' a' ' great' ' show' '.' ' I']" ", who is a big fan of the show .
+
+ The show is a great show . I" False Baker Hall, Gwyneth Paltrow and John C. 8 [' Baker', ' Hall', ',', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1369 295 Name of mother of x -1 Name of mother of Gwyneth Paltrow Blythe Danner Gwyneth Paltrow "[',' ' who' ' is' ' a' ' big' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ 'The' ' show' ' is' ' a' ' great' ' show' '.' ' I']" ", who is a big fan of the show .
+
+ The show is a great show . I" False Holmes, Gwen Stefani, Gwyneth Paltrow and Angelina Jolie. 12 [' Holmes', ',', ' G', 'wen', ' Stef', 'ani', ',', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1370 295 Name of mother of x -1 Name of mother of Gwyneth Paltrow Blythe Danner Gwyneth Paltrow "[',' ' who' ' is' ' a' ' big' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ 'The' ' show' ' is' ' a' ' great' ' show' '.' ' I']" ", who is a big fan of the show .
+
+ The show is a great show . I" False Renée Zellweger and Gwyneth Paltrow were considered 12 [' Ren', 'ée', ' Z', 'ell', 'we', 'ger', ' and', ' G', 'wyn', 'eth', ' P', 'alt', 'row']
+1371 295 Name of mother of x -1 Name of mother of Gwyneth Paltrow Blythe Danner Gwyneth Paltrow "[',' ' who' ' is' ' a' ' big' ' fan' ' of' ' the' ' show' '.' '\n' '\n'
+ 'The' ' show' ' is' ' a' ' great' ' show' '.' ' I']" ", who is a big fan of the show .
+
+ The show is a great show . I" False " ""doing"" a naked Gwyneth Paltrow originally featured" 10 "[' ""', 'doing', '""', ' a', ' naked', ' G', 'wyn', 'eth', ' P', 'alt', 'row']"
+1372 296 Name of mother of x -1 Name of mother of Frederick II of Prussia Sophia Dorothea of Hanover Frederick II of Prussia "[',' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z' 'oll' 'ern' 's' ','
+ ' and' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z']" , the last of the Ho hen z oll ern s , and the last of the Ho hen z False hegemony. King Frederick II of Prussia had no intention 7 [' hegemony', '.', ' King', ' Frederick', ' II', ' of', ' Pr', 'ussia']
+1373 296 Name of mother of x -1 Name of mother of Frederick II of Prussia Sophia Dorothea of Hanover Frederick II of Prussia "[',' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z' 'oll' 'ern' 's' ','
+ ' and' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z']" , the last of the Ho hen z oll ern s , and the last of the Ho hen z False hegemony. King Frederick II of Prussia had no intention 7 [' hegemony', '.', ' King', ' Frederick', ' II', ' of', ' Pr', 'ussia']
+1374 296 Name of mother of x -1 Name of mother of Frederick II of Prussia Sophia Dorothea of Hanover Frederick II of Prussia "[',' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z' 'oll' 'ern' 's' ','
+ ' and' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z']" , the last of the Ho hen z oll ern s , and the last of the Ho hen z False Years'War — used by Frederick II of Prussia — and in Napoleon's 10 "[' Years', ""'"", 'War', ' —', ' used', ' by', ' Frederick', ' II', ' of', ' Pr', 'ussia']"
+1375 296 Name of mother of x -1 Name of mother of Frederick II of Prussia Sophia Dorothea of Hanover Frederick II of Prussia "[',' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z' 'oll' 'ern' 's' ','
+ ' and' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z']" , the last of the Ho hen z oll ern s , and the last of the Ho hen z False In December, King Frederick II of Prussia invaded the Duchy 8 [' In', ' December', ',', ' King', ' Frederick', ' II', ' of', ' Pr', 'ussia']
+1376 296 Name of mother of x -1 Name of mother of Frederick II of Prussia Sophia Dorothea of Hanover Frederick II of Prussia "[',' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z' 'oll' 'ern' 's' ','
+ ' and' ' the' ' last' ' of' ' the' ' Ho' 'hen' 'z']" , the last of the Ho hen z oll ern s , and the last of the Ho hen z False Years'War — used by Frederick II of Prussia — and in Napoleon's 10 "[' Years', ""'"", 'War', ' —', ' used', ' by', ' Frederick', ' II', ' of', ' Pr', 'ussia']"
+1377 297 Name of mother of x -1 Name of mother of Joseph Conrad Ewa Korzeniewska Joseph Conrad "[',' ' the' ' son' ' of' ' a' ' wealthy' ' merchant' ',' ' and' ' a'
+ ' woman' ' of' ' the' ' people' '.' '\n' '\n' 'The' ' first' ' thing']" ", the son of a wealthy merchant , and a woman of the people .
+
+ The first thing" False by writers such as Joseph Conrad and H. G. Wells. 5 [' by', ' writers', ' such', ' as', ' Joseph', ' Conrad']
+1378 297 Name of mother of x -1 Name of mother of Joseph Conrad Ewa Korzeniewska Joseph Conrad "[',' ' the' ' son' ' of' ' a' ' wealthy' ' merchant' ',' ' and' ' a'
+ ' woman' ' of' ' the' ' people' '.' '\n' '\n' 'The' ' first' ' thing']" ", the son of a wealthy merchant , and a woman of the people .
+
+ The first thing" False debut, author Joseph Conrad agreed that the novel's 4 [' debut', ',', ' author', ' Joseph', ' Conrad']
+1379 297 Name of mother of x -1 Name of mother of Joseph Conrad Ewa Korzeniewska Joseph Conrad "[',' ' the' ' son' ' of' ' a' ' wealthy' ' merchant' ',' ' and' ' a'
+ ' woman' ' of' ' the' ' people' '.' '\n' '\n' 'The' ' first' ' thing']" ", the son of a wealthy merchant , and a woman of the people .
+
+ The first thing" False redemption — what Joseph Conrad in Lord Jim called 4 [' redemption', ' —', ' what', ' Joseph', ' Conrad']
+1380 297 Name of mother of x -1 Name of mother of Joseph Conrad Ewa Korzeniewska Joseph Conrad "[',' ' the' ' son' ' of' ' a' ' wealthy' ' merchant' ',' ' and' ' a'
+ ' woman' ' of' ' the' ' people' '.' '\n' '\n' 'The' ' first' ' thing']" ", the son of a wealthy merchant , and a woman of the people .
+
+ The first thing" False " Horseman"") and Joseph Conrad (Under Western" 5 "[' Horse', 'man', '"")', ' and', ' Joseph', ' Conrad']"
+1381 297 Name of mother of x -1 Name of mother of Joseph Conrad Ewa Korzeniewska Joseph Conrad "[',' ' the' ' son' ' of' ' a' ' wealthy' ' merchant' ',' ' and' ' a'
+ ' woman' ' of' ' the' ' people' '.' '\n' '\n' 'The' ' first' ' thing']" ", the son of a wealthy merchant , and a woman of the people .
+
+ The first thing" False " Bronze Horseman"") and Joseph Conrad (Under Western Eyes)." 6 "[' Bronze', ' Horse', 'man', '"")', ' and', ' Joseph', ' Conrad']"
+1382 298 Name of mother of x -1 Name of mother of Caravaggio Lucia Aratori Caravaggio "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Car' 'av' 'agg' 'io'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Car av agg io , the
+
+ Name of mother of" False Derek Jarman's Caravaggio (1986), playing 7 "[' Derek', ' Jar', 'man', ""'s"", ' Car', 'av', 'agg', 'io']"
+1383 298 Name of mother of x -1 Name of mother of Caravaggio Lucia Aratori Caravaggio "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Car' 'av' 'agg' 'io'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Car av agg io , the
+
+ Name of mother of" False Tarakanova (1937) and Caravaggio (1941). Post-war examples 11 [' Tar', 'ak', 'anova', ' (', '19', '37', ')', ' and', ' Car', 'av', 'agg', 'io']
+1384 298 Name of mother of x -1 Name of mother of Caravaggio Lucia Aratori Caravaggio "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Car' 'av' 'agg' 'io'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Car av agg io , the
+
+ Name of mother of" False Tarakanova (1937) and Caravaggio (1941). Post-war 11 [' Tar', 'ak', 'anova', ' (', '19', '37', ')', ' and', ' Car', 'av', 'agg', 'io']
+1385 298 Name of mother of x -1 Name of mother of Caravaggio Lucia Aratori Caravaggio "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Car' 'av' 'agg' 'io'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Car av agg io , the
+
+ Name of mother of" False Italian painter Caravaggio and the Russian 5 [' Italian', ' painter', ' Car', 'av', 'agg', 'io']
+1386 298 Name of mother of x -1 Name of mother of Caravaggio Lucia Aratori Caravaggio "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Car' 'av' 'agg' 'io'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Car av agg io , the
+
+ Name of mother of" False Battle of Caravaggio in 1448. It was 5 [' Battle', ' of', ' Car', 'av', 'agg', 'io']
+1387 299 Name of mother of x -1 Name of mother of George Eliot Christiana Pearson George Eliot "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Eliot' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of George Eliot , and the
+
+ Name of mother of" False Review from 1851. George Eliot lived at No. 6 [' Review', ' from', ' 18', '51', '.', ' George', ' Eliot']
+1388 299 Name of mother of x -1 Name of mother of George Eliot Christiana Pearson George Eliot "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Eliot' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of George Eliot , and the
+
+ Name of mother of" False writers such as George Eliot professed profound 4 [' writers', ' such', ' as', ' George', ' Eliot']
+1389 299 Name of mother of x -1 Name of mother of George Eliot Christiana Pearson George Eliot "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Eliot' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of George Eliot , and the
+
+ Name of mother of" False " ill-repute. George Eliot wrote ""there is" 6 [' ill', '-', 'rep', 'ute', '.', ' George', ' Eliot']
+1390 299 Name of mother of x -1 Name of mother of George Eliot Christiana Pearson George Eliot "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Eliot' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of George Eliot , and the
+
+ Name of mother of" False Review from 1851. George Eliot lived at No. 6 [' Review', ' from', ' 18', '51', '.', ' George', ' Eliot']
+1391 299 Name of mother of x -1 Name of mother of George Eliot Christiana Pearson George Eliot "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' George'
+ ' Eliot' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of George Eliot , and the
+
+ Name of mother of" False Stowe wrote to author George Eliot to update her on 6 [' St', 'owe', ' wrote', ' to', ' author', ' George', ' Eliot']
+1392 300 Name of mother of x -1 Name of mother of John Milton Sara Jeffrey John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False fellow judge John Milton Elliott on March 3 [' fellow', ' judge', ' John', ' Milton']
+1393 300 Name of mother of x -1 Name of mother of John Milton Sara Jeffrey John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Dorman Bridgman Eaton, John Milton Gregory, and 8 [' D', 'orman', ' Brid', 'g', 'man', ' Eaton', ',', ' John', ' Milton']
+1394 300 Name of mother of x -1 Name of mother of John Milton Sara Jeffrey John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False highly unpopular; John Milton wrote Areopagitica 4 [' highly', ' unpopular', ';', ' John', ' Milton']
+1395 300 Name of mother of x -1 Name of mother of John Milton Sara Jeffrey John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False associate of John Milton — with a preface by 3 [' associate', ' of', ' John', ' Milton']
+1396 300 Name of mother of x -1 Name of mother of John Milton Sara Jeffrey John Milton "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Dorman Bridgman Eaton, John Milton Gregory, and Leroy 8 [' D', 'orman', ' Brid', 'g', 'man', ' Eaton', ',', ' John', ' Milton']
+1397 301 Name of mother of x -1 Name of mother of George VI Mary of Teck George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False coronation of King George VI and Queen Elizabeth 5 [' coron', 'ation', ' of', ' King', ' George', ' VI']
+1398 301 Name of mother of x -1 Name of mother of George VI Mary of Teck George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False Filler cautions that George VI would never have 6 [' F', 'iller', ' ca', 'utions', ' that', ' George', ' VI']
+1399 301 Name of mother of x -1 Name of mother of George VI Mary of Teck George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False the use of the King George VI Graving Dock 6 [' the', ' use', ' of', ' the', ' King', ' George', ' VI']
+1400 301 Name of mother of x -1 Name of mother of George VI Mary of Teck George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False United Kingdom, George VI announced he 4 [' United', ' Kingdom', ',', ' George', ' VI']
+1401 301 Name of mother of x -1 Name of mother of George VI Mary of Teck George VI "['.' '\n' '\n' 'The' ' Queen' ' of' ' England' ',' ' the' ' Queen' ' of'
+ ' England' ',' ' the' ' Queen' ' of' ' England' ',' ' the' ' Queen']" ".
+
+ The Queen of England , the Queen of England , the Queen of England , the Queen" False areas excluded the King George VI Graving Dock 5 [' areas', ' excluded', ' the', ' King', ' George', ' VI']
+1402 302 Name of mother of x -1 Name of mother of Christopher Columbus Susanna Fontanarossa Christopher Columbus "[',' ' the' ' first' ' European' ' to' ' discover' ' America' '.' '\n'
+ '\n' 'The' ' first' ' European' ' to' ' discover' ' America' ' was'
+ ' Christopher' ' Columbus' '.']" ", the first European to discover America .
+
+ The first European to discover America was Christopher Columbus ." False with the voyages of Christopher Columbus in 1492 and continued 6 [' with', ' the', ' voy', 'ages', ' of', ' Christopher', ' Columbus']
+1403 302 Name of mother of x -1 Name of mother of Christopher Columbus Susanna Fontanarossa Christopher Columbus "[',' ' the' ' first' ' European' ' to' ' discover' ' America' '.' '\n'
+ '\n' 'The' ' first' ' European' ' to' ' discover' ' America' ' was'
+ ' Christopher' ' Columbus' '.']" ", the first European to discover America .
+
+ The first European to discover America was Christopher Columbus ." False 1 ['Christopher', ' Columbus']
+1404 302 Name of mother of x -1 Name of mother of Christopher Columbus Susanna Fontanarossa Christopher Columbus "[',' ' the' ' first' ' European' ' to' ' discover' ' America' '.' '\n'
+ '\n' 'The' ' first' ' European' ' to' ' discover' ' America' ' was'
+ ' Christopher' ' Columbus' '.']" ", the first European to discover America .
+
+ The first European to discover America was Christopher Columbus ." False contact made before Christopher Columbus between the first 4 [' contact', ' made', ' before', ' Christopher', ' Columbus']
+1405 302 Name of mother of x -1 Name of mother of Christopher Columbus Susanna Fontanarossa Christopher Columbus "[',' ' the' ' first' ' European' ' to' ' discover' ' America' '.' '\n'
+ '\n' 'The' ' first' ' European' ' to' ' discover' ' America' ' was'
+ ' Christopher' ' Columbus' '.']" ", the first European to discover America .
+
+ The first European to discover America was Christopher Columbus ." False surpasses all other Christopher Columbus sculptures in 5 [' surpass', 'es', ' all', ' other', ' Christopher', ' Columbus']
+1406 302 Name of mother of x -1 Name of mother of Christopher Columbus Susanna Fontanarossa Christopher Columbus "[',' ' the' ' first' ' European' ' to' ' discover' ' America' '.' '\n'
+ '\n' 'The' ' first' ' European' ' to' ' discover' ' America' ' was'
+ ' Christopher' ' Columbus' '.']" ", the first European to discover America .
+
+ The first European to discover America was Christopher Columbus ." False 1 ['Christopher', ' Columbus']
+1407 303 Name of mother of x -1 Name of mother of Sarah Bernhardt Judith-Julie Bernardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Gaiety Theatre, Sarah Bernhardt performed the 6 [' Ga', 'iety', ' Theatre', ',', ' Sarah', ' Bern', 'hardt']
+1408 303 Name of mother of x -1 Name of mother of Sarah Bernhardt Judith-Julie Bernardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False included Lillie Langtry, Sarah Bernhardt and Lady Randolph 9 [' included', ' L', 'ill', 'ie', ' Lang', 'try', ',', ' Sarah', ' Bern', 'hardt']
+1409 303 Name of mother of x -1 Name of mother of Sarah Bernhardt Judith-Julie Bernardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False included Lillie Langtry, Sarah Bernhardt and Lady Randolph 9 [' included', ' L', 'ill', 'ie', ' Lang', 'try', ',', ' Sarah', ' Bern', 'hardt']
+1410 303 Name of mother of x -1 Name of mother of Sarah Bernhardt Judith-Julie Bernardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False included Lillie Langtry, Sarah Bernhardt and Lady Randolph 9 [' included', ' L', 'ill', 'ie', ' Lang', 'try', ',', ' Sarah', ' Bern', 'hardt']
+1411 303 Name of mother of x -1 Name of mother of Sarah Bernhardt Judith-Julie Bernardt Sarah Bernhardt "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False included Lillie Langtry, Sarah Bernhardt and Lady Randolph 9 [' included', ' L', 'ill', 'ie', ' Lang', 'try', ',', ' Sarah', ' Bern', 'hardt']
+1412 304 Name of mother of x -1 Name of mother of Henry Wadsworth Longfellow Zilpah Wadsworth Henry Wadsworth Longfellow "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' the' ' house' '.' '\n' '\n'
+ 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the poet , was born in the house .
+
+ The house is now a museum , and" False character created by Henry Wadsworth Longfellow in his narrative poem 8 [' character', ' created', ' by', ' Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1413 304 Name of mother of x -1 Name of mother of Henry Wadsworth Longfellow Zilpah Wadsworth Henry Wadsworth Longfellow "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' the' ' house' '.' '\n' '\n'
+ 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the poet , was born in the house .
+
+ The house is now a museum , and" False " scorn"" for the Quaker. Henry Wadsworth Longfellow recreated the" 12 "[' scorn', '""', ' for', ' the', ' Qu', 'aker', '.', ' Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']"
+1414 304 Name of mother of x -1 Name of mother of Henry Wadsworth Longfellow Zilpah Wadsworth Henry Wadsworth Longfellow "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' the' ' house' '.' '\n' '\n'
+ 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the poet , was born in the house .
+
+ The house is now a museum , and" False 5 ['Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1415 304 Name of mother of x -1 Name of mother of Henry Wadsworth Longfellow Zilpah Wadsworth Henry Wadsworth Longfellow "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' the' ' house' '.' '\n' '\n'
+ 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the poet , was born in the house .
+
+ The house is now a museum , and" False fellow poet Henry Wadsworth Longfellow also wrote a poem 7 [' fellow', ' poet', ' Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1416 304 Name of mother of x -1 Name of mother of Henry Wadsworth Longfellow Zilpah Wadsworth Henry Wadsworth Longfellow "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' the' ' house' '.' '\n' '\n'
+ 'The' ' house' ' is' ' now' ' a' ' museum' ',' ' and']" ", the poet , was born in the house .
+
+ The house is now a museum , and" False " Longfellow =
+" 9 [' Long', 'f', 'ellow', ' =', 'Henry', ' Wad', 'sworth', ' Long', 'f', 'ellow']
+1417 305 Name of mother of x -1 Name of mother of Tom Hanks Janet Marylyn Frager Tom Hanks "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ',']" ", who is a great friend of mine .
+
+ I am a big fan of the show ," False being selected by Tom Hanks to direct the epic 5 [' being', ' selected', ' by', ' Tom', ' H', 'anks']
+1418 305 Name of mother of x -1 Name of mother of Tom Hanks Janet Marylyn Frager Tom Hanks "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ',']" ", who is a great friend of mine .
+
+ I am a big fan of the show ," False such as actor Tom Hanks and author and 5 [' such', ' as', ' actor', ' Tom', ' H', 'anks']
+1419 305 Name of mother of x -1 Name of mother of Tom Hanks Janet Marylyn Frager Tom Hanks "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ',']" ", who is a great friend of mine .
+
+ I am a big fan of the show ," False whether stars Tom Hanks and Tim Allen 4 [' whether', ' stars', ' Tom', ' H', 'anks']
+1420 305 Name of mother of x -1 Name of mother of Tom Hanks Janet Marylyn Frager Tom Hanks "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ',']" ", who is a great friend of mine .
+
+ I am a big fan of the show ," False " Hanks as Himself
+" 6 [' H', 'anks', ' as', ' Himself', 'Tom', ' H', 'anks']
+1421 305 Name of mother of x -1 Name of mother of Tom Hanks Janet Marylyn Frager Tom Hanks "[',' ' who' ' is' ' a' ' great' ' friend' ' of' ' mine' '.' '\n' '\n' 'I'
+ ' am' ' a' ' big' ' fan' ' of' ' the' ' show' ',']" ", who is a great friend of mine .
+
+ I am a big fan of the show ," False although actor Tom Hanks has expressed 4 [' although', ' actor', ' Tom', ' H', 'anks']
+1422 306 Name of mother of x -1 Name of mother of Alphonse de Lamartine Alix de Lamartine Alphonse de Lamartine "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '90'
+ ',' ' and' ' died' ' in' ' 18' '69' '.' '\n' '\n']" ", the French poet , who was born in 17 90 , and died in 18 69 .
+
+" False Romantic poet Alphonse de Lamartine and English 7 [' Romantic', ' poet', ' Alph', 'onse', ' de', ' Lam', 'art', 'ine']
+1423 306 Name of mother of x -1 Name of mother of Alphonse de Lamartine Alix de Lamartine Alphonse de Lamartine "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '90'
+ ',' ' and' ' died' ' in' ' 18' '69' '.' '\n' '\n']" ", the French poet , who was born in 17 90 , and died in 18 69 .
+
+" False Romantic poet Alphonse de Lamartine and English travel 7 [' Romantic', ' poet', ' Alph', 'onse', ' de', ' Lam', 'art', 'ine']
+1424 306 Name of mother of x -1 Name of mother of Alphonse de Lamartine Alix de Lamartine Alphonse de Lamartine "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '90'
+ ',' ' and' ' died' ' in' ' 18' '69' '.' '\n' '\n']" ", the French poet , who was born in 17 90 , and died in 18 69 .
+
+" False surrender. In 1847 Alphonse de Lamartine even invented 10 [' surrender', '.', ' In', ' 18', '47', ' Alph', 'onse', ' de', ' Lam', 'art', 'ine']
+1425 306 Name of mother of x -1 Name of mother of Alphonse de Lamartine Alix de Lamartine Alphonse de Lamartine "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '90'
+ ',' ' and' ' died' ' in' ' 18' '69' '.' '\n' '\n']" ", the French poet , who was born in 17 90 , and died in 18 69 .
+
+" False inspired by Alphonse de Lamartine and Victor 7 [' inspired', ' by', ' Alph', 'onse', ' de', ' Lam', 'art', 'ine']
+1426 306 Name of mother of x -1 Name of mother of Alphonse de Lamartine Alix de Lamartine Alphonse de Lamartine "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '90'
+ ',' ' and' ' died' ' in' ' 18' '69' '.' '\n' '\n']" ", the French poet , who was born in 17 90 , and died in 18 69 .
+
+" False especially that of Alphonse de Lamartine and Alfred de Musset. 8 [' especially', ' that', ' of', ' Alph', 'onse', ' de', ' Lam', 'art', 'ine']
+1427 307 Name of mother of x -1 Name of mother of Édith Piaf Line Marsa Édith Piaf "[',' ' the' ' French' ' singer' ',' ' who' ' died' ' in' ' 1963' '.' '\n'
+ '\n' 'The' ' song' ' was' ' written' ' by' ' Jacques' ' B' 'rel']" ", the French singer , who died in 1963 .
+
+ The song was written by Jacques B rel" False chanteuse Édith Piaf in Pam Gems' musical 8 [' ch', 'ante', 'use', ' É', 'd', 'ith', ' P', 'ia', 'f']
+1428 307 Name of mother of x -1 Name of mother of Édith Piaf Line Marsa Édith Piaf "[',' ' the' ' French' ' singer' ',' ' who' ' died' ' in' ' 1963' '.' '\n'
+ '\n' 'The' ' song' ' was' ' written' ' by' ' Jacques' ' B' 'rel']" ", the French singer , who died in 1963 .
+
+ The song was written by Jacques B rel" False old singer named Édith Piaf was discovered in 8 [' old', ' singer', ' named', ' É', 'd', 'ith', ' P', 'ia', 'f']
+1429 307 Name of mother of x -1 Name of mother of Édith Piaf Line Marsa Édith Piaf "[',' ' the' ' French' ' singer' ',' ' who' ' died' ' in' ' 1963' '.' '\n'
+ '\n' 'The' ' song' ' was' ' written' ' by' ' Jacques' ' B' 'rel']" ", the French singer , who died in 1963 .
+
+ The song was written by Jacques B rel" False old singer named Édith Piaf was discovered in 8 [' old', ' singer', ' named', ' É', 'd', 'ith', ' P', 'ia', 'f']
+1430 307 Name of mother of x -1 Name of mother of Édith Piaf Line Marsa Édith Piaf "[',' ' the' ' French' ' singer' ',' ' who' ' died' ' in' ' 1963' '.' '\n'
+ '\n' 'The' ' song' ' was' ' written' ' by' ' Jacques' ' B' 'rel']" ", the French singer , who died in 1963 .
+
+ The song was written by Jacques B rel" False singer named Édith Piaf was discovered in the 7 [' singer', ' named', ' É', 'd', 'ith', ' P', 'ia', 'f']
+1431 307 Name of mother of x -1 Name of mother of Édith Piaf Line Marsa Édith Piaf "[',' ' the' ' French' ' singer' ',' ' who' ' died' ' in' ' 1963' '.' '\n'
+ '\n' 'The' ' song' ' was' ' written' ' by' ' Jacques' ' B' 'rel']" ", the French singer , who died in 1963 .
+
+ The song was written by Jacques B rel" False French chanteuse Édith Piaf in Pam Gems' musical 9 [' French', ' ch', 'ante', 'use', ' É', 'd', 'ith', ' P', 'ia', 'f']
+1432 308 Name of mother of x -1 Name of mother of Emma Thompson Phyllida Law Emma Thompson "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False and screenwriter Emma Thompson aided in script 4 [' and', ' screen', 'writer', ' Emma', ' Thompson']
+1433 308 Name of mother of x -1 Name of mother of Emma Thompson Phyllida Law Emma Thompson "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False English, and it's Emma Thompson, Kate Winslet, 6 "[' English', ',', ' and', ' it', ""'s"", ' Emma', ' Thompson']"
+1434 308 Name of mother of x -1 Name of mother of Emma Thompson Phyllida Law Emma Thompson "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False the idea that Emma Thompson would be providing 4 [' the', ' idea', ' that', ' Emma', ' Thompson']
+1435 308 Name of mother of x -1 Name of mother of Emma Thompson Phyllida Law Emma Thompson "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Neville's daughter. Emma Thompson has an uncredited 5 "[' Neville', ""'s"", ' daughter', '.', ' Emma', ' Thompson']"
+1436 308 Name of mother of x -1 Name of mother of Emma Thompson Phyllida Law Emma Thompson "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Poppins, the film stars Emma Thompson as author P. L. 8 [' Po', 'pp', 'ins', ',', ' the', ' film', ' stars', ' Emma', ' Thompson']
+1437 309 Name of mother of x -1 Name of mother of Johnny Cash Carrie Cloveree Rivers Johnny Cash "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the father of the groom .
+
+ The wedding was a" False for a series of Johnny Cash tour dates coming 5 [' for', ' a', ' series', ' of', ' Johnny', ' Cash']
+1438 309 Name of mother of x -1 Name of mother of Johnny Cash Carrie Cloveree Rivers Johnny Cash "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the father of the groom .
+
+ The wedding was a" False Kristofferson, and Johnny Cash formed The Highwaymen, 6 [' Krist', 'off', 'erson', ',', ' and', ' Johnny', ' Cash']
+1439 309 Name of mother of x -1 Name of mother of Johnny Cash Carrie Cloveree Rivers Johnny Cash "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the father of the groom .
+
+ The wedding was a" False a cover of Johnny Cash and June Carter Cash's 4 [' a', ' cover', ' of', ' Johnny', ' Cash']
+1440 309 Name of mother of x -1 Name of mother of Johnny Cash Carrie Cloveree Rivers Johnny Cash "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the father of the groom .
+
+ The wedding was a" False " redemptions of Johnny Cash and Neil Diamond.""
+" 5 [' red', 'empt', 'ions', ' of', ' Johnny', ' Cash']
+1441 309 Name of mother of x -1 Name of mother of Johnny Cash Carrie Cloveree Rivers Johnny Cash "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the father of the groom .
+
+ The wedding was a" False originally wanted Johnny Cash for, went 3 [' originally', ' wanted', ' Johnny', ' Cash']
+1442 310 Name of mother of x -1 Name of mother of Jackie Chan Lee-Lee Chan Jackie Chan "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' young'
+ ' Bruce' ' Lee' ' in' ' the' ' film' ',' ' was' ' also' ' a' ' martial']" , the actor who played the role of the young Bruce Lee in the film , was also a martial False 2 ['Jack', 'ie', ' Chan']
+1443 310 Name of mother of x -1 Name of mother of Jackie Chan Lee-Lee Chan Jackie Chan "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' young'
+ ' Bruce' ' Lee' ' in' ' the' ' film' ',' ' was' ' also' ' a' ' martial']" , the actor who played the role of the young Bruce Lee in the film , was also a martial False actor in Asia after Jackie Chan at the time. While 5 [' actor', ' in', ' Asia', ' after', ' Jackie', ' Chan']
+1444 310 Name of mother of x -1 Name of mother of Jackie Chan Lee-Lee Chan Jackie Chan "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' young'
+ ' Bruce' ' Lee' ' in' ' the' ' film' ',' ' was' ' also' ' a' ' martial']" , the actor who played the role of the young Bruce Lee in the film , was also a martial False by entertainers Jackie Chan and Cristian Castro, 4 [' by', ' entertain', 'ers', ' Jackie', ' Chan']
+1445 310 Name of mother of x -1 Name of mother of Jackie Chan Lee-Lee Chan Jackie Chan "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' young'
+ ' Bruce' ' Lee' ' in' ' the' ' film' ',' ' was' ' also' ' a' ' martial']" , the actor who played the role of the young Bruce Lee in the film , was also a martial False Asia after Jackie Chan at the time. 3 [' Asia', ' after', ' Jackie', ' Chan']
+1446 310 Name of mother of x -1 Name of mother of Jackie Chan Lee-Lee Chan Jackie Chan "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' young'
+ ' Bruce' ' Lee' ' in' ' the' ' film' ',' ' was' ' also' ' a' ' martial']" , the actor who played the role of the young Bruce Lee in the film , was also a martial False " Amitabh Bachchan, and Jackie Chan titled Gold Struck.
+" 8 [' Amit', 'ab', 'h', ' Bach', 'chan', ',', ' and', ' Jackie', ' Chan']
+1447 311 Name of mother of x -1 Name of mother of Jean-Luc Godard Odile Monod Jean-Luc Godard "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',' ' and']" , the director of the film , and the film 's producer , and the film 's writer , and False 4 ['Jean', '-', 'Luc', ' God', 'ard']
+1448 311 Name of mother of x -1 Name of mother of Jean-Luc Godard Odile Monod Jean-Luc Godard "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',' ' and']" , the director of the film , and the film 's producer , and the film 's writer , and False initially wanted to get Jean-Luc Godard or Federico 8 [' initially', ' wanted', ' to', ' get', ' Jean', '-', 'Luc', ' God', 'ard']
+1449 311 Name of mother of x -1 Name of mother of Jean-Luc Godard Odile Monod Jean-Luc Godard "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',' ' and']" , the director of the film , and the film 's producer , and the film 's writer , and False earlier — directed by Jean-Luc Godard and Nicholas 8 [' earlier', ' —', ' directed', ' by', ' Jean', '-', 'Luc', ' God', 'ard']
+1450 311 Name of mother of x -1 Name of mother of Jean-Luc Godard Odile Monod Jean-Luc Godard "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',' ' and']" , the director of the film , and the film 's producer , and the film 's writer , and False directed by Jean-Luc Godard and Jean-Pierre Gorin. 6 [' directed', ' by', ' Jean', '-', 'Luc', ' God', 'ard']
+1451 311 Name of mother of x -1 Name of mother of Jean-Luc Godard Odile Monod Jean-Luc Godard "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' writer' ',' ' and']" , the director of the film , and the film 's producer , and the film 's writer , and False and director Jean-Luc Godard regarded the 6 [' and', ' director', ' Jean', '-', 'Luc', ' God', 'ard']
+1452 312 Name of mother of x -1 Name of mother of Henry James Mary Walsh James Henry James "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ',' ' and' ' a'
+ ' member' ' of' ' the' ' British' ' aristocracy' ',' ' was' ' a' ' man'
+ ' of']" , the son of a wealthy family , and a member of the British aristocracy , was a man of False met Edith Wharton, Henry James and Gertrude 7 [' met', ' Ed', 'ith', ' Wh', 'arton', ',', ' Henry', ' James']
+1453 312 Name of mother of x -1 Name of mother of Henry James Mary Walsh James Henry James "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ',' ' and' ' a'
+ ' member' ' of' ' the' ' British' ' aristocracy' ',' ' was' ' a' ' man'
+ ' of']" , the son of a wealthy family , and a member of the British aristocracy , was a man of False visiting writer Henry James noted in 1877 3 [' visiting', ' writer', ' Henry', ' James']
+1454 312 Name of mother of x -1 Name of mother of Henry James Mary Walsh James Henry James "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ',' ' and' ' a'
+ ' member' ' of' ' the' ' British' ' aristocracy' ',' ' was' ' a' ' man'
+ ' of']" , the son of a wealthy family , and a member of the British aristocracy , was a man of False The Reverend Henry James Prince (1811 – 99) 3 [' The', ' Reverend', ' Henry', ' James']
+1455 312 Name of mother of x -1 Name of mother of Henry James Mary Walsh James Henry James "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ',' ' and' ' a'
+ ' member' ' of' ' the' ' British' ' aristocracy' ',' ' was' ' a' ' man'
+ ' of']" , the son of a wealthy family , and a member of the British aristocracy , was a man of False including several by Henry James and Charles 4 [' including', ' several', ' by', ' Henry', ' James']
+1456 312 Name of mother of x -1 Name of mother of Henry James Mary Walsh James Henry James "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ',' ' and' ' a'
+ ' member' ' of' ' the' ' British' ' aristocracy' ',' ' was' ' a' ' man'
+ ' of']" , the son of a wealthy family , and a member of the British aristocracy , was a man of False the century, novelist Henry James ranked Austen with 5 [' the', ' century', ',', ' novelist', ' Henry', ' James']
+1457 313 Name of mother of x -1 Name of mother of Aleksandr Solzhenitsyn Taisiya Zakharovna Shcerbak Aleksandr Solzhenitsyn "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' born' ' in' ' 1918'
+ ',' ' and' ' his' ' wife' ',' ' Nat' 'alia' ',' ' who' ' was']" , the Russian writer , who was born in 1918 , and his wife , Nat alia , who was False personalities such as Aleksandr Solzhenitsyn and Andrei Sakharov, 9 [' personalities', ' such', ' as', ' Ale', 'ks', 'andr', ' Sol', 'zhen', 'its', 'yn']
+1458 314 Name of mother of x -1 Name of mother of Alberto Giacometti Annetta Giacometti Alberto Giacometti "[',' ' the' ' sculpt' 'or' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' wife' ',' ' the' ' sculpt']" , the sculpt or , and his wife , the painter , and the painter 's wife , the sculpt False Swiss artist Alberto Giacometti made Tzara the subject 6 [' Swiss', ' artist', ' Alberto', ' Gi', 'ac', 'omet', 'ti']
+1459 314 Name of mother of x -1 Name of mother of Alberto Giacometti Annetta Giacometti Alberto Giacometti "[',' ' the' ' sculpt' 'or' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' wife' ',' ' the' ' sculpt']" , the sculpt or , and his wife , the painter , and the painter 's wife , the sculpt False 1949, Swiss artist Alberto Giacometti made Tzara the subject 8 [' 1949', ',', ' Swiss', ' artist', ' Alberto', ' Gi', 'ac', 'omet', 'ti']
+1460 315 Name of mother of x -1 Name of mother of Francis of Assisi Pica de Bourlemont Francis of Assisi "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' of' ' Ass' 'isi' ',' ' the' '\n' '\n' 'The' ' name']" ", the
+
+ The name of the mother of Francis of Ass isi , the
+
+ The name" False dedicated to Saint Francis of Assisi which was built in 6 [' dedicated', ' to', ' Saint', ' Francis', ' of', ' Ass', 'isi']
+1461 315 Name of mother of x -1 Name of mother of Francis of Assisi Pica de Bourlemont Francis of Assisi "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' of' ' Ass' 'isi' ',' ' the' '\n' '\n' 'The' ' name']" ", the
+
+ The name of the mother of Francis of Ass isi , the
+
+ The name" False 4 ['Franc', 'is', ' of', ' Ass', 'isi']
+1462 315 Name of mother of x -1 Name of mother of Francis of Assisi Pica de Bourlemont Francis of Assisi "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' of' ' Ass' 'isi' ',' ' the' '\n' '\n' 'The' ' name']" ", the
+
+ The name of the mother of Francis of Ass isi , the
+
+ The name" False 4 ['Franc', 'is', ' of', ' Ass', 'isi']
+1463 315 Name of mother of x -1 Name of mother of Francis of Assisi Pica de Bourlemont Francis of Assisi "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' of' ' Ass' 'isi' ',' ' the' '\n' '\n' 'The' ' name']" ", the
+
+ The name of the mother of Francis of Ass isi , the
+
+ The name" False 4 ['Franc', 'is', ' of', ' Ass', 'isi']
+1464 315 Name of mother of x -1 Name of mother of Francis of Assisi Pica de Bourlemont Francis of Assisi "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' of' ' Ass' 'isi' ',' ' the' '\n' '\n' 'The' ' name']" ", the
+
+ The name of the mother of Francis of Ass isi , the
+
+ The name" False named for St. Francis of Assisi a few miles away. 7 [' named', ' for', ' St', '.', ' Francis', ' of', ' Ass', 'isi']
+1465 316 Name of mother of x -1 Name of mother of Rosa Bonheur Sophie Marquis Rosa Bonheur "[',' ' the' ' famous' ' French' ' artist' ',' ' who' ' was' ' born' ' in'
+ ' 18' '33' ',' ' and' ' died' ' in' ' 18' '94' '.' '\n']" ", the famous French artist , who was born in 18 33 , and died in 18 94 .
+" False realists such as Rosa Bonheur and Gustave Courbet. 7 [' real', 'ists', ' such', ' as', ' Rosa', ' Bon', 'he', 'ur']
+1466 316 Name of mother of x -1 Name of mother of Rosa Bonheur Sophie Marquis Rosa Bonheur "[',' ' the' ' famous' ' French' ' artist' ',' ' who' ' was' ' born' ' in'
+ ' 18' '33' ',' ' and' ' died' ' in' ' 18' '94' '.' '\n']" ", the famous French artist , who was born in 18 33 , and died in 18 94 .
+" False by the works of Rosa Bonheur and Gustave 7 [' by', ' the', ' works', ' of', ' Rosa', ' Bon', 'he', 'ur']
+1467 316 Name of mother of x -1 Name of mother of Rosa Bonheur Sophie Marquis Rosa Bonheur "[',' ' the' ' famous' ' French' ' artist' ',' ' who' ' was' ' born' ' in'
+ ' 18' '33' ',' ' and' ' died' ' in' ' 18' '94' '.' '\n']" ", the famous French artist , who was born in 18 33 , and died in 18 94 .
+" False the works of Rosa Bonheur and Gustave Courbet, 6 [' the', ' works', ' of', ' Rosa', ' Bon', 'he', 'ur']
+1468 316 Name of mother of x -1 Name of mother of Rosa Bonheur Sophie Marquis Rosa Bonheur "[',' ' the' ' famous' ' French' ' artist' ',' ' who' ' was' ' born' ' in'
+ ' 18' '33' ',' ' and' ' died' ' in' ' 18' '94' '.' '\n']" ", the famous French artist , who was born in 18 33 , and died in 18 94 .
+" False she-painters like Rosa Bonheur the mares and cows? 8 [' she', '-', 'pain', 'ters', ' like', ' Rosa', ' Bon', 'he', 'ur']
+1469 316 Name of mother of x -1 Name of mother of Rosa Bonheur Sophie Marquis Rosa Bonheur "[',' ' the' ' famous' ' French' ' artist' ',' ' who' ' was' ' born' ' in'
+ ' 18' '33' ',' ' and' ' died' ' in' ' 18' '94' '.' '\n']" ", the famous French artist , who was born in 18 33 , and died in 18 94 .
+" False influenced by the works of Rosa Bonheur and Gustave Courbet, 8 [' influenced', ' by', ' the', ' works', ' of', ' Rosa', ' Bon', 'he', 'ur']
+1470 317 Name of mother of x -1 Name of mother of Bernard of Clairvaux Alèthe Bernard of Clairvaux "[',' ' the' ' great' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Bernard' ' of' ' Clair' 'v' 'aux' ',' ' the' ' great' '\n']" ", the great
+
+ The name of the mother of Bernard of Clair v aux , the great
+" False the year St Bernard of Clairvaux began preaching for 7 [' the', ' year', ' St', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1471 317 Name of mother of x -1 Name of mother of Bernard of Clairvaux Alèthe Bernard of Clairvaux "[',' ' the' ' great' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Bernard' ' of' ' Clair' 'v' 'aux' ',' ' the' ' great' '\n']" ", the great
+
+ The name of the mother of Bernard of Clair v aux , the great
+" False contemporaries like Bernard of Clairvaux for being more in love 6 [' contemporaries', ' like', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1472 317 Name of mother of x -1 Name of mother of Bernard of Clairvaux Alèthe Bernard of Clairvaux "[',' ' the' ' great' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Bernard' ' of' ' Clair' 'v' 'aux' ',' ' the' ' great' '\n']" ", the great
+
+ The name of the mother of Bernard of Clair v aux , the great
+" False days (e.g., St. Bernard of Clairvaux on August 19 in the 12 [' days', ' (', 'e', '.', 'g', '.,', ' St', '.', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1473 317 Name of mother of x -1 Name of mother of Bernard of Clairvaux Alèthe Bernard of Clairvaux "[',' ' the' ' great' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Bernard' ' of' ' Clair' 'v' 'aux' ',' ' the' ' great' '\n']" ", the great
+
+ The name of the mother of Bernard of Clair v aux , the great
+" False and was attended by Bernard of Clairvaux and Suger of St 8 [' and', ' was', ' attended', ' by', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1474 317 Name of mother of x -1 Name of mother of Bernard of Clairvaux Alèthe Bernard of Clairvaux "[',' ' the' ' great' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Bernard' ' of' ' Clair' 'v' 'aux' ',' ' the' ' great' '\n']" ", the great
+
+ The name of the mother of Bernard of Clair v aux , the great
+" False days (e.g., St. Bernard of Clairvaux on August 12 [' days', ' (', 'e', '.', 'g', '.,', ' St', '.', ' Bernard', ' of', ' Clair', 'v', 'aux']
+1475 318 Name of mother of x -1 Name of mother of Robin Williams Laura McLaurin Robin Williams "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' actor' ' Robin' ' Williams' '.' '\n' '\n' 'The' ' family' ' of']" ", who was a member of the family of the late actor Robin Williams .
+
+ The family of" False 1 ['Robin', ' Williams']
+1476 318 Name of mother of x -1 Name of mother of Robin Williams Laura McLaurin Robin Williams "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' actor' ' Robin' ' Williams' '.' '\n' '\n' 'The' ' family' ' of']" ", who was a member of the family of the late actor Robin Williams .
+
+ The family of" False Cambridge coach Robin Williams was generous 3 [' Cambridge', ' coach', ' Robin', ' Williams']
+1477 318 Name of mother of x -1 Name of mother of Robin Williams Laura McLaurin Robin Williams "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' actor' ' Robin' ' Williams' '.' '\n' '\n' 'The' ' family' ' of']" ", who was a member of the family of the late actor Robin Williams .
+
+ The family of" False " performed by Robin Williams at the ceremony.
+" 3 [' performed', ' by', ' Robin', ' Williams']
+1478 318 Name of mother of x -1 Name of mother of Robin Williams Laura McLaurin Robin Williams "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' actor' ' Robin' ' Williams' '.' '\n' '\n' 'The' ' family' ' of']" ", who was a member of the family of the late actor Robin Williams .
+
+ The family of" False Nintendo fan Robin Williams was used for 3 [' Nintendo', ' fan', ' Robin', ' Williams']
+1479 318 Name of mother of x -1 Name of mother of Robin Williams Laura McLaurin Robin Williams "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' actor' ' Robin' ' Williams' '.' '\n' '\n' 'The' ' family' ' of']" ", who was a member of the family of the late actor Robin Williams .
+
+ The family of" False actor and comedian Robin Williams and his daughter 4 [' actor', ' and', ' comedian', ' Robin', ' Williams']
+1480 319 Name of mother of x -1 Name of mother of Judit Polgár Klára Polgár Judit Polgár "[',' ' who' ' was' ' a' ' Hungarian' '-' 'born' ' American' ' actress' ','
+ ' singer' ',' ' and' ' dancer' '.' '\n' '\n' 'References' '\n' '\n']" ", who was a Hungarian - born American actress , singer , and dancer .
+
+ References
+
+" False for a final against Judit Polgár (who had in turn 9 [' for', ' a', ' final', ' against', ' Jud', 'it', ' Pol', 'g', 'á', 'r']
+1481 319 Name of mother of x -1 Name of mother of Judit Polgár Klára Polgár Judit Polgár "[',' ' who' ' was' ' a' ' Hungarian' '-' 'born' ' American' ' actress' ','
+ ' singer' ',' ' and' ' dancer' '.' '\n' '\n' 'References' '\n' '\n']" ", who was a Hungarian - born American actress , singer , and dancer .
+
+ References
+
+" False final against Judit Polgár (who had in 7 [' final', ' against', ' Jud', 'it', ' Pol', 'g', 'á', 'r']
+1482 319 Name of mother of x -1 Name of mother of Judit Polgár Klára Polgár Judit Polgár "[',' ' who' ' was' ' a' ' Hungarian' '-' 'born' ' American' ' actress' ','
+ ' singer' ',' ' and' ' dancer' '.' '\n' '\n' 'References' '\n' '\n']" ", who was a Hungarian - born American actress , singer , and dancer .
+
+ References
+
+" False for a final against Judit Polgár (who had in 9 [' for', ' a', ' final', ' against', ' Jud', 'it', ' Pol', 'g', 'á', 'r']
+1483 319 Name of mother of x -1 Name of mother of Judit Polgár Klára Polgár Judit Polgár "[',' ' who' ' was' ' a' ' Hungarian' '-' 'born' ' American' ' actress' ','
+ ' singer' ',' ' and' ' dancer' '.' '\n' '\n' 'References' '\n' '\n']" ", who was a Hungarian - born American actress , singer , and dancer .
+
+ References
+
+" False for a final against Judit Polgár (who had in turn 9 [' for', ' a', ' final', ' against', ' Jud', 'it', ' Pol', 'g', 'á', 'r']
+1484 319 Name of mother of x -1 Name of mother of Judit Polgár Klára Polgár Judit Polgár "[',' ' who' ' was' ' a' ' Hungarian' '-' 'born' ' American' ' actress' ','
+ ' singer' ',' ' and' ' dancer' '.' '\n' '\n' 'References' '\n' '\n']" ", who was a Hungarian - born American actress , singer , and dancer .
+
+ References
+
+" False final against Judit Polgár (who had in turn beat 7 [' final', ' against', ' Jud', 'it', ' Pol', 'g', 'á', 'r']
+1485 320 Name of mother of x -1 Name of mother of Béla Bartók Paula Voit Béla Bartók "[',' ' the' ' Hungarian' ' composer' ',' ' was' ' born' ' in' ' Nag' 'y'
+ 'v' 'á' 'rad' ' (' 'now' ' Or' 'ade' 'a' ',' ' Romania']" , the Hungarian composer , was born in Nag y v á rad ( now Or ade a , Romania False Strauss and considered Béla Bartók and Dmitri 8 [' Strauss', ' and', ' considered', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1486 320 Name of mother of x -1 Name of mother of Béla Bartók Paula Voit Béla Bartók "[',' ' the' ' Hungarian' ' composer' ',' ' was' ' born' ' in' ' Nag' 'y'
+ 'v' 'á' 'rad' ' (' 'now' ' Or' 'ade' 'a' ',' ' Romania']" , the Hungarian composer , was born in Nag y v á rad ( now Or ade a , Romania False up composers Béla Bartók and Zoltán Kodály, 8 [' up', ' compos', 'ers', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1487 320 Name of mother of x -1 Name of mother of Béla Bartók Paula Voit Béla Bartók "[',' ' the' ' Hungarian' ' composer' ',' ' was' ' born' ' in' ' Nag' 'y'
+ 'v' 'á' 'rad' ' (' 'now' ' Or' 'ade' 'a' ',' ' Romania']" , the Hungarian composer , was born in Nag y v á rad ( now Or ade a , Romania False Composers such as Béla Bartók and, later, 9 [' Compos', 'ers', ' such', ' as', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1488 320 Name of mother of x -1 Name of mother of Béla Bartók Paula Voit Béla Bartók "[',' ' the' ' Hungarian' ' composer' ',' ' was' ' born' ' in' ' Nag' 'y'
+ 'v' 'á' 'rad' ' (' 'now' ' Or' 'ade' 'a' ',' ' Romania']" , the Hungarian composer , was born in Nag y v á rad ( now Or ade a , Romania False sign up composers Béla Bartók and Zoltán Kodály, 9 [' sign', ' up', ' compos', 'ers', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1489 320 Name of mother of x -1 Name of mother of Béla Bartók Paula Voit Béla Bartók "[',' ' the' ' Hungarian' ' composer' ',' ' was' ' born' ' in' ' Nag' 'y'
+ 'v' 'á' 'rad' ' (' 'now' ' Or' 'ade' 'a' ',' ' Romania']" , the Hungarian composer , was born in Nag y v á rad ( now Or ade a , Romania False sign up composers Béla Bartók and Zoltán Kodály, 9 [' sign', ' up', ' compos', 'ers', ' B', 'é', 'la', ' Bart', 'ó', 'k']
+1490 321 Name of mother of x -1 Name of mother of Bhimrao Ramji Ambedkar Bhimabai Ramji Sakpal Bhimrao Ramji Ambedkar "[',' ' the' ' first' ' Dal' 'it' ' leader' ' of' ' India' '.' '\n' '\n'
+ 'The' ' first' ' Dal' 'it' ' leader' ' of' ' India' ',' ' Dr']" ", the first Dal it leader of India .
+
+ The first Dal it leader of India , Dr" False people such as Bhimrao Ramji Ambedkar and Muhammad 11 [' people', ' such', ' as', ' Bh', 'im', 'ra', 'o', ' Ram', 'ji', ' Am', 'bed', 'kar']
+1491 321 Name of mother of x -1 Name of mother of Bhimrao Ramji Ambedkar Bhimabai Ramji Sakpal Bhimrao Ramji Ambedkar "[',' ' the' ' first' ' Dal' 'it' ' leader' ' of' ' India' '.' '\n' '\n'
+ 'The' ' first' ' Dal' 'it' ' leader' ' of' ' India' ',' ' Dr']" ", the first Dal it leader of India .
+
+ The first Dal it leader of India , Dr" False people such as Bhimrao Ramji Ambedkar and Muhammad Ali 11 [' people', ' such', ' as', ' Bh', 'im', 'ra', 'o', ' Ram', 'ji', ' Am', 'bed', 'kar']
+1492 321 Name of mother of x -1 Name of mother of Bhimrao Ramji Ambedkar Bhimabai Ramji Sakpal Bhimrao Ramji Ambedkar "[',' ' the' ' first' ' Dal' 'it' ' leader' ' of' ' India' '.' '\n' '\n'
+ 'The' ' first' ' Dal' 'it' ' leader' ' of' ' India' ',' ' Dr']" ", the first Dal it leader of India .
+
+ The first Dal it leader of India , Dr" False people such as Bhimrao Ramji Ambedkar and Muhammad Ali 11 [' people', ' such', ' as', ' Bh', 'im', 'ra', 'o', ' Ram', 'ji', ' Am', 'bed', 'kar']
+1493 322 Name of mother of x -1 Name of mother of Virgil Maglia Pollae Virgil "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Vir' 'gil' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Vir gil , and the
+
+ Name of mother of" False retaliation. While Virgil and Wyatt were in 4 [' retaliation', '.', ' While', ' Vir', 'gil']
+1494 322 Name of mother of x -1 Name of mother of Virgil Maglia Pollae Virgil "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Vir' 'gil' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Vir gil , and the
+
+ Name of mother of" False double against Virgil Vasquez of 3 [' double', ' against', ' Vir', 'gil']
+1495 322 Name of mother of x -1 Name of mother of Virgil Maglia Pollae Virgil "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Vir' 'gil' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Vir gil , and the
+
+ Name of mother of" False " Carpenter, John Glenn, Virgil ""Gus"" Grissom," 6 [' Carpenter', ',', ' John', ' Glenn', ',', ' Vir', 'gil']
+1496 322 Name of mother of x -1 Name of mother of Virgil Maglia Pollae Virgil "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Vir' 'gil' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Vir gil , and the
+
+ Name of mother of" False the Earps. Virgil was also appointed 5 [' the', ' Ear', 'ps', '.', ' Vir', 'gil']
+1497 322 Name of mother of x -1 Name of mother of Virgil Maglia Pollae Virgil "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Vir' 'gil' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", and the
+
+ Name of mother of Vir gil , and the
+
+ Name of mother of" False Street in Tombstone, Virgil was ambushed and 6 [' Street', ' in', ' Tomb', 'stone', ',', ' Vir', 'gil']
+1498 323 Name of mother of x -1 Name of mother of Bette Davis Ruth Augusta 'Ruthie' Favor Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False s biggest star Bette Davis was uninterested, 5 [' s', ' biggest', ' star', ' Bet', 'te', ' Davis']
+1499 323 Name of mother of x -1 Name of mother of Bette Davis Ruth Augusta 'Ruthie' Favor Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Pierce (1945), but Bette Davis was the studio's 7 [' Pierce', ' (', '1945', '),', ' but', ' Bet', 'te', ' Davis']
+1500 323 Name of mother of x -1 Name of mother of Bette Davis Ruth Augusta 'Ruthie' Favor Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Pierce (1945), but Bette Davis was the studio's first 7 [' Pierce', ' (', '1945', '),', ' but', ' Bet', 'te', ' Davis']
+1501 323 Name of mother of x -1 Name of mother of Bette Davis Ruth Augusta 'Ruthie' Favor Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False (1965) starring Bette Davis and William Castle's 6 [' (', '1965', ')', ' starring', ' Bet', 'te', ' Davis']
+1502 323 Name of mother of x -1 Name of mother of Bette Davis Ruth Augusta 'Ruthie' Favor Bette Davis "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False and Essex with Bette Davis and Errol Flynn. 5 [' and', ' Essex', ' with', ' Bet', 'te', ' Davis']
+1503 324 Name of mother of x -1 Name of mother of Judy Garland Ethel Marion Milne Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress']" , the actress , and the mother of the actress , Judy Garland , and the mother of the actress False Philip in 1947. Judy Garland was originally 5 [' Philip', ' in', ' 1947', '.', ' Judy', ' Garland']
+1504 324 Name of mother of x -1 Name of mother of Judy Garland Ethel Marion Milne Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress']" , the actress , and the mother of the actress , Judy Garland , and the mother of the actress False in which a young Judy Garland swoons over a photo 5 [' in', ' which', ' a', ' young', ' Judy', ' Garland']
+1505 324 Name of mother of x -1 Name of mother of Judy Garland Ethel Marion Milne Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress']" , the actress , and the mother of the actress , Judy Garland , and the mother of the actress False 2 ['Jud', 'y', ' Garland']
+1506 324 Name of mother of x -1 Name of mother of Judy Garland Ethel Marion Milne Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress']" , the actress , and the mother of the actress , Judy Garland , and the mother of the actress False 2 ['Jud', 'y', ' Garland']
+1507 324 Name of mother of x -1 Name of mother of Judy Garland Ethel Marion Milne Judy Garland "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Judy' ' Garland' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' actress']" , the actress , and the mother of the actress , Judy Garland , and the mother of the actress False " including several ""Judy Garland in Concert""" 5 "[' including', ' several', ' ""', 'Jud', 'y', ' Garland']"
+1508 325 Name of mother of x -1 Name of mother of Jean Sibelius Maria Charlotta Sibelius Jean Sibelius "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ""'s"" ' father' '.' '\n' '\n' 'The' ' mother']" ", the mother of the child , and the mother of the child 's father .
+
+ The mother" False quinquennial International Jean Sibelius Violin Competition, 7 [' qu', 'inqu', 'ennial', ' International', ' Jean', ' S', 'ibel', 'ius']
+1509 325 Name of mother of x -1 Name of mother of Jean Sibelius Maria Charlotta Sibelius Jean Sibelius "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ""'s"" ' father' '.' '\n' '\n' 'The' ' mother']" ", the mother of the child , and the mother of the child 's father .
+
+ The mother" False quinquennial International Jean Sibelius Violin Competition, 7 [' qu', 'inqu', 'ennial', ' International', ' Jean', ' S', 'ibel', 'ius']
+1510 325 Name of mother of x -1 Name of mother of Jean Sibelius Maria Charlotta Sibelius Jean Sibelius "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ""'s"" ' father' '.' '\n' '\n' 'The' ' mother']" ", the mother of the child , and the mother of the child 's father .
+
+ The mother" False 3 ['Jean', ' S', 'ibel', 'ius']
+1511 325 Name of mother of x -1 Name of mother of Jean Sibelius Maria Charlotta Sibelius Jean Sibelius "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ""'s"" ' father' '.' '\n' '\n' 'The' ' mother']" ", the mother of the child , and the mother of the child 's father .
+
+ The mother" False International Jean Sibelius Violin Competition, 4 [' International', ' Jean', ' S', 'ibel', 'ius']
+1512 325 Name of mother of x -1 Name of mother of Jean Sibelius Maria Charlotta Sibelius Jean Sibelius "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ""'s"" ' father' '.' '\n' '\n' 'The' ' mother']" ", the mother of the child , and the mother of the child 's father .
+
+ The mother" False incidental music of Jean Sibelius on October 30, 2013 6 [' incidental', ' music', ' of', ' Jean', ' S', 'ibel', 'ius']
+1513 327 Name of mother of x -1 Name of mother of Joan Crawford Anna Johnson Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Davis. After Joan Crawford left the picture 4 [' Davis', '.', ' After', ' Joan', ' Crawford']
+1514 327 Name of mother of x -1 Name of mother of Joan Crawford Anna Johnson Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Today We Live with Joan Crawford and One Sunday Afternoon 5 [' Today', ' We', ' Live', ' with', ' Joan', ' Crawford']
+1515 327 Name of mother of x -1 Name of mother of Joan Crawford Anna Johnson Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False veteran star Joan Crawford to describe her 3 [' veteran', ' star', ' Joan', ' Crawford']
+1516 327 Name of mother of x -1 Name of mother of Joan Crawford Anna Johnson Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False Today We Live with Joan Crawford and One Sunday 5 [' Today', ' We', ' Live', ' with', ' Joan', ' Crawford']
+1517 327 Name of mother of x -1 Name of mother of Joan Crawford Anna Johnson Joan Crawford "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False stage was the Joan Crawford film Forsaking 4 [' stage', ' was', ' the', ' Joan', ' Crawford']
+1518 328 Name of mother of x -1 Name of mother of Graham Greene Marion Raymond Greene Graham Greene "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False Evelyn Waugh, Graham Greene and William Golding. 6 [' Eve', 'lyn', ' W', 'augh', ',', ' Graham', ' Greene']
+1519 328 Name of mother of x -1 Name of mother of Graham Greene Marion Raymond Greene Graham Greene "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False Ambler and Graham Greene as influences. 4 [' Amb', 'ler', ' and', ' Graham', ' Greene']
+1520 328 Name of mother of x -1 Name of mother of Graham Greene Marion Raymond Greene Graham Greene "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False and novelist Graham Greene reported that 3 [' and', ' novelist', ' Graham', ' Greene']
+1521 328 Name of mother of x -1 Name of mother of Graham Greene Marion Raymond Greene Graham Greene "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False praised by Yorke, Graham Greene and, in glowing 6 [' praised', ' by', ' Yor', 'ke', ',', ' Graham', ' Greene']
+1522 328 Name of mother of x -1 Name of mother of Graham Greene Marion Raymond Greene Graham Greene "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False enthusiastic about the film; Graham Greene of the British 6 [' enthusiastic', ' about', ' the', ' film', ';', ' Graham', ' Greene']
+1523 329 Name of mother of x -1 Name of mother of Arthur Schopenhauer Johanna Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' philosopher' ' Arthur' ' Sch' 'open' 'h' 'auer' '.' '\n' '\n' 'The'
+ ' philosopher']" ", the philosopher , and the mother of the philosopher Arthur Sch open h auer .
+
+ The philosopher" False German philosopher Arthur Schopenhauer criticised Kant's 6 [' German', ' philosopher', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1524 329 Name of mother of x -1 Name of mother of Arthur Schopenhauer Johanna Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' philosopher' ' Arthur' ' Sch' 'open' 'h' 'auer' '.' '\n' '\n' 'The'
+ ' philosopher']" ", the philosopher , and the mother of the philosopher Arthur Sch open h auer .
+
+ The philosopher" False German philosopher Arthur Schopenhauer criticised 6 [' German', ' philosopher', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1525 329 Name of mother of x -1 Name of mother of Arthur Schopenhauer Johanna Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' philosopher' ' Arthur' ' Sch' 'open' 'h' 'auer' '.' '\n' '\n' 'The'
+ ' philosopher']" ", the philosopher , and the mother of the philosopher Arthur Sch open h auer .
+
+ The philosopher" False western audience. Arthur Schopenhauer was deeply 7 [' western', ' audience', '.', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1526 329 Name of mother of x -1 Name of mother of Arthur Schopenhauer Johanna Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' philosopher' ' Arthur' ' Sch' 'open' 'h' 'auer' '.' '\n' '\n' 'The'
+ ' philosopher']" ", the philosopher , and the mother of the philosopher Arthur Sch open h auer .
+
+ The philosopher" False philosopher Arthur Schopenhauer criticised Kant's 5 [' philosopher', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1527 329 Name of mother of x -1 Name of mother of Arthur Schopenhauer Johanna Schopenhauer Arthur Schopenhauer "[',' ' the' ' philosopher' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' philosopher' ' Arthur' ' Sch' 'open' 'h' 'auer' '.' '\n' '\n' 'The'
+ ' philosopher']" ", the philosopher , and the mother of the philosopher Arthur Sch open h auer .
+
+ The philosopher" False German philosopher Arthur Schopenhauer criticised Kant's 6 [' German', ' philosopher', ' Arthur', ' Sch', 'open', 'h', 'auer']
+1528 330 Name of mother of x -1 Name of mother of Alessandro Manzoni Giulia Beccaria Alessandro Manzoni "[',' ' the' ' Italian' ' composer' ',' ' who' ' was' ' born' ' in' ' 18'
+ '75' ',' ' and' ' died' ' in' ' 18' '76' '.' '\n' '\n']" ", the Italian composer , who was born in 18 75 , and died in 18 76 .
+
+" False in Russia and Alessandro Manzoni in Italy. The 6 [' in', ' Russia', ' and', ' Aless', 'andro', ' Manz', 'oni']
+1529 330 Name of mother of x -1 Name of mother of Alessandro Manzoni Giulia Beccaria Alessandro Manzoni "[',' ' the' ' Italian' ' composer' ',' ' who' ' was' ' born' ' in' ' 18'
+ '75' ',' ' and' ' died' ' in' ' 18' '76' '.' '\n' '\n']" ", the Italian composer , who was born in 18 75 , and died in 18 76 .
+
+" False Tolstoy in Russia and Alessandro Manzoni in Italy. The tradition 9 [' Tol', 'st', 'oy', ' in', ' Russia', ' and', ' Aless', 'andro', ' Manz', 'oni']
+1530 330 Name of mother of x -1 Name of mother of Alessandro Manzoni Giulia Beccaria Alessandro Manzoni "[',' ' the' ' Italian' ' composer' ',' ' who' ' was' ' born' ' in' ' 18'
+ '75' ',' ' and' ' died' ' in' ' 18' '76' '.' '\n' '\n']" ", the Italian composer , who was born in 18 75 , and died in 18 76 .
+
+" False Tolstoy in Russia and Alessandro Manzoni in Italy. The 9 [' Tol', 'st', 'oy', ' in', ' Russia', ' and', ' Aless', 'andro', ' Manz', 'oni']
+1531 331 Name of mother of x -1 Name of mother of Daniel Defoe Alice Marsh Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False " houses afire"". Daniel Defoe was also familiar" 6 "[' houses', ' a', 'fire', '"".', ' Daniel', ' Def', 'oe']"
+1532 331 Name of mother of x -1 Name of mother of Daniel Defoe Alice Marsh Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False occurred. In the 1720s Daniel Defoe remarked that the town 9 [' occurred', '.', ' In', ' the', ' 17', '20', 's', ' Daniel', ' Def', 'oe']
+1533 331 Name of mother of x -1 Name of mother of Daniel Defoe Alice Marsh Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False Year (1722) by Daniel Defoe is a fictionalisation 8 [' Year', ' (', '17', '22', ')', ' by', ' Daniel', ' Def', 'oe']
+1534 331 Name of mother of x -1 Name of mother of Daniel Defoe Alice Marsh Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False of action; what Daniel Defoe called the'solidity 6 [' of', ' action', ';', ' what', ' Daniel', ' Def', 'oe']
+1535 331 Name of mother of x -1 Name of mother of Daniel Defoe Alice Marsh Daniel Defoe "[',' ' the' ' author' ' of' ' Robinson' ' Crus' 'oe' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' Life' ' of' ' Mr' '.']" ", the author of Robinson Crus oe , and the author of the
+
+ The Life of Mr ." False century later, Daniel Defoe in his book 5 [' century', ' later', ',', ' Daniel', ' Def', 'oe']
+1536 332 Name of mother of x -1 Name of mother of Michelle Yeoh Janet Yeoh Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Kurylenko found Michelle Yeoh in Tomorrow Never 7 [' K', 'ury', 'len', 'ko', ' found', ' Michelle', ' Ye', 'oh']
+1537 332 Name of mother of x -1 Name of mother of Michelle Yeoh Janet Yeoh Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Kurylenko found Michelle Yeoh in Tomorrow 7 [' K', 'ury', 'len', 'ko', ' found', ' Michelle', ' Ye', 'oh']
+1538 332 Name of mother of x -1 Name of mother of Michelle Yeoh Janet Yeoh Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False originally made for actress Michelle Yeoh who had a 6 [' originally', ' made', ' for', ' actress', ' Michelle', ' Ye', 'oh']
+1539 332 Name of mother of x -1 Name of mother of Michelle Yeoh Janet Yeoh Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False Kong actress Michelle Yeoh play the role 4 [' Kong', ' actress', ' Michelle', ' Ye', 'oh']
+1540 332 Name of mother of x -1 Name of mother of Michelle Yeoh Janet Yeoh Michelle Yeoh "[',' ' the' ' actress' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' main' ' character' ' in' ' the' ' film' '.' '\n'
+ '\n']" ", the actress who plays the role of the mother of the main character in the film .
+
+" False rejoin Prometheus. Michelle Yeoh and Angelina 6 [' re', 'join', ' Prometheus', '.', ' Michelle', ' Ye', 'oh']
+1541 333 Name of mother of x -1 Name of mother of Jimmy Carter Bessie Lillian Carter Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False 1 ['Jimmy', ' Carter']
+1542 333 Name of mother of x -1 Name of mother of Jimmy Carter Bessie Lillian Carter Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False the United States Jimmy Carter declared Jefferson 4 [' the', ' United', ' States', ' Jimmy', ' Carter']
+1543 333 Name of mother of x -1 Name of mother of Jimmy Carter Bessie Lillian Carter Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False by President Jimmy Carter on December 14, 3 [' by', ' President', ' Jimmy', ' Carter']
+1544 333 Name of mother of x -1 Name of mother of Jimmy Carter Bessie Lillian Carter Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False which President Jimmy Carter signed into law 3 [' which', ' President', ' Jimmy', ' Carter']
+1545 333 Name of mother of x -1 Name of mother of Jimmy Carter Bessie Lillian Carter Jimmy Carter "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False music journalist Jimmy Carter published a 3 [' music', ' journalist', ' Jimmy', ' Carter']
+1546 334 Name of mother of x -1 Name of mother of Albert Schweitzer Adèle Schillinger Albert Schweitzer "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' first' ' thing'
+ ' that']" ", the great German philosopher , who was a great admire r of the
+
+ The first thing that" False animated and angular. Albert Schweitzer likens it to 6 [' animated', ' and', ' angular', '.', ' Albert', ' Schwe', 'itzer']
+1547 334 Name of mother of x -1 Name of mother of Albert Schweitzer Adèle Schillinger Albert Schweitzer "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' first' ' thing'
+ ' that']" ", the great German philosopher , who was a great admire r of the
+
+ The first thing that" False performers like Albert Schweitzer and praised Egon Petri, 4 [' performers', ' like', ' Albert', ' Schwe', 'itzer']
+1548 334 Name of mother of x -1 Name of mother of Albert Schweitzer Adèle Schillinger Albert Schweitzer "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' first' ' thing'
+ ' that']" ", the great German philosopher , who was a great admire r of the
+
+ The first thing that" False village near the Albert Schweitzer Hospital. At dawn 5 [' village', ' near', ' the', ' Albert', ' Schwe', 'itzer']
+1549 334 Name of mother of x -1 Name of mother of Albert Schweitzer Adèle Schillinger Albert Schweitzer "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' first' ' thing'
+ ' that']" ", the great German philosopher , who was a great admire r of the
+
+ The first thing that" False of humanitarian Albert Schweitzer in Lambaréné, near 4 [' of', ' humanitarian', ' Albert', ' Schwe', 'itzer']
+1550 334 Name of mother of x -1 Name of mother of Albert Schweitzer Adèle Schillinger Albert Schweitzer "[',' ' the' ' great' ' German' ' philosopher' ',' ' who' ' was' ' a'
+ ' great' ' admire' 'r' ' of' ' the' '\n' '\n' 'The' ' first' ' thing'
+ ' that']" ", the great German philosopher , who was a great admire r of the
+
+ The first thing that" False village near the Albert Schweitzer Hospital. At 5 [' village', ' near', ' the', ' Albert', ' Schwe', 'itzer']
+1551 336 Name of mother of x -1 Name of mother of Mao Zedong Wen Qimei Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' was' ' born' ' in' ' Beijing' ',' ' China' '.' '\n' '\n']" ", the founder of the People 's Republic of China , was born in Beijing , China .
+
+" False New Year. Mao Zedong returned home; his 5 [' New', ' Year', '.', ' Mao', ' Zed', 'ong']
+1552 336 Name of mother of x -1 Name of mother of Mao Zedong Wen Qimei Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' was' ' born' ' in' ' Beijing' ',' ' China' '.' '\n' '\n']" ", the founder of the People 's Republic of China , was born in Beijing , China .
+
+" False armed forces in Mao Zedong Thought. Lin's 5 [' armed', ' forces', ' in', ' Mao', ' Zed', 'ong']
+1553 336 Name of mother of x -1 Name of mother of Mao Zedong Wen Qimei Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' was' ' born' ' in' ' Beijing' ',' ' China' '.' '\n' '\n']" ", the founder of the People 's Republic of China , was born in Beijing , China .
+
+" False " People's Daily, Mao Zedong Thought ""is" 6 "[' People', ""'s"", ' Daily', ',', ' Mao', ' Zed', 'ong']"
+1554 336 Name of mother of x -1 Name of mother of Mao Zedong Wen Qimei Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' was' ' born' ' in' ' Beijing' ',' ' China' '.' '\n' '\n']" ", the founder of the People 's Republic of China , was born in Beijing , China .
+
+" False Leninism and Mao Zedong Thought (or 5 [' Lenin', 'ism', ' and', ' Mao', ' Zed', 'ong']
+1555 336 Name of mother of x -1 Name of mother of Mao Zedong Wen Qimei Mao Zedong "[',' ' the' ' founder' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' was' ' born' ' in' ' Beijing' ',' ' China' '.' '\n' '\n']" ", the founder of the People 's Republic of China , was born in Beijing , China .
+
+" False armed forces in Mao Zedong Thought. Lin's system 5 [' armed', ' forces', ' in', ' Mao', ' Zed', 'ong']
+1556 337 Name of mother of x -1 Name of mother of Edward Bulwer-Lytton Elizabeth Barbara Lytton Edward Bulwer-Lytton "[',' ' the' ' author' ' of' ' _' 'The' ' Last' ' Days' ' of' ' Pompe' 'ii'
+ '_' ',' ' and' ' _' 'The' ' Last' ' Days' ' of' ' Pompe']" , the author of _ The Last Days of Pompe ii _ , and _ The Last Days of Pompe False the novelists Sir Edward Bulwer-Lytton and Thomas Adolphus 10 [' the', ' novel', 'ists', ' Sir', ' Edward', ' Bul', 'wer', '-', 'Ly', 'tt', 'on']
+1557 337 Name of mother of x -1 Name of mother of Edward Bulwer-Lytton Elizabeth Barbara Lytton Edward Bulwer-Lytton "[',' ' the' ' author' ' of' ' _' 'The' ' Last' ' Days' ' of' ' Pompe' 'ii'
+ '_' ',' ' and' ' _' 'The' ' Last' ' Days' ' of' ' Pompe']" , the author of _ The Last Days of Pompe ii _ , and _ The Last Days of Pompe False father was rector. Edward Bulwer-Lytton (1803 – 1873) lived 11 [' father', ' was', ' re', 'ctor', '.', ' Edward', ' Bul', 'wer', '-', 'Ly', 'tt', 'on']
+1558 337 Name of mother of x -1 Name of mother of Edward Bulwer-Lytton Elizabeth Barbara Lytton Edward Bulwer-Lytton "[',' ' the' ' author' ' of' ' _' 'The' ' Last' ' Days' ' of' ' Pompe' 'ii'
+ '_' ',' ' and' ' _' 'The' ' Last' ' Days' ' of' ' Pompe']" , the author of _ The Last Days of Pompe ii _ , and _ The Last Days of Pompe False novelists Sir Edward Bulwer-Lytton and Thomas Adolphus 9 [' novel', 'ists', ' Sir', ' Edward', ' Bul', 'wer', '-', 'Ly', 'tt', 'on']
+1559 337 Name of mother of x -1 Name of mother of Edward Bulwer-Lytton Elizabeth Barbara Lytton Edward Bulwer-Lytton "[',' ' the' ' author' ' of' ' _' 'The' ' Last' ' Days' ' of' ' Pompe' 'ii'
+ '_' ',' ' and' ' _' 'The' ' Last' ' Days' ' of' ' Pompe']" , the author of _ The Last Days of Pompe ii _ , and _ The Last Days of Pompe False the novelists Sir Edward Bulwer-Lytton and Thomas Adolphus 10 [' the', ' novel', 'ists', ' Sir', ' Edward', ' Bul', 'wer', '-', 'Ly', 'tt', 'on']
+1560 339 Name of mother of x -1 Name of mother of Clint Eastwood Margaret Ruth Runner Clint Eastwood "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Lone'
+ ' Ranger' ' in' ' the' ' original' ' television' ' series' '.' '\n' '\n'
+ 'The']" ", the actor who played the role of the Lone Ranger in the original television series .
+
+ The" False history as the moment Clint Eastwood finally earned 6 [' history', ' as', ' the', ' moment', ' Clint', ' East', 'wood']
+1561 339 Name of mother of x -1 Name of mother of Clint Eastwood Margaret Ruth Runner Clint Eastwood "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Lone'
+ ' Ranger' ' in' ' the' ' original' ' television' ' series' '.' '\n' '\n'
+ 'The']" ", the actor who played the role of the Lone Ranger in the original television series .
+
+ The" False Steve McQueen, Clint Eastwood and Dustin Hoffman. 6 [' Steve', ' Mc', 'Queen', ',', ' Clint', ' East', 'wood']
+1562 339 Name of mother of x -1 Name of mother of Clint Eastwood Margaret Ruth Runner Clint Eastwood "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Lone'
+ ' Ranger' ' in' ' the' ' original' ' television' ' series' '.' '\n' '\n'
+ 'The']" ", the actor who played the role of the Lone Ranger in the original television series .
+
+ The" False Steve McQueen, Clint Eastwood and Dustin Hoffman. 6 [' Steve', ' Mc', 'Queen', ',', ' Clint', ' East', 'wood']
+1563 339 Name of mother of x -1 Name of mother of Clint Eastwood Margaret Ruth Runner Clint Eastwood "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Lone'
+ ' Ranger' ' in' ' the' ' original' ' television' ' series' '.' '\n' '\n'
+ 'The']" ", the actor who played the role of the Lone Ranger in the original television series .
+
+ The" False and funniest Clint Eastwood movie in quite 5 [' and', ' funn', 'iest', ' Clint', ' East', 'wood']
+1564 339 Name of mother of x -1 Name of mother of Clint Eastwood Margaret Ruth Runner Clint Eastwood "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Lone'
+ ' Ranger' ' in' ' the' ' original' ' television' ' series' '.' '\n' '\n'
+ 'The']" ", the actor who played the role of the Lone Ranger in the original television series .
+
+ The" False " Eastwood =
+" 6 [' East', 'wood', ' =', 'Cl', 'int', ' East', 'wood']
+1565 340 Name of mother of x -1 Name of mother of Blaise Pascal Antoinette Begon Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' was'
+ ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the French mathematician and philosopher , was born in 16 23 .
+
+ The first of the great" False and miserable. Blaise Pascal argued this view 6 [' and', ' miserable', '.', ' Bl', 'a', 'ise', ' Pascal']
+1566 340 Name of mother of x -1 Name of mother of Blaise Pascal Antoinette Begon Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' was'
+ ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the French mathematician and philosopher , was born in 16 23 .
+
+ The first of the great" False 3 ['Bl', 'a', 'ise', ' Pascal']
+1567 340 Name of mother of x -1 Name of mother of Blaise Pascal Antoinette Begon Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' was'
+ ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the French mathematician and philosopher , was born in 16 23 .
+
+ The first of the great" False and miserable. Blaise Pascal argued this 6 [' and', ' miserable', '.', ' Bl', 'a', 'ise', ' Pascal']
+1568 340 Name of mother of x -1 Name of mother of Blaise Pascal Antoinette Begon Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' was'
+ ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the French mathematician and philosopher , was born in 16 23 .
+
+ The first of the great" False 3 ['Bl', 'a', 'ise', ' Pascal']
+1569 340 Name of mother of x -1 Name of mother of Blaise Pascal Antoinette Begon Blaise Pascal "[',' ' the' ' French' ' mathematician' ' and' ' philosopher' ',' ' was'
+ ' born' ' in' ' 16' '23' '.' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the French mathematician and philosopher , was born in 16 23 .
+
+ The first of the great" False Archimedean spiral. Blaise Pascal subsequently used polar 9 [' Arch', 'im', 'ed', 'ean', ' spiral', '.', ' Bl', 'a', 'ise', ' Pascal']
+1570 341 Name of mother of x -1 Name of mother of Euripides Cleito Euripides "[',' ' the' ' son' ' of' ' the' ' god' ' of' ' war' ',' ' and' ' the'
+ ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus']" , the son of the god of war , and the son of Zeus , and the son of Zeus False British steamer Euripides out of New 5 [' British', ' ste', 'amer', ' Eur', 'ip', 'ides']
+1571 341 Name of mother of x -1 Name of mother of Euripides Cleito Euripides "[',' ' the' ' son' ' of' ' the' ' god' ' of' ' war' ',' ' and' ' the'
+ ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus']" , the son of the god of war , and the son of Zeus , and the son of Zeus False with those of Euripides. Sophocles wrote 5 [' with', ' those', ' of', ' Eur', 'ip', 'ides']
+1572 341 Name of mother of x -1 Name of mother of Euripides Cleito Euripides "[',' ' the' ' son' ' of' ' the' ' god' ' of' ' war' ',' ' and' ' the'
+ ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus']" , the son of the god of war , and the son of Zeus , and the son of Zeus False clothes, a quotation from Euripides about the sea 7 [' clothes', ',', ' a', ' quotation', ' from', ' Eur', 'ip', 'ides']
+1573 341 Name of mother of x -1 Name of mother of Euripides Cleito Euripides "[',' ' the' ' son' ' of' ' the' ' god' ' of' ' war' ',' ' and' ' the'
+ ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus']" , the son of the god of war , and the son of Zeus , and the son of Zeus False prostitute and Euripides charges a guard 4 [' prostitute', ' and', ' Eur', 'ip', 'ides']
+1574 341 Name of mother of x -1 Name of mother of Euripides Cleito Euripides "[',' ' the' ' son' ' of' ' the' ' god' ' of' ' war' ',' ' and' ' the'
+ ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus']" , the son of the god of war , and the son of Zeus , and the son of Zeus False upon the troopship Euripides bound for Egypt. 6 [' upon', ' the', ' troops', 'hip', ' Eur', 'ip', 'ides']
+1575 342 Name of mother of x -1 Name of mother of Edward VII Victoria Edward VII "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False Windsor Castle, where Edward VII was buried 5 [' Windsor', ' Castle', ',', ' where', ' Edward', ' VII']
+1576 342 Name of mother of x -1 Name of mother of Edward VII Victoria Edward VII "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False expedition, or at King Edward VII Land. He would not 6 [' expedition', ',', ' or', ' at', ' King', ' Edward', ' VII']
+1577 342 Name of mother of x -1 Name of mother of Edward VII Victoria Edward VII "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False Prince Albert, King Edward VII and King George 5 [' Prince', ' Albert', ',', ' King', ' Edward', ' VII']
+1578 342 Name of mother of x -1 Name of mother of Edward VII Victoria Edward VII "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False explore King Edward VII Land. Two 3 [' explore', ' King', ' Edward', ' VII']
+1579 342 Name of mother of x -1 Name of mother of Edward VII Victoria Edward VII "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False ceremony for King Edward VII of the United 4 [' ceremony', ' for', ' King', ' Edward', ' VII']
+1580 344 Name of mother of x -1 Name of mother of Nancy Sinatra Nancy Barbato Nancy Sinatra "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " Sinatra and Nancy Sinatra duet, ""Somethin'" 5 [' Sin', 'atra', ' and', ' Nancy', ' Sin', 'atra']
+1581 344 Name of mother of x -1 Name of mother of Nancy Sinatra Nancy Barbato Nancy Sinatra "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False on a cover of Nancy Sinatra and Lee Hazlewood's 6 [' on', ' a', ' cover', ' of', ' Nancy', ' Sin', 'atra']
+1582 344 Name of mother of x -1 Name of mother of Nancy Sinatra Nancy Barbato Nancy Sinatra "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False first wife, Nancy Sinatra (née Barbato; born 5 [' first', ' wife', ',', ' Nancy', ' Sin', 'atra']
+1583 344 Name of mother of x -1 Name of mother of Nancy Sinatra Nancy Barbato Nancy Sinatra "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " Sinatra version ==
+" 7 [' Sin', 'atra', ' version', ' ==', 'N', 'ancy', ' Sin', 'atra']
+1584 344 Name of mother of x -1 Name of mother of Nancy Sinatra Nancy Barbato Nancy Sinatra "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False first wife, Nancy Sinatra (née Barbato; 5 [' first', ' wife', ',', ' Nancy', ' Sin', 'atra']
+1585 345 Name of mother of x -1 Name of mother of Jodie Foster Brandy Almond Jodie Foster "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Foster' ""'s""]" ", who is a very good friend of mine .
+
+ I have been a fan of Foster 's" False unclear if Jodie Foster (Clarice Starling) 5 [' unclear', ' if', ' J', 'od', 'ie', ' Foster']
+1586 345 Name of mother of x -1 Name of mother of Jodie Foster Brandy Almond Jodie Foster "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Foster' ""'s""]" ", who is a very good friend of mine .
+
+ I have been a fan of Foster 's" False Washington, Owen, Jodie Foster and Chiwetel 7 [' Washington', ',', ' Owen', ',', ' J', 'od', 'ie', ' Foster']
+1587 345 Name of mother of x -1 Name of mother of Jodie Foster Brandy Almond Jodie Foster "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Foster' ""'s""]" ", who is a very good friend of mine .
+
+ I have been a fan of Foster 's" False Washington, Owen, Jodie Foster and Chiwetel 7 [' Washington', ',', ' Owen', ',', ' J', 'od', 'ie', ' Foster']
+1588 345 Name of mother of x -1 Name of mother of Jodie Foster Brandy Almond Jodie Foster "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Foster' ""'s""]" ", who is a very good friend of mine .
+
+ I have been a fan of Foster 's" False 3 ['J', 'od', 'ie', ' Foster']
+1589 345 Name of mother of x -1 Name of mother of Jodie Foster Brandy Almond Jodie Foster "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Foster' ""'s""]" ", who is a very good friend of mine .
+
+ I have been a fan of Foster 's" False American actress Jodie Foster has been substituted 5 [' American', ' actress', ' J', 'od', 'ie', ' Foster']
+1590 346 Name of mother of x -1 Name of mother of Nicholas II of Russia Maria Feodorovna (Dagmar of Denmark) Nicholas II of Russia "[',' ' the' ' Ts' 'ar' 'ina' ' Alexandra' ',' ' and' ' the' ' Ts' 'ar'
+ ""'s"" ' mother' ',' ' the' ' Dow' 'ager' ' Empress' ' Marie' ',']" , the Ts ar ina Alexandra , and the Ts ar 's mother , the Dow ager Empress Marie , False daughter of Tsar Nicholas II of Russia and Tsarina 7 [' daughter', ' of', ' Ts', 'ar', ' Nicholas', ' II', ' of', ' Russia']
+1591 346 Name of mother of x -1 Name of mother of Nicholas II of Russia Maria Feodorovna (Dagmar of Denmark) Nicholas II of Russia "[',' ' the' ' Ts' 'ar' 'ina' ' Alexandra' ',' ' and' ' the' ' Ts' 'ar'
+ ""'s"" ' mother' ',' ' the' ' Dow' 'ager' ' Empress' ' Marie' ',']" , the Ts ar ina Alexandra , and the Ts ar 's mother , the Dow ager Empress Marie , False meetings with Tsar Nicholas II of Russia in Kronstadt. She returned 7 [' meetings', ' with', ' Ts', 'ar', ' Nicholas', ' II', ' of', ' Russia']
+1592 346 Name of mother of x -1 Name of mother of Nicholas II of Russia Maria Feodorovna (Dagmar of Denmark) Nicholas II of Russia "[',' ' the' ' Ts' 'ar' 'ina' ' Alexandra' ',' ' and' ' the' ' Ts' 'ar'
+ ""'s"" ' mother' ',' ' the' ' Dow' 'ager' ' Empress' ' Marie' ',']" , the Ts ar ina Alexandra , and the Ts ar 's mother , the Dow ager Empress Marie , False died and her nephew Nicholas II of Russia became Tsar. Alexandra's 7 [' died', ' and', ' her', ' nephew', ' Nicholas', ' II', ' of', ' Russia']
+1593 346 Name of mother of x -1 Name of mother of Nicholas II of Russia Maria Feodorovna (Dagmar of Denmark) Nicholas II of Russia "[',' ' the' ' Ts' 'ar' 'ina' ' Alexandra' ',' ' and' ' the' ' Ts' 'ar'
+ ""'s"" ' mother' ',' ' the' ' Dow' 'ager' ' Empress' ' Marie' ',']" , the Ts ar ina Alexandra , and the Ts ar 's mother , the Dow ager Empress Marie , False conferred by Tsar Nicholas II of Russia to Lasker, Capablanca, 7 [' conferred', ' by', ' Ts', 'ar', ' Nicholas', ' II', ' of', ' Russia']
+1594 346 Name of mother of x -1 Name of mother of Nicholas II of Russia Maria Feodorovna (Dagmar of Denmark) Nicholas II of Russia "[',' ' the' ' Ts' 'ar' 'ina' ' Alexandra' ',' ' and' ' the' ' Ts' 'ar'
+ ""'s"" ' mother' ',' ' the' ' Dow' 'ager' ' Empress' ' Marie' ',']" , the Ts ar ina Alexandra , and the Ts ar 's mother , the Dow ager Empress Marie , False Wilhelm II and Czar Nicholas II of Russia visited the 8 [' Wilhelm', ' II', ' and', ' C', 'zar', ' Nicholas', ' II', ' of', ' Russia']
+1595 347 Name of mother of x -1 Name of mother of Bernie Sanders Dorothy Sanders Bernie Sanders "[',' ' the' ' former' ' Democratic' ' presidential' ' candidate' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'I' ' am' ' deeply'
+ ' saddened']" ", the former Democratic presidential candidate , said in a statement .
+
+ � � I am deeply saddened" False chaired by U.S. Sen. Bernie Sanders (I-VT) regarding farm 9 [' chaired', ' by', ' U', '.', 'S', '.', ' Sen', '.', ' Bernie', ' Sanders']
+1596 347 Name of mother of x -1 Name of mother of Bernie Sanders Dorothy Sanders Bernie Sanders "[',' ' the' ' former' ' Democratic' ' presidential' ' candidate' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'I' ' am' ' deeply'
+ ' saddened']" ", the former Democratic presidential candidate , said in a statement .
+
+ � � I am deeply saddened" False support for Bernie Sanders among young 3 [' support', ' for', ' Bernie', ' Sanders']
+1597 347 Name of mother of x -1 Name of mother of Bernie Sanders Dorothy Sanders Bernie Sanders "[',' ' the' ' former' ' Democratic' ' presidential' ' candidate' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'I' ' am' ' deeply'
+ ' saddened']" ", the former Democratic presidential candidate , said in a statement .
+
+ � � I am deeply saddened" False 1 ['Bernie', ' Sanders']
+1598 347 Name of mother of x -1 Name of mother of Bernie Sanders Dorothy Sanders Bernie Sanders "[',' ' the' ' former' ' Democratic' ' presidential' ' candidate' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'I' ' am' ' deeply'
+ ' saddened']" ", the former Democratic presidential candidate , said in a statement .
+
+ � � I am deeply saddened" False Democratic debate with Bernie Sanders on February 4 [' Democratic', ' debate', ' with', ' Bernie', ' Sanders']
+1599 347 Name of mother of x -1 Name of mother of Bernie Sanders Dorothy Sanders Bernie Sanders "[',' ' the' ' former' ' Democratic' ' presidential' ' candidate' ','
+ ' said' ' in' ' a' ' statement' '.' '\n' '\n' '�' '�' 'I' ' am' ' deeply'
+ ' saddened']" ", the former Democratic presidential candidate , said in a statement .
+
+ � � I am deeply saddened" False broad support for Bernie Sanders among young Democratic 4 [' broad', ' support', ' for', ' Bernie', ' Sanders']
+1600 348 Name of mother of x -1 Name of mother of Prosper Mérimée Anne Louise Moreau Prosper Mérimée "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '03'
+ ',' ' and' ' died' ' in' ' 18' '71' '.' '\n' '\n']" ", the French writer , who was born in 18 03 , and died in 18 71 .
+
+" False and by the time Prosper Mérimée saw it in 1835 7 [' and', ' by', ' the', ' time', ' Prosper', ' Mé', 'rim', 'ée']
+1601 348 Name of mother of x -1 Name of mother of Prosper Mérimée Anne Louise Moreau Prosper Mérimée "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '03'
+ ',' ' and' ' died' ' in' ' 18' '71' '.' '\n' '\n']" ", the French writer , who was born in 18 03 , and died in 18 71 .
+
+" False by the time Prosper Mérimée saw it in 1835 6 [' by', ' the', ' time', ' Prosper', ' Mé', 'rim', 'ée']
+1602 348 Name of mother of x -1 Name of mother of Prosper Mérimée Anne Louise Moreau Prosper Mérimée "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '03'
+ ',' ' and' ' died' ' in' ' 18' '71' '.' '\n' '\n']" ", the French writer , who was born in 18 03 , and died in 18 71 .
+
+" False deteriorate and by the time Prosper Mérimée saw it in 1835 9 [' deterior', 'ate', ' and', ' by', ' the', ' time', ' Prosper', ' Mé', 'rim', 'ée']
+1603 348 Name of mother of x -1 Name of mother of Prosper Mérimée Anne Louise Moreau Prosper Mérimée "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '03'
+ ',' ' and' ' died' ' in' ' 18' '71' '.' '\n' '\n']" ", the French writer , who was born in 18 03 , and died in 18 71 .
+
+" False French writer Prosper Mérimée first proposed 5 [' French', ' writer', ' Prosper', ' Mé', 'rim', 'ée']
+1604 348 Name of mother of x -1 Name of mother of Prosper Mérimée Anne Louise Moreau Prosper Mérimée "[',' ' the' ' French' ' writer' ',' ' who' ' was' ' born' ' in' ' 18' '03'
+ ',' ' and' ' died' ' in' ' 18' '71' '.' '\n' '\n']" ", the French writer , who was born in 18 03 , and died in 18 71 .
+
+" False 4 ['Pros', 'per', ' Mé', 'rim', 'ée']
+1605 349 Name of mother of x -1 Name of mother of Margaret Thatcher Beatrice Ethel Stephenson Margaret Thatcher "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False and the Pope, Margaret Thatcher and Tony Benn, peace 5 [' and', ' the', ' Pope', ',', ' Margaret', ' Thatcher']
+1606 349 Name of mother of x -1 Name of mother of Margaret Thatcher Beatrice Ethel Stephenson Margaret Thatcher "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False was under Margaret Thatcher who was breaking 3 [' was', ' under', ' Margaret', ' Thatcher']
+1607 349 Name of mother of x -1 Name of mother of Margaret Thatcher Beatrice Ethel Stephenson Margaret Thatcher "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False 43 seats, and Margaret Thatcher became prime 5 [' 43', ' seats', ',', ' and', ' Margaret', ' Thatcher']
+1608 349 Name of mother of x -1 Name of mother of Margaret Thatcher Beatrice Ethel Stephenson Margaret Thatcher "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False Conservative Leader Margaret Thatcher in 2013, Miliband 3 [' Conservative', ' Leader', ' Margaret', ' Thatcher']
+1609 349 Name of mother of x -1 Name of mother of Margaret Thatcher Beatrice Ethel Stephenson Margaret Thatcher "[',' ' the' ' Queen' ' of' ' England' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',' ' and' ' the']" , the Queen of England , and the Queen of England , and the Queen of England , and the False 2 ['Marg', 'aret', ' Thatcher']
+1610 350 Name of mother of x -1 Name of mother of Fidel Castro Lina Ruz González Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' people' '.'
+ '\n' '\n' 'The' ' Cuban' ' people' ' are' ' not' ' only' ' the' ' most']" ", the Cuban leader , and the Cuban people .
+
+ The Cuban people are not only the most" False world, after Fidel Castro of Cuba and Kim 4 [' world', ',', ' after', ' Fidel', ' Castro']
+1611 350 Name of mother of x -1 Name of mother of Fidel Castro Lina Ruz González Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' people' '.'
+ '\n' '\n' 'The' ' Cuban' ' people' ' are' ' not' ' only' ' the' ' most']" ", the Cuban leader , and the Cuban people .
+
+ The Cuban people are not only the most" False exiles linked to Fidel Castro through the July 5 [' ex', 'iles', ' linked', ' to', ' Fidel', ' Castro']
+1612 350 Name of mother of x -1 Name of mother of Fidel Castro Lina Ruz González Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' people' '.'
+ '\n' '\n' 'The' ' Cuban' ' people' ' are' ' not' ' only' ' the' ' most']" ", the Cuban leader , and the Cuban people .
+
+ The Cuban people are not only the most" False Cuban President Fidel Castro described the 3 [' Cuban', ' President', ' Fidel', ' Castro']
+1613 350 Name of mother of x -1 Name of mother of Fidel Castro Lina Ruz González Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' people' '.'
+ '\n' '\n' 'The' ' Cuban' ' people' ' are' ' not' ' only' ' the' ' most']" ", the Cuban leader , and the Cuban people .
+
+ The Cuban people are not only the most" False 2 ['F', 'idel', ' Castro']
+1614 350 Name of mother of x -1 Name of mother of Fidel Castro Lina Ruz González Fidel Castro "[',' ' the' ' Cuban' ' leader' ',' ' and' ' the' ' Cuban' ' people' '.'
+ '\n' '\n' 'The' ' Cuban' ' people' ' are' ' not' ' only' ' the' ' most']" ", the Cuban leader , and the Cuban people .
+
+ The Cuban people are not only the most" False Revolution 6 [' Rev', 'ol', 'ut', 'io', 'F', 'idel', ' Castro']
+1615 351 Name of mother of x -1 Name of mother of Stefan Zweig Ida Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " Collection
+" 5 [' Collection', 'Ste', 'fan', ' Z', 'we', 'ig']
+1616 351 Name of mother of x -1 Name of mother of Stefan Zweig Ida Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False Royal Game by Stefan Zweig and Vladimir Nabokov's 6 [' Royal', ' Game', ' by', ' Stefan', ' Z', 'we', 'ig']
+1617 351 Name of mother of x -1 Name of mother of Stefan Zweig Ida Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False The Royal Game by Stefan Zweig and Vladimir 7 [' The', ' Royal', ' Game', ' by', ' Stefan', ' Z', 'we', 'ig']
+1618 351 Name of mother of x -1 Name of mother of Stefan Zweig Ida Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " Collection
+" 5 [' Collection', 'Ste', 'fan', ' Z', 'we', 'ig']
+1619 351 Name of mother of x -1 Name of mother of Stefan Zweig Ida Zweig Stefan Zweig "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " Howl. Austrian writer Stefan Zweig remarked, ""If" 8 [' How', 'l', '.', ' Austrian', ' writer', ' Stefan', ' Z', 'we', 'ig']
+1620 352 Name of mother of x -1 Name of mother of François-René de Chateaubriand Apolline Jeanne Suzanne de Bédée François-René de Chateaubriand "[',' ' the' ' author' ' of' ' the' ' _' 'G' 'é' 'nie' ' du' ' Christian'
+ 'ism' 'e' '_' ',' ' and' ' of' ' the' ' _' 'M']" , the author of the _ G é nie du Christian ism e _ , and of the _ M False 12 ['Fran', 'ç', 'ois', '-', 'Ren', 'é', ' de', ' Ch', 'ate', 'a', 'ub', 'ri', 'and']
+1621 352 Name of mother of x -1 Name of mother of François-René de Chateaubriand Apolline Jeanne Suzanne de Bédée François-René de Chateaubriand "[',' ' the' ' author' ' of' ' the' ' _' 'G' 'é' 'nie' ' du' ' Christian'
+ 'ism' 'e' '_' ',' ' and' ' of' ' the' ' _' 'M']" , the author of the _ G é nie du Christian ism e _ , and of the _ M False centenary of the death of François-René de Chateaubriand and she visited 16 [' cent', 'enary', ' of', ' the', ' death', ' of', ' François', '-', 'Ren', 'é', ' de', ' Ch', 'ate', 'a', 'ub', 'ri', 'and']
+1622 352 Name of mother of x -1 Name of mother of François-René de Chateaubriand Apolline Jeanne Suzanne de Bédée François-René de Chateaubriand "[',' ' the' ' author' ' of' ' the' ' _' 'G' 'é' 'nie' ' du' ' Christian'
+ 'ism' 'e' '_' ',' ' and' ' of' ' the' ' _' 'M']" , the author of the _ G é nie du Christian ism e _ , and of the _ M False centenary of the death of François-René de Chateaubriand and she visited 16 [' cent', 'enary', ' of', ' the', ' death', ' of', ' François', '-', 'Ren', 'é', ' de', ' Ch', 'ate', 'a', 'ub', 'ri', 'and']
+1623 354 Name of mother of x -1 Name of mother of Richard Feynman Lucille Feynman Richard Feynman "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' the' ' atomic' ' bomb' '.' '\n' '\n' 'The' ' father'
+ ' of']" ", the father of modern physics , and the father of the atomic bomb .
+
+ The father of" False revisited his role as Richard Feynman in Peter Parnell's 8 [' revis', 'ited', ' his', ' role', ' as', ' Richard', ' Fe', 'yn', 'man']
+1624 354 Name of mother of x -1 Name of mother of Richard Feynman Lucille Feynman Richard Feynman "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' the' ' atomic' ' bomb' '.' '\n' '\n' 'The' ' father'
+ ' of']" ", the father of modern physics , and the father of the atomic bomb .
+
+ The father of" False " Schwinger and Richard Feynman in the late 1940s.
+" 6 [' Schw', 'inger', ' and', ' Richard', ' Fe', 'yn', 'man']
+1625 354 Name of mother of x -1 Name of mother of Richard Feynman Lucille Feynman Richard Feynman "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' the' ' atomic' ' bomb' '.' '\n' '\n' 'The' ' father'
+ ' of']" ", the father of modern physics , and the father of the atomic bomb .
+
+ The father of" False bongo drum, as Richard Feynman played bongos. 8 [' b', 'ongo', ' drum', ',', ' as', ' Richard', ' Fe', 'yn', 'man']
+1626 354 Name of mother of x -1 Name of mother of Richard Feynman Lucille Feynman Richard Feynman "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' the' ' atomic' ' bomb' '.' '\n' '\n' 'The' ' father'
+ ' of']" ", the father of modern physics , and the father of the atomic bomb .
+
+ The father of" False theory. Developed by Richard Feynman and others in 8 [' theory', '.', ' Develop', 'ed', ' by', ' Richard', ' Fe', 'yn', 'man']
+1627 354 Name of mother of x -1 Name of mother of Richard Feynman Lucille Feynman Richard Feynman "[',' ' the' ' father' ' of' ' modern' ' physics' ',' ' and' ' the'
+ ' father' ' of' ' the' ' atomic' ' bomb' '.' '\n' '\n' 'The' ' father'
+ ' of']" ", the father of modern physics , and the father of the atomic bomb .
+
+ The father of" False " quantum vacuum. Richard Feynman called it ""the" 6 [' quantum', ' vacuum', '.', ' Richard', ' Fe', 'yn', 'man']
+1628 355 Name of mother of x -1 Name of mother of Carl Friedrich Gauss Dorthea Benze Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' was' ' born' ' in' ' 17' '77' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' Ga' 'uss' ' family' ' to']" ", the mathematician , was born in 17 77 .
+
+ The first of the Ga uss family to" False when, in 1832, Carl Friedrich Gauss used it, the 9 [' when', ',', ' in', ' 18', '32', ',', ' Carl', ' Friedrich', ' Ga', 'uss']
+1629 355 Name of mother of x -1 Name of mother of Carl Friedrich Gauss Dorthea Benze Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' was' ' born' ' in' ' 17' '77' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' Ga' 'uss' ' family' ' to']" ", the mathematician , was born in 17 77 .
+
+ The first of the Ga uss family to" False system when, in 1832, Carl Friedrich Gauss used it, the centimetre 10 [' system', ' when', ',', ' in', ' 18', '32', ',', ' Carl', ' Friedrich', ' Ga', 'uss']
+1630 355 Name of mother of x -1 Name of mother of Carl Friedrich Gauss Dorthea Benze Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' was' ' born' ' in' ' 17' '77' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' Ga' 'uss' ' family' ' to']" ", the mathematician , was born in 17 77 .
+
+ The first of the Ga uss family to" False respectively. In the 1830s Carl Friedrich Gauss laid the foundations 9 [' respectively', '.', ' In', ' the', ' 1830', 's', ' Carl', ' Friedrich', ' Ga', 'uss']
+1631 355 Name of mother of x -1 Name of mother of Carl Friedrich Gauss Dorthea Benze Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' was' ' born' ' in' ' 17' '77' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' Ga' 'uss' ' family' ' to']" ", the mathematician , was born in 17 77 .
+
+ The first of the Ga uss family to" False system when, in 1832, Carl Friedrich Gauss used it, the centimetre 10 [' system', ' when', ',', ' in', ' 18', '32', ',', ' Carl', ' Friedrich', ' Ga', 'uss']
+1632 355 Name of mother of x -1 Name of mother of Carl Friedrich Gauss Dorthea Benze Carl Friedrich Gauss "[',' ' the' ' mathematician' ',' ' was' ' born' ' in' ' 17' '77' '.' '\n'
+ '\n' 'The' ' first' ' of' ' the' ' Ga' 'uss' ' family' ' to']" ", the mathematician , was born in 17 77 .
+
+ The first of the Ga uss family to" False system when, in 1832, Carl Friedrich Gauss used it, the centimetre 10 [' system', ' when', ',', ' in', ' 18', '32', ',', ' Carl', ' Friedrich', ' Ga', 'uss']
+1633 356 Name of mother of x -1 Name of mother of Bill Gates Mary Maxwell Gates Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' founder'
+ ' of' ' the' ' Bill' ' and' ' Mel' 'inda' ' Gates' ' Foundation' ','
+ ' and' ' the']" , the founder of Microsoft , and the founder of the Bill and Mel inda Gates Foundation , and the False " on people like Bill Gates and Sam Walton""." 4 [' on', ' people', ' like', ' Bill', ' Gates']
+1634 356 Name of mother of x -1 Name of mother of Bill Gates Mary Maxwell Gates Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' founder'
+ ' of' ' the' ' Bill' ' and' ' Mel' 'inda' ' Gates' ' Foundation' ','
+ ' and' ' the']" , the founder of Microsoft , and the founder of the Bill and Mel inda Gates Foundation , and the False " later dismantled by Bill Gates and his goons.
+" 4 [' later', ' dismantled', ' by', ' Bill', ' Gates']
+1635 356 Name of mother of x -1 Name of mother of Bill Gates Mary Maxwell Gates Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' founder'
+ ' of' ' the' ' Bill' ' and' ' Mel' 'inda' ' Gates' ' Foundation' ','
+ ' and' ' the']" , the founder of Microsoft , and the founder of the Bill and Mel inda Gates Foundation , and the False It is owned by Bill Gates and is displayed 5 [' It', ' is', ' owned', ' by', ' Bill', ' Gates']
+1636 356 Name of mother of x -1 Name of mother of Bill Gates Mary Maxwell Gates Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' founder'
+ ' of' ' the' ' Bill' ' and' ' Mel' 'inda' ' Gates' ' Foundation' ','
+ ' and' ' the']" , the founder of Microsoft , and the founder of the Bill and Mel inda Gates Foundation , and the False Melinda and Bill Gates Foundation, saw 4 [' Mel', 'inda', ' and', ' Bill', ' Gates']
+1637 356 Name of mother of x -1 Name of mother of Bill Gates Mary Maxwell Gates Bill Gates "[',' ' the' ' founder' ' of' ' Microsoft' ',' ' and' ' the' ' founder'
+ ' of' ' the' ' Bill' ' and' ' Mel' 'inda' ' Gates' ' Foundation' ','
+ ' and' ' the']" , the founder of Microsoft , and the founder of the Bill and Mel inda Gates Foundation , and the False negotiated with Bill Gates for Windows 1.0. 3 [' negotiated', ' with', ' Bill', ' Gates']
+1638 357 Name of mother of x -1 Name of mother of Mariah Carey Patricia Hickey Carey Mariah Carey "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False and produced by Mariah Carey and Danja, with 5 [' and', ' produced', ' by', ' Mar', 'iah', ' Carey']
+1639 357 Name of mother of x -1 Name of mother of Mariah Carey Patricia Hickey Carey Mariah Carey "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False Janet Jackson and Mariah Carey to sell over 300,000 5 [' Janet', ' Jackson', ' and', ' Mar', 'iah', ' Carey']
+1640 357 Name of mother of x -1 Name of mother of Mariah Carey Patricia Hickey Carey Mariah Carey "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False " that week - Mariah Carey and Westlife's ""Against" 5 [' that', ' week', ' -', ' Mar', 'iah', ' Carey']
+1641 357 Name of mother of x -1 Name of mother of Mariah Carey Patricia Hickey Carey Mariah Carey "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False Company titled Mariah Carey Homecoming 4 [' Company', ' titled', ' Mar', 'iah', ' Carey']
+1642 357 Name of mother of x -1 Name of mother of Mariah Carey Patricia Hickey Carey Mariah Carey "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the mother of the bride , and the mother of the groom .
+
+ The wedding was held" False first time alongside Mariah Carey on October 5 [' first', ' time', ' alongside', ' Mar', 'iah', ' Carey']
+1643 358 Name of mother of x -1 Name of mother of Percy Bysshe Shelley Elizabeth Pilford Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Shelley' ','
+ ' the' ' novelist' ',' ' and' ' their' ' daughter' ',' ' Claire' ',']" , the poet , and his wife , Mary Shelley , the novelist , and their daughter , Claire , False contemporary, poet Percy Bysshe Shelley (1792 – 1822) 7 [' contemporary', ',', ' poet', ' Percy', ' By', 'ss', 'he', ' Shelley']
+1644 358 Name of mother of x -1 Name of mother of Percy Bysshe Shelley Elizabeth Pilford Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Shelley' ','
+ ' the' ' novelist' ',' ' and' ' their' ' daughter' ',' ' Claire' ',']" , the poet , and his wife , Mary Shelley , the novelist , and their daughter , Claire , False contemporary, poet Percy Bysshe Shelley (1792 – 1822) 7 [' contemporary', ',', ' poet', ' Percy', ' By', 'ss', 'he', ' Shelley']
+1645 358 Name of mother of x -1 Name of mother of Percy Bysshe Shelley Elizabeth Pilford Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Shelley' ','
+ ' the' ' novelist' ',' ' and' ' their' ' daughter' ',' ' Claire' ',']" , the poet , and his wife , Mary Shelley , the novelist , and their daughter , Claire , False 1816, Mary Shelley and Percy Bysshe Shelley toured the Alps and 10 [' 18', '16', ',', ' Mary', ' Shelley', ' and', ' Percy', ' By', 'ss', 'he', ' Shelley']
+1646 358 Name of mother of x -1 Name of mother of Percy Bysshe Shelley Elizabeth Pilford Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Shelley' ','
+ ' the' ' novelist' ',' ' and' ' their' ' daughter' ',' ' Claire' ',']" , the poet , and his wife , Mary Shelley , the novelist , and their daughter , Claire , False her husband Percy Bysshe Shelley had lived in Italy 6 [' her', ' husband', ' Percy', ' By', 'ss', 'he', ' Shelley']
+1647 358 Name of mother of x -1 Name of mother of Percy Bysshe Shelley Elizabeth Pilford Percy Bysshe Shelley "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Mary' ' Shelley' ','
+ ' the' ' novelist' ',' ' and' ' their' ' daughter' ',' ' Claire' ',']" , the poet , and his wife , Mary Shelley , the novelist , and their daughter , Claire , False 5 ['Per', 'cy', ' By', 'ss', 'he', ' Shelley']
+1648 359 Name of mother of x -1 Name of mother of Norman Mailer Fanny Schneider Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " Monroe's contemporary Norman Mailer wrote that ""Marilyn" 5 "[' Monroe', ""'s"", ' contemporary', ' Norman', ' Mail', 'er']"
+1649 359 Name of mother of x -1 Name of mother of Norman Mailer Fanny Schneider Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False truth that, as Norman Mailer put it in one of 6 [' truth', ' that', ',', ' as', ' Norman', ' Mail', 'er']
+1650 359 Name of mother of x -1 Name of mother of Norman Mailer Fanny Schneider Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " Monroe's contemporary Norman Mailer wrote that ""Marilyn" 5 "[' Monroe', ""'s"", ' contemporary', ' Norman', ' Mail', 'er']"
+1651 359 Name of mother of x -1 Name of mother of Norman Mailer Fanny Schneider Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " of the earth."" Norman Mailer called her the ""Robespierre" 6 "[' of', ' the', ' earth', '.""', ' Norman', ' Mail', 'er']"
+1652 359 Name of mother of x -1 Name of mother of Norman Mailer Fanny Schneider Norman Mailer "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False ranking at No. 46. Norman Mailer described 8 [' ranking', ' at', ' No', '.', ' 46', '.', ' Norman', ' Mail', 'er']
+1653 360 Name of mother of x -1 Name of mother of John Constable Ann Watts John Constable "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False English painter John Constable visited Hadleigh in 3 [' English', ' painter', ' John', ' Constable']
+1654 360 Name of mother of x -1 Name of mother of John Constable Ann Watts John Constable "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False – 1851) and John Constable (1776 – 1837) 6 [' –', ' 18', '51', ')', ' and', ' John', ' Constable']
+1655 360 Name of mother of x -1 Name of mother of John Constable Ann Watts John Constable "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False British artists John Constable and J. M. W. 3 [' British', ' artists', ' John', ' Constable']
+1656 360 Name of mother of x -1 Name of mother of John Constable Ann Watts John Constable "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False soundly defeated John Constable by eighteen 4 [' sound', 'ly', ' defeated', ' John', ' Constable']
+1657 360 Name of mother of x -1 Name of mother of John Constable Ann Watts John Constable "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False Painters such as John Constable and Jean-Baptiste-Camille 5 [' Pain', 'ters', ' such', ' as', ' John', ' Constable']
+1658 361 Name of mother of x -1 Name of mother of George Gershwin Rose Bruskina George Gershwin "[',' ' the' ' composer' ',' ' and' ' the' ' composer' ' of' ' the' ' song'
+ ' ""' 'I' ' Got' ' Rhythm' '""' ' (' '19' '17' ')' ' and']" ", the composer , and the composer of the song "" I Got Rhythm "" ( 19 17 ) and" False " Lu Walters on the George Gershwin aria ""Summertime""." 8 [' Lu', ' Walters', ' on', ' the', ' George', ' G', 'ers', 'h', 'win']
+1659 361 Name of mother of x -1 Name of mother of George Gershwin Rose Bruskina George Gershwin "[',' ' the' ' composer' ',' ' and' ' the' ' composer' ' of' ' the' ' song'
+ ' ""' 'I' ' Got' ' Rhythm' '""' ' (' '19' '17' ')' ' and']" ", the composer , and the composer of the song "" I Got Rhythm "" ( 19 17 ) and" False songwriters like George Gershwin or Cole Porter, 7 [' song', 'writers', ' like', ' George', ' G', 'ers', 'h', 'win']
+1660 361 Name of mother of x -1 Name of mother of George Gershwin Rose Bruskina George Gershwin "[',' ' the' ' composer' ',' ' and' ' the' ' composer' ' of' ' the' ' song'
+ ' ""' 'I' ' Got' ' Rhythm' '""' ' (' '19' '17' ')' ' and']" ", the composer , and the composer of the song "" I Got Rhythm "" ( 19 17 ) and" False Pulitzer, as composer George Gershwin had not been recognized 8 [' Pulitzer', ',', ' as', ' composer', ' George', ' G', 'ers', 'h', 'win']
+1661 361 Name of mother of x -1 Name of mother of George Gershwin Rose Bruskina George Gershwin "[',' ' the' ' composer' ',' ' and' ' the' ' composer' ' of' ' the' ' song'
+ ' ""' 'I' ' Got' ' Rhythm' '""' ' (' '19' '17' ')' ' and']" ", the composer , and the composer of the song "" I Got Rhythm "" ( 19 17 ) and" False Show Boat in 1927. George Gershwin was perhaps 9 [' Show', ' Boat', ' in', ' 1927', '.', ' George', ' G', 'ers', 'h', 'win']
+1662 361 Name of mother of x -1 Name of mother of George Gershwin Rose Bruskina George Gershwin "[',' ' the' ' composer' ',' ' and' ' the' ' composer' ' of' ' the' ' song'
+ ' ""' 'I' ' Got' ' Rhythm' '""' ' (' '19' '17' ')' ' and']" ", the composer , and the composer of the song "" I Got Rhythm "" ( 19 17 ) and" False with songwriters like George Gershwin or Cole Porter, he 8 [' with', ' song', 'writers', ' like', ' George', ' G', 'ers', 'h', 'win']
+1663 362 Name of mother of x -1 Name of mother of Justin Bieber Pattie Mallette Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ' beautiful'
+ ' children' ',' ' a' ' daughter' ' and' ' a' ' son' '.' ' I' ' am']" ".
+
+ I am a mother of two beautiful children , a daughter and a son . I am" False celebrity endorsements. Justin Bieber for example, was paid 4 [' celebrity', ' endorsements', '.', ' Justin', ' Bieber']
+1664 362 Name of mother of x -1 Name of mother of Justin Bieber Pattie Mallette Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ' beautiful'
+ ' children' ',' ' a' ' daughter' ' and' ' a' ' son' '.' ' I' ' am']" ".
+
+ I am a mother of two beautiful children , a daughter and a son . I am" False unofficially altered Justin Bieber song served as 5 [' un', 'offic', 'ially', ' altered', ' Justin', ' Bieber']
+1665 362 Name of mother of x -1 Name of mother of Justin Bieber Pattie Mallette Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ' beautiful'
+ ' children' ',' ' a' ' daughter' ' and' ' a' ' son' '.' ' I' ' am']" ".
+
+ I am a mother of two beautiful children , a daughter and a son . I am" False " bouquet and other Justin Bieber merchandise.
+" 5 [' bou', 'quet', ' and', ' other', ' Justin', ' Bieber']
+1666 362 Name of mother of x -1 Name of mother of Justin Bieber Pattie Mallette Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ' beautiful'
+ ' children' ',' ' a' ' daughter' ' and' ' a' ' son' '.' ' I' ' am']" ".
+
+ I am a mother of two beautiful children , a daughter and a son . I am" False Montana and Justin Bieber at recording studio 3 [' Montana', ' and', ' Justin', ' Bieber']
+1667 362 Name of mother of x -1 Name of mother of Justin Bieber Pattie Mallette Justin Bieber "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ' beautiful'
+ ' children' ',' ' a' ' daughter' ' and' ' a' ' son' '.' ' I' ' am']" ".
+
+ I am a mother of two beautiful children , a daughter and a son . I am" False throws with Justin Bieber chants. The 2013 3 [' throws', ' with', ' Justin', ' Bieber']
+1668 363 Name of mother of x -1 Name of mother of Selena Gomez Mandy Teefey Selena Gomez "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' huge' ' fan' ' of' ' Sel' 'ena']" ", who is a very good friend of mine .
+
+ I am a huge fan of Sel ena" False American pop band Selena Gomez & the Scene performed 5 [' American', ' pop', ' band', ' Sel', 'ena', ' Gomez']
+1669 363 Name of mother of x -1 Name of mother of Selena Gomez Mandy Teefey Selena Gomez "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' huge' ' fan' ' of' ' Sel' 'ena']" ", who is a very good friend of mine .
+
+ I am a huge fan of Sel ena" False sections and lyrics. Selena Gomez & the Scene performed 6 [' sections', ' and', ' lyrics', '.', ' Sel', 'ena', ' Gomez']
+1670 363 Name of mother of x -1 Name of mother of Selena Gomez Mandy Teefey Selena Gomez "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' huge' ' fan' ' of' ' Sel' 'ena']" ", who is a very good friend of mine .
+
+ I am a huge fan of Sel ena" False artists, including Selena Gomez & the Scene, 5 [' artists', ',', ' including', ' Sel', 'ena', ' Gomez']
+1671 363 Name of mother of x -1 Name of mother of Selena Gomez Mandy Teefey Selena Gomez "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' huge' ' fan' ' of' ' Sel' 'ena']" ", who is a very good friend of mine .
+
+ I am a huge fan of Sel ena" False American entertainer Selena Gomez (who was featured on 5 [' American', ' entertain', 'er', ' Sel', 'ena', ' Gomez']
+1672 363 Name of mother of x -1 Name of mother of Selena Gomez Mandy Teefey Selena Gomez "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' huge' ' fan' ' of' ' Sel' 'ena']" ", who is a very good friend of mine .
+
+ I am a huge fan of Sel ena" False Walker and Selena Gomez & the Scene, and 4 [' Walker', ' and', ' Sel', 'ena', ' Gomez']
+1673 364 Name of mother of x -1 Name of mother of Ludwig Wittgenstein Leopoldine 'Poldy' Kalmus Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' philosophy' '.' '\n' '\n' 'The' ' philosopher' ' Ludwig' ' Witt' 'gen'
+ 'stein' ' was']" ", the philosopher , and the father of modern philosophy .
+
+ The philosopher Ludwig Witt gen stein was" False successfully lobbied for Ludwig Wittgenstein to be allowed residency 6 [' successfully', ' lobbied', ' for', ' Ludwig', ' Witt', 'gen', 'stein']
+1674 364 Name of mother of x -1 Name of mother of Ludwig Wittgenstein Leopoldine 'Poldy' Kalmus Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' philosophy' '.' '\n' '\n' 'The' ' philosopher' ' Ludwig' ' Witt' 'gen'
+ 'stein' ' was']" ", the philosopher , and the father of modern philosophy .
+
+ The philosopher Ludwig Witt gen stein was" False fulfils this function. Ludwig Wittgenstein also proposed 8 [' fulf', 'ils', ' this', ' function', '.', ' Ludwig', ' Witt', 'gen', 'stein']
+1675 364 Name of mother of x -1 Name of mother of Ludwig Wittgenstein Leopoldine 'Poldy' Kalmus Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' philosophy' '.' '\n' '\n' 'The' ' philosopher' ' Ludwig' ' Witt' 'gen'
+ 'stein' ' was']" ", the philosopher , and the father of modern philosophy .
+
+ The philosopher Ludwig Witt gen stein was" False philosophers Ludwig Wittgenstein and Hans-Georg 4 [' philosophers', ' Ludwig', ' Witt', 'gen', 'stein']
+1676 364 Name of mother of x -1 Name of mother of Ludwig Wittgenstein Leopoldine 'Poldy' Kalmus Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' philosophy' '.' '\n' '\n' 'The' ' philosopher' ' Ludwig' ' Witt' 'gen'
+ 'stein' ' was']" ", the philosopher , and the father of modern philosophy .
+
+ The philosopher Ludwig Witt gen stein was" False fulfils this function. Ludwig Wittgenstein also proposed that 8 [' fulf', 'ils', ' this', ' function', '.', ' Ludwig', ' Witt', 'gen', 'stein']
+1677 364 Name of mother of x -1 Name of mother of Ludwig Wittgenstein Leopoldine 'Poldy' Kalmus Ludwig Wittgenstein "[',' ' the' ' philosopher' ',' ' and' ' the' ' father' ' of' ' modern'
+ ' philosophy' '.' '\n' '\n' 'The' ' philosopher' ' Ludwig' ' Witt' 'gen'
+ 'stein' ' was']" ", the philosopher , and the father of modern philosophy .
+
+ The philosopher Ludwig Witt gen stein was" False lectures given in 1939 by Ludwig Wittgenstein about the foundations 8 [' lectures', ' given', ' in', ' 1939', ' by', ' Ludwig', ' Witt', 'gen', 'stein']
+1678 365 Name of mother of x -1 Name of mother of Arthur Wellesley, 1st Duke of Wellington Anne Wellesley, Countess of Mornington Arthur Wellesley, 1st Duke of Wellington "[',' ' 1' 'st' ' Marqu' 'ess' ' of' ' Wellington' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Wellington' ',' ' 1' 'st' ' Earl' ' of' ' Wellington']" , 1 st Marqu ess of Wellington , 1 st V isc ount Wellington , 1 st Earl of Wellington False " Battle record of Arthur Wellesley, 1st Duke of Wellington =
+" 11 [' Battle', ' record', ' of', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1679 365 Name of mother of x -1 Name of mother of Arthur Wellesley, 1st Duke of Wellington Anne Wellesley, Countess of Mornington Arthur Wellesley, 1st Duke of Wellington "[',' ' 1' 'st' ' Marqu' 'ess' ' of' ' Wellington' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Wellington' ',' ' 1' 'st' ' Earl' ' of' ' Wellington']" , 1 st Marqu ess of Wellington , 1 st V isc ount Wellington , 1 st Earl of Wellington False support of the Arthur Wellesley, 1st Duke of Wellington (who was also UK's 11 [' support', ' of', ' the', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1680 365 Name of mother of x -1 Name of mother of Arthur Wellesley, 1st Duke of Wellington Anne Wellesley, Countess of Mornington Arthur Wellesley, 1st Duke of Wellington "[',' ' 1' 'st' ' Marqu' 'ess' ' of' ' Wellington' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Wellington' ',' ' 1' 'st' ' Earl' ' of' ' Wellington']" , 1 st Marqu ess of Wellington , 1 st V isc ount Wellington , 1 st Earl of Wellington False " = Battle record of Arthur Wellesley, 1st Duke of Wellington =
+" 12 [' =', ' Battle', ' record', ' of', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1681 365 Name of mother of x -1 Name of mother of Arthur Wellesley, 1st Duke of Wellington Anne Wellesley, Countess of Mornington Arthur Wellesley, 1st Duke of Wellington "[',' ' 1' 'st' ' Marqu' 'ess' ' of' ' Wellington' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Wellington' ',' ' 1' 'st' ' Earl' ' of' ' Wellington']" , 1 st Marqu ess of Wellington , 1 st V isc ount Wellington , 1 st Earl of Wellington False " Battle record of Arthur Wellesley, 1st Duke of Wellington =
+" 11 [' Battle', ' record', ' of', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1682 365 Name of mother of x -1 Name of mother of Arthur Wellesley, 1st Duke of Wellington Anne Wellesley, Countess of Mornington Arthur Wellesley, 1st Duke of Wellington "[',' ' 1' 'st' ' Marqu' 'ess' ' of' ' Wellington' ',' ' 1' 'st' ' V' 'isc'
+ 'ount' ' Wellington' ',' ' 1' 'st' ' Earl' ' of' ' Wellington']" , 1 st Marqu ess of Wellington , 1 st V isc ount Wellington , 1 st Earl of Wellington False support of the Arthur Wellesley, 1st Duke of Wellington (who was also 11 [' support', ' of', ' the', ' Arthur', ' Well', 'esley', ',', ' 1', 'st', ' Duke', ' of', ' Wellington']
+1683 366 Name of mother of x -1 Name of mother of John Everett Millais Mary Emily Evamy John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False Rossetti and John Everett Millais were leaders. Prominent 7 [' Ross', 'etti', ' and', ' John', ' Everett', ' M', 'illa', 'is']
+1684 366 Name of mother of x -1 Name of mother of John Everett Millais Mary Emily Evamy John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False a friend of John Everett Millais and he subsequently 7 [' a', ' friend', ' of', ' John', ' Everett', ' M', 'illa', 'is']
+1685 366 Name of mother of x -1 Name of mother of John Everett Millais Mary Emily Evamy John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False " early paintings of John Everett Millais and ""the wonderful" 7 [' early', ' paintings', ' of', ' John', ' Everett', ' M', 'illa', 'is']
+1686 366 Name of mother of x -1 Name of mother of John Everett Millais Mary Emily Evamy John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False Browning, Lord Tennyson, John Everett Millais and Henry James. 11 [' Brown', 'ing', ',', ' Lord', ' Tenn', 'yson', ',', ' John', ' Everett', ' M', 'illa', 'is']
+1687 366 Name of mother of x -1 Name of mother of John Everett Millais Mary Emily Evamy John Everett Millais "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' their' ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',']" , the painter , and his wife , the artist , and their daughter , the artist 's daughter , False (1860), by John Everett Millais was inspired 9 [' (', '18', '60', '),', ' by', ' John', ' Everett', ' M', 'illa', 'is']
+1688 367 Name of mother of x -1 Name of mother of Thomas Alva Edison Nancy Elliott Thomas Alva Edison "[',' ' the' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the inventor of the phon ograph , the
+ " False technology created by Thomas Alva Edison in 1930. The overhead 6 [' technology', ' created', ' by', ' Thomas', ' Al', 'va', ' Edison']
+1689 367 Name of mother of x -1 Name of mother of Thomas Alva Edison Nancy Elliott Thomas Alva Edison "[',' ' the' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the inventor of the phon ograph , the
+ " False technology created by Thomas Alva Edison in 1930. The overhead 6 [' technology', ' created', ' by', ' Thomas', ' Al', 'va', ' Edison']
+1690 367 Name of mother of x -1 Name of mother of Thomas Alva Edison Nancy Elliott Thomas Alva Edison "[',' ' the' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the inventor of the phon ograph , the
+ " False technology created by Thomas Alva Edison in 1930. The 6 [' technology', ' created', ' by', ' Thomas', ' Al', 'va', ' Edison']
+1691 367 Name of mother of x -1 Name of mother of Thomas Alva Edison Nancy Elliott Thomas Alva Edison "[',' ' the' ' inventor' ' of' ' the' ' phon' 'ograph' ',' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the inventor of the phon ograph , the
+ " False passes near the Thomas Alva Edison Memorial Tower 6 [' passes', ' near', ' the', ' Thomas', ' Al', 'va', ' Edison']
+1692 368 Name of mother of x -1 Name of mother of Lauren Bacall Natalie Alberta Bacall Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False Claudette Colbert, Lauren Bacall and Mildred Natwick 6 [' Claud', 'ette', ' Colbert', ',', ' Lauren', ' Bac', 'all']
+1693 368 Name of mother of x -1 Name of mother of Lauren Bacall Natalie Alberta Bacall Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False " cinematographers who lit Lauren Bacall and Grace Kelly.""
+" 7 [' cinem', 'at', 'ographers', ' who', ' lit', ' Lauren', ' Bac', 'all']
+1694 368 Name of mother of x -1 Name of mother of Lauren Bacall Natalie Alberta Bacall Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False a reference to Lauren Bacall in the 1944 film To 5 [' a', ' reference', ' to', ' Lauren', ' Bac', 'all']
+1695 368 Name of mother of x -1 Name of mother of Lauren Bacall Natalie Alberta Bacall Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False Claudette Colbert, Lauren Bacall and Mildred Natwick 6 [' Claud', 'ette', ' Colbert', ',', ' Lauren', ' Bac', 'all']
+1696 368 Name of mother of x -1 Name of mother of Lauren Bacall Natalie Alberta Bacall Lauren Bacall "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' '.' '\n' '\n' 'The']" ", the actress who played the role of the mother of the bride in the film .
+
+ The" False off engagements to Lauren Bacall in 1958, and Juliet 5 [' off', ' engagements', ' to', ' Lauren', ' Bac', 'all']
+1697 369 Name of mother of x -1 Name of mother of Carl Jung Emilie Preiswerk Carl Jung "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of psycho analysis , and the father of modern psychology .
+
+ The first thing that" False 1 ['Carl', ' Jung']
+1698 369 Name of mother of x -1 Name of mother of Carl Jung Emilie Preiswerk Carl Jung "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of psycho analysis , and the father of modern psychology .
+
+ The first thing that" False influenced both Carl Jung and Joseph Campbell. 3 [' influenced', ' both', ' Carl', ' Jung']
+1699 369 Name of mother of x -1 Name of mother of Carl Jung Emilie Preiswerk Carl Jung "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of psycho analysis , and the father of modern psychology .
+
+ The first thing that" False psychoanalysts Carl Jung and Sigmund Freud. 5 [' psych', 'oan', 'aly', 'sts', ' Carl', ' Jung']
+1700 369 Name of mother of x -1 Name of mother of Carl Jung Emilie Preiswerk Carl Jung "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of psycho analysis , and the father of modern psychology .
+
+ The first thing that" False cultures influenced both Carl Jung and Joseph Campbell. 4 [' cultures', ' influenced', ' both', ' Carl', ' Jung']
+1701 369 Name of mother of x -1 Name of mother of Carl Jung Emilie Preiswerk Carl Jung "[',' ' the' ' father' ' of' ' psycho' 'analysis' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' psychology' '.' '\n' '\n' 'The' ' first'
+ ' thing' ' that']" ", the father of psycho analysis , and the father of modern psychology .
+
+ The first thing that" False 1 ['Carl', ' Jung']
+1702 370 Name of mother of x -1 Name of mother of Franz Joseph I of Austria Princess Sophie of Bavaria Franz Joseph I of Austria "[',' ' the' ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of'
+ ' Austria' ',' ' the' ' Empress' ' of' ' Austria' ',' ' and' ' the'
+ ' Empress']" , the Emperor of Austria , and the Empress of Austria , the Empress of Austria , and the Empress False acquired by Emperor Franz Joseph I of Austria for 50,000 7 [' acquired', ' by', ' Emperor', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1703 370 Name of mother of x -1 Name of mother of Franz Joseph I of Austria Princess Sophie of Bavaria Franz Joseph I of Austria "[',' ' the' ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of'
+ ' Austria' ',' ' the' ' Empress' ' of' ' Austria' ',' ' and' ' the'
+ ' Empress']" , the Emperor of Austria , and the Empress of Austria , the Empress of Austria , and the Empress False acquired by Emperor Franz Joseph I of Austria for 50,000 franks. 7 [' acquired', ' by', ' Emperor', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1704 370 Name of mother of x -1 Name of mother of Franz Joseph I of Austria Princess Sophie of Bavaria Franz Joseph I of Austria "[',' ' the' ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of'
+ ' Austria' ',' ' the' ' Empress' ' of' ' Austria' ',' ' and' ' the'
+ ' Empress']" , the Emperor of Austria , and the Empress of Austria , the Empress of Austria , and the Empress False presumptive to the Emperor Franz Joseph I of Austria in 1914 after the 8 [' presumptive', ' to', ' the', ' Emperor', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1705 370 Name of mother of x -1 Name of mother of Franz Joseph I of Austria Princess Sophie of Bavaria Franz Joseph I of Austria "[',' ' the' ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of'
+ ' Austria' ',' ' the' ' Empress' ' of' ' Austria' ',' ' and' ' the'
+ ' Empress']" , the Emperor of Austria , and the Empress of Austria , the Empress of Austria , and the Empress False acquired by Emperor Franz Joseph I of Austria for 50,000 franks. 7 [' acquired', ' by', ' Emperor', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1706 370 Name of mother of x -1 Name of mother of Franz Joseph I of Austria Princess Sophie of Bavaria Franz Joseph I of Austria "[',' ' the' ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of'
+ ' Austria' ',' ' the' ' Empress' ' of' ' Austria' ',' ' and' ' the'
+ ' Empress']" , the Emperor of Austria , and the Empress of Austria , the Empress of Austria , and the Empress False Napoleon II of France, Franz Joseph I of Austria and Maximilian 9 [' Napoleon', ' II', ' of', ' France', ',', ' Franz', ' Joseph', ' I', ' of', ' Austria']
+1707 371 Name of mother of x -1 Name of mother of Catherine II of Russia Joanna Elisabeth of Holstein-Gottorp Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' and' ' the'
+ ' Empress' ' Elizabeth' '.' '\n' '\n' 'The' ' Empress' ' Catherine' ' II'
+ ' was' ' the']" ", the Empress Catherine the Great , and the Empress Elizabeth .
+
+ The Empress Catherine II was the" False Lubowla), Catherine II of Russia and her advisor 7 [' Lub', 'owl', 'a', '),', ' Catherine', ' II', ' of', ' Russia']
+1708 371 Name of mother of x -1 Name of mother of Catherine II of Russia Joanna Elisabeth of Holstein-Gottorp Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' and' ' the'
+ ' Empress' ' Elizabeth' '.' '\n' '\n' 'The' ' Empress' ' Catherine' ' II'
+ ' was' ' the']" ", the Empress Catherine the Great , and the Empress Elizabeth .
+
+ The Empress Catherine II was the" False II of Prussia and Catherine II of Russia would do it 8 [' II', ' of', ' Pr', 'ussia', ' and', ' Catherine', ' II', ' of', ' Russia']
+1709 371 Name of mother of x -1 Name of mother of Catherine II of Russia Joanna Elisabeth of Holstein-Gottorp Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' and' ' the'
+ ' Empress' ' Elizabeth' '.' '\n' '\n' 'The' ' Empress' ' Catherine' ' II'
+ ' was' ' the']" ", the Empress Catherine the Great , and the Empress Elizabeth .
+
+ The Empress Catherine II was the" False its completion as Catherine II of Russia had offered him a place 6 [' its', ' completion', ' as', ' Catherine', ' II', ' of', ' Russia']
+1710 371 Name of mother of x -1 Name of mother of Catherine II of Russia Joanna Elisabeth of Holstein-Gottorp Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' and' ' the'
+ ' Empress' ' Elizabeth' '.' '\n' '\n' 'The' ' Empress' ' Catherine' ' II'
+ ' was' ' the']" ", the Empress Catherine the Great , and the Empress Elizabeth .
+
+ The Empress Catherine II was the" False 4 ['C', 'atherine', ' II', ' of', ' Russia']
+1711 371 Name of mother of x -1 Name of mother of Catherine II of Russia Joanna Elisabeth of Holstein-Gottorp Catherine II of Russia "[',' ' the' ' Empress' ' Catherine' ' the' ' Great' ',' ' and' ' the'
+ ' Empress' ' Elizabeth' '.' '\n' '\n' 'The' ' Empress' ' Catherine' ' II'
+ ' was' ' the']" ", the Empress Catherine the Great , and the Empress Elizabeth .
+
+ The Empress Catherine II was the" False people. For example, Catherine II of Russia sent him seeds from 8 [' people', '.', ' For', ' example', ',', ' Catherine', ' II', ' of', ' Russia']
+1712 372 Name of mother of x -1 Name of mother of Theodore von Kármán Helene Konn Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' '.' '\n' '\n'
+ 'The' 'odore' ' von' ' K' 'á' 'rm' 'án' ' was' ' born']" ", the first American to fly a plane .
+
+ The odore von K á rm án was born" False doctoral advisor Theodore von Kármán saw more promise 7 [' doctoral', ' advisor', ' Theodore', ' von', ' K', 'á', 'rm', 'án']
+1713 372 Name of mother of x -1 Name of mother of Theodore von Kármán Helene Konn Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' '.' '\n' '\n'
+ 'The' 'odore' ' von' ' K' 'á' 'rm' 'án' ' was' ' born']" ", the first American to fly a plane .
+
+ The odore von K á rm án was born" False km (62 mi), as Theodore von Kármán calculated, a vehicle 11 [' km', ' (', '62', ' mi', '),', ' as', ' Theodore', ' von', ' K', 'á', 'rm', 'án']
+1714 372 Name of mother of x -1 Name of mother of Theodore von Kármán Helene Konn Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' '.' '\n' '\n'
+ 'The' 'odore' ' von' ' K' 'á' 'rm' 'án' ' was' ' born']" ", the first American to fly a plane .
+
+ The odore von K á rm án was born" False that included Theodore von Kármán (b. 1881), George 7 [' that', ' included', ' Theodore', ' von', ' K', 'á', 'rm', 'án']
+1715 372 Name of mother of x -1 Name of mother of Theodore von Kármán Helene Konn Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' '.' '\n' '\n'
+ 'The' 'odore' ' von' ' K' 'á' 'rm' 'án' ' was' ' born']" ", the first American to fly a plane .
+
+ The odore von K á rm án was born" False " student), ""Ka"" for Theodore von Kármán (a Privatdozent)," 11 "[' student', '),', ' ""', 'Ka', '""', ' for', ' Theodore', ' von', ' K', 'á', 'rm', 'án']"
+1716 372 Name of mother of x -1 Name of mother of Theodore von Kármán Helene Konn Theodore von Kármán "[',' ' the' ' first' ' American' ' to' ' fly' ' a' ' plane' '.' '\n' '\n'
+ 'The' 'odore' ' von' ' K' 'á' 'rm' 'án' ' was' ' born']" ", the first American to fly a plane .
+
+ The odore von K á rm án was born" False Malina's doctoral advisor Theodore von Kármán saw more promise 10 "[' Mal', 'ina', ""'s"", ' doctoral', ' advisor', ' Theodore', ' von', ' K', 'á', 'rm', 'án']"
+1717 373 Name of mother of x -1 Name of mother of Woodrow Wilson Janet Woodrow Sumersimpson Woodrow Wilson "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States'
+ ' of']" ", the
+
+ President of the United States , and the
+
+ President of the United States of" False by President Woodrow Wilson and Secretary of the 4 [' by', ' President', ' Wood', 'row', ' Wilson']
+1718 373 Name of mother of x -1 Name of mother of Woodrow Wilson Janet Woodrow Sumersimpson Woodrow Wilson "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States'
+ ' of']" ", the
+
+ President of the United States , and the
+
+ President of the United States of" False Turkey. President Woodrow Wilson agreed to act 5 [' Turkey', '.', ' President', ' Wood', 'row', ' Wilson']
+1719 373 Name of mother of x -1 Name of mother of Woodrow Wilson Janet Woodrow Sumersimpson Woodrow Wilson "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States'
+ ' of']" ", the
+
+ President of the United States , and the
+
+ President of the United States of" False States. When President Woodrow Wilson arrived at Brest 6 [' States', '.', ' When', ' President', ' Wood', 'row', ' Wilson']
+1720 373 Name of mother of x -1 Name of mother of Woodrow Wilson Janet Woodrow Sumersimpson Woodrow Wilson "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States'
+ ' of']" ", the
+
+ President of the United States , and the
+
+ President of the United States of" False ballet. President Woodrow Wilson refused to miss any 5 [' ballet', '.', ' President', ' Wood', 'row', ' Wilson']
+1721 373 Name of mother of x -1 Name of mother of Woodrow Wilson Janet Woodrow Sumersimpson Woodrow Wilson "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States'
+ ' of']" ", the
+
+ President of the United States , and the
+
+ President of the United States of" False Regent's Fellowship, Woodrow Wilson Fellowship 7 "[' Reg', 'ent', ""'s"", ' Fellowship', ',', ' Wood', 'row', ' Wilson']"
+1722 374 Name of mother of x -1 Name of mother of Adam Mickiewicz Barbara Mickiewicz, née Majewska Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False of sociology at Adam Mickiewicz University where he 5 [' of', ' sociology', ' at', ' Adam', ' Mick', 'iewicz']
+1723 374 Name of mother of x -1 Name of mother of Adam Mickiewicz Barbara Mickiewicz, née Majewska Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False Cultural Studies from Adam Mickiewicz University, Poznań, 5 [' Cultural', ' Studies', ' from', ' Adam', ' Mick', 'iewicz']
+1724 374 Name of mother of x -1 Name of mother of Adam Mickiewicz Barbara Mickiewicz, née Majewska Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False culture as poet Adam Mickiewicz and painter Jan 5 [' culture', ' as', ' poet', ' Adam', ' Mick', 'iewicz']
+1725 374 Name of mother of x -1 Name of mother of Adam Mickiewicz Barbara Mickiewicz, née Majewska Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False Cultural Studies from Adam Mickiewicz University, 5 [' Cultural', ' Studies', ' from', ' Adam', ' Mick', 'iewicz']
+1726 374 Name of mother of x -1 Name of mother of Adam Mickiewicz Barbara Mickiewicz, née Majewska Adam Mickiewicz "[',' ' the' ' Polish' ' poet' ',' ' was' ' born' ' in' ' the' ' village'
+ ' of' ' W' 'ola' ',' ' near' ' Warsaw' ',' ' in' ' 17' '98']" , the Polish poet , was born in the village of W ola , near Warsaw , in 17 98 False " Adam Mickiewicz =
+" 2 [' Adam', ' Mick', 'iewicz']
+1727 375 Name of mother of x -1 Name of mother of Pablo Neruda Rosa Neftalí Basoalto Opazo Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False politician Pablo Neruda organized the immigration 3 [' politician', ' Pablo', ' Ner', 'uda']
+1728 375 Name of mother of x -1 Name of mother of Pablo Neruda Rosa Neftalí Basoalto Opazo Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False others, the Chilean poet Pablo Neruda is also associated 7 [' others', ',', ' the', ' Chilean', ' poet', ' Pablo', ' Ner', 'uda']
+1729 375 Name of mother of x -1 Name of mother of Pablo Neruda Rosa Neftalí Basoalto Opazo Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False Guatemala by Chilean poet Pablo Neruda highlighted Neruda's 6 [' Guatemala', ' by', ' Chilean', ' poet', ' Pablo', ' Ner', 'uda']
+1730 375 Name of mother of x -1 Name of mother of Pablo Neruda Rosa Neftalí Basoalto Opazo Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False others, the Chilean poet Pablo Neruda is also associated 7 [' others', ',', ' the', ' Chilean', ' poet', ' Pablo', ' Ner', 'uda']
+1731 375 Name of mother of x -1 Name of mother of Pablo Neruda Rosa Neftalí Basoalto Opazo Pablo Neruda "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False and politician Pablo Neruda organized the immigration 4 [' and', ' politician', ' Pablo', ' Ner', 'uda']
+1732 376 Name of mother of x -1 Name of mother of Henry David Thoreau Cynthia Dunbar Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' _' 'W' 'ald']" , the author of Wald en , and the author of the book that inspired the movie _ W ald False " systematic form."". Henry David Thoreau (1817 – 1862)" 8 "[' systematic', ' form', '.""', '.', ' Henry', ' David', ' Th', 'ore', 'au']"
+1733 376 Name of mother of x -1 Name of mother of Henry David Thoreau Cynthia Dunbar Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' _' 'W' 'ald']" , the author of Wald en , and the author of the book that inspired the movie _ W ald False example arose when Henry David Thoreau the author of Walden 7 [' example', ' arose', ' when', ' Henry', ' David', ' Th', 'ore', 'au']
+1734 376 Name of mother of x -1 Name of mother of Henry David Thoreau Cynthia Dunbar Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' _' 'W' 'ald']" , the author of Wald en , and the author of the book that inspired the movie _ W ald False Allan Poe, and Henry David Thoreau established a distinctive 8 [' Allan', ' Poe', ',', ' and', ' Henry', ' David', ' Th', 'ore', 'au']
+1735 376 Name of mother of x -1 Name of mother of Henry David Thoreau Cynthia Dunbar Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' _' 'W' 'ald']" , the author of Wald en , and the author of the book that inspired the movie _ W ald False example arose when Henry David Thoreau the author of Walden 7 [' example', ' arose', ' when', ' Henry', ' David', ' Th', 'ore', 'au']
+1736 376 Name of mother of x -1 Name of mother of Henry David Thoreau Cynthia Dunbar Henry David Thoreau "[',' ' the' ' author' ' of' ' Wald' 'en' ',' ' and' ' the' ' author' ' of'
+ ' the' ' book' ' that' ' inspired' ' the' ' movie' ' _' 'W' 'ald']" , the author of Wald en , and the author of the book that inspired the movie _ W ald False Edgar Allan Poe, and Henry David Thoreau established a distinctive 9 [' Edgar', ' Allan', ' Poe', ',', ' and', ' Henry', ' David', ' Th', 'ore', 'au']
+1737 377 Name of mother of x -1 Name of mother of Heracles Alkmene Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False labour of Greek hero Heracles - slaying the horses 5 [' labour', ' of', ' Greek', ' hero', ' Her', 'acles']
+1738 377 Name of mother of x -1 Name of mother of Heracles Alkmene Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False and, possibly, Heracles of Macedon from 5 [' and', ',', ' possibly', ',', ' Her', 'acles']
+1739 377 Name of mother of x -1 Name of mother of Heracles Alkmene Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False accidentally killing Heracles after he had completed 3 [' accidentally', ' killing', ' Her', 'acles']
+1740 377 Name of mother of x -1 Name of mother of Heracles Alkmene Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False accidentally killing Heracles after he had 3 [' accidentally', ' killing', ' Her', 'acles']
+1741 377 Name of mother of x -1 Name of mother of Heracles Alkmene Heracles "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False sacked by Heracles in a brief allusion 3 [' sacked', ' by', ' Her', 'acles']
+1742 378 Name of mother of x -1 Name of mother of Robert De Niro Virginia Admiral Robert De Niro "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great' ' actor'
+ '.' '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the']" ", who is a great actor , and a great actor .
+
+ I have a feeling that the" False Mickey Rourke, Robert De Niro and Lisa Bonet. The 7 [' Mickey', ' R', 'ourke', ',', ' Robert', ' De', ' N', 'iro']
+1743 378 Name of mother of x -1 Name of mother of Robert De Niro Virginia Admiral Robert De Niro "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great' ' actor'
+ '.' '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the']" ", who is a great actor , and a great actor .
+
+ I have a feeling that the" False reviewing the script, Robert De Niro was actually the person 7 [' reviewing', ' the', ' script', ',', ' Robert', ' De', ' N', 'iro']
+1744 378 Name of mother of x -1 Name of mother of Robert De Niro Virginia Admiral Robert De Niro "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great' ' actor'
+ '.' '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the']" ", who is a great actor , and a great actor .
+
+ I have a feeling that the" False with the likes of Robert De Niro and Sarah Jessica 7 [' with', ' the', ' likes', ' of', ' Robert', ' De', ' N', 'iro']
+1745 378 Name of mother of x -1 Name of mother of Robert De Niro Virginia Admiral Robert De Niro "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great' ' actor'
+ '.' '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the']" ", who is a great actor , and a great actor .
+
+ I have a feeling that the" False Bradley Cooper and Robert De Niro film, Limitless. 6 [' Bradley', ' Cooper', ' and', ' Robert', ' De', ' N', 'iro']
+1746 378 Name of mother of x -1 Name of mother of Robert De Niro Virginia Admiral Robert De Niro "[',' ' who' ' is' ' a' ' great' ' actor' ',' ' and' ' a' ' great' ' actor'
+ '.' '\n' '\n' 'I' ' have' ' a' ' feeling' ' that' ' the']" ", who is a great actor , and a great actor .
+
+ I have a feeling that the" False Damon joined Robert De Niro in The Good Shepherd 5 [' Damon', ' joined', ' Robert', ' De', ' N', 'iro']
+1747 379 Name of mother of x -1 Name of mother of Kanye West Donda West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' are' ' expecting' ' their' ' first'
+ ' child']" , the rapper , and his wife Kim Kardashian , the reality TV star , are expecting their first child False " Produced by 88-Keys, Kanye West and Mike Dean
+" 8 [' Produ', 'ced', ' by', ' 88', '-', 'Keys', ',', ' Kanye', ' West']
+1748 379 Name of mother of x -1 Name of mother of Kanye West Donda West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' are' ' expecting' ' their' ' first'
+ ' child']" , the rapper , and his wife Kim Kardashian , the reality TV star , are expecting their first child False and fellow rapper Kanye West on his 2008 record 4 [' and', ' fellow', ' rapper', ' Kanye', ' West']
+1749 379 Name of mother of x -1 Name of mother of Kanye West Donda West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' are' ' expecting' ' their' ' first'
+ ' child']" , the rapper , and his wife Kim Kardashian , the reality TV star , are expecting their first child False Tour by Jay-Z and Kanye West and noted 7 [' Tour', ' by', ' Jay', '-', 'Z', ' and', ' Kanye', ' West']
+1750 379 Name of mother of x -1 Name of mother of Kanye West Donda West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' are' ' expecting' ' their' ' first'
+ ' child']" , the rapper , and his wife Kim Kardashian , the reality TV star , are expecting their first child False joke about Kanye West not understanding 3 [' joke', ' about', ' Kanye', ' West']
+1751 379 Name of mother of x -1 Name of mother of Kanye West Donda West Kanye West "[',' ' the' ' rapper' ',' ' and' ' his' ' wife' ' Kim' ' Kardashian' ','
+ ' the' ' reality' ' TV' ' star' ',' ' are' ' expecting' ' their' ' first'
+ ' child']" , the rapper , and his wife Kim Kardashian , the reality TV star , are expecting their first child False insulted Jay-Z and Kanye West along with musicians 6 [' insulted', ' Jay', '-', 'Z', ' and', ' Kanye', ' West']
+1752 380 Name of mother of x -1 Name of mother of Mary Cassatt Katherine Kelso Cassatt Mary Cassatt "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 1870, Sartain met Mary Cassatt in Philadelphia 8 [' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+1753 380 Name of mother of x -1 Name of mother of Mary Cassatt Katherine Kelso Cassatt Mary Cassatt "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 1870, Sartain met Mary Cassatt in Philadelphia 8 [' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+1754 380 Name of mother of x -1 Name of mother of Mary Cassatt Katherine Kelso Cassatt Mary Cassatt "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 1870, Sartain met Mary Cassatt in Philadelphia 8 [' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+1755 380 Name of mother of x -1 Name of mother of Mary Cassatt Katherine Kelso Cassatt Mary Cassatt "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False In 1870, Sartain met Mary Cassatt in Philadelphia 9 [' In', ' 1870', ',', ' S', 'art', 'ain', ' met', ' Mary', ' Cass', 'att']
+1756 381 Name of mother of x -1 Name of mother of Jacques Chirac Marie-Louise Valette Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False then-mayor of Paris Jacques Chirac and placed a full-sized 9 [' then', '-', 'may', 'or', ' of', ' Paris', ' Jacques', ' Ch', 'ir', 'ac']
+1757 381 Name of mother of x -1 Name of mother of Jacques Chirac Marie-Louise Valette Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False President Jacques Chirac nor Prime 4 [' President', ' Jacques', ' Ch', 'ir', 'ac']
+1758 381 Name of mother of x -1 Name of mother of Jacques Chirac Marie-Louise Valette Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False then-mayor of Paris Jacques Chirac and placed a full-sized 9 [' then', '-', 'may', 'or', ' of', ' Paris', ' Jacques', ' Ch', 'ir', 'ac']
+1759 381 Name of mother of x -1 Name of mother of Jacques Chirac Marie-Louise Valette Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False cited French president Jacques Chirac as a witness; 6 [' cited', ' French', ' president', ' Jacques', ' Ch', 'ir', 'ac']
+1760 381 Name of mother of x -1 Name of mother of Jacques Chirac Marie-Louise Valette Jacques Chirac "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False by president Jacques Chirac during his official 5 [' by', ' president', ' Jacques', ' Ch', 'ir', 'ac']
+1761 382 Name of mother of x -1 Name of mother of Herman Melville Maria Gansevoort Melvill Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False 3 ['H', 'erman', ' Mel', 'ville']
+1762 382 Name of mother of x -1 Name of mother of Herman Melville Maria Gansevoort Melvill Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False of Adams, and Herman Melville alluded to the case 6 [' of', ' Adams', ',', ' and', ' Herman', ' Mel', 'ville']
+1763 382 Name of mother of x -1 Name of mother of Herman Melville Maria Gansevoort Melvill Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False in Pierre, Herman Melville focuses on a 5 [' in', ' Pierre', ',', ' Herman', ' Mel', 'ville']
+1764 382 Name of mother of x -1 Name of mother of Herman Melville Maria Gansevoort Melvill Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False Hemisphere. Indeed, Herman Melville mentions it and 6 [' Hemisphere', '.', ' Indeed', ',', ' Herman', ' Mel', 'ville']
+1765 382 Name of mother of x -1 Name of mother of Herman Melville Maria Gansevoort Melvill Herman Melville "[',' ' the' ' author' ' of' ' Mob' 'y' '-' 'Dick' ',' ' and' ' the'
+ ' author' ' of' ' the' '\n' '\n' 'The' ' author' ' of' ' Mob']" ", the author of Mob y - Dick , and the author of the
+
+ The author of Mob" False Battle-Pieces publication, Herman Melville penned a poem about 8 [' Battle', '-', 'Pie', 'ces', ' publication', ',', ' Herman', ' Mel', 'ville']
+1766 383 Name of mother of x -1 Name of mother of Georgia O'Keeffe Ida Ten Eyck Totto Georgia O'Keeffe "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' New' ' York' ' City'
+ ' family' ',' ' and' ' the' ' daughter' ' of' ' a' ' wealthy' ' New'
+ ' York' ' City']" , the daughter of a wealthy New York City family , and the daughter of a wealthy New York City False van Gogh and Georgia O'Keeffe on her paintings and 9 "[' van', ' Go', 'gh', ' and', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+1767 383 Name of mother of x -1 Name of mother of Georgia O'Keeffe Ida Ten Eyck Totto Georgia O'Keeffe "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' New' ' York' ' City'
+ ' family' ',' ' and' ' the' ' daughter' ' of' ' a' ' wealthy' ' New'
+ ' York' ' City']" , the daughter of a wealthy New York City family , and the daughter of a wealthy New York City False Vincent van Gogh and Georgia O'Keeffe on her paintings 10 "[' Vincent', ' van', ' Go', 'gh', ' and', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+1768 383 Name of mother of x -1 Name of mother of Georgia O'Keeffe Ida Ten Eyck Totto Georgia O'Keeffe "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' New' ' York' ' City'
+ ' family' ',' ' and' ' the' ' daughter' ' of' ' a' ' wealthy' ' New'
+ ' York' ' City']" , the daughter of a wealthy New York City family , and the daughter of a wealthy New York City False van Gogh and Georgia O'Keeffe on her paintings 9 "[' van', ' Go', 'gh', ' and', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+1769 383 Name of mother of x -1 Name of mother of Georgia O'Keeffe Ida Ten Eyck Totto Georgia O'Keeffe "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' New' ' York' ' City'
+ ' family' ',' ' and' ' the' ' daughter' ' of' ' a' ' wealthy' ' New'
+ ' York' ' City']" , the daughter of a wealthy New York City family , and the daughter of a wealthy New York City False including works by Georgia O'Keeffe and George Inness. 8 "[' including', ' works', ' by', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+1770 383 Name of mother of x -1 Name of mother of Georgia O'Keeffe Ida Ten Eyck Totto Georgia O'Keeffe "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' New' ' York' ' City'
+ ' family' ',' ' and' ' the' ' daughter' ' of' ' a' ' wealthy' ' New'
+ ' York' ' City']" , the daughter of a wealthy New York City family , and the daughter of a wealthy New York City False including works by Georgia O'Keeffe and George 8 "[' including', ' works', ' by', ' Georgia', ' O', ""'"", 'K', 'ee', 'ffe']"
+1771 384 Name of mother of x -1 Name of mother of Quincy Jones Sarah Frances Wells Quincy Jones "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' late' ' Quincy' ' Jones' ',' ' Jr' '.' ' and' ' his' ' wife' ',' ' the']" , the first of the three children of the late Quincy Jones , Jr . and his wife , the False Orchestra, with Quincy Jones conducting. Sinatra 4 [' Orchestra', ',', ' with', ' Quincy', ' Jones']
+1772 384 Name of mother of x -1 Name of mother of Quincy Jones Sarah Frances Wells Quincy Jones "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' late' ' Quincy' ' Jones' ',' ' Jr' '.' ' and' ' his' ' wife' ',' ' the']" , the first of the three children of the late Quincy Jones , Jr . and his wife , the False All songs produced by Quincy Jones and co-produced by 5 [' All', ' songs', ' produced', ' by', ' Quincy', ' Jones']
+1773 384 Name of mother of x -1 Name of mother of Quincy Jones Sarah Frances Wells Quincy Jones "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' late' ' Quincy' ' Jones' ',' ' Jr' '.' ' and' ' his' ' wife' ',' ' the']" , the first of the three children of the late Quincy Jones , Jr . and his wife , the False Phillinganes. Quincy Jones passed on 5 [' Ph', 'illing', 'anes', '.', ' Quincy', ' Jones']
+1774 384 Name of mother of x -1 Name of mother of Quincy Jones Sarah Frances Wells Quincy Jones "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' late' ' Quincy' ' Jones' ',' ' Jr' '.' ' and' ' his' ' wife' ',' ' the']" , the first of the three children of the late Quincy Jones , Jr . and his wife , the False Michael Jackson and Quincy Jones went back 4 [' Michael', ' Jackson', ' and', ' Quincy', ' Jones']
+1775 384 Name of mother of x -1 Name of mother of Quincy Jones Sarah Frances Wells Quincy Jones "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' late' ' Quincy' ' Jones' ',' ' Jr' '.' ' and' ' his' ' wife' ',' ' the']" , the first of the three children of the late Quincy Jones , Jr . and his wife , the False " that at times Quincy Jones may ""depersonalize" 4 [' that', ' at', ' times', ' Quincy', ' Jones']
+1776 385 Name of mother of x -1 Name of mother of Jean Racine Jeanne Sconin Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' in' ' Paris'
+ ' in' ' 16' '39' '.' ' He' ' was' ' the' ' son' ' of']" , the French dram at ist , was born in Paris in 16 39 . He was the son of False Bérénice, a play by Jean Racine (1670) which focuses 10 [' B', 'é', 'ré', 'nice', ',', ' a', ' play', ' by', ' Jean', ' Rac', 'ine']
+1777 385 Name of mother of x -1 Name of mother of Jean Racine Jeanne Sconin Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' in' ' Paris'
+ ' in' ' 16' '39' '.' ' He' ' was' ' the' ' son' ' of']" , the French dram at ist , was born in Paris in 16 39 . He was the son of False Pierre Corneille and Jean Racine and through them the 7 [' Pierre', ' Cor', 'ne', 'ille', ' and', ' Jean', ' Rac', 'ine']
+1778 385 Name of mother of x -1 Name of mother of Jean Racine Jeanne Sconin Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' in' ' Paris'
+ ' in' ' 16' '39' '.' ' He' ' was' ' the' ' son' ' of']" , the French dram at ist , was born in Paris in 16 39 . He was the son of False Corneille and Jean Racine and through them the 6 [' Cor', 'ne', 'ille', ' and', ' Jean', ' Rac', 'ine']
+1779 385 Name of mother of x -1 Name of mother of Jean Racine Jeanne Sconin Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' in' ' Paris'
+ ' in' ' 16' '39' '.' ' He' ' was' ' the' ' son' ' of']" , the French dram at ist , was born in Paris in 16 39 . He was the son of False Corneille and Jean Racine and through them 6 [' Cor', 'ne', 'ille', ' and', ' Jean', ' Rac', 'ine']
+1780 385 Name of mother of x -1 Name of mother of Jean Racine Jeanne Sconin Jean Racine "[',' ' the' ' French' ' dram' 'at' 'ist' ',' ' was' ' born' ' in' ' Paris'
+ ' in' ' 16' '39' '.' ' He' ' was' ' the' ' son' ' of']" , the French dram at ist , was born in Paris in 16 39 . He was the son of False libretto, based on Jean Racine ’ s Athalie and 8 [' lib', 'rett', 'o', ',', ' based', ' on', ' Jean', ' Rac', 'ine']
+1781 386 Name of mother of x -1 Name of mother of Johannes Kepler Katharina Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the astronomer , and the
+ " False mathematical astronomer Johannes Kepler modelled the universe 3 [' mathematical', ' astronomer', ' Johannes', ' Kepler']
+1782 386 Name of mother of x -1 Name of mother of Johannes Kepler Katharina Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the astronomer , and the
+ " False " Johannes Kepler =
+" 1 [' Johannes', ' Kepler']
+1783 386 Name of mother of x -1 Name of mother of Johannes Kepler Katharina Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the astronomer , and the
+ " False Harmonices Mundi, Johannes Kepler first applied the 6 [' Harmon', 'ices', ' Mund', 'i', ',', ' Johannes', ' Kepler']
+1784 386 Name of mother of x -1 Name of mother of Johannes Kepler Katharina Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the astronomer , and the
+ " False also says that Johannes Kepler came to his conclusions 4 [' also', ' says', ' that', ' Johannes', ' Kepler']
+1785 386 Name of mother of x -1 Name of mother of Johannes Kepler Katharina Kepler Johannes Kepler "[',' ' the' ' astronomer' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the astronomer , and the
+ " False but the theories of Johannes Kepler and William 5 [' but', ' the', ' theories', ' of', ' Johannes', ' Kepler']
+1786 387 Name of mother of x -1 Name of mother of Carl Maria von Weber Genovefa von Weber Carl Maria von Weber "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Carl' ' Maria' ' von' ' Weber' ',']" ", the composer , and his wife , the
+
+ Name of mother of Carl Maria von Weber ," False and Romantic eras. Carl Maria von Weber and Felix Mendelssohn 7 [' and', ' Romantic', ' eras', '.', ' Carl', ' Maria', ' von', ' Weber']
+1787 387 Name of mother of x -1 Name of mother of Carl Maria von Weber Genovefa von Weber Carl Maria von Weber "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Carl' ' Maria' ' von' ' Weber' ',']" ", the composer , and his wife , the
+
+ Name of mother of Carl Maria von Weber ," False and Romantic eras. Carl Maria von Weber and Felix Mendelssohn 7 [' and', ' Romantic', ' eras', '.', ' Carl', ' Maria', ' von', ' Weber']
+1788 388 Name of mother of x -1 Name of mother of Che Guevara Celia de la Serna Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' had' ' been' ' a'
+ ' guerrilla' ' fighter' ',' ' a' ' revolutionary' ',' ' a' ' Marxist' ','
+ ' a']" , the revolutionary , the man who had been a guerrilla fighter , a revolutionary , a Marxist , a False Jinnah, Kim Il-Sung, Che Guevara and several Presidential 13 [' Jinn', 'ah', ',', ' Kim', ' Il', '-', 'S', 'ung', ',', ' Che', ' G', 'ue', 'v', 'ara']
+1789 388 Name of mother of x -1 Name of mother of Che Guevara Celia de la Serna Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' had' ' been' ' a'
+ ' guerrilla' ' fighter' ',' ' a' ' revolutionary' ',' ' a' ' Marxist' ','
+ ' a']" , the revolutionary , the man who had been a guerrilla fighter , a revolutionary , a Marxist , a False cooperation with the Che Guevara Publishing Project, 7 [' cooperation', ' with', ' the', ' Che', ' G', 'ue', 'v', 'ara']
+1790 388 Name of mother of x -1 Name of mother of Che Guevara Celia de la Serna Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' had' ' been' ' a'
+ ' guerrilla' ' fighter' ',' ' a' ' revolutionary' ',' ' a' ' Marxist' ','
+ ' a']" , the revolutionary , the man who had been a guerrilla fighter , a revolutionary , a Marxist , a False support with Che Guevara stating on 15 August 6 [' support', ' with', ' Che', ' G', 'ue', 'v', 'ara']
+1791 388 Name of mother of x -1 Name of mother of Che Guevara Celia de la Serna Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' had' ' been' ' a'
+ ' guerrilla' ' fighter' ',' ' a' ' revolutionary' ',' ' a' ' Marxist' ','
+ ' a']" , the revolutionary , the man who had been a guerrilla fighter , a revolutionary , a Marxist , a False advisors led by Che Guevara to advise 7 [' advisors', ' led', ' by', ' Che', ' G', 'ue', 'v', 'ara']
+1792 388 Name of mother of x -1 Name of mother of Che Guevara Celia de la Serna Che Guevara "[',' ' the' ' revolutionary' ',' ' the' ' man' ' who' ' had' ' been' ' a'
+ ' guerrilla' ' fighter' ',' ' a' ' revolutionary' ',' ' a' ' Marxist' ','
+ ' a']" , the revolutionary , the man who had been a guerrilla fighter , a revolutionary , a Marxist , a False principles of Che Guevara are very important 6 [' principles', ' of', ' Che', ' G', 'ue', 'v', 'ara']
+1793 389 Name of mother of x -1 Name of mother of Demi Moore Virginia King Demi Moore "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' a' ' son' ',' ' Dylan' ',' ' who' ' is' ' a']" , who is a former Miss America , and the mother of a son , Dylan , who is a False 2 ['Dem', 'i', ' Moore']
+1794 389 Name of mother of x -1 Name of mother of Demi Moore Virginia King Demi Moore "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' a' ' son' ',' ' Dylan' ',' ' who' ' is' ' a']" , who is a former Miss America , and the mother of a son , Dylan , who is a False 30 A.M. and Demi Moore slept that night 8 [' 30', ' A', '.', 'M', '.', ' and', ' Dem', 'i', ' Moore']
+1795 389 Name of mother of x -1 Name of mother of Demi Moore Virginia King Demi Moore "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' a' ' son' ',' ' Dylan' ',' ' who' ' is' ' a']" , who is a former Miss America , and the mother of a son , Dylan , who is a False Suit cover of Demi Moore in a body painting 5 [' Suit', ' cover', ' of', ' Dem', 'i', ' Moore']
+1796 389 Name of mother of x -1 Name of mother of Demi Moore Virginia King Demi Moore "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' a' ' son' ',' ' Dylan' ',' ' who' ' is' ' a']" , who is a former Miss America , and the mother of a son , Dylan , who is a False " Demi Moore =
+" 2 [' Dem', 'i', ' Moore']
+1797 389 Name of mother of x -1 Name of mother of Demi Moore Virginia King Demi Moore "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' a' ' son' ',' ' Dylan' ',' ' who' ' is' ' a']" , who is a former Miss America , and the mother of a son , Dylan , who is a False and actress Demi Moore discussed ways 4 [' and', ' actress', ' Dem', 'i', ' Moore']
+1798 390 Name of mother of x -1 Name of mother of Pliny the Elder Marcella Pliny the Elder "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' sons' ' of' ' the'
+ ' Emperor' ' Augustus' '.' '\n' '\n' 'The' ' first' ' of' ' these' ',']" ", and the other two are the sons of the Emperor Augustus .
+
+ The first of these ," False Animals and by Pliny the Elder in his Natural 6 [' Animals', ' and', ' by', ' Pl', 'iny', ' the', ' Elder']
+1799 390 Name of mother of x -1 Name of mother of Pliny the Elder Marcella Pliny the Elder "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' sons' ' of' ' the'
+ ' Emperor' ' Augustus' '.' '\n' '\n' 'The' ' first' ' of' ' these' ',']" ", and the other two are the sons of the Emperor Augustus .
+
+ The first of these ," False Latin such as Pliny the Elder (Gaius Plinius 6 [' Latin', ' such', ' as', ' Pl', 'iny', ' the', ' Elder']
+1800 390 Name of mother of x -1 Name of mother of Pliny the Elder Marcella Pliny the Elder "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' sons' ' of' ' the'
+ ' Emperor' ' Augustus' '.' '\n' '\n' 'The' ' first' ' of' ' these' ',']" ", and the other two are the sons of the Emperor Augustus .
+
+ The first of these ," False naturalist Pliny the Elder wrote about 5 [' natural', 'ist', ' Pl', 'iny', ' the', ' Elder']
+1801 390 Name of mother of x -1 Name of mother of Pliny the Elder Marcella Pliny the Elder "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' sons' ' of' ' the'
+ ' Emperor' ' Augustus' '.' '\n' '\n' 'The' ' first' ' of' ' these' ',']" ", and the other two are the sons of the Emperor Augustus .
+
+ The first of these ," False rescue her. Pliny the Elder claimed that these 6 [' rescue', ' her', '.', ' Pl', 'iny', ' the', ' Elder']
+1802 390 Name of mother of x -1 Name of mother of Pliny the Elder Marcella Pliny the Elder "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' sons' ' of' ' the'
+ ' Emperor' ' Augustus' '.' '\n' '\n' 'The' ' first' ' of' ' these' ',']" ", and the other two are the sons of the Emperor Augustus .
+
+ The first of these ," False " that avoids light""). Pliny the Elder recorded the" 7 "[' that', ' avoids', ' light', '"").', ' Pl', 'iny', ' the', ' Elder']"
+1803 391 Name of mother of x -1 Name of mother of Napoleon III Hortense de Beauharnais Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' to'
+ ' be' ' tr' 'ifled' ' with' ',' ' was' ' not' ' to']" ".
+
+ The Emperor , who was not a man to be tr ifled with , was not to" False Nazareth and Acre, Napoleon III of France presented 7 [' Naz', 'areth', ' and', ' Ac', 're', ',', ' Napoleon', ' III']
+1804 391 Name of mother of x -1 Name of mother of Napoleon III Hortense de Beauharnais Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' to'
+ ' be' ' tr' 'ifled' ' with' ',' ' was' ' not' ' to']" ".
+
+ The Emperor , who was not a man to be tr ifled with , was not to" False the French emperor Napoleon III declared war on 15 4 [' the', ' French', ' emperor', ' Napoleon', ' III']
+1805 391 Name of mother of x -1 Name of mother of Napoleon III Hortense de Beauharnais Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' to'
+ ' be' ' tr' 'ifled' ' with' ',' ' was' ' not' ' to']" ".
+
+ The Emperor , who was not a man to be tr ifled with , was not to" False Second French Empire of Napoleon III; the emperor and 5 [' Second', ' French', ' Empire', ' of', ' Napoleon', ' III']
+1806 391 Name of mother of x -1 Name of mother of Napoleon III Hortense de Beauharnais Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' to'
+ ' be' ' tr' 'ifled' ' with' ',' ' was' ' not' ' to']" ".
+
+ The Emperor , who was not a man to be tr ifled with , was not to" False Nazareth and Acre, Napoleon III of France 7 [' Naz', 'areth', ' and', ' Ac', 're', ',', ' Napoleon', ' III']
+1807 391 Name of mother of x -1 Name of mother of Napoleon III Hortense de Beauharnais Napoleon III "['.' '\n' '\n' 'The' ' Emperor' ',' ' who' ' was' ' not' ' a' ' man' ' to'
+ ' be' ' tr' 'ifled' ' with' ',' ' was' ' not' ' to']" ".
+
+ The Emperor , who was not a man to be tr ifled with , was not to" False Empire of Napoleon III; the emperor and 3 [' Empire', ' of', ' Napoleon', ' III']
+1808 393 Name of mother of x -1 Name of mother of Gregory I Saint Silvia Gregory I "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False letter from Pope Gregory I known as the Epistola 4 [' letter', ' from', ' Pope', ' Gregory', ' I']
+1809 393 Name of mother of x -1 Name of mother of Gregory I Saint Silvia Gregory I "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False background that Pope Gregory I decided to send 4 [' background', ' that', ' Pope', ' Gregory', ' I']
+1810 393 Name of mother of x -1 Name of mother of Gregory I Saint Silvia Gregory I "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False letters, Pope Gregory I called him an abbot, 4 [' letters', ',', ' Pope', ' Gregory', ' I']
+1811 393 Name of mother of x -1 Name of mother of Gregory I Saint Silvia Gregory I "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False 595, when Pope Gregory I decided to send 6 [' 5', '95', ',', ' when', ' Pope', ' Gregory', ' I']
+1812 393 Name of mother of x -1 Name of mother of Gregory I Saint Silvia Gregory I "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False reply by Pope Gregory I to questions posed 4 [' reply', ' by', ' Pope', ' Gregory', ' I']
+1813 394 Name of mother of x -1 Name of mother of Richard Strauss Josephine Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' his' ' children' ',' ' and' ' the' ' mother' ' of' ' his'
+ ' children']" , the composer , and his wife , the mother of his children , and the mother of his children False written in 1896 by Richard Strauss and very popular 5 [' written', ' in', ' 1896', ' by', ' Richard', ' Strauss']
+1814 394 Name of mother of x -1 Name of mother of Richard Strauss Josephine Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' his' ' children' ',' ' and' ' the' ' mother' ' of' ' his'
+ ' children']" , the composer , and his wife , the mother of his children , and the mother of his children False " under the baton of Richard Strauss at the Queen's Hall.
+" 6 [' under', ' the', ' bat', 'on', ' of', ' Richard', ' Strauss']
+1815 394 Name of mother of x -1 Name of mother of Richard Strauss Josephine Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' his' ' children' ',' ' and' ' the' ' mother' ' of' ' his'
+ ' children']" , the composer , and his wife , the mother of his children , and the mother of his children False Brahms. Works by Richard Strauss featured almost as 6 [' Brah', 'ms', '.', ' Works', ' by', ' Richard', ' Strauss']
+1816 394 Name of mother of x -1 Name of mother of Richard Strauss Josephine Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' his' ' children' ',' ' and' ' the' ' mother' ' of' ' his'
+ ' children']" , the composer , and his wife , the mother of his children , and the mother of his children False Kodály, Prokofiev, Richard Strauss and Stravinsky. 10 [' Kod', 'á', 'ly', ',', ' Pro', 'k', 'of', 'iev', ',', ' Richard', ' Strauss']
+1817 394 Name of mother of x -1 Name of mother of Richard Strauss Josephine Strauss Richard Strauss "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' mother'
+ ' of' ' his' ' children' ',' ' and' ' the' ' mother' ' of' ' his'
+ ' children']" , the composer , and his wife , the mother of his children , and the mother of his children False Composer-conductors included Richard Strauss and Anton Webern. 7 [' Compos', 'er', '-', 'conduct', 'ors', ' included', ' Richard', ' Strauss']
+1818 395 Name of mother of x -1 Name of mother of Edward Burne-Jones Elizabeth Coley Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Edward']" ", the painter , and his wife , the artist , and the
+
+ Name of mother of Edward" False Pre-Raphaelite artists Edward Burne-Jones and Dante Gabriel Rossetti 10 [' Pre', '-', 'R', 'aphael', 'ite', ' artists', ' Edward', ' Burn', 'e', '-', 'Jones']
+1819 395 Name of mother of x -1 Name of mother of Edward Burne-Jones Elizabeth Coley Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Edward']" ", the painter , and his wife , the artist , and the
+
+ Name of mother of Edward" False paintings of Edward Burne-Jones and the Pre-Raphaelite 6 [' paintings', ' of', ' Edward', ' Burn', 'e', '-', 'Jones']
+1820 395 Name of mother of x -1 Name of mother of Edward Burne-Jones Elizabeth Coley Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Edward']" ", the painter , and his wife , the artist , and the
+
+ Name of mother of Edward" False movement such as Edward Burne-Jones as designers, and 7 [' movement', ' such', ' as', ' Edward', ' Burn', 'e', '-', 'Jones']
+1821 395 Name of mother of x -1 Name of mother of Edward Burne-Jones Elizabeth Coley Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Edward']" ", the painter , and his wife , the artist , and the
+
+ Name of mother of Edward" False Pre-Raphaelite artists Edward Burne-Jones and Dante Gabriel 10 [' Pre', '-', 'R', 'aphael', 'ite', ' artists', ' Edward', ' Burn', 'e', '-', 'Jones']
+1822 395 Name of mother of x -1 Name of mother of Edward Burne-Jones Elizabeth Coley Edward Burne-Jones "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ','
+ ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Edward']" ", the painter , and his wife , the artist , and the
+
+ Name of mother of Edward" False He commissioned Edward Burne-Jones of James Powell & 6 [' He', ' commissioned', ' Edward', ' Burn', 'e', '-', 'Jones']
+1823 396 Name of mother of x -1 Name of mother of Augustus Atia Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False granted 139, and Augustus O. Stanley granted 4 [' granted', ' 139', ',', ' and', ' Augustus']
+1824 396 Name of mother of x -1 Name of mother of Augustus Atia Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False the title of Augustus. This displeased 3 [' the', ' title', ' of', ' Augustus']
+1825 396 Name of mother of x -1 Name of mother of Augustus Atia Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False the Forum of Augustus with its Temple 3 [' the', ' Forum', ' of', ' Augustus']
+1826 396 Name of mother of x -1 Name of mother of Augustus Atia Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False particularly for metalwork. Augustus Siebe, the pioneer 5 [' particularly', ' for', ' metal', 'work', '.', ' Augustus']
+1827 396 Name of mother of x -1 Name of mother of Augustus Atia Augustus "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Augustus' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Augustus , and the son of the Emperor Augustus , and the son of False " == Change to Augustus ==
+" 3 [' ==', ' Change', ' to', ' Augustus']
+1828 397 Name of mother of x -1 Name of mother of Fred Astaire Ann Astaire Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False " Band Wagon, starring Fred Astaire and Cyd Charisse.
+" 7 [' Band', ' W', 'agon', ',', ' starring', ' Fred', ' Ast', 'aire']
+1829 397 Name of mother of x -1 Name of mother of Fred Astaire Ann Astaire Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False for choreographing Fred Astaire and Cyd Charisse 6 [' for', ' chore', 'ograp', 'hing', ' Fred', ' Ast', 'aire']
+1830 397 Name of mother of x -1 Name of mother of Fred Astaire Ann Astaire Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False Charisse and Fred Astaire in the 1953 5 [' Char', 'isse', ' and', ' Fred', ' Ast', 'aire']
+1831 397 Name of mother of x -1 Name of mother of Fred Astaire Ann Astaire Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False Frank (1992) and the Fred Astaire tribute Steppin 8 [' Frank', ' (', '1992', ')', ' and', ' the', ' Fred', ' Ast', 'aire']
+1832 397 Name of mother of x -1 Name of mother of Fred Astaire Ann Astaire Fred Astaire "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False comedy film in which Fred Astaire plays an American 6 [' comedy', ' film', ' in', ' which', ' Fred', ' Ast', 'aire']
+1833 398 Name of mother of x -1 Name of mother of W. H. Auden Constance Rosalie Bicknell W. H. Auden "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' E' 'ileen' ',' ' who'
+ ' was' ' a' ' painter' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , E ileen , who was a painter .
+
+ The house" False Complete Works of W. H. Auden are indicated 8 [' Complete', ' Works', ' of', ' W', '.', ' H', '.', ' Aud', 'en']
+1834 398 Name of mother of x -1 Name of mother of W. H. Auden Constance Rosalie Bicknell W. H. Auden "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' E' 'ileen' ',' ' who'
+ ' was' ' a' ' painter' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , E ileen , who was a painter .
+
+ The house" False of Yeats's work; W. H. Auden called it the 11 "[' of', ' Ye', 'ats', ""'s"", ' work', ';', ' W', '.', ' H', '.', ' Aud', 'en']"
+1835 398 Name of mother of x -1 Name of mother of W. H. Auden Constance Rosalie Bicknell W. H. Auden "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' E' 'ileen' ',' ' who'
+ ' was' ' a' ' painter' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , E ileen , who was a painter .
+
+ The house" False 20th century, W. H. Auden once called Wagner 9 [' 20', 'th', ' century', ',', ' W', '.', ' H', '.', ' Aud', 'en']
+1836 398 Name of mother of x -1 Name of mother of W. H. Auden Constance Rosalie Bicknell W. H. Auden "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' E' 'ileen' ',' ' who'
+ ' was' ' a' ' painter' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , E ileen , who was a painter .
+
+ The house" False 5 ['W', '.', ' H', '.', ' Aud', 'en']
+1837 398 Name of mother of x -1 Name of mother of W. H. Auden Constance Rosalie Bicknell W. H. Auden "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' E' 'ileen' ',' ' who'
+ ' was' ' a' ' painter' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , E ileen , who was a painter .
+
+ The house" False the 20th century, W. H. Auden once called Wagner 10 [' the', ' 20', 'th', ' century', ',', ' W', '.', ' H', '.', ' Aud', 'en']
+1838 399 Name of mother of x -1 Name of mother of Nathaniel Hawthorne Elizabeth Clarke Hathorne Nathaniel Hawthorne "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Nathaniel'
+ ' Hawth' 'orne' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Nathaniel Hawth orne , and the
+
+ Name of mother" False when critics like Nathaniel Hawthorne and H. L. Mencken 5 [' when', ' critics', ' like', ' Nathaniel', ' Hawth', 'orne']
+1839 399 Name of mother of x -1 Name of mother of Nathaniel Hawthorne Elizabeth Clarke Hathorne Nathaniel Hawthorne "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Nathaniel'
+ ' Hawth' 'orne' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Nathaniel Hawth orne , and the
+
+ Name of mother" False century, when critics like Nathaniel Hawthorne and H. L. Mencken 7 [' century', ',', ' when', ' critics', ' like', ' Nathaniel', ' Hawth', 'orne']
+1840 399 Name of mother of x -1 Name of mother of Nathaniel Hawthorne Elizabeth Clarke Hathorne Nathaniel Hawthorne "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Nathaniel'
+ ' Hawth' 'orne' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Nathaniel Hawth orne , and the
+
+ Name of mother" False called Fruitlands. Nathaniel Hawthorne was a founding member 6 [' called', ' Fruit', 'lands', '.', ' Nathaniel', ' Hawth', 'orne']
+1841 399 Name of mother of x -1 Name of mother of Nathaniel Hawthorne Elizabeth Clarke Hathorne Nathaniel Hawthorne "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Nathaniel'
+ ' Hawth' 'orne' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Nathaniel Hawth orne , and the
+
+ Name of mother" False to Congress) and Nathaniel Hawthorne (the author), 6 [' to', ' Congress', ')', ' and', ' Nathaniel', ' Hawth', 'orne']
+1842 399 Name of mother of x -1 Name of mother of Nathaniel Hawthorne Elizabeth Clarke Hathorne Nathaniel Hawthorne "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Nathaniel'
+ ' Hawth' 'orne' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Nathaniel Hawth orne , and the
+
+ Name of mother" False called Fruitlands. Nathaniel Hawthorne was a founding 6 [' called', ' Fruit', 'lands', '.', ' Nathaniel', ' Hawth', 'orne']
+1843 400 Name of mother of x -1 Name of mother of Sigourney Weaver Elizabeth Inglis Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False crane. While Sigourney Weaver was inside the 5 [' crane', '.', ' While', ' Sig', 'ourney', ' Weaver']
+1844 400 Name of mother of x -1 Name of mother of Sigourney Weaver Elizabeth Inglis Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False " strength begun by Sigourney Weaver in Alien"", but" 5 [' strength', ' begun', ' by', ' Sig', 'ourney', ' Weaver']
+1845 400 Name of mother of x -1 Name of mother of Sigourney Weaver Elizabeth Inglis Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False crane. While Sigourney Weaver was inside the 5 [' crane', '.', ' While', ' Sig', 'ourney', ' Weaver']
+1846 400 Name of mother of x -1 Name of mother of Sigourney Weaver Elizabeth Inglis Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False script for Alien 5, but Sigourney Weaver was not interested 8 [' script', ' for', ' Alien', ' 5', ',', ' but', ' Sig', 'ourney', ' Weaver']
+1847 400 Name of mother of x -1 Name of mother of Sigourney Weaver Elizabeth Inglis Sigourney Weaver "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' Sig' 'ourney']" ", who is a very good friend of mine .
+
+ I have been a fan of Sig ourney" False was confirmed that Sigourney Weaver would have a role in 5 [' was', ' confirmed', ' that', ' Sig', 'ourney', ' Weaver']
+1848 401 Name of mother of x -1 Name of mother of Tupac Shakur Afeni Shakur Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' born' ' in' ' Brooklyn' ',' ' New'
+ ' York' ',' ' and' ' was' ' raised' ' in' ' the' ' Bronx' '.' ' He']" , the rapper , was born in Brooklyn , New York , and was raised in the Bronx . He False B.I.G. and Tupac Shakur and the group's 10 [' B', '.', 'I', '.', 'G', '.', ' and', ' Tup', 'ac', ' Shak', 'ur']
+1849 401 Name of mother of x -1 Name of mother of Tupac Shakur Afeni Shakur Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' born' ' in' ' Brooklyn' ',' ' New'
+ ' York' ',' ' and' ' was' ' raised' ' in' ' the' ' Bronx' '.' ' He']" , the rapper , was born in Brooklyn , New York , and was raised in the Bronx . He False " to sample the 1996 Tupac Shakur song ""Me and" 7 [' to', ' sample', ' the', ' 1996', ' Tup', 'ac', ' Shak', 'ur']
+1850 401 Name of mother of x -1 Name of mother of Tupac Shakur Afeni Shakur Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' born' ' in' ' Brooklyn' ',' ' New'
+ ' York' ',' ' and' ' was' ' raised' ' in' ' the' ' Bronx' '.' ' He']" , the rapper , was born in Brooklyn , New York , and was raised in the Bronx . He False Notorious B.I.G. and Tupac Shakur in conjunction 12 [' Not', 'orious', ' B', '.', 'I', '.', 'G', '.', ' and', ' Tup', 'ac', ' Shak', 'ur']
+1851 401 Name of mother of x -1 Name of mother of Tupac Shakur Afeni Shakur Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' born' ' in' ' Brooklyn' ',' ' New'
+ ' York' ',' ' and' ' was' ' raised' ' in' ' the' ' Bronx' '.' ' He']" , the rapper , was born in Brooklyn , New York , and was raised in the Bronx . He False 4 ['T', 'up', 'ac', ' Shak', 'ur']
+1852 401 Name of mother of x -1 Name of mother of Tupac Shakur Afeni Shakur Tupac Shakur "[',' ' the' ' rapper' ',' ' was' ' born' ' in' ' Brooklyn' ',' ' New'
+ ' York' ',' ' and' ' was' ' raised' ' in' ' the' ' Bronx' '.' ' He']" , the rapper , was born in Brooklyn , New York , and was raised in the Bronx . He False Wonder, Tonic, Prince, Tupac Shakur and Marvin Gaye 10 [' Wonder', ',', ' T', 'onic', ',', ' Prince', ',', ' Tup', 'ac', ' Shak', 'ur']
+1853 403 Name of mother of x -1 Name of mother of Leonardo DiCaprio Irmelin DiCaprio Leonardo DiCaprio "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' his' ' work']" ", who is a very good friend of mine .
+
+ I am a big fan of his work" False Sycamore Grove. Leonardo DiCaprio was Romeo and Claire 8 [' Sy', 'cam', 'ore', ' Grove', '.', ' Leonardo', ' Di', 'Cap', 'rio']
+1854 403 Name of mother of x -1 Name of mother of Leonardo DiCaprio Irmelin DiCaprio Leonardo DiCaprio "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' his' ' work']" ", who is a very good friend of mine .
+
+ I am a big fan of his work" False starred a young Leonardo DiCaprio as Streep's character's 6 [' starred', ' a', ' young', ' Leonardo', ' Di', 'Cap', 'rio']
+1855 403 Name of mother of x -1 Name of mother of Leonardo DiCaprio Irmelin DiCaprio Leonardo DiCaprio "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' his' ' work']" ", who is a very good friend of mine .
+
+ I am a big fan of his work" False " DiCaprio =
+" 8 [' Di', 'Cap', 'rio', ' =', 'Leon', 'ardo', ' Di', 'Cap', 'rio']
+1856 403 Name of mother of x -1 Name of mother of Leonardo DiCaprio Irmelin DiCaprio Leonardo DiCaprio "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' his' ' work']" ", who is a very good friend of mine .
+
+ I am a big fan of his work" False " the second most for Leonardo DiCaprio — behind Titanic.
+" 7 [' the', ' second', ' most', ' for', ' Leonardo', ' Di', 'Cap', 'rio']
+1857 403 Name of mother of x -1 Name of mother of Leonardo DiCaprio Irmelin DiCaprio Leonardo DiCaprio "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' his' ' work']" ", who is a very good friend of mine .
+
+ I am a big fan of his work" False " Kid"" Herod
+" 7 "[' Kid', '""', ' Herod', 'Leon', 'ardo', ' Di', 'Cap', 'rio']"
+1858 404 Name of mother of x -1 Name of mother of Laurence Olivier Agnes Louise Crookenden Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False and the 25th Laurence Olivier Awards in 2001 for 6 [' and', ' the', ' 25', 'th', ' Laure', 'nce', ' Olivier']
+1859 404 Name of mother of x -1 Name of mother of Laurence Olivier Agnes Louise Crookenden Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False " ==
+" 4 [' ==', 'La', 'ure', 'nce', ' Olivier']
+1860 404 Name of mother of x -1 Name of mother of Laurence Olivier Agnes Louise Crookenden Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False " hosted the Laurence Olivier Awards.
+" 4 [' hosted', ' the', ' Laure', 'nce', ' Olivier']
+1861 404 Name of mother of x -1 Name of mother of Laurence Olivier Agnes Louise Crookenden Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False long-running Laurence Olivier Award – nominated 5 [' long', '-', 'running', ' Laure', 'nce', ' Olivier']
+1862 404 Name of mother of x -1 Name of mother of Laurence Olivier Agnes Louise Crookenden Laurence Olivier "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Viv' 'ien' ' Leigh'
+ ',' ' who' ' was' ' the' ' first' ' to' ' be' ' cast' ' in']" , the actor , and his wife , Viv ien Leigh , who was the first to be cast in False Shepherd earned a Laurence Olivier Award for Best 5 [' Shepherd', ' earned', ' a', ' Laure', 'nce', ' Olivier']
+1863 405 Name of mother of x -1 Name of mother of Ringo Starr Elsie Starkey Ringo Starr "[',' ' the' ' drummer' ' for' ' the' ' Beatles' ',' ' and' ' the'
+ ' drummer' ' for' ' the' ' Beatles' ',' ' Ring' 'o' ' Starr' ',' ' was'
+ ' born']" , the drummer for the Beatles , and the drummer for the Beatles , Ring o Starr , was born False Russell, Jim Horn, Ringo Starr and Jim Keltner. 7 [' Russell', ',', ' Jim', ' Horn', ',', ' Ring', 'o', ' Starr']
+1864 405 Name of mother of x -1 Name of mother of Ringo Starr Elsie Starkey Ringo Starr "[',' ' the' ' drummer' ' for' ' the' ' Beatles' ',' ' and' ' the'
+ ' drummer' ' for' ' the' ' Beatles' ',' ' Ring' 'o' ' Starr' ',' ' was'
+ ' born']" , the drummer for the Beatles , and the drummer for the Beatles , Ring o Starr , was born False Former Beatle Ringo Starr defended the song's 5 [' Former', ' Beat', 'le', ' Ring', 'o', ' Starr']
+1865 405 Name of mother of x -1 Name of mother of Ringo Starr Elsie Starkey Ringo Starr "[',' ' the' ' drummer' ' for' ' the' ' Beatles' ',' ' and' ' the'
+ ' drummer' ' for' ' the' ' Beatles' ',' ' Ring' 'o' ' Starr' ',' ' was'
+ ' born']" , the drummer for the Beatles , and the drummer for the Beatles , Ring o Starr , was born False and drummer Ringo Starr replaced Best, 4 [' and', ' drummer', ' Ring', 'o', ' Starr']
+1866 405 Name of mother of x -1 Name of mother of Ringo Starr Elsie Starkey Ringo Starr "[',' ' the' ' drummer' ' for' ' the' ' Beatles' ',' ' and' ' the'
+ ' drummer' ' for' ' the' ' Beatles' ',' ' Ring' 'o' ' Starr' ',' ' was'
+ ' born']" , the drummer for the Beatles , and the drummer for the Beatles , Ring o Starr , was born False Gibson SG, and Ringo Starr used Ludwig drums. 6 [' Gibson', ' SG', ',', ' and', ' Ring', 'o', ' Starr']
+1867 405 Name of mother of x -1 Name of mother of Ringo Starr Elsie Starkey Ringo Starr "[',' ' the' ' drummer' ' for' ' the' ' Beatles' ',' ' and' ' the'
+ ' drummer' ' for' ' the' ' Beatles' ',' ' Ring' 'o' ' Starr' ',' ' was'
+ ' born']" , the drummer for the Beatles , and the drummer for the Beatles , Ring o Starr , was born False " drums, handclaps
+" 7 [' drums', ',', ' hand', 'cl', 'aps', 'R', 'ingo', ' Starr']
+1868 406 Name of mother of x -1 Name of mother of Phil Collins June Strange Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' the' ' groom' ""'s"" ' father' ',' ' and' ' the' ' groom' ""'s""]" , the father of the bride , and the groom , the groom 's father , and the groom 's False " Collins – drums (1985)
+" 7 [' Collins', ' –', ' drums', ' (', '1985', ')', 'Phil', ' Collins']
+1869 406 Name of mother of x -1 Name of mother of Phil Collins June Strange Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' the' ' groom' ""'s"" ' father' ',' ' and' ' the' ' groom' ""'s""]" , the father of the bride , and the groom , the groom 's father , and the groom 's False 120 artists with Phil Collins topping the 4 [' 120', ' artists', ' with', ' Phil', ' Collins']
+1870 406 Name of mother of x -1 Name of mother of Phil Collins June Strange Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' the' ' groom' ""'s"" ' father' ',' ' and' ' the' ' groom' ""'s""]" , the father of the bride , and the groom , the groom 's father , and the groom 's False score by pop star Phil Collins resulted in significant 5 [' score', ' by', ' pop', ' star', ' Phil', ' Collins']
+1871 406 Name of mother of x -1 Name of mother of Phil Collins June Strange Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' the' ' groom' ""'s"" ' father' ',' ' and' ' the' ' groom' ""'s""]" , the father of the bride , and the groom , the groom 's father , and the groom 's False time later. The Phil Collins Big Band played this 5 [' time', ' later', '.', ' The', ' Phil', ' Collins']
+1872 406 Name of mother of x -1 Name of mother of Phil Collins June Strange Phil Collins "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom' ','
+ ' the' ' groom' ""'s"" ' father' ',' ' and' ' the' ' groom' ""'s""]" , the father of the bride , and the groom , the groom 's father , and the groom 's False the works by Phil Collins and Lionel 4 [' the', ' works', ' by', ' Phil', ' Collins']
+1873 407 Name of mother of x -1 Name of mother of Gustav Mahler Marie Herrmann Gustav Mahler "[',' ' the' ' great' ' composer' ',' ' and' ' the' ' great' ' composer'
+ ' of' ' the' ' future' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' heard']" ", the great composer , and the great composer of the future .
+
+ The first time I heard" False Aires in July 1893; Gustav Mahler conducted the opera 7 [' Aires', ' in', ' July', ' 1893', ';', ' Gustav', ' Mah', 'ler']
+1874 407 Name of mother of x -1 Name of mother of Gustav Mahler Marie Herrmann Gustav Mahler "[',' ' the' ' great' ' composer' ',' ' and' ' the' ' great' ' composer'
+ ' of' ' the' ' future' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' heard']" ", the great composer , and the great composer of the future .
+
+ The first time I heard" False exception occurred and Gustav Mahler was named director 5 [' exception', ' occurred', ' and', ' Gustav', ' Mah', 'ler']
+1875 407 Name of mother of x -1 Name of mother of Gustav Mahler Marie Herrmann Gustav Mahler "[',' ' the' ' great' ' composer' ',' ' and' ' the' ' great' ' composer'
+ ' of' ' the' ' future' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' heard']" ", the great composer , and the great composer of the future .
+
+ The first time I heard" False International Gustav Mahler Institute was 3 [' International', ' Gustav', ' Mah', 'ler']
+1876 407 Name of mother of x -1 Name of mother of Gustav Mahler Marie Herrmann Gustav Mahler "[',' ' the' ' great' ' composer' ',' ' and' ' the' ' great' ' composer'
+ ' of' ' the' ' future' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' heard']" ", the great composer , and the great composer of the future .
+
+ The first time I heard" False E-flat major by Gustav Mahler is one of the 7 [' E', '-', 'flat', ' major', ' by', ' Gustav', ' Mah', 'ler']
+1877 407 Name of mother of x -1 Name of mother of Gustav Mahler Marie Herrmann Gustav Mahler "[',' ' the' ' great' ' composer' ',' ' and' ' the' ' great' ' composer'
+ ' of' ' the' ' future' '.' '\n' '\n' 'The' ' first' ' time' ' I' ' heard']" ", the great composer , and the great composer of the future .
+
+ The first time I heard" False 4 ['G', 'ust', 'av', ' Mah', 'ler']
+1878 408 Name of mother of x -1 Name of mother of John Updike Linda Grace Hoyer John Updike "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False American art critic John Updike suggests that 6 [' American', ' art', ' critic', ' John', ' Up', 'd', 'ike']
+1879 408 Name of mother of x -1 Name of mother of John Updike Linda Grace Hoyer John Updike "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False this time, John Updike noticed his 6 [' this', ' time', ',', ' John', ' Up', 'd', 'ike']
+1880 408 Name of mother of x -1 Name of mother of John Updike Linda Grace Hoyer John Updike "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False 3 ['John', ' Up', 'd', 'ike']
+1881 408 Name of mother of x -1 Name of mother of John Updike Linda Grace Hoyer John Updike "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False Prize-winning novelist John Updike attested that 7 [' Prize', '-', 'winning', ' novelist', ' John', ' Up', 'd', 'ike']
+1882 408 Name of mother of x -1 Name of mother of John Updike Linda Grace Hoyer John Updike "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False US writer John Updike expressed his 5 [' US', ' writer', ' John', ' Up', 'd', 'ike']
+1883 410 Name of mother of x -1 Name of mother of Roald Dahl Sofie Magdalene Hesselberg Roald Dahl "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False screenplay written by Roald Dahl and Ken Hughes. 5 [' screenplay', ' written', ' by', ' Ro', 'ald', ' Dahl']
+1884 410 Name of mother of x -1 Name of mother of Roald Dahl Sofie Magdalene Hesselberg Roald Dahl "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False invisible lift in Roald Dahl Plass and follows them 5 [' invisible', ' lift', ' in', ' Ro', 'ald', ' Dahl']
+1885 410 Name of mother of x -1 Name of mother of Roald Dahl Sofie Magdalene Hesselberg Roald Dahl "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False screenplay by Roald Dahl and Ken Hughes; 4 [' screenplay', ' by', ' Ro', 'ald', ' Dahl']
+1886 410 Name of mother of x -1 Name of mother of Roald Dahl Sofie Magdalene Hesselberg Roald Dahl "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False children's author Roald Dahl (1916 – 90) lived 5 "[' children', ""'s"", ' author', ' Ro', 'ald', ' Dahl']"
+1887 410 Name of mother of x -1 Name of mother of Roald Dahl Sofie Magdalene Hesselberg Roald Dahl "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False wedding was recorded at Roald Dahl Plass on 25 6 [' wedding', ' was', ' recorded', ' at', ' Ro', 'ald', ' Dahl']
+1888 411 Name of mother of x -1 Name of mother of Arthur Miller Augusta Barnett Arthur Miller "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False 1 ['Arthur', ' Miller']
+1889 411 Name of mother of x -1 Name of mother of Arthur Miller Augusta Barnett Arthur Miller "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False discussion. Arthur Miller considered 3 [' discussion', '.', ' Arthur', ' Miller']
+1890 411 Name of mother of x -1 Name of mother of Arthur Miller Augusta Barnett Arthur Miller "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False home of playwright Arthur Miller where he was 5 [' home', ' of', ' play', 'wright', ' Arthur', ' Miller']
+1891 411 Name of mother of x -1 Name of mother of Arthur Miller Augusta Barnett Arthur Miller "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False the discussion. Arthur Miller considered 4 [' the', ' discussion', '.', ' Arthur', ' Miller']
+1892 411 Name of mother of x -1 Name of mother of Arthur Miller Augusta Barnett Arthur Miller "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False American playwright Arthur Miller to Turkey in 1985 4 [' American', ' play', 'wright', ' Arthur', ' Miller']
+1893 412 Name of mother of x -1 Name of mother of Anne Frank Edith Frank-Holländer Anne Frank "[',' ' the' ' daughter' ' of' ' a' ' Jewish' ' family' ',' ' was' ' born'
+ ' in' ' Amsterdam' ' in' ' the' ' Netherlands' ' in' ' 1929' '.' ' She'
+ ' was']" , the daughter of a Jewish family , was born in Amsterdam in the Netherlands in 1929 . She was False such as The Diary of Anne Frank (1959) or the 6 [' such', ' as', ' The', ' Diary', ' of', ' Anne', ' Frank']
+1894 412 Name of mother of x -1 Name of mother of Anne Frank Edith Frank-Holländer Anne Frank "[',' ' the' ' daughter' ' of' ' a' ' Jewish' ' family' ',' ' was' ' born'
+ ' in' ' Amsterdam' ' in' ' the' ' Netherlands' ' in' ' 1929' '.' ' She'
+ ' was']" , the daughter of a Jewish family , was born in Amsterdam in the Netherlands in 1929 . She was False " tributes to Anne Frank and Ruby Bridges.
+" 4 [' t', 'ributes', ' to', ' Anne', ' Frank']
+1895 412 Name of mother of x -1 Name of mother of Anne Frank Edith Frank-Holländer Anne Frank "[',' ' the' ' daughter' ' of' ' a' ' Jewish' ' family' ',' ' was' ' born'
+ ' in' ' Amsterdam' ' in' ' the' ' Netherlands' ' in' ' 1929' '.' ' She'
+ ' was']" , the daughter of a Jewish family , was born in Amsterdam in the Netherlands in 1929 . She was False Geiringer-Markovits, set up the Anne Frank Fonds as a charitable 12 [' Ge', 'iring', 'er', '-', 'Mark', 'ov', 'its', ',', ' set', ' up', ' the', ' Anne', ' Frank']
+1896 412 Name of mother of x -1 Name of mother of Anne Frank Edith Frank-Holländer Anne Frank "[',' ' the' ' daughter' ' of' ' a' ' Jewish' ' family' ',' ' was' ' born'
+ ' in' ' Amsterdam' ' in' ' the' ' Netherlands' ' in' ' 1929' '.' ' She'
+ ' was']" , the daughter of a Jewish family , was born in Amsterdam in the Netherlands in 1929 . She was False The Diary of Anne Frank, which was a critical 4 [' The', ' Diary', ' of', ' Anne', ' Frank']
+1897 412 Name of mother of x -1 Name of mother of Anne Frank Edith Frank-Holländer Anne Frank "[',' ' the' ' daughter' ' of' ' a' ' Jewish' ' family' ',' ' was' ' born'
+ ' in' ' Amsterdam' ' in' ' the' ' Netherlands' ' in' ' 1929' '.' ' She'
+ ' was']" , the daughter of a Jewish family , was born in Amsterdam in the Netherlands in 1929 . She was False June 1942, Anne Frank received a book she 4 [' June', ' 1942', ',', ' Anne', ' Frank']
+1898 413 Name of mother of x -1 Name of mother of George III of Great Britain Princess Augusta of Saxe-Gotha George III of Great Britain "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' George' ' III' ' of' ' Great' ' Britain' ',' ' and']" ", and the
+
+ The
+
+ Name of the father of George III of Great Britain , and" False allowance from King George III of Great Britain and sent some 7 [' allowance', ' from', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+1899 413 Name of mother of x -1 Name of mother of George III of Great Britain Princess Augusta of Saxe-Gotha George III of Great Britain "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' George' ' III' ' of' ' Great' ' Britain' ',' ' and']" ", and the
+
+ The
+
+ Name of the father of George III of Great Britain , and" False about the fact King George III of Great Britain was a major shareholder. 8 [' about', ' the', ' fact', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+1900 413 Name of mother of x -1 Name of mother of George III of Great Britain Princess Augusta of Saxe-Gotha George III of Great Britain "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' George' ' III' ' of' ' Great' ' Britain' ',' ' and']" ", and the
+
+ The
+
+ Name of the father of George III of Great Britain , and" False about the fact King George III of Great Britain was a major shareholder. 8 [' about', ' the', ' fact', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+1901 413 Name of mother of x -1 Name of mother of George III of Great Britain Princess Augusta of Saxe-Gotha George III of Great Britain "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' George' ' III' ' of' ' Great' ' Britain' ',' ' and']" ", and the
+
+ The
+
+ Name of the father of George III of Great Britain , and" False allowance from King George III of Great Britain and sent some money 7 [' allowance', ' from', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+1902 413 Name of mother of x -1 Name of mother of George III of Great Britain Princess Augusta of Saxe-Gotha George III of Great Britain "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' George' ' III' ' of' ' Great' ' Britain' ',' ' and']" ", and the
+
+ The
+
+ Name of the father of George III of Great Britain , and" False allowance from King George III of Great Britain and sent some 7 [' allowance', ' from', ' King', ' George', ' III', ' of', ' Great', ' Britain']
+1903 415 Name of mother of x -1 Name of mother of Thomas Henry Huxley Rachel Withers Thomas Henry Huxley "[',' ' the' ' great' ' biologist' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' Darwin' ""'s"" ' work' ',' ' and' ' who' ' was' ' a']" , the great biologist , who was a great admire r of Darwin 's work , and who was a False " ""Darwin's Bulldog"" Thomas Henry Huxley spent the last" 11 "[' ""', 'Dar', 'win', ""'s"", ' Bull', 'dog', '""', ' Thomas', ' Henry', ' H', 'ux', 'ley']"
+1904 415 Name of mother of x -1 Name of mother of Thomas Henry Huxley Rachel Withers Thomas Henry Huxley "[',' ' the' ' great' ' biologist' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' Darwin' ""'s"" ' work' ',' ' and' ' who' ' was' ' a']" , the great biologist , who was a great admire r of Darwin 's work , and who was a False 4 ['Thomas', ' Henry', ' H', 'ux', 'ley']
+1905 415 Name of mother of x -1 Name of mother of Thomas Henry Huxley Rachel Withers Thomas Henry Huxley "[',' ' the' ' great' ' biologist' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' Darwin' ""'s"" ' work' ',' ' and' ' who' ' was' ' a']" , the great biologist , who was a great admire r of Darwin 's work , and who was a False young friend Thomas Henry Huxley was firmly against 6 [' young', ' friend', ' Thomas', ' Henry', ' H', 'ux', 'ley']
+1906 415 Name of mother of x -1 Name of mother of Thomas Henry Huxley Rachel Withers Thomas Henry Huxley "[',' ' the' ' great' ' biologist' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' Darwin' ""'s"" ' work' ',' ' and' ' who' ' was' ' a']" , the great biologist , who was a great admire r of Darwin 's work , and who was a False 4 ['Thomas', ' Henry', ' H', 'ux', 'ley']
+1907 415 Name of mother of x -1 Name of mother of Thomas Henry Huxley Rachel Withers Thomas Henry Huxley "[',' ' the' ' great' ' biologist' ',' ' who' ' was' ' a' ' great'
+ ' admire' 'r' ' of' ' Darwin' ""'s"" ' work' ',' ' and' ' who' ' was' ' a']" , the great biologist , who was a great admire r of Darwin 's work , and who was a False evolution between Thomas Henry Huxley and Richard Owen, 6 [' evolution', ' between', ' Thomas', ' Henry', ' H', 'ux', 'ley']
+1908 416 Name of mother of x -1 Name of mother of Ralph Waldo Emerson Ruth Haskins Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False " Transcendentalist Ralph Waldo Emerson reacted to ""The" 7 [' Trans', 'cend', 'ental', 'ist', ' Ralph', ' Wald', 'o', ' Emerson']
+1909 416 Name of mother of x -1 Name of mother of Ralph Waldo Emerson Ruth Haskins Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False the works, Ralph Waldo Emerson expressed this 6 [' the', ' works', ',', ' Ralph', ' Wald', 'o', ' Emerson']
+1910 416 Name of mother of x -1 Name of mother of Ralph Waldo Emerson Ruth Haskins Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Shakespeare and John Milton. Ralph Waldo Emerson noted that, though 8 [' Shakespeare', ' and', ' John', ' Milton', '.', ' Ralph', ' Wald', 'o', ' Emerson']
+1911 416 Name of mother of x -1 Name of mother of Ralph Waldo Emerson Ruth Haskins Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False author of the works, Ralph Waldo Emerson expressed this disjunction 8 [' author', ' of', ' the', ' works', ',', ' Ralph', ' Wald', 'o', ' Emerson']
+1912 416 Name of mother of x -1 Name of mother of Ralph Waldo Emerson Ruth Haskins Ralph Waldo Emerson "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Transcendentalist Ralph Waldo Emerson reacted to 7 [' Trans', 'cend', 'ental', 'ist', ' Ralph', ' Wald', 'o', ' Emerson']
+1913 417 Name of mother of x -1 Name of mother of Bing Crosby Catherine Helen Harrigan Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' the' ' other' ' one' ' is'
+ ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' famous' ' singer'
+ '.']" , the famous singer , and the other one is the name of the father of the famous singer . False worked on the first Bing Crosby and Bob Hope 5 [' worked', ' on', ' the', ' first', ' Bing', ' Crosby']
+1914 417 Name of mother of x -1 Name of mother of Bing Crosby Catherine Helen Harrigan Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' the' ' other' ' one' ' is'
+ ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' famous' ' singer'
+ '.']" , the famous singer , and the other one is the name of the father of the famous singer . False Perry Como credited Bing Crosby for influencing 5 [' Perry', ' Com', 'o', ' credited', ' Bing', ' Crosby']
+1915 417 Name of mother of x -1 Name of mother of Bing Crosby Catherine Helen Harrigan Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' the' ' other' ' one' ' is'
+ ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' famous' ' singer'
+ '.']" , the famous singer , and the other one is the name of the father of the famous singer . False Hutton Building, Bing Crosby Theater), and Chicago 5 [' H', 'utton', ' Building', ',', ' Bing', ' Crosby']
+1916 417 Name of mother of x -1 Name of mother of Bing Crosby Catherine Helen Harrigan Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' the' ' other' ' one' ' is'
+ ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' famous' ' singer'
+ '.']" , the famous singer , and the other one is the name of the father of the famous singer . False which featured Bing Crosby on vocals. 3 [' which', ' featured', ' Bing', ' Crosby']
+1917 417 Name of mother of x -1 Name of mother of Bing Crosby Catherine Helen Harrigan Bing Crosby "[',' ' the' ' famous' ' singer' ',' ' and' ' the' ' other' ' one' ' is'
+ ' the' ' name' ' of' ' the' ' father' ' of' ' the' ' famous' ' singer'
+ '.']" , the famous singer , and the other one is the name of the father of the famous singer . False Bob Hope and Bing Crosby each served as 4 [' Bob', ' Hope', ' and', ' Bing', ' Crosby']
+1918 418 Name of mother of x -1 Name of mother of Washington Irving Sarah Sanders Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book']" ", the author of the "" S ket ch - book ,"" and the "" S ket ch - book" False American author Washington Irving produced several 3 [' American', ' author', ' Washington', ' Irving']
+1919 418 Name of mother of x -1 Name of mother of Washington Irving Sarah Sanders Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book']" ", the author of the "" S ket ch - book ,"" and the "" S ket ch - book" False 1 ['Washington', ' Irving']
+1920 418 Name of mother of x -1 Name of mother of Washington Irving Sarah Sanders Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book']" ", the author of the "" S ket ch - book ,"" and the "" S ket ch - book" False American Civil War. Washington Irving used it as part of 5 [' American', ' Civil', ' War', '.', ' Washington', ' Irving']
+1921 418 Name of mother of x -1 Name of mother of Washington Irving Sarah Sanders Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book']" ", the author of the "" S ket ch - book ,"" and the "" S ket ch - book" False allusion to the Washington Irving character Ichabod 5 [' all', 'usion', ' to', ' the', ' Washington', ' Irving']
+1922 418 Name of mother of x -1 Name of mother of Washington Irving Sarah Sanders Washington Irving "[',' ' the' ' author' ' of' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book' ',""'
+ ' and' ' the' ' ""' 'S' 'ket' 'ch' '-' 'book']" ", the author of the "" S ket ch - book ,"" and the "" S ket ch - book" False 1 ['Washington', ' Irving']
+1923 419 Name of mother of x -1 Name of mother of Mel Brooks Kate Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False " Has Two Mommies"". Mel Brooks has a cameo appearance" 7 "[' Has', ' Two', ' M', 'omm', 'ies', '"".', ' Mel', ' Brooks']"
+1924 419 Name of mother of x -1 Name of mother of Mel Brooks Kate Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False was taken from the Mel Brooks movie The Producers, 5 [' was', ' taken', ' from', ' the', ' Mel', ' Brooks']
+1925 419 Name of mother of x -1 Name of mother of Mel Brooks Kate Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False taking it from the Mel Brooks film The Producers. 5 [' taking', ' it', ' from', ' the', ' Mel', ' Brooks']
+1926 419 Name of mother of x -1 Name of mother of Mel Brooks Kate Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False worked for Mel Brooks and Brooksfilms 3 [' worked', ' for', ' Mel', ' Brooks']
+1927 419 Name of mother of x -1 Name of mother of Mel Brooks Kate Kaminsky Mel Brooks "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Anne' ' B' 'anc' 'ro'
+ 'ft' ',' ' who' ' was' ' nominated' ' for' ' an' ' Oscar']" , the actor , and his wife , Anne B anc ro ft , who was nominated for an Oscar False taking it from the Mel Brooks film The Producers. 5 [' taking', ' it', ' from', ' the', ' Mel', ' Brooks']
+1928 420 Name of mother of x -1 Name of mother of Ezra Pound Isabel Weston Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False " Stein, James Joyce, and Ezra Pound who ""could help" 7 [' Stein', ',', ' James', ' Joyce', ',', ' and', ' Ezra', ' Pound']
+1929 420 Name of mother of x -1 Name of mother of Ezra Pound Isabel Weston Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Lowell and Ezra Pound found inspiration in 3 [' Lowell', ' and', ' Ezra', ' Pound']
+1930 420 Name of mother of x -1 Name of mother of Ezra Pound Isabel Weston Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False also here that Ezra Pound brought him to the 4 [' also', ' here', ' that', ' Ezra', ' Pound']
+1931 420 Name of mother of x -1 Name of mother of Ezra Pound Isabel Weston Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Modernist poet Ezra Pound called Whitman 4 [' Modern', 'ist', ' poet', ' Ezra', ' Pound']
+1932 420 Name of mother of x -1 Name of mother of Ezra Pound Isabel Weston Ezra Pound "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False factual 1970 Life of Ezra Pound – although the 5 [' factual', ' 1970', ' Life', ' of', ' Ezra', ' Pound']
+1933 421 Name of mother of x -1 Name of mother of Mick Jagger Eva Ensley Mary Scutts Mick Jagger "[',' ' the' ' rock' ' star' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' rock' ' star' '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' rock']" ", the rock star , and the mother of the rock star .
+
+ The mother of the rock" False Morrison, Ronnie Wood, Mick Jagger and Roger Daltrey. 7 [' Morrison', ',', ' Ronnie', ' Wood', ',', ' Mick', ' J', 'agger']
+1934 421 Name of mother of x -1 Name of mother of Mick Jagger Eva Ensley Mary Scutts Mick Jagger "[',' ' the' ' rock' ' star' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' rock' ' star' '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' rock']" ", the rock star , and the mother of the rock star .
+
+ The mother of the rock" False already gathered. Mick Jagger read a short 5 [' already', ' gathered', '.', ' Mick', ' J', 'agger']
+1935 421 Name of mother of x -1 Name of mother of Mick Jagger Eva Ensley Mary Scutts Mick Jagger "[',' ' the' ' rock' ' star' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' rock' ' star' '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' rock']" ", the rock star , and the mother of the rock star .
+
+ The mother of the rock" False Led Zeppelin, Mick Jagger and Keith Richards 6 [' Led', ' Ze', 'ppelin', ',', ' Mick', ' J', 'agger']
+1936 421 Name of mother of x -1 Name of mother of Mick Jagger Eva Ensley Mary Scutts Mick Jagger "[',' ' the' ' rock' ' star' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' rock' ' star' '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' rock']" ", the rock star , and the mother of the rock star .
+
+ The mother of the rock" False 3 ['M', 'ick', ' J', 'agger']
+1937 421 Name of mother of x -1 Name of mother of Mick Jagger Eva Ensley Mary Scutts Mick Jagger "[',' ' the' ' rock' ' star' ',' ' and' ' the' ' mother' ' of' ' the'
+ ' rock' ' star' '.' '\n' '\n' 'The' ' mother' ' of' ' the' ' rock']" ", the rock star , and the mother of the rock star .
+
+ The mother of the rock" False Munich, and rock star Mick Jagger along with his wife 7 [' Munich', ',', ' and', ' rock', ' star', ' Mick', ' J', 'agger']
+1938 423 Name of mother of x -1 Name of mother of Sandra Bullock Helga Meyer Sandra Bullock "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Blind']" , the actress who played the role of the mother of the bride in the movie � � The Blind False " producers on the Sandra Bullock film, The Proposal.
+" 5 [' producers', ' on', ' the', ' Sandra', ' Bull', 'ock']
+1939 423 Name of mother of x -1 Name of mother of Sandra Bullock Helga Meyer Sandra Bullock "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Blind']" , the actress who played the role of the mother of the bride in the movie � � The Blind False then approached Sandra Bullock for the role. 4 [' then', ' approached', ' Sandra', ' Bull', 'ock']
+1940 423 Name of mother of x -1 Name of mother of Sandra Bullock Helga Meyer Sandra Bullock "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Blind']" , the actress who played the role of the mother of the bride in the movie � � The Blind False Bont and McCormick. Sandra Bullock stars in the film, 8 [' B', 'ont', ' and', ' McCorm', 'ick', '.', ' Sandra', ' Bull', 'ock']
+1941 423 Name of mother of x -1 Name of mother of Sandra Bullock Helga Meyer Sandra Bullock "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Blind']" , the actress who played the role of the mother of the bride in the movie � � The Blind False Proposal, starring Sandra Bullock and Ryan Reynolds 6 [' Pro', 'posal', ',', ' starring', ' Sandra', ' Bull', 'ock']
+1942 423 Name of mother of x -1 Name of mother of Sandra Bullock Helga Meyer Sandra Bullock "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'The'
+ ' Blind']" , the actress who played the role of the mother of the bride in the movie � � The Blind False and McCormick. Sandra Bullock stars in the 6 [' and', ' McCorm', 'ick', '.', ' Sandra', ' Bull', 'ock']
+1943 424 Name of mother of x -1 Name of mother of Marcel Proust Jeanne-Clémence Proust Marcel Proust "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' '�' '�' ' la'
+ ' rec' 'her' 'che' ' du' ' tem' 'ps' ' per' 'du' '_']" , the author of the famous novel _ � � la rec her che du tem ps per du _ False 4 ['Mar', 'cel', ' P', 'rou', 'st']
+1944 424 Name of mother of x -1 Name of mother of Marcel Proust Jeanne-Clémence Proust Marcel Proust "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' '�' '�' ' la'
+ ' rec' 'her' 'che' ' du' ' tem' 'ps' ' per' 'du' '_']" , the author of the famous novel _ � � la rec her che du tem ps per du _ False have included Marcel Proust and André Gide; 5 [' have', ' included', ' Marcel', ' P', 'rou', 'st']
+1945 424 Name of mother of x -1 Name of mother of Marcel Proust Jeanne-Clémence Proust Marcel Proust "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' '�' '�' ' la'
+ ' rec' 'her' 'che' ' du' ' tem' 'ps' ' per' 'du' '_']" , the author of the famous novel _ � � la rec her che du tem ps per du _ False 4 ['Mar', 'cel', ' P', 'rou', 'st']
+1946 424 Name of mother of x -1 Name of mother of Marcel Proust Jeanne-Clémence Proust Marcel Proust "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' '�' '�' ' la'
+ ' rec' 'her' 'che' ' du' ' tem' 'ps' ' per' 'du' '_']" , the author of the famous novel _ � � la rec her che du tem ps per du _ False " which French novelist Marcel Proust later termed ""retrospective" 6 [' which', ' French', ' novelist', ' Marcel', ' P', 'rou', 'st']
+1947 424 Name of mother of x -1 Name of mother of Marcel Proust Jeanne-Clémence Proust Marcel Proust "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' _' '�' '�' ' la'
+ ' rec' 'her' 'che' ' du' ' tem' 'ps' ' per' 'du' '_']" , the author of the famous novel _ � � la rec her che du tem ps per du _ False care for it. Marcel Proust never attended 7 [' care', ' for', ' it', '.', ' Marcel', ' P', 'rou', 'st']
+1948 426 Name of mother of x -1 Name of mother of Sting Audrey Cowell Sting "['ray' ',' ' the' ' cat' ',' ' and' ' the' ' cat' ""'s"" ' name' ' is'
+ ' Sting' 'ray' '.' '\n' '\n' 'S' 'ting' 'ray' ' is']" "ray , the cat , and the cat 's name is Sting ray .
+
+ S ting ray is" False Samoa Joe, Rhino, Sting, and Jeff Jarrett 5 [' Samoa', ' Joe', ',', ' Rhino', ',', ' Sting']
+1949 426 Name of mother of x -1 Name of mother of Sting Audrey Cowell Sting "['ray' ',' ' the' ' cat' ',' ' and' ' the' ' cat' ""'s"" ' name' ' is'
+ ' Sting' 'ray' '.' '\n' '\n' 'S' 'ting' 'ray' ' is']" "ray , the cat , and the cat 's name is Sting ray .
+
+ S ting ray is" False " torpedoes, Sting Ray torpedoes
+" 3 [' torped', 'oes', ',', ' Sting']
+1950 426 Name of mother of x -1 Name of mother of Sting Audrey Cowell Sting "['ray' ',' ' the' ' cat' ',' ' and' ' the' ' cat' ""'s"" ' name' ' is'
+ ' Sting' 'ray' '.' '\n' '\n' 'S' 'ting' 'ray' ' is']" "ray , the cat , and the cat 's name is Sting ray .
+
+ S ting ray is" False Tonic, Filter, Sting and Aswad had 5 [' T', 'onic', ',', ' Filter', ',', ' Sting']
+1951 426 Name of mother of x -1 Name of mother of Sting Audrey Cowell Sting "['ray' ',' ' the' ' cat' ',' ' and' ' the' ' cat' ""'s"" ' name' ' is'
+ ' Sting' 'ray' '.' '\n' '\n' 'S' 'ting' 'ray' ' is']" "ray , the cat , and the cat 's name is Sting ray .
+
+ S ting ray is" False " treatment ===
+" 3 [' treatment', ' ===', 'S', 'ting']
+1952 426 Name of mother of x -1 Name of mother of Sting Audrey Cowell Sting "['ray' ',' ' the' ' cat' ',' ' and' ' the' ' cat' ""'s"" ' name' ' is'
+ ' Sting' 'ray' '.' '\n' '\n' 'S' 'ting' 'ray' ' is']" "ray , the cat , and the cat 's name is Sting ray .
+
+ S ting ray is" False after interference from Sting and various security 3 [' after', ' interference', ' from', ' Sting']
+1953 427 Name of mother of x -1 Name of mother of Elizabeth I of England Anne Boleyn Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' daughter' ' of' ' a'
+ ' king' ',' ' and' ' I' ' am' ' the' ' mother' ' of' ' a']" ", and the
+
+ I am the daughter of a king , and I am the mother of a" False the 16th century. Elizabeth I of England received a wristwatch 8 [' the', ' 16', 'th', ' century', '.', ' Elizabeth', ' I', ' of', ' England']
+1954 427 Name of mother of x -1 Name of mother of Elizabeth I of England Anne Boleyn Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' daughter' ' of' ' a'
+ ' king' ',' ' and' ' I' ' am' ' the' ' mother' ' of' ' a']" ", and the
+
+ I am the daughter of a king , and I am the mother of a" False December 1600, Queen Elizabeth I of England (1533 – 1603) granted 7 [' December', ' 1600', ',', ' Queen', ' Elizabeth', ' I', ' of', ' England']
+1955 427 Name of mother of x -1 Name of mother of Elizabeth I of England Anne Boleyn Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' daughter' ' of' ' a'
+ ' king' ',' ' and' ' I' ' am' ' the' ' mother' ' of' ' a']" ", and the
+
+ I am the daughter of a king , and I am the mother of a" False with Spain. Queen Elizabeth I of England chose to support 7 [' with', ' Spain', '.', ' Queen', ' Elizabeth', ' I', ' of', ' England']
+1956 427 Name of mother of x -1 Name of mother of Elizabeth I of England Anne Boleyn Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' daughter' ' of' ' a'
+ ' king' ',' ' and' ' I' ' am' ' the' ' mother' ' of' ' a']" ", and the
+
+ I am the daughter of a king , and I am the mother of a" False 1585, Queen Elizabeth I of England took the United Provinces 7 [' 15', '85', ',', ' Queen', ' Elizabeth', ' I', ' of', ' England']
+1957 427 Name of mother of x -1 Name of mother of Elizabeth I of England Anne Boleyn Elizabeth I of England "[',' ' and' ' the' '\n' '\n' 'I' ' am' ' the' ' daughter' ' of' ' a'
+ ' king' ',' ' and' ' I' ' am' ' the' ' mother' ' of' ' a']" ", and the
+
+ I am the daughter of a king , and I am the mother of a" False " England =
+" 5 [' England', ' =', 'Elizabeth', ' I', ' of', ' England']
+1958 428 Name of mother of x -1 Name of mother of Willy Brandt Martha Frahm Willy Brandt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False Indeed, since Willy Brandt had become chancellor 6 [' Indeed', ',', ' since', ' Will', 'y', ' Brand', 't']
+1959 428 Name of mother of x -1 Name of mother of Willy Brandt Martha Frahm Willy Brandt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False Chancellor Willy Brandt had paid a visit to 4 [' Chancellor', ' Will', 'y', ' Brand', 't']
+1960 428 Name of mother of x -1 Name of mother of Willy Brandt Martha Frahm Willy Brandt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False Hartley Shawcross. Willy Brandt was a strong 8 [' Hart', 'ley', ' Shaw', 'cross', '.', ' Will', 'y', ' Brand', 't']
+1961 428 Name of mother of x -1 Name of mother of Willy Brandt Martha Frahm Willy Brandt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False Britain and Willy Brandt in Germany, but to 5 [' Britain', ' and', ' Will', 'y', ' Brand', 't']
+1962 428 Name of mother of x -1 Name of mother of Willy Brandt Martha Frahm Willy Brandt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German' ' people'
+ '.' '\n' '\n' 'The' ' German' ' people' ' are' ' not' ' the' ' only'
+ ' ones']" ", the German chancellor , and the German people .
+
+ The German people are not the only ones" False as Chancellor Willy Brandt had paid a visit to 5 [' as', ' Chancellor', ' Will', 'y', ' Brand', 't']
+1963 429 Name of mother of x -1 Name of mother of Grace Kelly Margaret Majer Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' ' She' ' was' ' a'
+ ' beautiful']" , the daughter of a wealthy family , and the daughter of a wealthy family . She was a beautiful False Stewart gave actress Grace Kelly in the 1954 film 4 [' Stewart', ' gave', ' actress', ' Grace', ' Kelly']
+1964 429 Name of mother of x -1 Name of mother of Grace Kelly Margaret Majer Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' ' She' ' was' ' a'
+ ' beautiful']" , the daughter of a wealthy family , and the daughter of a wealthy family . She was a beautiful False company. He dated Grace Kelly and Joan Bennett 5 [' company', '.', ' He', ' dated', ' Grace', ' Kelly']
+1965 429 Name of mother of x -1 Name of mother of Grace Kelly Margaret Majer Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' ' She' ' was' ' a'
+ ' beautiful']" , the daughter of a wealthy family , and the daughter of a wealthy family . She was a beautiful False was won, however, by Grace Kelly for The Country 7 [' was', ' won', ',', ' however', ',', ' by', ' Grace', ' Kelly']
+1966 429 Name of mother of x -1 Name of mother of Grace Kelly Margaret Majer Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' ' She' ' was' ' a'
+ ' beautiful']" , the daughter of a wealthy family , and the daughter of a wealthy family . She was a beautiful False Noon (1952) with Grace Kelly for United 7 [' Noon', ' (', '19', '52', ')', ' with', ' Grace', ' Kelly']
+1967 429 Name of mother of x -1 Name of mother of Grace Kelly Margaret Majer Grace Kelly "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ' family' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' wealthy' ' family' '.' ' She' ' was' ' a'
+ ' beautiful']" , the daughter of a wealthy family , and the daughter of a wealthy family . She was a beautiful False (1952) with Grace Kelly for United Artists. 6 [' (', '19', '52', ')', ' with', ' Grace', ' Kelly']
+1968 430 Name of mother of x -1 Name of mother of Dennis Hopper Marjorie Mae Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' first' ' of' ' the' ' Ho' 'pper'
+ ' children' ',' ' was' ' a' ' man' ' of' ' the' ' cloth' '.' ' He']" , the father of the first of the Ho pper children , was a man of the cloth . He False departure, Des McAnuff, Dennis Hopper and Milčo Mančevski 9 [' departure', ',', ' Des', ' Mc', 'An', 'uff', ',', ' Dennis', ' Ho', 'pper']
+1969 430 Name of mother of x -1 Name of mother of Dennis Hopper Marjorie Mae Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' first' ' of' ' the' ' Ho' 'pper'
+ ' children' ',' ' was' ' a' ' man' ' of' ' the' ' cloth' '.' ' He']" , the father of the first of the Ho pper children , was a man of the cloth . He False Peter Fonda, Dennis Hopper (Director) and Jack 6 [' Peter', ' F', 'onda', ',', ' Dennis', ' Ho', 'pper']
+1970 430 Name of mother of x -1 Name of mother of Dennis Hopper Marjorie Mae Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' first' ' of' ' the' ' Ho' 'pper'
+ ' children' ',' ' was' ' a' ' man' ' of' ' the' ' cloth' '.' ' He']" , the father of the first of the Ho pper children , was a man of the cloth . He False success for Warhol. Dennis Hopper was the first 7 [' success', ' for', ' War', 'hol', '.', ' Dennis', ' Ho', 'pper']
+1971 430 Name of mother of x -1 Name of mother of Dennis Hopper Marjorie Mae Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' first' ' of' ' the' ' Ho' 'pper'
+ ' children' ',' ' was' ' a' ' man' ' of' ' the' ' cloth' '.' ' He']" , the father of the first of the Ho pper children , was a man of the cloth . He False departure, Des McAnuff, Dennis Hopper and Milčo 9 [' departure', ',', ' Des', ' Mc', 'An', 'uff', ',', ' Dennis', ' Ho', 'pper']
+1972 430 Name of mother of x -1 Name of mother of Dennis Hopper Marjorie Mae Hopper Dennis Hopper "[',' ' the' ' father' ' of' ' the' ' first' ' of' ' the' ' Ho' 'pper'
+ ' children' ',' ' was' ' a' ' man' ' of' ' the' ' cloth' '.' ' He']" , the father of the first of the Ho pper children , was a man of the cloth . He False 3 ['D', 'ennis', ' Ho', 'pper']
+1973 431 Name of mother of x -1 Name of mother of Harriet Beecher Stowe Roxana Foote Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False Uncle Tom's Cabin by Harriet Beecher Stowe was the best-selling 10 "[' Uncle', ' Tom', ""'s"", ' Cabin', ' by', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']"
+1974 431 Name of mother of x -1 Name of mother of Harriet Beecher Stowe Roxana Foote Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False Tom's Cabin by Harriet Beecher Stowe was the best-selling 9 "[' Tom', ""'s"", ' Cabin', ' by', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']"
+1975 431 Name of mother of x -1 Name of mother of Harriet Beecher Stowe Roxana Foote Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False guide written by Harriet Beecher Stowe about her 8 [' guide', ' written', ' by', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']
+1976 431 Name of mother of x -1 Name of mother of Harriet Beecher Stowe Roxana Foote Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False however, famed author Harriet Beecher Stowe lived near Jacksonville 9 [' however', ',', ' famed', ' author', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']
+1977 431 Name of mother of x -1 Name of mother of Harriet Beecher Stowe Roxana Foote Beecher Harriet Beecher Stowe "[',' ' the' ' author' ' of' ' Uncle' ' Tom' ""'s"" ' Cabin' ',' ' and'
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of Uncle Tom 's Cabin , and the
+ " False freedom in Canada. Harriet Beecher Stowe was living 9 [' freedom', ' in', ' Canada', '.', ' Harriet', ' Be', 'ec', 'her', ' St', 'owe']
+1978 432 Name of mother of x -1 Name of mother of Geoffrey Chaucer Agnes Copton Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's mother , and the poet 's father , and the poet 's False of Ovid's material. Geoffrey Chaucer recounted the 8 "[' of', ' O', 'vid', ""'s"", ' material', '.', ' Geoffrey', ' Chau', 'cer']"
+1979 432 Name of mother of x -1 Name of mother of Geoffrey Chaucer Agnes Copton Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's mother , and the poet 's father , and the poet 's False importance, from Geoffrey Chaucer to Beatrix Potter, 5 [' importance', ',', ' from', ' Geoffrey', ' Chau', 'cer']
+1980 432 Name of mother of x -1 Name of mother of Geoffrey Chaucer Agnes Copton Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's mother , and the poet 's father , and the poet 's False The work of Geoffrey Chaucer from the 1370s 5 [' The', ' work', ' of', ' Geoffrey', ' Chau', 'cer']
+1981 432 Name of mother of x -1 Name of mother of Geoffrey Chaucer Agnes Copton Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's mother , and the poet 's father , and the poet 's False technical treatise by Geoffrey Chaucer (modern Antarctic 6 [' technical', ' treat', 'ise', ' by', ' Geoffrey', ' Chau', 'cer']
+1982 432 Name of mother of x -1 Name of mother of Geoffrey Chaucer Agnes Copton Geoffrey Chaucer "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' father' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's mother , and the poet 's father , and the poet 's False 14th-century Italy, Geoffrey Chaucer (d. 1400) and William 8 [' 14', 'th', '-', 'century', ' Italy', ',', ' Geoffrey', ' Chau', 'cer']
+1983 433 Name of mother of x -1 Name of mother of John Steinbeck Olive Hamilton John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Nobel Prize-winner John Steinbeck wrote about 6 [' Nobel', ' Prize', '-', 'winner', ' John', ' Stein', 'beck']
+1984 433 Name of mother of x -1 Name of mother of John Steinbeck Olive Hamilton John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Ernest Hemingway and John Steinbeck are often named 7 [' Ernest', ' Hem', 'ing', 'way', ' and', ' John', ' Stein', 'beck']
+1985 433 Name of mother of x -1 Name of mother of John Steinbeck Olive Hamilton John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Hemingway and John Steinbeck are often named among 6 [' Hem', 'ing', 'way', ' and', ' John', ' Stein', 'beck']
+1986 433 Name of mother of x -1 Name of mother of John Steinbeck Olive Hamilton John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Robert Frost, John Steinbeck and E.M. Forster. The 5 [' Robert', ' Frost', ',', ' John', ' Stein', 'beck']
+1987 433 Name of mother of x -1 Name of mother of John Steinbeck Olive Hamilton John Steinbeck "[""'s"" ' _' 'The' ' G' 'rap' 'es' ' of' ' Wrath' '_' ',' ' and' ' the' ' _'
+ 'New' ' York' ' Times' '_' ' called' ' it' ' ""']" "'s _ The G rap es of Wrath _ , and the _ New York Times _ called it """ False Nobel Prize-winner John Steinbeck wrote about 6 [' Nobel', ' Prize', '-', 'winner', ' John', ' Stein', 'beck']
+1988 434 Name of mother of x -1 Name of mother of Marie Antoinette Maria Theresa of Austria Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',']" ", the Queen of France , and the
+
+ Queen of England , and the Queen of England ," False General Lafayette, Marie Antoinette and Louis XVI, and 6 [' General', ' Lafayette', ',', ' Marie', ' Ant', 'oin', 'ette']
+1989 434 Name of mother of x -1 Name of mother of Marie Antoinette Maria Theresa of Austria Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',']" ", the Queen of France , and the
+
+ Queen of England , and the Queen of England ," False " John Diefenbaker as Marie Antoinette saying ""Let" 10 [' John', ' D', 'ief', 'en', 'b', 'aker', ' as', ' Marie', ' Ant', 'oin', 'ette']
+1990 434 Name of mother of x -1 Name of mother of Marie Antoinette Maria Theresa of Austria Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',']" ", the Queen of France , and the
+
+ Queen of England , and the Queen of England ," False reception room, the Marie Antoinette Suite near 7 [' reception', ' room', ',', ' the', ' Marie', ' Ant', 'oin', 'ette']
+1991 434 Name of mother of x -1 Name of mother of Marie Antoinette Maria Theresa of Austria Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',']" ", the Queen of France , and the
+
+ Queen of England , and the Queen of England ," False begun, Louis XVI and Marie Antoinette had been tried 8 [' begun', ',', ' Louis', ' XVI', ' and', ' Marie', ' Ant', 'oin', 'ette']
+1992 434 Name of mother of x -1 Name of mother of Marie Antoinette Maria Theresa of Austria Marie Antoinette "[',' ' the' ' Queen' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'Queen'
+ ' of' ' England' ',' ' and' ' the' ' Queen' ' of' ' England' ',']" ", the Queen of France , and the
+
+ Queen of England , and the Queen of England ," False set-piece of Louis XVI and Marie Antoinette forced from their 10 [' set', '-', 'piece', ' of', ' Louis', ' XVI', ' and', ' Marie', ' Ant', 'oin', 'ette']
+1993 435 Name of mother of x -1 Name of mother of James VI and I Mary, Queen of Scots James VI and I "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False granddaughter of James VI and I through his 5 [' granddaughter', ' of', ' James', ' VI', ' and', ' I']
+1994 435 Name of mother of x -1 Name of mother of James VI and I Mary, Queen of Scots James VI and I "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False auspices of King James VI and I the Authorised 7 [' ausp', 'ices', ' of', ' King', ' James', ' VI', ' and', ' I']
+1995 435 Name of mother of x -1 Name of mother of James VI and I Mary, Queen of Scots James VI and I "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False " James VI and I =
+" 3 [' James', ' VI', ' and', ' I']
+1996 435 Name of mother of x -1 Name of mother of James VI and I Mary, Queen of Scots James VI and I "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False granddaughter of James VI and I through his eldest 5 [' granddaughter', ' of', ' James', ' VI', ' and', ' I']
+1997 435 Name of mother of x -1 Name of mother of James VI and I Mary, Queen of Scots James VI and I "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False auspices of King James VI and I the Authorised King 7 [' ausp', 'ices', ' of', ' King', ' James', ' VI', ' and', ' I']
+1998 436 Name of mother of x -1 Name of mother of Muhammad Aminah Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False as Babe Ruth and Muhammad Ali. Jordan placed 4 [' as', ' Babe', ' Ruth', ' and', ' Muhammad']
+1999 436 Name of mother of x -1 Name of mother of Muhammad Aminah Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False microfinance, such as Muhammad Yunus of Grameen 6 [' micro', 'f', 'inance', ',', ' such', ' as', ' Muhammad']
+2000 436 Name of mother of x -1 Name of mother of Muhammad Aminah Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False farewell hajj of Muhammad. During which 4 [' farewell', ' ha', 'jj', ' of', ' Muhammad']
+2001 436 Name of mother of x -1 Name of mother of Muhammad Aminah Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False acquainted with Nur Muhammad Taraki, a communist. 3 [' acquainted', ' with', ' Nur', ' Muhammad']
+2002 436 Name of mother of x -1 Name of mother of Muhammad Aminah Muhammad "[' Ali' ',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ','
+ ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great']" Ali , the great boxer , the great boxer , the great boxer , the great boxer , the great False Sultan Ala ad-Din Muhammad succumbed to disease 6 [' Sultan', ' Ala', ' ad', '-', 'D', 'in', ' Muhammad']
+2003 437 Name of mother of x -1 Name of mother of Samuel Taylor Coleridge Anne Bowden Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False Hertford in 1680. Samuel Taylor Coleridge (1772 – 1834) 10 [' Hert', 'ford', ' in', ' 16', '80', '.', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2004 437 Name of mother of x -1 Name of mother of Samuel Taylor Coleridge Anne Bowden Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False Wordsworth and Samuel Taylor Coleridge that was both Wordsworth's 7 [' Word', 'sworth', ' and', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2005 437 Name of mother of x -1 Name of mother of Samuel Taylor Coleridge Anne Bowden Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False ministry, the poet Samuel Taylor Coleridge occasionally preached 8 [' ministry', ',', ' the', ' poet', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2006 437 Name of mother of x -1 Name of mother of Samuel Taylor Coleridge Anne Bowden Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False poem written by Samuel Taylor Coleridge in 1795 and 7 [' poem', ' written', ' by', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2007 437 Name of mother of x -1 Name of mother of Samuel Taylor Coleridge Anne Bowden Samuel Taylor Coleridge "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 17' '72' ',' ' and' ' died'
+ ' in' ' 18' '34' '.' '\n' '\n' 'The' ' following']" ", the poet , was born in 17 72 , and died in 18 34 .
+
+ The following" False heaped upon her by Samuel Taylor Coleridge and William 9 [' he', 'aped', ' upon', ' her', ' by', ' Samuel', ' Taylor', ' Col', 'er', 'idge']
+2008 438 Name of mother of x -1 Name of mother of Tennessee Williams Edwina Estelle Dakin Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False had a quote from Tennessee Williams taped on my 5 [' had', ' a', ' quote', ' from', ' Tennessee', ' Williams']
+2009 438 Name of mother of x -1 Name of mother of Tennessee Williams Edwina Estelle Dakin Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False I had a quote from Tennessee Williams taped on my 6 [' I', ' had', ' a', ' quote', ' from', ' Tennessee', ' Williams']
+2010 438 Name of mother of x -1 Name of mother of Tennessee Williams Edwina Estelle Dakin Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False over playwright Tennessee Williams and their experience 4 [' over', ' play', 'wright', ' Tennessee', ' Williams']
+2011 438 Name of mother of x -1 Name of mother of Tennessee Williams Edwina Estelle Dakin Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False Arthur Miller and Tennessee Williams as influences 4 [' Arthur', ' Miller', ' and', ' Tennessee', ' Williams']
+2012 438 Name of mother of x -1 Name of mother of Tennessee Williams Edwina Estelle Dakin Tennessee Williams "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False great-grandnephew of Tennessee Williams on his father's side. 8 [' great', '-', 'grand', 'n', 'ep', 'hew', ' of', ' Tennessee', ' Williams']
+2013 440 Name of mother of x -1 Name of mother of Friedrich Engels Elisabeth Franziska Mauritia van Haar Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' to' ' be' ' done' ' is' ' to' ' get' ' rid' ' of']" ", the father of the proletariat .
+
+ The first thing to be done is to get rid of" False collaborated with Friedrich Engels on his work 3 [' collaborated', ' with', ' Friedrich', ' Engels']
+2014 440 Name of mother of x -1 Name of mother of Friedrich Engels Elisabeth Franziska Mauritia van Haar Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' to' ' be' ' done' ' is' ' to' ' get' ' rid' ' of']" ", the father of the proletariat .
+
+ The first thing to be done is to get rid of" False Marianne Weber and Friedrich Engels as founders 5 [' Marian', 'ne', ' Weber', ' and', ' Friedrich', ' Engels']
+2015 440 Name of mother of x -1 Name of mother of Friedrich Engels Elisabeth Franziska Mauritia van Haar Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' to' ' be' ' done' ' is' ' to' ' get' ' rid' ' of']" ", the father of the proletariat .
+
+ The first thing to be done is to get rid of" False introduction, Friedrich Engels described the 3 [' introduction', ',', ' Friedrich', ' Engels']
+2016 440 Name of mother of x -1 Name of mother of Friedrich Engels Elisabeth Franziska Mauritia van Haar Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' to' ' be' ' done' ' is' ' to' ' get' ' rid' ' of']" ", the father of the proletariat .
+
+ The first thing to be done is to get rid of" False International, about whom Friedrich Engels once had written, 5 [' International', ',', ' about', ' whom', ' Friedrich', ' Engels']
+2017 440 Name of mother of x -1 Name of mother of Friedrich Engels Elisabeth Franziska Mauritia van Haar Friedrich Engels "[',' ' the' ' father' ' of' ' the' ' proletariat' '.' '\n' '\n' 'The'
+ ' first' ' thing' ' to' ' be' ' done' ' is' ' to' ' get' ' rid' ' of']" ", the father of the proletariat .
+
+ The first thing to be done is to get rid of" False safety protection. Friedrich Engels in his The 4 [' safety', ' protection', '.', ' Friedrich', ' Engels']
+2018 441 Name of mother of x -1 Name of mother of George V Alexandra of Denmark George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' cinema' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the cinema ." False Tenacious. King George V did not participate 5 [' Ten', 'acious', '.', ' King', ' George', ' V']
+2019 441 Name of mother of x -1 Name of mother of George V Alexandra of Denmark George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' cinema' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the cinema ." False last ship of the King George V class, was laid 6 [' last', ' ship', ' of', ' the', ' King', ' George', ' V']
+2020 441 Name of mother of x -1 Name of mother of George V Alexandra of Denmark George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' cinema' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the cinema ." False the Coronation of George V of the United 6 [' the', ' Cor', 'on', 'ation', ' of', ' George', ' V']
+2021 441 Name of mother of x -1 Name of mother of George V Alexandra of Denmark George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' cinema' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the cinema ." False battleship class: the King George V class. The calibre 7 [' battles', 'hip', ' class', ':', ' the', ' King', ' George', ' V']
+2022 441 Name of mother of x -1 Name of mother of George V Alexandra of Denmark George V "['.' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I'
+ ' was' ' in' ' the' ' audience' ' at' ' the' ' cinema' '.']" ".
+
+ The first time I saw the movie , I was in the audience at the cinema ." False Buckingham Palace by King George V on 15 September 1916. 5 [' Buckingham', ' Palace', ' by', ' King', ' George', ' V']
+2023 442 Name of mother of x -1 Name of mother of Constantine the Great Helena Augusta Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Constantine' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Constantine the Great , and the name of the" False 4 ['Con', 'stant', 'ine', ' the', ' Great']
+2024 442 Name of mother of x -1 Name of mother of Constantine the Great Helena Augusta Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Constantine' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Constantine the Great , and the name of the" False was carried by Constantine the Great during the founding 5 [' was', ' carried', ' by', ' Constantine', ' the', ' Great']
+2025 442 Name of mother of x -1 Name of mother of Constantine the Great Helena Augusta Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Constantine' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Constantine the Great , and the name of the" False Empire, under Emperor Constantine the Great (reigned 306 6 [' Empire', ',', ' under', ' Emperor', ' Constantine', ' the', ' Great']
+2026 442 Name of mother of x -1 Name of mother of Constantine the Great Helena Augusta Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Constantine' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Constantine the Great , and the name of the" False Constantine I known as Constantine the Great — another early Roman 6 [' Constantine', ' I', ' known', ' as', ' Constantine', ' the', ' Great']
+2027 442 Name of mother of x -1 Name of mother of Constantine the Great Helena Augusta Constantine the Great "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Constantine' ' the' ' Great' ',' ' and' ' the' ' name' ' of' ' the']" ", and the
+
+ The name of the mother of Constantine the Great , and the name of the" False 4 ['Con', 'stant', 'ine', ' the', ' Great']
+2028 443 Name of mother of x -1 Name of mother of Tenzin Gyatso Diki Tsering Tenzin Gyatso "[',' ' the' ' Dalai' ' Lama' ',' ' the' ' 14' 'th' ' Dalai' ' Lama' ','
+ ' the' ' Dalai' ' Lama' ',' ' the' ' 14' 'th' ' Dalai' ' Lama']" , the Dalai Lama , the 14 th Dalai Lama , the Dalai Lama , the 14 th Dalai Lama False Dalai Lama, Tenzin Gyatso in Einsiedeln, 7 [' Dalai', ' Lama', ',', ' Ten', 'zin', ' Gy', 'at', 'so']
+2029 444 Name of mother of x -1 Name of mother of Tomáš Garrigue Masaryk Terezie Masaryková Tomáš Garrigue Masaryk "[',' ' the' ' Czech' 'oslov' 'ak' 'ian' ' states' 'man' ' and'
+ ' philosopher' ',' ' who' ' was' ' the' ' first' ' president' ' of'
+ ' Czech' 'oslov' 'akia']" , the Czech oslov ak ian states man and philosopher , who was the first president of Czech oslov akia False Karamchand Gandhi of India, Tomáš Garrigue Masaryk of Czechoslovakia, 15 [' Kar', 'am', 'ch', 'and', ' Gandhi', ' of', ' India', ',', ' Tom', 'á', 'š', ' Garr', 'igue', ' Mas', 'ary', 'k']
+2030 444 Name of mother of x -1 Name of mother of Tomáš Garrigue Masaryk Terezie Masaryková Tomáš Garrigue Masaryk "[',' ' the' ' Czech' 'oslov' 'ak' 'ian' ' states' 'man' ' and'
+ ' philosopher' ',' ' who' ' was' ' the' ' first' ' president' ' of'
+ ' Czech' 'oslov' 'akia']" , the Czech oslov ak ian states man and philosopher , who was the first president of Czech oslov akia False Gandhi of India, Tomáš Garrigue Masaryk of Czechoslovakia, 11 [' Gandhi', ' of', ' India', ',', ' Tom', 'á', 'š', ' Garr', 'igue', ' Mas', 'ary', 'k']
+2031 445 Name of mother of x -1 Name of mother of Francisco Franco Pilar Bahamonde y Pardo de Andrade Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False and especially of Francisco Franco (1939 – 1975), 4 [' and', ' especially', ' of', ' Francisco', ' Franco']
+2032 445 Name of mother of x -1 Name of mother of Francisco Franco Pilar Bahamonde y Pardo de Andrade Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False Gil-Robles and General Francisco Franco had approached 7 [' Gil', '-', 'Rob', 'les', ' and', ' General', ' Francisco', ' Franco']
+2033 445 Name of mother of x -1 Name of mother of Francisco Franco Pilar Bahamonde y Pardo de Andrade Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False Spain. General Francisco Franco was put in informal 4 [' Spain', '.', ' General', ' Francisco', ' Franco']
+2034 445 Name of mother of x -1 Name of mother of Francisco Franco Pilar Bahamonde y Pardo de Andrade Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False the governments of Francisco Franco of Spain and 4 [' the', ' governments', ' of', ' Francisco', ' Franco']
+2035 445 Name of mother of x -1 Name of mother of Francisco Franco Pilar Bahamonde y Pardo de Andrade Francisco Franco "[',' ' the' ' Spanish' ' dictator' ',' ' and' ' the' ' Spanish' ' Civil'
+ ' War' '.' '\n' '\n' 'The' ' Spanish' ' Civil' ' War' ' was' ' a' ' war']" ", the Spanish dictator , and the Spanish Civil War .
+
+ The Spanish Civil War was a war" False spring, but General Francisco Franco hesitated until 5 [' spring', ',', ' but', ' General', ' Francisco', ' Franco']
+2036 446 Name of mother of x -1 Name of mother of Georges Bizet Aimée Léopoldine Joséphine Delsarte Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' ' great' ' composer' ',' ' and']" , the French composer , was born in Paris in 18 38 . He was a great composer , and False work, the composer Georges Bizet was disappointed 8 [' work', ',', ' the', ' composer', ' Georg', 'es', ' B', 'iz', 'et']
+2037 446 Name of mother of x -1 Name of mother of Georges Bizet Aimée Léopoldine Joséphine Delsarte Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' ' great' ' composer' ',' ' and']" , the French composer , was born in Paris in 18 38 . He was a great composer , and False established a fund for a Georges Bizet prize, to be awarded 9 [' established', ' a', ' fund', ' for', ' a', ' Georg', 'es', ' B', 'iz', 'et']
+2038 446 Name of mother of x -1 Name of mother of Georges Bizet Aimée Léopoldine Joséphine Delsarte Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' ' great' ' composer' ',' ' and']" , the French composer , was born in Paris in 18 38 . He was a great composer , and False " Georges Bizet =
+" 4 [' Georg', 'es', ' B', 'iz', 'et']
+2039 446 Name of mother of x -1 Name of mother of Georges Bizet Aimée Léopoldine Joséphine Delsarte Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' ' great' ' composer' ',' ' and']" , the French composer , was born in Paris in 18 38 . He was a great composer , and False 4 ['Georg', 'es', ' B', 'iz', 'et']
+2040 446 Name of mother of x -1 Name of mother of Georges Bizet Aimée Léopoldine Joséphine Delsarte Georges Bizet "[',' ' the' ' French' ' composer' ',' ' was' ' born' ' in' ' Paris' ' in'
+ ' 18' '38' '.' ' He' ' was' ' a' ' great' ' composer' ',' ' and']" , the French composer , was born in Paris in 18 38 . He was a great composer , and False established a fund for a Georges Bizet prize, to be awarded 9 [' established', ' a', ' fund', ' for', ' a', ' Georg', 'es', ' B', 'iz', 'et']
+2041 449 Name of mother of x -1 Name of mother of John Locke Agnes Keene John Locke "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False and neighbour of John Locke a philosopher 4 [' and', ' neighbour', ' of', ' John', ' Locke']
+2042 449 Name of mother of x -1 Name of mother of John Locke Agnes Keene John Locke "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False only stopped when John Locke (Terry O 'Quinn) 4 [' only', ' stopped', ' when', ' John', ' Locke']
+2043 449 Name of mother of x -1 Name of mother of John Locke Agnes Keene John Locke "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False Newton sent to John Locke in which he disputed 4 [' Newton', ' sent', ' to', ' John', ' Locke']
+2044 449 Name of mother of x -1 Name of mother of John Locke Agnes Keene John Locke "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False the system, with John Locke writing a 5 [' the', ' system', ',', ' with', ' John', ' Locke']
+2045 449 Name of mother of x -1 Name of mother of John Locke Agnes Keene John Locke "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False encounters John Locke (Terry O 'Quinn) at 2 [' encounters', ' John', ' Locke']
+2046 450 Name of mother of x -1 Name of mother of William Makepeace Thackeray Anne Becher William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works' ' of']" ", the author of Vanity Fair , and the
+
+ The following is a list of the works of" False (1847 / 8) by William Makepeace Thackeray opens at Miss Pinkerton's 12 [' (', '18', '47', ' /', ' 8', ')', ' by', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2047 450 Name of mother of x -1 Name of mother of William Makepeace Thackeray Anne Becher William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works' ' of']" ", the author of Vanity Fair , and the
+
+ The following is a list of the works of" False Fair (1847 / 8) by William Makepeace Thackeray opens at Miss 13 [' Fair', ' (', '18', '47', ' /', ' 8', ')', ' by', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2048 450 Name of mother of x -1 Name of mother of William Makepeace Thackeray Anne Becher William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works' ' of']" ", the author of Vanity Fair , and the
+
+ The following is a list of the works of" False the novelist William Makepeace Thackeray visited Dublin 7 [' the', ' novelist', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2049 450 Name of mother of x -1 Name of mother of William Makepeace Thackeray Anne Becher William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works' ' of']" ", the author of Vanity Fair , and the
+
+ The following is a list of the works of" False June of that year William Makepeace Thackeray (under the pen name 9 [' June', ' of', ' that', ' year', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2050 450 Name of mother of x -1 Name of mother of William Makepeace Thackeray Anne Becher William Makepeace Thackeray "[',' ' the' ' author' ' of' ' Vanity' ' Fair' ',' ' and' ' the' '\n' '\n'
+ 'The' ' following' ' is' ' a' ' list' ' of' ' the' ' works' ' of']" ", the author of Vanity Fair , and the
+
+ The following is a list of the works of" False " 1842 the writer William Makepeace Thackeray noted Nelson ""upon" 9 [' 18', '42', ' the', ' writer', ' William', ' Make', 'peace', ' Th', 'acker', 'ay']
+2051 451 Name of mother of x -1 Name of mother of Alfred Tennyson Elizabeth Fytch Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False one to attend. Alfred Tennyson contributed 6 [' one', ' to', ' attend', '.', ' Alfred', ' Tenn', 'yson']
+2052 451 Name of mother of x -1 Name of mother of Alfred Tennyson Elizabeth Fytch Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False " Gladstone as well as Alfred Tennyson and Francis Parkman.
+" 7 [' Glad', 'stone', ' as', ' well', ' as', ' Alfred', ' Tenn', 'yson']
+2053 451 Name of mother of x -1 Name of mother of Alfred Tennyson Elizabeth Fytch Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False " Gladstone as well as Alfred Tennyson and Francis Parkman.
+" 7 [' Glad', 'stone', ' as', ' well', ' as', ' Alfred', ' Tenn', 'yson']
+2054 451 Name of mother of x -1 Name of mother of Alfred Tennyson Elizabeth Fytch Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False to attend. Alfred Tennyson contributed 5 [' to', ' attend', '.', ' Alfred', ' Tenn', 'yson']
+2055 451 Name of mother of x -1 Name of mother of Alfred Tennyson Elizabeth Fytch Alfred Tennyson "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '09' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' poet']" ", the poet , was born in 18 09 , and died in 18 92 .
+
+ The poet" False one to attend. Alfred Tennyson contributed 6 [' one', ' to', ' attend', '.', ' Alfred', ' Tenn', 'yson']
+2056 452 Name of mother of x -1 Name of mother of Benjamin Netanyahu Zila Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False Israeli Prime Minister Benjamin Netanyahu has rejected a proposal 4 [' Israeli', ' Prime', ' Minister', ' Benjamin', ' Netanyahu']
+2057 452 Name of mother of x -1 Name of mother of Benjamin Netanyahu Zila Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False Israel's Prime Minister Benjamin Netanyahu called the 5 "[' Israel', ""'s"", ' Prime', ' Minister', ' Benjamin', ' Netanyahu']"
+2058 452 Name of mother of x -1 Name of mother of Benjamin Netanyahu Zila Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False Prime Minister Benjamin Netanyahu called the 3 [' Prime', ' Minister', ' Benjamin', ' Netanyahu']
+2059 452 Name of mother of x -1 Name of mother of Benjamin Netanyahu Zila Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False Prime Minister Benjamin Netanyahu called Russia an 3 [' Prime', ' Minister', ' Benjamin', ' Netanyahu']
+2060 452 Name of mother of x -1 Name of mother of Benjamin Netanyahu Zila Netanyahu Benjamin Netanyahu "[',' ' the' ' Israeli' ' Prime' ' Minister' ',' ' and' ' the' ' Israeli'
+ ' Prime' ' Minister' ',' ' Benjamin' ' Netanyahu' ',' ' is' ' a' ' liar'
+ '.' '\n']" ", the Israeli Prime Minister , and the Israeli Prime Minister , Benjamin Netanyahu , is a liar .
+" False as a signal Benjamin Netanyahu was sending 4 [' as', ' a', ' signal', ' Benjamin', ' Netanyahu']
+2061 453 Name of mother of x -1 Name of mother of Deborah Kerr Kathleen Rose Smale Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 2 ['Deb', 'orah', ' Kerr']
+2062 453 Name of mother of x -1 Name of mother of Deborah Kerr Kathleen Rose Smale Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Peck as Scott and Deborah Kerr as Graham. The book 5 [' Peck', ' as', ' Scott', ' and', ' Deborah', ' Kerr']
+2063 453 Name of mother of x -1 Name of mother of Deborah Kerr Kathleen Rose Smale Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Robarts, co-starring with Deborah Kerr and Diana 9 [' Rob', 'arts', ',', ' co', '-', 'star', 'ring', ' with', ' Deborah', ' Kerr']
+2064 453 Name of mother of x -1 Name of mother of Deborah Kerr Kathleen Rose Smale Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 1960. Cary Grant and Deborah Kerr play the earl 6 [' 1960', '.', ' Cary', ' Grant', ' and', ' Deborah', ' Kerr']
+2065 453 Name of mother of x -1 Name of mother of Deborah Kerr Kathleen Rose Smale Deborah Kerr "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Mitchum, Jean Simmons and Deborah Kerr in The Grass 7 [' Mitch', 'um', ',', ' Jean', ' Simmons', ' and', ' Deborah', ' Kerr']
+2066 454 Name of mother of x -1 Name of mother of Romy Schneider Magda Schneider Romy Schneider "[',' ' who' ' was' ' a' ' German' ' actress' ' who' ' was' ' born' ' in'
+ ' Berlin' ',' ' Germany' '.' ' She' ' was' ' the' ' daughter' ' of' ' a']" , who was a German actress who was born in Berlin , Germany . She was the daughter of a False awarded France's Romy Schneider and Jean Gabin Prizes 5 "[' awarded', ' France', ""'s"", ' Rom', 'y', ' Schneider']"
+2067 454 Name of mother of x -1 Name of mother of Romy Schneider Magda Schneider Romy Schneider "[',' ' who' ' was' ' a' ' German' ' actress' ' who' ' was' ' born' ' in'
+ ' Berlin' ',' ' Germany' '.' ' She' ' was' ' the' ' daughter' ' of' ' a']" , who was a German actress who was born in Berlin , Germany . She was the daughter of a False were awarded France's Romy Schneider and Jean Gabin 6 "[' were', ' awarded', ' France', ""'s"", ' Rom', 'y', ' Schneider']"
+2068 455 Name of mother of x -1 Name of mother of Katharine Hepburn Katharine Martha Houghton Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Kath' 'arine' ' Hep' 'burn' ',' ' and' ' the' ' mother' ' of']" , the actress , and the mother of the actress , Kath arine Hep burn , and the mother of False " Avenue was renamed ""Katharine Hepburn Place"". Three" 8 "[' Avenue', ' was', ' renamed', ' ""', 'K', 'ath', 'arine', ' Hep', 'burn']"
+2069 455 Name of mother of x -1 Name of mother of Katharine Hepburn Katharine Martha Houghton Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Kath' 'arine' ' Hep' 'burn' ',' ' and' ' the' ' mother' ' of']" , the actress , and the mother of the actress , Kath arine Hep burn , and the mother of False with actress Katharine Hepburn that received 5 [' with', ' actress', ' Kath', 'arine', ' Hep', 'burn']
+2070 455 Name of mother of x -1 Name of mother of Katharine Hepburn Katharine Martha Houghton Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Kath' 'arine' ' Hep' 'burn' ',' ' and' ' the' ' mother' ' of']" , the actress , and the mother of the actress , Kath arine Hep burn , and the mother of False Tracy and Katharine Hepburn two-some who married 5 [' Tracy', ' and', ' Kath', 'arine', ' Hep', 'burn']
+2071 455 Name of mother of x -1 Name of mother of Katharine Hepburn Katharine Martha Houghton Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Kath' 'arine' ' Hep' 'burn' ',' ' and' ' the' ' mother' ' of']" , the actress , and the mother of the actress , Kath arine Hep burn , and the mother of False with actress Katharine Hepburn that received 5 [' with', ' actress', ' Kath', 'arine', ' Hep', 'burn']
+2072 455 Name of mother of x -1 Name of mother of Katharine Hepburn Katharine Martha Houghton Hepburn Katharine Hepburn "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' Kath' 'arine' ' Hep' 'burn' ',' ' and' ' the' ' mother' ' of']" , the actress , and the mother of the actress , Kath arine Hep burn , and the mother of False Divorcement, opposite Katharine Hepburn in her screen debut. 8 [' Div', 'orce', 'ment', ',', ' opposite', ' Kath', 'arine', ' Hep', 'burn']
+2073 456 Name of mother of x -1 Name of mother of Robert Graves Amalie von Ranke Robert Graves "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False Greek Myths (1955), Robert Graves views Oenopion as 8 [' Greek', ' My', 'ths', ' (', '19', '55', '),', ' Robert', ' Graves']
+2074 456 Name of mother of x -1 Name of mother of Robert Graves Amalie von Ranke Robert Graves "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False the direction of Robert Graves and worked as an 4 [' the', ' direction', ' of', ' Robert', ' Graves']
+2075 456 Name of mother of x -1 Name of mother of Robert Graves Amalie von Ranke Robert Graves "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False classicist Robert Graves and Italian ethnobotanist 3 [' classic', 'ist', ' Robert', ' Graves']
+2076 456 Name of mother of x -1 Name of mother of Robert Graves Amalie von Ranke Robert Graves "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False published by his friend Robert Graves and his older brother 5 [' published', ' by', ' his', ' friend', ' Robert', ' Graves']
+2077 456 Name of mother of x -1 Name of mother of Robert Graves Amalie von Ranke Robert Graves "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " than sound. Robert Graves recalled ""I" 4 [' than', ' sound', '.', ' Robert', ' Graves']
+2078 457 Name of mother of x -1 Name of mother of David Hume Katherine Falconer David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in' ' 17' '11'
+ '.']" , the father of modern philosophy , and the father of modern science , was born in 17 11 . False Scottish philosopher David Hume (1711 – 76). This 3 [' Scottish', ' philosopher', ' David', ' Hume']
+2079 457 Name of mother of x -1 Name of mother of David Hume Katherine Falconer David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in' ' 17' '11'
+ '.']" , the father of modern philosophy , and the father of modern science , was born in 17 11 . False in Scotland was David Hume (1711 – 76) 4 [' in', ' Scotland', ' was', ' David', ' Hume']
+2080 457 Name of mother of x -1 Name of mother of David Hume Katherine Falconer David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in' ' 17' '11'
+ '.']" , the father of modern philosophy , and the father of modern science , was born in 17 11 . False philosopher David Hume contended that 2 [' philosopher', ' David', ' Hume']
+2081 457 Name of mother of x -1 Name of mother of David Hume Katherine Falconer David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in' ' 17' '11'
+ '.']" , the father of modern philosophy , and the father of modern science , was born in 17 11 . False (1570? – 1611) and David Hume of Godscroft 10 [' (', '15', '70', '?', ' –', ' 16', '11', ')', ' and', ' David', ' Hume']
+2082 457 Name of mother of x -1 Name of mother of David Hume Katherine Falconer David Hume "[',' ' the' ' father' ' of' ' modern' ' philosophy' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' science' ',' ' was' ' born' ' in' ' 17' '11'
+ '.']" , the father of modern philosophy , and the father of modern science , was born in 17 11 . False The philosopher David Hume developed a skeptical 3 [' The', ' philosopher', ' David', ' Hume']
+2083 458 Name of mother of x -1 Name of mother of Benjamin West Sarah Pearson Benjamin West "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Benjamin' ' West' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Benjamin West , the painter , and the" False third portrait is by Benjamin West and was painted 5 [' third', ' portrait', ' is', ' by', ' Benjamin', ' West']
+2084 458 Name of mother of x -1 Name of mother of Benjamin West Sarah Pearson Benjamin West "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Benjamin' ' West' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Benjamin West , the painter , and the" False American 3 [' America', 'Ben', 'jamin', ' West']
+2085 458 Name of mother of x -1 Name of mother of Benjamin West Sarah Pearson Benjamin West "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Benjamin' ' West' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Benjamin West , the painter , and the" False appeared in paintings by Benjamin West and Arthur William 5 [' appeared', ' in', ' paintings', ' by', ' Benjamin', ' West']
+2086 458 Name of mother of x -1 Name of mother of Benjamin West Sarah Pearson Benjamin West "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Benjamin' ' West' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Benjamin West , the painter , and the" False in paintings by Benjamin West and Arthur William 4 [' in', ' paintings', ' by', ' Benjamin', ' West']
+2087 458 Name of mother of x -1 Name of mother of Benjamin West Sarah Pearson Benjamin West "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Benjamin' ' West' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Benjamin West , the painter , and the" False " contributed, such as Benjamin West and Henry Fuseli.
+" 5 [' contributed', ',', ' such', ' as', ' Benjamin', ' West']
+2088 459 Name of mother of x -1 Name of mother of Philipp Melanchthon Barbara Reuter Philipp Melanchthon "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Philipp' ' Mel'
+ 'anch' 'th' 'on' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Philipp Mel anch th on , the
+
+ Name of mother" False with Luther and Philipp Melanchthon arriving shortly thereafter. 7 [' with', ' Luther', ' and', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2089 459 Name of mother of x -1 Name of mother of Philipp Melanchthon Barbara Reuter Philipp Melanchthon "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Philipp' ' Mel'
+ 'anch' 'th' 'on' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Philipp Mel anch th on , the
+
+ Name of mother" False Luther and Philipp Melanchthon arriving shortly 6 [' Luther', ' and', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2090 459 Name of mother of x -1 Name of mother of Philipp Melanchthon Barbara Reuter Philipp Melanchthon "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Philipp' ' Mel'
+ 'anch' 'th' 'on' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Philipp Mel anch th on , the
+
+ Name of mother" False and signed by Philipp Melanchthon and others, much 7 [' and', ' signed', ' by', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2091 459 Name of mother of x -1 Name of mother of Philipp Melanchthon Barbara Reuter Philipp Melanchthon "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Philipp' ' Mel'
+ 'anch' 'th' 'on' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Philipp Mel anch th on , the
+
+ Name of mother" False with Luther and Philipp Melanchthon arriving shortly thereafter. 7 [' with', ' Luther', ' and', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2092 459 Name of mother of x -1 Name of mother of Philipp Melanchthon Barbara Reuter Philipp Melanchthon "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Philipp' ' Mel'
+ 'anch' 'th' 'on' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Philipp Mel anch th on , the
+
+ Name of mother" False and signed by Philipp Melanchthon and others, 7 [' and', ' signed', ' by', ' Philipp', ' Mel', 'anch', 'th', 'on']
+2093 460 Name of mother of x -1 Name of mother of Martin Scorsese Catherine Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False April 2002, Martin Scorsese was involved with 6 [' April', ' 2002', ',', ' Martin', ' Sc', 'ors', 'ese']
+2094 460 Name of mother of x -1 Name of mother of Martin Scorsese Catherine Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False " ""Bad"" — directed by Martin Scorsese — Jackson began" 9 "[' ""', 'Bad', '""', ' —', ' directed', ' by', ' Martin', ' Sc', 'ors', 'ese']"
+2095 460 Name of mother of x -1 Name of mother of Martin Scorsese Catherine Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False this version are Martin Scorsese Presents the 6 [' this', ' version', ' are', ' Martin', ' Sc', 'ors', 'ese']
+2096 460 Name of mother of x -1 Name of mother of Martin Scorsese Catherine Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False idol Fred Astaire. Martin Scorsese presented the 8 [' idol', ' Fred', ' Ast', 'aire', '.', ' Martin', ' Sc', 'ors', 'ese']
+2097 460 Name of mother of x -1 Name of mother of Martin Scorsese Catherine Scorsese Martin Scorsese "[""'s"" ' ""' 'The' ' Wolf' ' of' ' Wall' ' Street' '""' ' and' ' ""' 'The'
+ ' Wolf' ' of' ' Wall' ' Street' '""' '\n' '\n' 'The' ' Wolf']" "'s "" The Wolf of Wall Street "" and "" The Wolf of Wall Street ""
+
+ The Wolf" False film of 1999. Martin Scorsese ranked it as his second 7 [' film', ' of', ' 1999', '.', ' Martin', ' Sc', 'ors', 'ese']
+2098 461 Name of mother of x -1 Name of mother of Giacomo Leopardi Adelaide Antici Leopardi Giacomo Leopardi "[',' ' the' ' father' ' of' ' modern' ' Italian' ' literature' ',' ' was'
+ ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' is' ' now']" ", the father of modern Italian literature , was born in this house .
+
+ The house is now" False when the poet Giacomo Leopardi had visited and met 7 [' when', ' the', ' poet', ' Gi', 'ac', 'omo', ' Leopard', 'i']
+2099 461 Name of mother of x -1 Name of mother of Giacomo Leopardi Adelaide Antici Leopardi Giacomo Leopardi "[',' ' the' ' father' ' of' ' modern' ' Italian' ' literature' ',' ' was'
+ ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' is' ' now']" ", the father of modern Italian literature , was born in this house .
+
+ The house is now" False Rabindranath Tagore, Giacomo Leopardi and pursued academic 11 [' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore', ',', ' Gi', 'ac', 'omo', ' Leopard', 'i']
+2100 461 Name of mother of x -1 Name of mother of Giacomo Leopardi Adelaide Antici Leopardi Giacomo Leopardi "[',' ' the' ' father' ' of' ' modern' ' Italian' ' literature' ',' ' was'
+ ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' is' ' now']" ", the father of modern Italian literature , was born in this house .
+
+ The house is now" False Dante Aligheri, and Giacomo Leopardi (the latter poet had 11 [' Dante', ' Al', 'ig', 'her', 'i', ',', ' and', ' Gi', 'ac', 'omo', ' Leopard', 'i']
+2101 461 Name of mother of x -1 Name of mother of Giacomo Leopardi Adelaide Antici Leopardi Giacomo Leopardi "[',' ' the' ' father' ' of' ' modern' ' Italian' ' literature' ',' ' was'
+ ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' is' ' now']" ", the father of modern Italian literature , was born in this house .
+
+ The house is now" False Rabindranath Tagore, Giacomo Leopardi and pursued academic 11 [' Rab', 'ind', 'ran', 'ath', ' Tag', 'ore', ',', ' Gi', 'ac', 'omo', ' Leopard', 'i']
+2102 461 Name of mother of x -1 Name of mother of Giacomo Leopardi Adelaide Antici Leopardi Giacomo Leopardi "[',' ' the' ' father' ' of' ' modern' ' Italian' ' literature' ',' ' was'
+ ' born' ' in' ' this' ' house' '.' '\n' '\n' 'The' ' house' ' is' ' now']" ", the father of modern Italian literature , was born in this house .
+
+ The house is now" False when the poet Giacomo Leopardi had visited and 7 [' when', ' the', ' poet', ' Gi', 'ac', 'omo', ' Leopard', 'i']
+2103 462 Name of mother of x -1 Name of mother of Al Gore Pauline LaFon Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' the' ' former' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former vice president of the United States , and the former president of the United States , and False but is distracted by Al Gore measuring a window. 5 [' but', ' is', ' distracted', ' by', ' Al', ' Gore']
+2104 462 Name of mother of x -1 Name of mother of Al Gore Pauline LaFon Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' the' ' former' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former vice president of the United States , and the former president of the United States , and False defeated Democrat Al Gore in the 2000 3 [' defeated', ' Democrat', ' Al', ' Gore']
+2105 462 Name of mother of x -1 Name of mother of Al Gore Pauline LaFon Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' the' ' former' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former vice president of the United States , and the former president of the United States , and False " the end, when Al Gore says ""Quiet!" 5 [' the', ' end', ',', ' when', ' Al', ' Gore']
+2106 462 Name of mother of x -1 Name of mother of Al Gore Pauline LaFon Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' the' ' former' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former vice president of the United States , and the former president of the United States , and False Mondale, Dan Quayle, Al Gore and Dick Cheney, along 9 [' Mond', 'ale', ',', ' Dan', ' Qu', 'ay', 'le', ',', ' Al', ' Gore']
+2107 462 Name of mother of x -1 Name of mother of Al Gore Pauline LaFon Gore Al Gore "[',' ' the' ' former' ' vice' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and' ' the' ' former' ' president' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former vice president of the United States , and the former president of the United States , and False married to Quagmire; Al Gore is now President; Chris, 7 [' married', ' to', ' Qu', 'ag', 'mire', ';', ' Al', ' Gore']
+2108 463 Name of mother of x -1 Name of mother of John Cale Margaret Davies John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom'
+ ',' ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the']" o , the father of the bride , and the groom , and the groom 's mother , and the False release, Reed and John Cale included a song 6 [' release', ',', ' Reed', ' and', ' John', ' C', 'ale']
+2109 463 Name of mother of x -1 Name of mother of John Cale Margaret Davies John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom'
+ ',' ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the']" o , the father of the bride , and the groom , and the groom 's mother , and the False Underground member John Cale make cameo appearances 4 [' Underground', ' member', ' John', ' C', 'ale']
+2110 463 Name of mother of x -1 Name of mother of John Cale Margaret Davies John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom'
+ ',' ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the']" o , the father of the bride , and the groom , and the groom 's mother , and the False Wade, Baha Men, and John Cale (covering Leonard 9 [' Wade', ',', ' B', 'aha', ' Men', ',', ' and', ' John', ' C', 'ale']
+2111 463 Name of mother of x -1 Name of mother of John Cale Margaret Davies John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom'
+ ',' ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the']" o , the father of the bride , and the groom , and the groom 's mother , and the False " Morrison's studio, musician John Cale reported, ""Morrison" 7 "[' Morrison', ""'s"", ' studio', ',', ' musician', ' John', ' C', 'ale']"
+2112 463 Name of mother of x -1 Name of mother of John Cale Margaret Davies John Cale "['o' ',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' groom'
+ ',' ' and' ' the' ' groom' ""'s"" ' mother' ',' ' and' ' the']" o , the father of the bride , and the groom , and the groom 's mother , and the False Underground member John Cale make cameo appearances 4 [' Underground', ' member', ' John', ' C', 'ale']
+2113 464 Name of mother of x -1 Name of mother of Elon Musk Maye Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False thrust fluctuations. Elon Musk reported that this 4 [' thrust', ' fluctuations', '.', ' Elon', ' Musk']
+2114 464 Name of mother of x -1 Name of mother of Elon Musk Maye Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False thrust fluctuations. Elon Musk reported that this 4 [' thrust', ' fluctuations', '.', ' Elon', ' Musk']
+2115 464 Name of mother of x -1 Name of mother of Elon Musk Maye Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False the launch site. Elon Musk first publicly referred 5 [' the', ' launch', ' site', '.', ' Elon', ' Musk']
+2116 464 Name of mother of x -1 Name of mother of Elon Musk Maye Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False thrust fluctuations. Elon Musk reported that 4 [' thrust', ' fluctuations', '.', ' Elon', ' Musk']
+2117 464 Name of mother of x -1 Name of mother of Elon Musk Maye Musk Elon Musk "[',' ' the' ' founder' ' of' ' SpaceX' ',' ' Elon' ' Musk' ',' ' is' ' a'
+ ' man' ' who' ' has' ' been' ' a' ' pioneer' ' in' ' the' ' space']" , the founder of SpaceX , Elon Musk , is a man who has been a pioneer in the space False future. In January 2016 Elon Musk estimated 6 [' future', '.', ' In', ' January', ' 2016', ' Elon', ' Musk']
+2118 465 Name of mother of x -1 Name of mother of Fridtjof Nansen Baronesse Adelaide Johanne Thekla Isidore Wedel Jarlsberg Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False discovery inspired Fridtjof Nansen to mount his 8 [' discovery', ' inspired', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2119 465 Name of mother of x -1 Name of mother of Fridtjof Nansen Baronesse Adelaide Johanne Thekla Isidore Wedel Jarlsberg Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False discovery inspired Fridtjof Nansen to mount his 8 [' discovery', ' inspired', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2120 465 Name of mother of x -1 Name of mother of Fridtjof Nansen Baronesse Adelaide Johanne Thekla Isidore Wedel Jarlsberg Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False 6 ['F', 'rid', 't', 'j', 'of', ' N', 'ansen']
+2121 465 Name of mother of x -1 Name of mother of Fridtjof Nansen Baronesse Adelaide Johanne Thekla Isidore Wedel Jarlsberg Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False discovery inspired Fridtjof Nansen to mount his 8 [' discovery', ' inspired', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2122 465 Name of mother of x -1 Name of mother of Fridtjof Nansen Baronesse Adelaide Johanne Thekla Isidore Wedel Jarlsberg Fridtjof Nansen "[',' ' the' ' Norwegian' ' explorer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' cross' ' the' ' Arctic' ' Ocean' ' in' ' a' ' balloon' '.' '\n'
+ '\n']" ", the Norwegian explorer , who was the first to cross the Arctic Ocean in a balloon .
+
+" False Norwegian explorer Fridtjof Nansen and his team crossed 8 [' Norwegian', ' explorer', ' Fr', 'id', 't', 'j', 'of', ' N', 'ansen']
+2123 466 Name of mother of x -1 Name of mother of Brad Pitt Jane Etta Hillhouse Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False issues. In 2002, Brad Pitt was attached to 6 [' issues', '.', ' In', ' 2002', ',', ' Brad', ' Pitt']
+2124 466 Name of mother of x -1 Name of mother of Brad Pitt Jane Etta Hillhouse Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False Inception. Both Brad Pitt and Will Smith 5 [' In', 'ception', '.', ' Both', ' Brad', ' Pitt']
+2125 466 Name of mother of x -1 Name of mother of Brad Pitt Jane Etta Hillhouse Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False 1 ['Brad', ' Pitt']
+2126 466 Name of mother of x -1 Name of mother of Brad Pitt Jane Etta Hillhouse Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False Hollywood actors like Brad Pitt and Tom Cruise, 4 [' Hollywood', ' actors', ' like', ' Brad', ' Pitt']
+2127 466 Name of mother of x -1 Name of mother of Brad Pitt Jane Etta Hillhouse Brad Pitt "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Angel' 'ina' ' Jol'
+ 'ie' ',' ' are' ' expecting' ' their' ' first' ' child' ' together' '.']" , the actor , and his wife , Angel ina Jol ie , are expecting their first child together . False Inception. Both Brad Pitt and Will Smith were 5 [' In', 'ception', '.', ' Both', ' Brad', ' Pitt']
+2128 467 Name of mother of x -1 Name of mother of Johnny Depp Betty Sue Wells Johnny Depp "[',' ' the' ' actor' ' who' ' plays' ' the' ' title' ' character' ' in'
+ ' the' ' film' ',' ' is' ' a' ' former' ' child' ' actor' ' who' ' has'
+ ' been']" , the actor who plays the title character in the film , is a former child actor who has been False scripts, and talked to Johnny Depp about the possibility 7 [' scripts', ',', ' and', ' talked', ' to', ' Johnny', ' De', 'pp']
+2129 467 Name of mother of x -1 Name of mother of Johnny Depp Betty Sue Wells Johnny Depp "[',' ' the' ' actor' ' who' ' plays' ' the' ' title' ' character' ' in'
+ ' the' ' film' ',' ' is' ' a' ' former' ' child' ' actor' ' who' ' has'
+ ' been']" , the actor who plays the title character in the film , is a former child actor who has been False director cast Johnny Depp to replace Stiller 4 [' director', ' cast', ' Johnny', ' De', 'pp']
+2130 467 Name of mother of x -1 Name of mother of Johnny Depp Betty Sue Wells Johnny Depp "[',' ' the' ' actor' ' who' ' plays' ' the' ' title' ' character' ' in'
+ ' the' ' film' ',' ' is' ' a' ' former' ' child' ' actor' ' who' ' has'
+ ' been']" , the actor who plays the title character in the film , is a former child actor who has been False portrayed once again by Johnny Depp coming May 26, 6 [' portrayed', ' once', ' again', ' by', ' Johnny', ' De', 'pp']
+2131 467 Name of mother of x -1 Name of mother of Johnny Depp Betty Sue Wells Johnny Depp "[',' ' the' ' actor' ' who' ' plays' ' the' ' title' ' character' ' in'
+ ' the' ' film' ',' ' is' ' a' ' former' ' child' ' actor' ' who' ' has'
+ ' been']" , the actor who plays the title character in the film , is a former child actor who has been False Atlantis SquarePantis, Johnny Depp as the voice 7 [' Atlantis', ' Square', 'P', 'antis', ',', ' Johnny', ' De', 'pp']
+2132 467 Name of mother of x -1 Name of mother of Johnny Depp Betty Sue Wells Johnny Depp "[',' ' the' ' actor' ' who' ' plays' ' the' ' title' ' character' ' in'
+ ' the' ' film' ',' ' is' ' a' ' former' ' child' ' actor' ' who' ' has'
+ ' been']" , the actor who plays the title character in the film , is a former child actor who has been False " Edward Scissorhands
+" 7 [' Edward', ' Sc', 'iss', 'or', 'hands', 'Johnny', ' De', 'pp']
+2133 468 Name of mother of x -1 Name of mother of Socrates Phaenarete Socrates "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False 432 BC, where Socrates was said to 4 [' 432', ' BC', ',', ' where', ' Socrates']
+2134 468 Name of mother of x -1 Name of mother of Socrates Phaenarete Socrates "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False Pericles, and quotes Socrates as claiming 5 [' Per', 'icles', ',', ' and', ' quotes', ' Socrates']
+2135 468 Name of mother of x -1 Name of mother of Socrates Phaenarete Socrates "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False life, friends of Socrates brought their wives 4 [' life', ',', ' friends', ' of', ' Socrates']
+2136 468 Name of mother of x -1 Name of mother of Socrates Phaenarete Socrates "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False an allusion to Socrates and his manner 4 [' an', ' all', 'usion', ' to', ' Socrates']
+2137 468 Name of mother of x -1 Name of mother of Socrates Phaenarete Socrates "[',' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the' ' son' ' of' ' Zeus'
+ ',' ' and' ' the' ' son' ' of' ' Zeus' ',' ' and' ' the']" , the son of Zeus , and the son of Zeus , and the son of Zeus , and the False timid lawyer, Socrates Poole (Christian 3 [' timid', ' lawyer', ',', ' Socrates']
+2138 469 Name of mother of x -1 Name of mother of Richard Nixon Hannah Milhous Nixon Richard Nixon "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False In 1974, President Richard Nixon stopped at 5 [' In', ' 1974', ',', ' President', ' Richard', ' Nixon']
+2139 469 Name of mother of x -1 Name of mother of Richard Nixon Hannah Milhous Nixon Richard Nixon "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False earlier, President Richard Nixon nominated Representative 4 [' earlier', ',', ' President', ' Richard', ' Nixon']
+2140 469 Name of mother of x -1 Name of mother of Richard Nixon Hannah Milhous Nixon Richard Nixon "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False presidential candidate Richard Nixon in 1960. Although 3 [' presidential', ' candidate', ' Richard', ' Nixon']
+2141 469 Name of mother of x -1 Name of mother of Richard Nixon Hannah Milhous Nixon Richard Nixon "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False " former U.S. President Richard Nixon because: ""there's" 7 [' former', ' U', '.', 'S', '.', ' President', ' Richard', ' Nixon']
+2142 469 Name of mother of x -1 Name of mother of Richard Nixon Hannah Milhous Nixon Richard Nixon "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False leaving Vice President Richard Nixon to conduct the 4 [' leaving', ' Vice', ' President', ' Richard', ' Nixon']
+2143 470 Name of mother of x -1 Name of mother of Seamus Heaney Margaret Kathleen McCann Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Marie' ',' ' who'
+ ' was' ' a' ' poet' 'ess' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , Marie , who was a poet ess .
+
+ The house" False Nobel laureate Seamus Heaney makes a less 5 [' Nobel', ' laureate', ' Se', 'amus', ' He', 'aney']
+2144 470 Name of mother of x -1 Name of mother of Seamus Heaney Margaret Kathleen McCann Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Marie' ',' ' who'
+ ' was' ' a' ' poet' 'ess' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , Marie , who was a poet ess .
+
+ The house" False by the Irish poet Seamus Heaney in the introduction 7 [' by', ' the', ' Irish', ' poet', ' Se', 'amus', ' He', 'aney']
+2145 470 Name of mother of x -1 Name of mother of Seamus Heaney Margaret Kathleen McCann Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Marie' ',' ' who'
+ ' was' ' a' ' poet' 'ess' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , Marie , who was a poet ess .
+
+ The house" False Christopher Ricks and Seamus Heaney looked at the poems, 7 [' Christopher', ' R', 'icks', ' and', ' Se', 'amus', ' He', 'aney']
+2146 470 Name of mother of x -1 Name of mother of Seamus Heaney Margaret Kathleen McCann Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Marie' ',' ' who'
+ ' was' ' a' ' poet' 'ess' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , Marie , who was a poet ess .
+
+ The house" False New York. When Seamus Heaney gave an Oxford lecture 7 [' New', ' York', '.', ' When', ' Se', 'amus', ' He', 'aney']
+2147 470 Name of mother of x -1 Name of mother of Seamus Heaney Margaret Kathleen McCann Seamus Heaney "[',' ' the' ' poet' ',' ' and' ' his' ' wife' ',' ' Marie' ',' ' who'
+ ' was' ' a' ' poet' 'ess' '.' '\n' '\n' 'The' ' house']" ", the poet , and his wife , Marie , who was a poet ess .
+
+ The house" False Nobel laureate Seamus Heaney makes a less 5 [' Nobel', ' laureate', ' Se', 'amus', ' He', 'aney']
+2148 471 Name of mother of x -1 Name of mother of M. C. Escher Sara Adriana Gleichman M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of mother of M . C . Esc her , the" False mathematics, such as M. C. Escher (inspired by 9 [' mathematics', ',', ' such', ' as', ' M', '.', ' C', '.', ' Esc', 'her']
+2149 471 Name of mother of x -1 Name of mother of M. C. Escher Sara Adriana Gleichman M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of mother of M . C . Esc her , the" False mathematics, such as M. C. Escher (inspired by H. 9 [' mathematics', ',', ' such', ' as', ' M', '.', ' C', '.', ' Esc', 'her']
+2150 471 Name of mother of x -1 Name of mother of M. C. Escher Sara Adriana Gleichman M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of mother of M . C . Esc her , the" False the work of M. C. Escher often made use 8 [' the', ' work', ' of', ' M', '.', ' C', '.', ' Esc', 'her']
+2151 471 Name of mother of x -1 Name of mother of M. C. Escher Sara Adriana Gleichman M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of mother of M . C . Esc her , the" False graphic artist M. C. Escher made intensive 7 [' graphic', ' artist', ' M', '.', ' C', '.', ' Esc', 'her']
+2152 471 Name of mother of x -1 Name of mother of M. C. Escher Sara Adriana Gleichman M. C. Escher "[',' ' the' ' artist' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' M' '.' ' C' '.' ' Esc' 'her' ',' ' the']" ", the artist , and the
+
+ Name of mother of M . C . Esc her , the" False compared by critics to M. C. Escher drawings and Echochrome. 9 [' compared', ' by', ' critics', ' to', ' M', '.', ' C', '.', ' Esc', 'her']
+2153 473 Name of mother of x -1 Name of mother of Max Weber Helene Weber Max Weber "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False seven children of Max Weber Sr., a wealthy 4 [' seven', ' children', ' of', ' Max', ' Weber']
+2154 473 Name of mother of x -1 Name of mother of Max Weber Helene Weber Max Weber "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False children of Max Weber Sr., a wealthy 3 [' children', ' of', ' Max', ' Weber']
+2155 473 Name of mother of x -1 Name of mother of Max Weber Helene Weber Max Weber "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False wrote in 1956 that Max Weber was the only economist 5 [' wrote', ' in', ' 1956', ' that', ' Max', ' Weber']
+2156 473 Name of mother of x -1 Name of mother of Max Weber Helene Weber Max Weber "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False 1 ['Max', ' Weber']
+2157 473 Name of mother of x -1 Name of mother of Max Weber Helene Weber Max Weber "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False 1 ['Max', ' Weber']
+2158 474 Name of mother of x -1 Name of mother of Justin Timberlake Lynn Bomar Justin Timberlake "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Timbaland was on tour with Justin Timberlake to promote Timberlake's 9 [' Tim', 'bal', 'and', ' was', ' on', ' tour', ' with', ' Justin', ' Timber', 'lake']
+2159 474 Name of mother of x -1 Name of mother of Justin Timberlake Lynn Bomar Justin Timberlake "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False song was written by Justin Timberlake and The Neptunes, 6 [' song', ' was', ' written', ' by', ' Justin', ' Timber', 'lake']
+2160 474 Name of mother of x -1 Name of mother of Justin Timberlake Lynn Bomar Justin Timberlake "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False chemistry between Justin Timberlake and Mila Kunis 4 [' chemistry', ' between', ' Justin', ' Timber', 'lake']
+2161 474 Name of mother of x -1 Name of mother of Justin Timberlake Lynn Bomar Justin Timberlake "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False 2 ['Justin', ' Timber', 'lake']
+2162 474 Name of mother of x -1 Name of mother of Justin Timberlake Lynn Bomar Justin Timberlake "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Justified earned Justin Timberlake three American 5 [' Just', 'ified', ' earned', ' Justin', ' Timber', 'lake']
+2163 475 Name of mother of x -1 Name of mother of Pius II Vittoria Forteguerri Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False been banned by Pope Pius II in a conflict over 6 [' been', ' banned', ' by', ' Pope', ' P', 'ius', ' II']
+2164 475 Name of mother of x -1 Name of mother of Pius II Vittoria Forteguerri Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False sent by Pope Pius II in 1463, at the 5 [' sent', ' by', ' Pope', ' P', 'ius', ' II']
+2165 475 Name of mother of x -1 Name of mother of Pius II Vittoria Forteguerri Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False planned by Pope Pius II with Skanderbeg 5 [' planned', ' by', ' Pope', ' P', 'ius', ' II']
+2166 475 Name of mother of x -1 Name of mother of Pius II Vittoria Forteguerri Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False sent by Pope Pius II in 1463, at the 5 [' sent', ' by', ' Pope', ' P', 'ius', ' II']
+2167 475 Name of mother of x -1 Name of mother of Pius II Vittoria Forteguerri Pius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False wrote to Pope Pius II to request that 5 [' wrote', ' to', ' Pope', ' P', 'ius', ' II']
+2168 479 Name of mother of x -1 Name of mother of Jean de La Fontaine Françoise Pidoux Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False classical antiquity. Jean de La Fontaine began his 7 [' classical', ' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2169 479 Name of mother of x -1 Name of mother of Jean de La Fontaine Françoise Pidoux Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False classical antiquity. Jean de La Fontaine began his collection 7 [' classical', ' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2170 479 Name of mother of x -1 Name of mother of Jean de La Fontaine Françoise Pidoux Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False classical antiquity. Jean de La Fontaine began his collection 7 [' classical', ' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2171 479 Name of mother of x -1 Name of mother of Jean de La Fontaine Françoise Pidoux Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False antiquity. Jean de La Fontaine began his collection 6 [' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2172 479 Name of mother of x -1 Name of mother of Jean de La Fontaine Françoise Pidoux Jean de La Fontaine "[',' ' the' ' author' ' of' ' the' ' _' 'F' 'ables' '_' ',' ' and' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the author of the _ F ables _ , and the
+ " False classical antiquity. Jean de La Fontaine began his collection 7 [' classical', ' antiquity', '.', ' Jean', ' de', ' La', ' Font', 'aine']
+2173 480 Name of mother of x -1 Name of mother of Apollo Leto Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False " ""Godspeed the crew of Apollo 17.""
+" 6 "[' ""', 'God', 'speed', ' the', ' crew', ' of', ' Apollo']"
+2174 480 Name of mother of x -1 Name of mother of Apollo Leto Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False depicting the Apollo 11 landing when 2 [' depicting', ' the', ' Apollo']
+2175 480 Name of mother of x -1 Name of mother of Apollo Leto Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False 1 ['Ap', 'ollo']
+2176 480 Name of mother of x -1 Name of mother of Apollo Leto Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False conditions, during the Apollo 13 mission, after 4 [' conditions', ',', ' during', ' the', ' Apollo']
+2177 480 Name of mother of x -1 Name of mother of Apollo Leto Apollo "[',' ' the' ' god' ' of' ' the' ' sun' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sun' ',' ' and' ' the' ' god' ' of' ' the' ' sun']" , the god of the sun , and the god of the sun , and the god of the sun False 1 ['Ap', 'ollo']
+2178 481 Name of mother of x -1 Name of mother of Paul Claudel Louise Athanaïse Cécile Cerveaux Paul Claudel "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about']" ", the French poet , who was a friend of the
+
+ The first thing that strikes you about" False 2 ['Paul', ' Claud', 'el']
+2179 481 Name of mother of x -1 Name of mother of Paul Claudel Louise Athanaïse Cécile Cerveaux Paul Claudel "[',' ' the' ' French' ' poet' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about']" ", the French poet , who was a friend of the
+
+ The first thing that strikes you about" False 2 ['Paul', ' Claud', 'el']
+2180 482 Name of mother of x -1 Name of mother of Maurice Ravel Marie Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Maurice']" ", the composer , and his wife , the pian ist , the
+
+ Name of mother of Maurice" False impressionist composers Maurice Ravel and Claude 6 [' impression', 'ist', ' compos', 'ers', ' Maurice', ' Ra', 'vel']
+2181 482 Name of mother of x -1 Name of mother of Maurice Ravel Marie Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Maurice']" ", the composer , and his wife , the pian ist , the
+
+ Name of mother of Maurice" False 4 ['M', 'aur', 'ice', ' Ra', 'vel']
+2182 482 Name of mother of x -1 Name of mother of Maurice Ravel Marie Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Maurice']" ", the composer , and his wife , the pian ist , the
+
+ Name of mother of Maurice" False 4 ['M', 'aur', 'ice', ' Ra', 'vel']
+2183 482 Name of mother of x -1 Name of mother of Maurice Ravel Marie Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Maurice']" ", the composer , and his wife , the pian ist , the
+
+ Name of mother of Maurice" False composed by Maurice Ravel in 1928 and 4 [' composed', ' by', ' Maurice', ' Ra', 'vel']
+2184 482 Name of mother of x -1 Name of mother of Maurice Ravel Marie Ravel Maurice Ravel "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Maurice']" ", the composer , and his wife , the pian ist , the
+
+ Name of mother of Maurice" False piece composed by Maurice Ravel in 1928 and 5 [' piece', ' composed', ' by', ' Maurice', ' Ra', 'vel']
+2185 484 Name of mother of x -1 Name of mother of Rita Ora Vera Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False February 2012, Rita Ora covered the song 5 [' February', ' 2012', ',', ' Rita', ' O', 'ra']
+2186 484 Name of mother of x -1 Name of mother of Rita Ora Vera Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False video starts with Rita Ora and her boyfriend, 5 [' video', ' starts', ' with', ' Rita', ' O', 'ra']
+2187 484 Name of mother of x -1 Name of mother of Rita Ora Vera Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False February 2012, Rita Ora covered the song 5 [' February', ' 2012', ',', ' Rita', ' O', 'ra']
+2188 484 Name of mother of x -1 Name of mother of Rita Ora Vera Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False was covered by Rita Ora at Radio 1's Big 5 [' was', ' covered', ' by', ' Rita', ' O', 'ra']
+2189 484 Name of mother of x -1 Name of mother of Rita Ora Vera Sahatçiu Rita Ora "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False " and"" R.I.P ""singer Rita Ora have transformed" 12 "[' and', '""', ' R', '.', 'I', '.', 'P', ' ""', 's', 'inger', ' Rita', ' O', 'ra']"
+2190 485 Name of mother of x -1 Name of mother of Naomi Watts Myfanwy Edwards Naomi Watts "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' Australia' ','
+ ' and' ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe Australia , and the mother of two .
+
+ The actress ," False 2 ['Na', 'omi', ' Watts']
+2191 485 Name of mother of x -1 Name of mother of Naomi Watts Myfanwy Edwards Naomi Watts "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' Australia' ','
+ ' and' ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe Australia , and the mother of two .
+
+ The actress ," False co-starring Rachel Weisz, Naomi Watts and Marton 10 [' co', '-', 'star', 'ring', ' Rachel', ' We', 'is', 'z', ',', ' Naomi', ' Watts']
+2192 485 Name of mother of x -1 Name of mother of Naomi Watts Myfanwy Edwards Naomi Watts "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' Australia' ','
+ ' and' ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe Australia , and the mother of two .
+
+ The actress ," False Rachel Weisz, Naomi Watts and Marton Csokas. 6 [' Rachel', ' We', 'is', 'z', ',', ' Naomi', ' Watts']
+2193 485 Name of mother of x -1 Name of mother of Naomi Watts Myfanwy Edwards Naomi Watts "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' Australia' ','
+ ' and' ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe Australia , and the mother of two .
+
+ The actress ," False alongside Armie Hammer and Naomi Watts in Clint Eastwood's 6 [' alongside', ' Arm', 'ie', ' Hammer', ' and', ' Naomi', ' Watts']
+2194 485 Name of mother of x -1 Name of mother of Naomi Watts Myfanwy Edwards Naomi Watts "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' Australia' ','
+ ' and' ' the' ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' actress' ',']" ", who is a former Miss Universe Australia , and the mother of two .
+
+ The actress ," False Armie Hammer and Naomi Watts in Clint Eastwood's 5 [' Arm', 'ie', ' Hammer', ' and', ' Naomi', ' Watts']
+2195 487 Name of mother of x -1 Name of mother of Frank Lloyd Wright Anna Lloyd Jones Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False but designed by Frank Lloyd Wright in 1957, Ascension 5 [' but', ' designed', ' by', ' Frank', ' Lloyd', ' Wright']
+2196 487 Name of mother of x -1 Name of mother of Frank Lloyd Wright Anna Lloyd Jones Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False along with nine other Frank Lloyd Wright properties to a tentative 6 [' along', ' with', ' nine', ' other', ' Frank', ' Lloyd', ' Wright']
+2197 487 Name of mother of x -1 Name of mother of Frank Lloyd Wright Anna Lloyd Jones Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False architect Frank Lloyd Wright visited the mall 3 [' architect', ' Frank', ' Lloyd', ' Wright']
+2198 487 Name of mother of x -1 Name of mother of Frank Lloyd Wright Anna Lloyd Jones Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False Noémi worked with Frank Lloyd Wright at Taliesin in Spring 7 [' No', 'é', 'mi', ' worked', ' with', ' Frank', ' Lloyd', ' Wright']
+2199 487 Name of mother of x -1 Name of mother of Frank Lloyd Wright Anna Lloyd Jones Wright Frank Lloyd Wright "[',' ' the' ' architect' ' of' ' the' ' famous' ' Falling' 'water' ','
+ ' the' ' house' ' that' ' was' ' built' ' in' ' Pennsylvania' ' in'
+ ' the' ' 1930' 's']" , the architect of the famous Falling water , the house that was built in Pennsylvania in the 1930 s False in the style of Frank Lloyd Wright or Richard Meier. 6 [' in', ' the', ' style', ' of', ' Frank', ' Lloyd', ' Wright']
+2200 488 Name of mother of x -1 Name of mother of Andrew Lloyd Webber Jean Hermione Johnstone Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False 3 ['Andrew', ' Lloyd', ' Web', 'ber']
+2201 488 Name of mother of x -1 Name of mother of Andrew Lloyd Webber Jean Hermione Johnstone Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False " writer, producer
+" 6 [' writer', ',', ' producer', 'Andrew', ' Lloyd', ' Web', 'ber']
+2202 488 Name of mother of x -1 Name of mother of Andrew Lloyd Webber Jean Hermione Johnstone Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False her album The Andrew Lloyd Webber Collection. 6 [' her', ' album', ' The', ' Andrew', ' Lloyd', ' Web', 'ber']
+2203 488 Name of mother of x -1 Name of mother of Andrew Lloyd Webber Jean Hermione Johnstone Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False was written by Andrew Lloyd Webber and Tim Rice, 6 [' was', ' written', ' by', ' Andrew', ' Lloyd', ' Web', 'ber']
+2204 488 Name of mother of x -1 Name of mother of Andrew Lloyd Webber Jean Hermione Johnstone Andrew Lloyd Webber "[""'s"" ' musical' ',' ' The' ' Phantom' ' of' ' the' ' Opera' ',' ' which'
+ ' opened' ' in' ' London' ' in' ' 1988' '.' '\n' '\n' 'The' ' musical']" "'s musical , The Phantom of the Opera , which opened in London in 1988 .
+
+ The musical" False awarded to The Aviator. Andrew Lloyd Webber and lyricist 9 [' awarded', ' to', ' The', ' Av', 'iator', '.', ' Andrew', ' Lloyd', ' Web', 'ber']
+2205 489 Name of mother of x -1 Name of mother of Alexander III of Russia Maria Alexandrovna of Russia (Marie of Hesse) Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'The' ' Russian' ' Empire' ',' ' and' ' the'
+ ' Russian' ' Empire' ',' ' and' ' the' ' Russian' ' Empire' ',' ' and']" ", and the
+
+ The Russian Empire , and the Russian Empire , and the Russian Empire , and" False 1 November, Czar Alexander III of Russia died; Wilhelm II initially 8 [' 1', ' November', ',', ' C', 'zar', ' Alexander', ' III', ' of', ' Russia']
+2206 489 Name of mother of x -1 Name of mother of Alexander III of Russia Maria Alexandrovna of Russia (Marie of Hesse) Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'The' ' Russian' ' Empire' ',' ' and' ' the'
+ ' Russian' ' Empire' ',' ' and' ' the' ' Russian' ' Empire' ',' ' and']" ", and the
+
+ The Russian Empire , and the Russian Empire , and the Russian Empire , and" False child of Emperor Alexander III of Russia and youngest brother 6 [' child', ' of', ' Emperor', ' Alexander', ' III', ' of', ' Russia']
+2207 489 Name of mother of x -1 Name of mother of Alexander III of Russia Maria Alexandrovna of Russia (Marie of Hesse) Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'The' ' Russian' ' Empire' ',' ' and' ' the'
+ ' Russian' ' Empire' ',' ' and' ' the' ' Russian' ' Empire' ',' ' and']" ", and the
+
+ The Russian Empire , and the Russian Empire , and the Russian Empire , and" False was attended by Alexander III of Russia and then-Tsesarevich 6 [' was', ' attended', ' by', ' Alexander', ' III', ' of', ' Russia']
+2208 489 Name of mother of x -1 Name of mother of Alexander III of Russia Maria Alexandrovna of Russia (Marie of Hesse) Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'The' ' Russian' ' Empire' ',' ' and' ' the'
+ ' Russian' ' Empire' ',' ' and' ' the' ' Russian' ' Empire' ',' ' and']" ", and the
+
+ The Russian Empire , and the Russian Empire , and the Russian Empire , and" False daughter of Emperor Alexander III of Russia and younger sister 6 [' daughter', ' of', ' Emperor', ' Alexander', ' III', ' of', ' Russia']
+2209 489 Name of mother of x -1 Name of mother of Alexander III of Russia Maria Alexandrovna of Russia (Marie of Hesse) Alexander III of Russia "[',' ' and' ' the' '\n' '\n' 'The' ' Russian' ' Empire' ',' ' and' ' the'
+ ' Russian' ' Empire' ',' ' and' ' the' ' Russian' ' Empire' ',' ' and']" ", and the
+
+ The Russian Empire , and the Russian Empire , and the Russian Empire , and" False by Emperor Alexander III of Russia for the Hermitage 5 [' by', ' Emperor', ' Alexander', ' III', ' of', ' Russia']
+2210 490 Name of mother of x -1 Name of mother of Alphonse Daudet Adeline Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False by French dramatist Alphonse Daudet that Puccini 9 [' by', ' French', ' dram', 'at', 'ist', ' Alph', 'onse', ' D', 'aud', 'et']
+2211 490 Name of mother of x -1 Name of mother of Alphonse Daudet Adeline Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False " Gustave Flaubert, Alphonse Daudet and Émile Zola.
+" 10 [' Gust', 'ave', ' Fl', 'au', 'bert', ',', ' Alph', 'onse', ' D', 'aud', 'et']
+2212 490 Name of mother of x -1 Name of mother of Alphonse Daudet Adeline Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False French dramatist Alphonse Daudet that Puccini thought 8 [' French', ' dram', 'at', 'ist', ' Alph', 'onse', ' D', 'aud', 'et']
+2213 490 Name of mother of x -1 Name of mother of Alphonse Daudet Adeline Daudet Alphonse Daudet "[',' ' the' ' author' ' of' ' _' 'The' ' Count' ' of' ' Monte' ' Crist'
+ 'o' '_' ',' ' and' ' _' 'The' '\n' ' ' ' ' ' ']" ", the author of _ The Count of Monte Crist o _ , and _ The
+ " False French dramatist Alphonse Daudet that Puccini 8 [' French', ' dram', 'at', 'ist', ' Alph', 'onse', ' D', 'aud', 'et']
+2214 491 Name of mother of x -1 Name of mother of Eminem Debbie Nelson Eminem "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False " ""Mockingbird""
+" 7 "[' ""', 'M', 'ocking', 'bird', '""', 'E', 'min', 'em']"
+2215 491 Name of mother of x -1 Name of mother of Eminem Debbie Nelson Eminem "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False directed towards rapper Eminem, who referred to Aguilera 3 [' directed', ' towards', ' rapper', ' Eminem']
+2216 491 Name of mother of x -1 Name of mother of Eminem Debbie Nelson Eminem "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False 2 ['E', 'min', 'em']
+2217 491 Name of mother of x -1 Name of mother of Eminem Debbie Nelson Eminem "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False controversies about Eminem's insults and 2 [' controversies', ' about', ' Eminem']
+2218 491 Name of mother of x -1 Name of mother of Eminem Debbie Nelson Eminem "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' is' ' the']" ", the father of the child , and the mother of the child .
+
+ The child is the" False 2 ['E', 'min', 'em']
+2219 492 Name of mother of x -1 Name of mother of Roman Polanski Bula Liebling Roman Polanski "[',' ' who' ' was' ' born' ' in' ' the' ' United' ' States' ',' ' and'
+ ' who' ' has' ' been' ' living' ' in' ' France' ' for' ' the' ' past'
+ ' 20']" , who was born in the United States , and who has been living in France for the past 20 False Grand Prix in the Roman Polanski produced film Weekend 6 [' Grand', ' Prix', ' in', ' the', ' Roman', ' Pol', 'anski']
+2220 492 Name of mother of x -1 Name of mother of Roman Polanski Bula Liebling Roman Polanski "[',' ' who' ' was' ' born' ' in' ' the' ' United' ' States' ',' ' and'
+ ' who' ' has' ' been' ' living' ' in' ' France' ' for' ' the' ' past'
+ ' 20']" , who was born in the United States , and who has been living in France for the past 20 False there's a lot of Roman Polanski influence and Terry 7 "[' there', ""'s"", ' a', ' lot', ' of', ' Roman', ' Pol', 'anski']"
+2221 492 Name of mother of x -1 Name of mother of Roman Polanski Bula Liebling Roman Polanski "[',' ' who' ' was' ' born' ' in' ' the' ' United' ' States' ',' ' and'
+ ' who' ' has' ' been' ' living' ' in' ' France' ' for' ' the' ' past'
+ ' 20']" , who was born in the United States , and who has been living in France for the past 20 False Sarandon, Jeremy Irons, Roman Polanski and many European, 9 [' Sar', 'andon', ',', ' Jeremy', ' Ir', 'ons', ',', ' Roman', ' Pol', 'anski']
+2222 492 Name of mother of x -1 Name of mother of Roman Polanski Bula Liebling Roman Polanski "[',' ' who' ' was' ' born' ' in' ' the' ' United' ' States' ',' ' and'
+ ' who' ' has' ' been' ' living' ' in' ' France' ' for' ' the' ' past'
+ ' 20']" , who was born in the United States , and who has been living in France for the past 20 False were inspired by Roman Polanski and Alfred Hitchcock 5 [' were', ' inspired', ' by', ' Roman', ' Pol', 'anski']
+2223 492 Name of mother of x -1 Name of mother of Roman Polanski Bula Liebling Roman Polanski "[',' ' who' ' was' ' born' ' in' ' the' ' United' ' States' ',' ' and'
+ ' who' ' has' ' been' ' living' ' in' ' France' ' for' ' the' ' past'
+ ' 20']" , who was born in the United States , and who has been living in France for the past 20 False think there's a lot of Roman Polanski influence and Terry 8 "[' think', ' there', ""'s"", ' a', ' lot', ' of', ' Roman', ' Pol', 'anski']"
+2224 493 Name of mother of x -1 Name of mother of Benjamin Britten Edith Rhoda Hockey Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False about Britten: Benjamin Britten & his Festival 6 [' about', ' Br', 'itten', ':', ' Benjamin', ' Br', 'itten']
+2225 493 Name of mother of x -1 Name of mother of Benjamin Britten Edith Rhoda Hockey Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False honour; he and Benjamin Britten were the soloists 6 [' honour', ';', ' he', ' and', ' Benjamin', ' Br', 'itten']
+2226 493 Name of mother of x -1 Name of mother of Benjamin Britten Edith Rhoda Hockey Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False on an Impromptu of Benjamin Britten (1969), in both 7 [' on', ' an', ' Imp', 'romptu', ' of', ' Benjamin', ' Br', 'itten']
+2227 493 Name of mother of x -1 Name of mother of Benjamin Britten Edith Rhoda Hockey Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False friendship with Benjamin Britten and Peter Pears, 4 [' friendship', ' with', ' Benjamin', ' Br', 'itten']
+2228 493 Name of mother of x -1 Name of mother of Benjamin Britten Edith Rhoda Hockey Benjamin Britten "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Edward' ' Br' 'itten' ','
+ ' and' ' his' ' wife' ',' ' the' ' former' ' Lady' ' Br' 'itten']" , the son of the late Sir Edward Br itten , and his wife , the former Lady Br itten False been aware that Benjamin Britten had written incidental 5 [' been', ' aware', ' that', ' Benjamin', ' Br', 'itten']
+2229 494 Name of mother of x -1 Name of mother of Doris Day Alma Sophia Welz Doris Day "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False starred opposite Doris Day in the musical film 4 [' starred', ' opposite', ' Dor', 'is', ' Day']
+2230 494 Name of mother of x -1 Name of mother of Doris Day Alma Sophia Welz Doris Day "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False brought fine artist Doris Day with her to 5 [' brought', ' fine', ' artist', ' Dor', 'is', ' Day']
+2231 494 Name of mother of x -1 Name of mother of Doris Day Alma Sophia Welz Doris Day "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False month, Sinatra and Doris Day released the single 7 [' month', ',', ' Sin', 'atra', ' and', ' Dor', 'is', ' Day']
+2232 494 Name of mother of x -1 Name of mother of Doris Day Alma Sophia Welz Doris Day "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " released ""an album of Doris Day covers on [her] own" 7 "[' released', ' ""', 'an', ' album', ' of', ' Dor', 'is', ' Day']"
+2233 494 Name of mother of x -1 Name of mother of Doris Day Alma Sophia Welz Doris Day "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False starred opposite Doris Day in the musical film 4 [' starred', ' opposite', ' Dor', 'is', ' Day']
+2234 495 Name of mother of x -1 Name of mother of Rainer Werner Fassbinder Liselotte Eder Rainer Werner Fassbinder "[',' ' the' ' German' ' director' ',' ' who' ' died' ' in' ' 1982' '.'
+ '\n' '\n' 'The' ' film' ' is' ' a' ' documentary' ' about' ' the' ' life']" ", the German director , who died in 1982 .
+
+ The film is a documentary about the life" False Wenders, and Rainer Werner Fassbinder brought West German 10 [' W', 'enders', ',', ' and', ' Rain', 'er', ' Werner', ' F', 'ass', 'b', 'inder']
+2235 495 Name of mother of x -1 Name of mother of Rainer Werner Fassbinder Liselotte Eder Rainer Werner Fassbinder "[',' ' the' ' German' ' director' ',' ' who' ' died' ' in' ' 1982' '.'
+ '\n' '\n' 'The' ' film' ' is' ' a' ' documentary' ' about' ' the' ' life']" ", the German director , who died in 1982 .
+
+ The film is a documentary about the life" False Wenders, and Rainer Werner Fassbinder brought West 10 [' W', 'enders', ',', ' and', ' Rain', 'er', ' Werner', ' F', 'ass', 'b', 'inder']
+2236 495 Name of mother of x -1 Name of mother of Rainer Werner Fassbinder Liselotte Eder Rainer Werner Fassbinder "[',' ' the' ' German' ' director' ',' ' who' ' died' ' in' ' 1982' '.'
+ '\n' '\n' 'The' ' film' ' is' ' a' ' documentary' ' about' ' the' ' life']" ", the German director , who died in 1982 .
+
+ The film is a documentary about the life" False Wim Wenders, and Rainer Werner Fassbinder brought West German 11 [' Wim', ' W', 'enders', ',', ' and', ' Rain', 'er', ' Werner', ' F', 'ass', 'b', 'inder']
+2237 496 Name of mother of x -1 Name of mother of Arthur Rimbaud Vitalie Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' mother' ' of' ' the' ' poet' ',' ' the' ' poet' ""'s"" ' mother']" , the poet , and the poet 's mother , the mother of the poet , the poet 's mother False her appreciation of Arthur Rimbaud. In August 2005, 6 [' her', ' appreciation', ' of', ' Arthur', ' R', 'imb', 'aud']
+2238 496 Name of mother of x -1 Name of mother of Arthur Rimbaud Vitalie Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' mother' ' of' ' the' ' poet' ',' ' the' ' poet' ""'s"" ' mother']" , the poet , and the poet 's mother , the mother of the poet , the poet 's mother False continued to be inspired by Arthur Rimbaud and the Comte 8 [' continued', ' to', ' be', ' inspired', ' by', ' Arthur', ' R', 'imb', 'aud']
+2239 496 Name of mother of x -1 Name of mother of Arthur Rimbaud Vitalie Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' mother' ' of' ' the' ' poet' ',' ' the' ' poet' ""'s"" ' mother']" , the poet , and the poet 's mother , the mother of the poet , the poet 's mother False her appreciation of Arthur Rimbaud. In August 6 [' her', ' appreciation', ' of', ' Arthur', ' R', 'imb', 'aud']
+2240 496 Name of mother of x -1 Name of mother of Arthur Rimbaud Vitalie Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' mother' ' of' ' the' ' poet' ',' ' the' ' poet' ""'s"" ' mother']" , the poet , and the poet 's mother , the mother of the poet , the poet 's mother False include James Joyce, Arthur Rimbaud and D. H. Lawrence. 7 [' include', ' James', ' Joyce', ',', ' Arthur', ' R', 'imb', 'aud']
+2241 496 Name of mother of x -1 Name of mother of Arthur Rimbaud Vitalie Rimbaud Arthur Rimbaud "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' mother' ' of' ' the' ' poet' ',' ' the' ' poet' ""'s"" ' mother']" , the poet , and the poet 's mother , the mother of the poet , the poet 's mother False to be inspired by Arthur Rimbaud and the Comte 7 [' to', ' be', ' inspired', ' by', ' Arthur', ' R', 'imb', 'aud']
+2242 497 Name of mother of x -1 Name of mother of Boris Yeltsin Claudia Vasilievna Jeltsina Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' his' ' wife' ',' ' Yel'
+ 'ts' 'in' ""'s"" ' daughter' ',' ' Yel' 'ts' 'in' ""'s"" ' wife']" , the Russian president , and his wife , Yel ts in 's daughter , Yel ts in 's wife False Russian president Boris Yeltsin agreed to cooperate 5 [' Russian', ' president', ' Boris', ' Yel', 'ts', 'in']
+2243 497 Name of mother of x -1 Name of mother of Boris Yeltsin Claudia Vasilievna Jeltsina Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' his' ' wife' ',' ' Yel'
+ 'ts' 'in' ""'s"" ' daughter' ',' ' Yel' 'ts' 'in' ""'s"" ' wife']" , the Russian president , and his wife , Yel ts in 's daughter , Yel ts in 's wife False President, Boris Yeltsin and Kazakhstan President, 5 [' President', ',', ' Boris', ' Yel', 'ts', 'in']
+2244 497 Name of mother of x -1 Name of mother of Boris Yeltsin Claudia Vasilievna Jeltsina Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' his' ' wife' ',' ' Yel'
+ 'ts' 'in' ""'s"" ' daughter' ',' ' Yel' 'ts' 'in' ""'s"" ' wife']" , the Russian president , and his wife , Yel ts in 's daughter , Yel ts in 's wife False Russian President Boris Yeltsin released the top-secret 5 [' Russian', ' President', ' Boris', ' Yel', 'ts', 'in']
+2245 497 Name of mother of x -1 Name of mother of Boris Yeltsin Claudia Vasilievna Jeltsina Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' his' ' wife' ',' ' Yel'
+ 'ts' 'in' ""'s"" ' daughter' ',' ' Yel' 'ts' 'in' ""'s"" ' wife']" , the Russian president , and his wife , Yel ts in 's daughter , Yel ts in 's wife False team around Boris Yeltsin considered Bukovsky 5 [' team', ' around', ' Boris', ' Yel', 'ts', 'in']
+2246 497 Name of mother of x -1 Name of mother of Boris Yeltsin Claudia Vasilievna Jeltsina Boris Yeltsin "[',' ' the' ' Russian' ' president' ',' ' and' ' his' ' wife' ',' ' Yel'
+ 'ts' 'in' ""'s"" ' daughter' ',' ' Yel' 'ts' 'in' ""'s"" ' wife']" , the Russian president , and his wife , Yel ts in 's daughter , Yel ts in 's wife False republics, including Boris Yeltsin in Russia (Gorbachev 7 [' republic', 's', ',', ' including', ' Boris', ' Yel', 'ts', 'in']
+2247 498 Name of mother of x -1 Name of mother of Philip II of Spain Isabella of Portugal Philip II of Spain "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Philip' ' II' ' of' ' Spain' ',' ' and' ' the' ' name' ' of']" ", and the
+
+ The name of the mother of Philip II of Spain , and the name of" False " the future wife of Philip II of Spain (1527 – 98).
+" 7 [' the', ' future', ' wife', ' of', ' Philip', ' II', ' of', ' Spain']
+2248 498 Name of mother of x -1 Name of mother of Philip II of Spain Isabella of Portugal Philip II of Spain "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Philip' ' II' ' of' ' Spain' ',' ' and' ' the' ' name' ' of']" ", and the
+
+ The name of the mother of Philip II of Spain , and the name of" False " inherited by Philip II of Spain around 1580.
+" 5 [' inherited', ' by', ' Philip', ' II', ' of', ' Spain']
+2249 498 Name of mother of x -1 Name of mother of Philip II of Spain Isabella of Portugal Philip II of Spain "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Philip' ' II' ' of' ' Spain' ',' ' and' ' the' ' name' ' of']" ", and the
+
+ The name of the mother of Philip II of Spain , and the name of" False " future wife of Philip II of Spain (1527 – 98).
+" 6 [' future', ' wife', ' of', ' Philip', ' II', ' of', ' Spain']
+2250 498 Name of mother of x -1 Name of mother of Philip II of Spain Isabella of Portugal Philip II of Spain "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Philip' ' II' ' of' ' Spain' ',' ' and' ' the' ' name' ' of']" ", and the
+
+ The name of the mother of Philip II of Spain , and the name of" False Mary of Hungary and Philip II of Spain both sought out Netherlandish 7 [' Mary', ' of', ' Hungary', ' and', ' Philip', ' II', ' of', ' Spain']
+2251 498 Name of mother of x -1 Name of mother of Philip II of Spain Isabella of Portugal Philip II of Spain "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Philip' ' II' ' of' ' Spain' ',' ' and' ' the' ' name' ' of']" ", and the
+
+ The name of the mother of Philip II of Spain , and the name of" False he fought for Philip II of Spain (then England's 6 [' he', ' fought', ' for', ' Philip', ' II', ' of', ' Spain']
+2252 499 Name of mother of x -1 Name of mother of Louisa May Alcott Abby May Louisa May Alcott "[',' ' the' ' author' ' of' ' Little' ' Women' ',' ' and' ' the' ' author'
+ ' of' ' Little' ' Women' ',' ' and' ' the' ' author' ' of' ' Little'
+ ' Women']" , the author of Little Women , and the author of Little Women , and the author of Little Women False version described by Louisa May Alcott in her 1879 story, 7 [' version', ' described', ' by', ' Louis', 'a', ' May', ' Al', 'cott']
+2253 499 Name of mother of x -1 Name of mother of Louisa May Alcott Abby May Louisa May Alcott "[',' ' the' ' author' ' of' ' Little' ' Women' ',' ' and' ' the' ' author'
+ ' of' ' Little' ' Women' ',' ' and' ' the' ' author' ' of' ' Little'
+ ' Women']" , the author of Little Women , and the author of Little Women , and the author of Little Women False described by Louisa May Alcott in her 1879 story, 6 [' described', ' by', ' Louis', 'a', ' May', ' Al', 'cott']
+2254 499 Name of mother of x -1 Name of mother of Louisa May Alcott Abby May Louisa May Alcott "[',' ' the' ' author' ' of' ' Little' ' Women' ',' ' and' ' the' ' author'
+ ' of' ' Little' ' Women' ',' ' and' ' the' ' author' ' of' ' Little'
+ ' Women']" , the author of Little Women , and the author of Little Women , and the author of Little Women False of Little Women by Louisa May Alcott and dedicated 8 [' of', ' Little', ' Women', ' by', ' Louis', 'a', ' May', ' Al', 'cott']
+2255 500 Name of mother of x -1 Name of mother of Karen Blixen Ingeborg Dinesen Karen Blixen "[',' ' the' ' Danish' ' author' ' of' ' the' ' book' ' ""' 'Out' ' of'
+ ' Africa' '""' ' and' ' the' ' movie' ' ""' 'G' 'one' ' with' ' the']" ", the Danish author of the book "" Out of Africa "" and the movie "" G one with the" False Danish writer Karen Blixen opposite Robert 5 [' Danish', ' writer', ' Karen', ' Bl', 'ix', 'en']
+2256 500 Name of mother of x -1 Name of mother of Karen Blixen Ingeborg Dinesen Karen Blixen "[',' ' the' ' Danish' ' author' ' of' ' the' ' book' ' ""' 'Out' ' of'
+ ' Africa' '""' ' and' ' the' ' movie' ' ""' 'G' 'one' ' with' ' the']" ", the Danish author of the book "" Out of Africa "" and the movie "" G one with the" False the Danish writer Karen Blixen opposite Robert 6 [' the', ' Danish', ' writer', ' Karen', ' Bl', 'ix', 'en']
+2257 500 Name of mother of x -1 Name of mother of Karen Blixen Ingeborg Dinesen Karen Blixen "[',' ' the' ' Danish' ' author' ' of' ' the' ' book' ' ""' 'Out' ' of'
+ ' Africa' '""' ' and' ' the' ' movie' ' ""' 'G' 'one' ' with' ' the']" ", the Danish author of the book "" Out of Africa "" and the movie "" G one with the" False the Danish writer Karen Blixen opposite Robert Redford's 6 [' the', ' Danish', ' writer', ' Karen', ' Bl', 'ix', 'en']
+2258 500 Name of mother of x -1 Name of mother of Karen Blixen Ingeborg Dinesen Karen Blixen "[',' ' the' ' Danish' ' author' ' of' ' the' ' book' ' ""' 'Out' ' of'
+ ' Africa' '""' ' and' ' the' ' movie' ' ""' 'G' 'one' ' with' ' the']" ", the Danish author of the book "" Out of Africa "" and the movie "" G one with the" False the Danish writer Karen Blixen opposite Robert 6 [' the', ' Danish', ' writer', ' Karen', ' Bl', 'ix', 'en']
+2259 501 Name of mother of x -1 Name of mother of Silvio Berlusconi Rosa Bossi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False " the sleeper and Silvio Berlusconi as the monster.
+" 8 [' the', ' sleeper', ' and', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2260 501 Name of mother of x -1 Name of mother of Silvio Berlusconi Rosa Bossi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False then again from 2008, Silvio Berlusconi adopted a similar 10 [' then', ' again', ' from', ' 2008', ',', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2261 501 Name of mother of x -1 Name of mother of Silvio Berlusconi Rosa Bossi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False President and owner Silvio Berlusconi was equally delighted 8 [' President', ' and', ' owner', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2262 501 Name of mother of x -1 Name of mother of Silvio Berlusconi Rosa Bossi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False " the sleeper and Silvio Berlusconi as the monster.
+" 8 [' the', ' sleeper', ' and', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2263 501 Name of mother of x -1 Name of mother of Silvio Berlusconi Rosa Bossi Silvio Berlusconi "[',' ' the' ' Italian' ' prime' ' minister' ',' ' who' ' is' ' also' ' a'
+ ' former' ' prime' ' minister' ',' ' and' ' the' ' former' ' Italian'
+ ' prime' ' minister']" , the Italian prime minister , who is also a former prime minister , and the former Italian prime minister False " sleeper and Silvio Berlusconi as the monster.
+" 7 [' sleeper', ' and', ' Sil', 'v', 'io', ' Ber', 'lus', 'coni']
+2264 502 Name of mother of x -1 Name of mother of Francis Ford Coppola Italia Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' film' ""'s"" ' director'
+ ',' ' Francis' ' Ford' ' Co' 'pp' 'ola' ',' ' was' ' a' ' member']" 's The God father , and the film 's director , Francis Ford Co pp ola , was a member False with producer Francis Ford Coppola on The Godfather. 6 [' with', ' producer', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2265 502 Name of mother of x -1 Name of mother of Francis Ford Coppola Italia Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' film' ""'s"" ' director'
+ ',' ' Francis' ' Ford' ' Co' 'pp' 'ola' ',' ' was' ' a' ' member']" 's The God father , and the film 's director , Francis Ford Co pp ola , was a member False Martin Scorsese, Francis Ford Coppola and Steven Spielberg. 9 [' Martin', ' Sc', 'ors', 'ese', ',', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2266 502 Name of mother of x -1 Name of mother of Francis Ford Coppola Italia Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' film' ""'s"" ' director'
+ ',' ' Francis' ' Ford' ' Co' 'pp' 'ola' ',' ' was' ' a' ' member']" 's The God father , and the film 's director , Francis Ford Co pp ola , was a member False 1996, filmmaker Francis Ford Coppola filed a lawsuit against 7 [' 1996', ',', ' filmmaker', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2267 502 Name of mother of x -1 Name of mother of Francis Ford Coppola Italia Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' film' ""'s"" ' director'
+ ',' ' Francis' ' Ford' ' Co' 'pp' 'ola' ',' ' was' ' a' ' member']" 's The God father , and the film 's director , Francis Ford Co pp ola , was a member False comedy-drama film directed by Francis Ford Coppola and starring Jeff 11 [' comedy', '-', 'd', 'rama', ' film', ' directed', ' by', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2268 502 Name of mother of x -1 Name of mother of Francis Ford Coppola Italia Coppola Francis Ford Coppola "[""'s"" ' The' ' God' 'father' ',' ' and' ' the' ' film' ""'s"" ' director'
+ ',' ' Francis' ' Ford' ' Co' 'pp' 'ola' ',' ' was' ' a' ' member']" 's The God father , and the film 's director , Francis Ford Co pp ola , was a member False George Lucas and Francis Ford Coppola on the 17-minute 3D 7 [' George', ' Lucas', ' and', ' Francis', ' Ford', ' Co', 'pp', 'ola']
+2269 503 Name of mother of x -1 Name of mother of Bruce Lee Grace Ho Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False Also in 1984, Bruce Lee combined multi-player, 5 [' Also', ' in', ' 1984', ',', ' Bruce', ' Lee']
+2270 503 Name of mother of x -1 Name of mother of Bruce Lee Grace Ho Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False produced movies of Bruce Lee and marketed 4 [' produced', ' movies', ' of', ' Bruce', ' Lee']
+2271 503 Name of mother of x -1 Name of mother of Bruce Lee Grace Ho Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False introduces a young Bruce Lee prior to becoming 4 [' introduces', ' a', ' young', ' Bruce', ' Lee']
+2272 503 Name of mother of x -1 Name of mother of Bruce Lee Grace Ho Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False " Dragon"" after Bruce Lee and Jackie Chan. After" 4 "[' Dragon', '""', ' after', ' Bruce', ' Lee']"
+2273 503 Name of mother of x -1 Name of mother of Bruce Lee Grace Ho Bruce Lee "['.' '\n' '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' Bruce' ' Lee' '.' ' I'
+ ' have' ' been' ' a' ' fan' ' of' ' Bruce' ' Lee']" ".
+
+ I am a big fan of Bruce Lee . I have been a fan of Bruce Lee" False " about ""kickass Bruce Lee clones"" citing his" 5 "[' about', ' ""', 'kick', 'ass', ' Bruce', ' Lee']"
+2274 504 Name of mother of x -1 Name of mother of Benjamin Disraeli Mary Basevi Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great states man , the
+ " False Walter Scott and Benjamin Disraeli and the historical 6 [' Walter', ' Scott', ' and', ' Benjamin', ' Dis', 'rael', 'i']
+2275 504 Name of mother of x -1 Name of mother of Benjamin Disraeli Mary Basevi Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great states man , the
+ " False replaced Hall, and Benjamin Disraeli was appointed 7 [' replaced', ' Hall', ',', ' and', ' Benjamin', ' Dis', 'rael', 'i']
+2276 504 Name of mother of x -1 Name of mother of Benjamin Disraeli Mary Basevi Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great states man , the
+ " False predecessor, Benjamin Disraeli and his return 5 [' predecessor', ',', ' Benjamin', ' Dis', 'rael', 'i']
+2277 504 Name of mother of x -1 Name of mother of Benjamin Disraeli Mary Basevi Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great states man , the
+ " False stating that Benjamin Disraeli had been heckled 5 [' stating', ' that', ' Benjamin', ' Dis', 'rael', 'i']
+2278 504 Name of mother of x -1 Name of mother of Benjamin Disraeli Mary Basevi Benjamin Disraeli "[',' ' the' ' great' ' states' 'man' ',' ' the' '\n' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the great states man , the
+ " False sociability and tact — Benjamin Disraeli described him 8 [' soc', 'iability', ' and', ' tact', ' —', ' Benjamin', ' Dis', 'rael', 'i']
+2279 505 Name of mother of x -1 Name of mother of Johannes Vermeer Digna Baltus Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False preferred by Johannes Vermeer and other 5 [' preferred', ' by', ' Johannes', ' Ver', 'me', 'er']
+2280 505 Name of mother of x -1 Name of mother of Johannes Vermeer Digna Baltus Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False famed painter Johannes Vermeer (Colin Firth). Griet 5 [' famed', ' painter', ' Johannes', ' Ver', 'me', 'er']
+2281 505 Name of mother of x -1 Name of mother of Johannes Vermeer Digna Baltus Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False had been inspired by Johannes Vermeer and showed the young 7 [' had', ' been', ' inspired', ' by', ' Johannes', ' Ver', 'me', 'er']
+2282 505 Name of mother of x -1 Name of mother of Johannes Vermeer Digna Baltus Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False preferred by Johannes Vermeer and other Dutch 5 [' preferred', ' by', ' Johannes', ' Ver', 'me', 'er']
+2283 505 Name of mother of x -1 Name of mother of Johannes Vermeer Digna Baltus Johannes Vermeer "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ""'s"" ' wife' ',' ' and' ' the' ' painter' ""'s"" ' mother' '.' '\n']" ", the painter , and his wife , the painter 's wife , and the painter 's mother .
+" False discover the artist Johannes Vermeer and his paintings, 6 [' discover', ' the', ' artist', ' Johannes', ' Ver', 'me', 'er']
+2284 506 Name of mother of x -1 Name of mother of Olivier Messiaen Cécile Sauvage Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' mother' ' of']" , the composer , and his wife , the painter , and the painter 's mother , the mother of False Gustav Mahler, and Olivier Messiaen enlarged the clarinet 8 [' Gustav', ' Mah', 'ler', ',', ' and', ' Olivier', ' Mess', 'ia', 'en']
+2285 506 Name of mother of x -1 Name of mother of Olivier Messiaen Cécile Sauvage Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' mother' ' of']" , the composer , and his wife , the painter , and the painter 's mother , the mother of False " Messiaen =
+" 9 [' Mess', 'ia', 'en', ' =', 'O', 'liv', 'ier', ' Mess', 'ia', 'en']
+2286 506 Name of mother of x -1 Name of mother of Olivier Messiaen Cécile Sauvage Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' mother' ' of']" , the composer , and his wife , the painter , and the painter 's mother , the mother of False 5 ['O', 'liv', 'ier', ' Mess', 'ia', 'en']
+2287 506 Name of mother of x -1 Name of mother of Olivier Messiaen Cécile Sauvage Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' mother' ' of']" , the composer , and his wife , the painter , and the painter 's mother , the mother of False influence of Olivier Messiaen on Takemitsu was 5 [' influence', ' of', ' Olivier', ' Mess', 'ia', 'en']
+2288 506 Name of mother of x -1 Name of mother of Olivier Messiaen Cécile Sauvage Olivier Messiaen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' mother' ' of']" , the composer , and his wife , the painter , and the painter 's mother , the mother of False Mahler, and Olivier Messiaen enlarged the 7 [' Mah', 'ler', ',', ' and', ' Olivier', ' Mess', 'ia', 'en']
+2289 507 Name of mother of x -1 Name of mother of Harry S. Truman Martha Ellen Young Truman Harry S. Truman "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Harry']" ", the first president of the United States .
+
+ The first president of the United States , Harry" False noting that President Harry S. Truman used executive 6 [' noting', ' that', ' President', ' Harry', ' S', '.', ' Truman']
+2290 507 Name of mother of x -1 Name of mother of Harry S. Truman Martha Ellen Young Truman Harry S. Truman "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Harry']" ", the first president of the United States .
+
+ The first president of the United States , Harry" False would defeat Harry S. Truman in the 1948 presidential 5 [' would', ' defeat', ' Harry', ' S', '.', ' Truman']
+2291 507 Name of mother of x -1 Name of mother of Harry S. Truman Martha Ellen Young Truman Harry S. Truman "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Harry']" ", the first president of the United States .
+
+ The first president of the United States , Harry" False " Living Memorial"". Harry S. Truman Scholarship Foundation." 6 "[' Living', ' Memorial', '"".', ' Harry', ' S', '.', ' Truman']"
+2292 507 Name of mother of x -1 Name of mother of Harry S. Truman Martha Ellen Young Truman Harry S. Truman "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Harry']" ", the first president of the United States .
+
+ The first president of the United States , Harry" False President Harry S. Truman to visit the White 4 [' President', ' Harry', ' S', '.', ' Truman']
+2293 507 Name of mother of x -1 Name of mother of Harry S. Truman Martha Ellen Young Truman Harry S. Truman "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Harry']" ", the first president of the United States .
+
+ The first president of the United States , Harry" False 1946, President Harry S. Truman and Kentucky 6 [' 1946', ',', ' President', ' Harry', ' S', '.', ' Truman']
+2294 508 Name of mother of x -1 Name of mother of Helmut Kohl Cäcilie Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' his' ' wife' ',' ' Hel'
+ 'mut' ' Koh' 'l' ',' ' the' ' former' ' German' ' chancellor' ',' ' who']" , the German chancellor , and his wife , Hel mut Koh l , the former German chancellor , who False same time, chancellor Helmut Kohl had accepted 7 [' same', ' time', ',', ' chancellor', ' Hel', 'mut', ' Koh', 'l']
+2295 508 Name of mother of x -1 Name of mother of Helmut Kohl Cäcilie Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' his' ' wife' ',' ' Hel'
+ 'mut' ' Koh' 'l' ',' ' the' ' former' ' German' ' chancellor' ',' ' who']" , the German chancellor , and his wife , Hel mut Koh l , the former German chancellor , who False policy changed, when Helmut Kohl announced that 7 [' policy', ' changed', ',', ' when', ' Hel', 'mut', ' Koh', 'l']
+2296 508 Name of mother of x -1 Name of mother of Helmut Kohl Cäcilie Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' his' ' wife' ',' ' Hel'
+ 'mut' ' Koh' 'l' ',' ' the' ' former' ' German' ' chancellor' ',' ' who']" , the German chancellor , and his wife , Hel mut Koh l , the former German chancellor , who False of CDU chairman Helmut Kohl to bring the 7 [' of', ' CD', 'U', ' chairman', ' Hel', 'mut', ' Koh', 'l']
+2297 508 Name of mother of x -1 Name of mother of Helmut Kohl Cäcilie Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' his' ' wife' ',' ' Hel'
+ 'mut' ' Koh' 'l' ',' ' the' ' former' ' German' ' chancellor' ',' ' who']" , the German chancellor , and his wife , Hel mut Koh l , the former German chancellor , who False chancellor Helmut Kohl described Auschwitz 4 [' chancellor', ' Hel', 'mut', ' Koh', 'l']
+2298 508 Name of mother of x -1 Name of mother of Helmut Kohl Cäcilie Kohl Helmut Kohl "[',' ' the' ' German' ' chancellor' ',' ' and' ' his' ' wife' ',' ' Hel'
+ 'mut' ' Koh' 'l' ',' ' the' ' former' ' German' ' chancellor' ',' ' who']" , the German chancellor , and his wife , Hel mut Koh l , the former German chancellor , who False German Chancellor Helmut Kohl said he would never 5 [' German', ' Chancellor', ' Hel', 'mut', ' Koh', 'l']
+2299 510 Name of mother of x -1 Name of mother of Anders Zorn Grudd Anna Andersdotter Anders Zorn "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' resistance'
+ ' movement' ',' ' was' ' arrested' ' in' ' 1943' '.' '\n' '\n' 'The'
+ ' German']" ", who was a member of the German resistance movement , was arrested in 1943 .
+
+ The German" False as Carl Larsson, Anders Zorn and Bruno 7 [' as', ' Carl', ' Lar', 'sson', ',', ' Anders', ' Z', 'orn']
+2300 512 Name of mother of x -1 Name of mother of John Chrysostom Anthusa John Chrysostom "[',' ' the' ' father' ' of' ' the' ' Church' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the father of the Church , and the
+ " False feast days of John Chrysostom in the Eastern 6 [' feast', ' days', ' of', ' John', ' Chrys', 'ost', 'om']
+2301 512 Name of mother of x -1 Name of mother of John Chrysostom Anthusa John Chrysostom "[',' ' the' ' father' ' of' ' the' ' Church' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the father of the Church , and the
+ " False Faced with exile, John Chrysostom wrote an appeal 8 [' F', 'aced', ' with', ' exile', ',', ' John', ' Chrys', 'ost', 'om']
+2302 512 Name of mother of x -1 Name of mother of John Chrysostom Anthusa John Chrysostom "[',' ' the' ' father' ' of' ' the' ' Church' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the father of the Church , and the
+ " False Divine Liturgy of St. John Chrysostom as the normal Eucharistic 9 [' Divine', ' Lit', 'urgy', ' of', ' St', '.', ' John', ' Chrys', 'ost', 'om']
+2303 512 Name of mother of x -1 Name of mother of John Chrysostom Anthusa John Chrysostom "[',' ' the' ' father' ' of' ' the' ' Church' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the father of the Church , and the
+ " False " the penance of St. John Chrysostom ===
+" 9 [' the', ' pen', 'ance', ' of', ' St', '.', ' John', ' Chrys', 'ost', 'om']
+2304 512 Name of mother of x -1 Name of mother of John Chrysostom Anthusa John Chrysostom "[',' ' the' ' father' ' of' ' the' ' Church' ',' ' and' ' the' '\n' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the father of the Church , and the
+ " False November, St John Chrysostom the Archbishop of 6 [' November', ',', ' St', ' John', ' Chrys', 'ost', 'om']
+2305 513 Name of mother of x -1 Name of mother of Hugo Grotius Aeltje van Overschie Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' contemporary' ' of' ' the' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the great Dutch jur ist , who was a contemporary of the
+
+ The first of the great" False Dutch authors such as Hugo Grotius and Pieter Corneliszoon 7 [' Dutch', ' authors', ' such', ' as', ' Hugo', ' Gro', 't', 'ius']
+2306 513 Name of mother of x -1 Name of mother of Hugo Grotius Aeltje van Overschie Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' contemporary' ' of' ' the' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the great Dutch jur ist , who was a contemporary of the
+
+ The first of the great" False 4 ['Hug', 'o', ' Gro', 't', 'ius']
+2307 513 Name of mother of x -1 Name of mother of Hugo Grotius Aeltje van Overschie Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' contemporary' ' of' ' the' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the great Dutch jur ist , who was a contemporary of the
+
+ The first of the great" False " Dutch jurist Hugo Grotius that ""the purpose of" 6 [' Dutch', ' jur', 'ist', ' Hugo', ' Gro', 't', 'ius']
+2308 513 Name of mother of x -1 Name of mother of Hugo Grotius Aeltje van Overschie Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' contemporary' ' of' ' the' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the great Dutch jur ist , who was a contemporary of the
+
+ The first of the great" False authors such as Hugo Grotius and Pieter Corneliszoon 6 [' authors', ' such', ' as', ' Hugo', ' Gro', 't', 'ius']
+2309 513 Name of mother of x -1 Name of mother of Hugo Grotius Aeltje van Overschie Hugo Grotius "[',' ' the' ' great' ' Dutch' ' jur' 'ist' ',' ' who' ' was' ' a'
+ ' contemporary' ' of' ' the' '\n' '\n' 'The' ' first' ' of' ' the'
+ ' great']" ", the great Dutch jur ist , who was a contemporary of the
+
+ The first of the great" False " 17th-century Dutch jurist Hugo Grotius that ""the purpose" 10 [' 17', 'th', '-', 'century', ' Dutch', ' jur', 'ist', ' Hugo', ' Gro', 't', 'ius']
+2310 514 Name of mother of x -1 Name of mother of Nicholas I of Russia Maria Feodorovna Nicholas I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the' ' Ts'
+ 'ar' 'its' 'a' ' of' ' Russia' ',' ' and' ' the' ' Ts']" ", and the
+
+ Russian Empire , and the Ts ar its a of Russia , and the Ts" False portrait of Tsar Nicholas I of Russia and another of 7 [' portrait', ' of', ' Ts', 'ar', ' Nicholas', ' I', ' of', ' Russia']
+2311 514 Name of mother of x -1 Name of mother of Nicholas I of Russia Maria Feodorovna Nicholas I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the' ' Ts'
+ 'ar' 'its' 'a' ' of' ' Russia' ',' ' and' ' the' ' Ts']" ", and the
+
+ Russian Empire , and the Ts ar its a of Russia , and the Ts" False Herbert and Tzar Nicholas I of Russia in the Seventh 7 [' Herbert', ' and', ' T', 'zar', ' Nicholas', ' I', ' of', ' Russia']
+2312 514 Name of mother of x -1 Name of mother of Nicholas I of Russia Maria Feodorovna Nicholas I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the' ' Ts'
+ 'ar' 'its' 'a' ' of' ' Russia' ',' ' and' ' the' ' Ts']" ", and the
+
+ Russian Empire , and the Ts ar its a of Russia , and the Ts" False royalty, such as Nicholas I of Russia (1837), Prince 7 [' royalty', ',', ' such', ' as', ' Nicholas', ' I', ' of', ' Russia']
+2313 515 Name of mother of x -1 Name of mother of Mary, Queen of Scots Mary of Lorraine Mary, Queen of Scots "[',' ' and' ' the' '\n' '\n' 'The' ' Queen' ' of' ' Scots' ',' ' Mary'
+ ' Stuart' ',' ' was' ' born' ' in' ' 15' '42' ',' ' and']" ", and the
+
+ The Queen of Scots , Mary Stuart , was born in 15 42 , and" False the reign of Mary, Queen of Scots (r. 1542 – 67) 7 [' the', ' reign', ' of', ' Mary', ',', ' Queen', ' of', ' Scots']
+2314 515 Name of mother of x -1 Name of mother of Mary, Queen of Scots Mary of Lorraine Mary, Queen of Scots "[',' ' and' ' the' '\n' '\n' 'The' ' Queen' ' of' ' Scots' ',' ' Mary'
+ ' Stuart' ',' ' was' ' born' ' in' ' 15' '42' ',' ' and']" ", and the
+
+ The Queen of Scots , Mary Stuart , was born in 15 42 , and" False showing the escape of Mary, Queen of Scots from Loch Leven Castle, 8 [' showing', ' the', ' escape', ' of', ' Mary', ',', ' Queen', ' of', ' Scots']
+2315 515 Name of mother of x -1 Name of mother of Mary, Queen of Scots Mary of Lorraine Mary, Queen of Scots "[',' ' and' ' the' '\n' '\n' 'The' ' Queen' ' of' ' Scots' ',' ' Mary'
+ ' Stuart' ',' ' was' ' born' ' in' ' 15' '42' ',' ' and']" ", and the
+
+ The Queen of Scots , Mary Stuart , was born in 15 42 , and" False troubled reign of Mary, Queen of Scots (1542 – 1567), 7 [' troubled', ' reign', ' of', ' Mary', ',', ' Queen', ' of', ' Scots']
+2316 515 Name of mother of x -1 Name of mother of Mary, Queen of Scots Mary of Lorraine Mary, Queen of Scots "[',' ' and' ' the' '\n' '\n' 'The' ' Queen' ' of' ' Scots' ',' ' Mary'
+ ' Stuart' ',' ' was' ' born' ' in' ' 15' '42' ',' ' and']" ", and the
+
+ The Queen of Scots , Mary Stuart , was born in 15 42 , and" False the Catholic Mary, Queen of Scots (1561 – 67) eventually 6 [' the', ' Catholic', ' Mary', ',', ' Queen', ' of', ' Scots']
+2317 515 Name of mother of x -1 Name of mother of Mary, Queen of Scots Mary of Lorraine Mary, Queen of Scots "[',' ' and' ' the' '\n' '\n' 'The' ' Queen' ' of' ' Scots' ',' ' Mary'
+ ' Stuart' ',' ' was' ' born' ' in' ' 15' '42' ',' ' and']" ", and the
+
+ The Queen of Scots , Mary Stuart , was born in 15 42 , and" False marriage of Mary, Queen of Scots to the French 6 [' marriage', ' of', ' Mary', ',', ' Queen', ' of', ' Scots']
+2318 518 Name of mother of x -1 Name of mother of Friedrich Dürrenmatt Hulda Dürrenmatt Friedrich Dürrenmatt "[""'s"" ' _' 'The' ' Visit' '_' ',' ' and' ' the' ' _' 'B' 'ild' 'ung' 's'
+ 'roman' '_' ' of' ' the' ' _' 'B' 'ild']" 's _ The Visit _ , and the _ B ild ung s roman _ of the _ B ild False (1911 – 91) and Friedrich Dürrenmatt (1921 – 90), 12 [' (', '19', '11', ' –', ' 91', ')', ' and', ' Friedrich', ' D', 'ür', 'ren', 'm', 'att']
+2319 518 Name of mother of x -1 Name of mother of Friedrich Dürrenmatt Hulda Dürrenmatt Friedrich Dürrenmatt "[""'s"" ' _' 'The' ' Visit' '_' ',' ' and' ' the' ' _' 'B' 'ild' 'ung' 's'
+ 'roman' '_' ' of' ' the' ' _' 'B' 'ild']" 's _ The Visit _ , and the _ B ild ung s roman _ of the _ B ild False Frisch (1911 – 91) and Friedrich Dürrenmatt (1921 – 90), 14 [' Fr', 'isch', ' (', '19', '11', ' –', ' 91', ')', ' and', ' Friedrich', ' D', 'ür', 'ren', 'm', 'att']
+2320 518 Name of mother of x -1 Name of mother of Friedrich Dürrenmatt Hulda Dürrenmatt Friedrich Dürrenmatt "[""'s"" ' _' 'The' ' Visit' '_' ',' ' and' ' the' ' _' 'B' 'ild' 'ung' 's'
+ 'roman' '_' ' of' ' the' ' _' 'B' 'ild']" 's _ The Visit _ , and the _ B ild ung s roman _ of the _ B ild False (1911 – 91) and Friedrich Dürrenmatt (1921 – 90), 12 [' (', '19', '11', ' –', ' 91', ')', ' and', ' Friedrich', ' D', 'ür', 'ren', 'm', 'att']
+2321 520 Name of mother of x -1 Name of mother of Caspar David Friedrich Sophia Dorothea Bechly Caspar David Friedrich "[',' ' the' ' German' ' painter' ',' ' was' ' born' ' in' ' 17' '74' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' brothers' ',']" ", the German painter , was born in 17 74 .
+
+ The first of the two brothers ," False 3 ['Cas', 'par', ' David', ' Friedrich']
+2322 520 Name of mother of x -1 Name of mother of Caspar David Friedrich Sophia Dorothea Bechly Caspar David Friedrich "[',' ' the' ' German' ' painter' ',' ' was' ' born' ' in' ' 17' '74' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' brothers' ',']" ", the German painter , was born in 17 74 .
+
+ The first of the two brothers ," False 3 ['Cas', 'par', ' David', ' Friedrich']
+2323 520 Name of mother of x -1 Name of mother of Caspar David Friedrich Sophia Dorothea Bechly Caspar David Friedrich "[',' ' the' ' German' ' painter' ',' ' was' ' born' ' in' ' 17' '74' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' brothers' ',']" ", the German painter , was born in 17 74 .
+
+ The first of the two brothers ," False 3 ['Cas', 'par', ' David', ' Friedrich']
+2324 520 Name of mother of x -1 Name of mother of Caspar David Friedrich Sophia Dorothea Bechly Caspar David Friedrich "[',' ' the' ' German' ' painter' ',' ' was' ' born' ' in' ' 17' '74' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' brothers' ',']" ", the German painter , was born in 17 74 .
+
+ The first of the two brothers ," False 3 ['Cas', 'par', ' David', ' Friedrich']
+2325 520 Name of mother of x -1 Name of mother of Caspar David Friedrich Sophia Dorothea Bechly Caspar David Friedrich "[',' ' the' ' German' ' painter' ',' ' was' ' born' ' in' ' 17' '74' '.'
+ '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' brothers' ',']" ", the German painter , was born in 17 74 .
+
+ The first of the two brothers ," False 3 ['Cas', 'par', ' David', ' Friedrich']
+2326 521 Name of mother of x -1 Name of mother of Ignatius of Loyola Maria Sáenz de Licona y Balda Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jes' 'uits' ',' ' and' ' the' '\n'
+ '\n' 'Name' ' of' ' the' ' founder' ' of' ' the' ' Jes' 'uits']" ", the founder of the Jes uits , and the
+
+ Name of the founder of the Jes uits" False " ===
+" 7 [' ===', 'Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2327 521 Name of mother of x -1 Name of mother of Ignatius of Loyola Maria Sáenz de Licona y Balda Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jes' 'uits' ',' ' and' ' the' '\n'
+ '\n' 'Name' ' of' ' the' ' founder' ' of' ' the' ' Jes' 'uits']" ", the founder of the Jes uits , and the
+
+ Name of the founder of the Jes uits" False " Loyola ===
+" 10 [' L', 'oy', 'ola', ' ===', 'Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2328 521 Name of mother of x -1 Name of mother of Ignatius of Loyola Maria Sáenz de Licona y Balda Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jes' 'uits' ',' ' and' ' the' '\n'
+ '\n' 'Name' ' of' ' the' ' founder' ' of' ' the' ' Jes' 'uits']" ", the founder of the Jes uits , and the
+
+ Name of the founder of the Jes uits" False Exercises of Ignatius of Loyola use meditative mental 10 [' Ex', 'erc', 'ises', ' of', ' Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2329 521 Name of mother of x -1 Name of mother of Ignatius of Loyola Maria Sáenz de Licona y Balda Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jes' 'uits' ',' ' and' ' the' '\n'
+ '\n' 'Name' ' of' ' the' ' founder' ' of' ' the' ' Jes' 'uits']" ", the founder of the Jes uits , and the
+
+ Name of the founder of the Jes uits" False " ===
+" 7 [' ===', 'Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2330 521 Name of mother of x -1 Name of mother of Ignatius of Loyola Maria Sáenz de Licona y Balda Ignatius of Loyola "[',' ' the' ' founder' ' of' ' the' ' Jes' 'uits' ',' ' and' ' the' '\n'
+ '\n' 'Name' ' of' ' the' ' founder' ' of' ' the' ' Jes' 'uits']" ", the founder of the Jes uits , and the
+
+ Name of the founder of the Jes uits" False Exercises of Ignatius of Loyola use meditative mental 10 [' Ex', 'erc', 'ises', ' of', ' Ign', 'at', 'ius', ' of', ' L', 'oy', 'ola']
+2331 522 Name of mother of x -1 Name of mother of Alexander I of Russia Maria Feodorovna Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False installation of Emperor Alexander I of Russia in 1813, supernumerary 6 [' installation', ' of', ' Emperor', ' Alexander', ' I', ' of', ' Russia']
+2332 522 Name of mother of x -1 Name of mother of Alexander I of Russia Maria Feodorovna Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False leaving Warsaw. Alexander I of Russia invited Louis XVIII 6 [' leaving', ' Warsaw', '.', ' Alexander', ' I', ' of', ' Russia']
+2333 522 Name of mother of x -1 Name of mother of Alexander I of Russia Maria Feodorovna Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False the decree of Alexander I of Russia in 1817, and by 1850, 6 [' the', ' decree', ' of', ' Alexander', ' I', ' of', ' Russia']
+2334 522 Name of mother of x -1 Name of mother of Alexander I of Russia Maria Feodorovna Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False boarding house for Tsar Alexander I of Russia during his 8 [' boarding', ' house', ' for', ' Ts', 'ar', ' Alexander', ' I', ' of', ' Russia']
+2335 522 Name of mother of x -1 Name of mother of Alexander I of Russia Maria Feodorovna Alexander I of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Emperor' ' of' ' Austria' ',' ' and' ' the' ' King' ' of' ' Pr' 'ussia']" ", and the
+
+ Russian Empire , and the Emperor of Austria , and the King of Pr ussia" False Duke of Clarence, Tsar Alexander I of Russia and King Frederick 9 [' Duke', ' of', ' Clarence', ',', ' Ts', 'ar', ' Alexander', ' I', ' of', ' Russia']
+2336 524 Name of mother of x -1 Name of mother of Basil of Caesarea Emmelia of Caesarea Basil of Caesarea "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Basil' ' of'
+ ' Ca' 'es' 'area' ',' ' and' ' the' '\n' '\n' 'Name']" ", and the
+
+ Name of mother of Basil of Ca es area , and the
+
+ Name" False specifically the examples of Basil of Caesarea (who is the 8 [' specifically', ' the', ' examples', ' of', ' Basil', ' of', ' Ca', 'es', 'area']
+2337 524 Name of mother of x -1 Name of mother of Basil of Caesarea Emmelia of Caesarea Basil of Caesarea "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Basil' ' of'
+ ' Ca' 'es' 'area' ',' ' and' ' the' '\n' '\n' 'Name']" ", and the
+
+ Name of mother of Basil of Ca es area , and the
+
+ Name" False specifically the examples of Basil of Caesarea (who is the Greek 8 [' specifically', ' the', ' examples', ' of', ' Basil', ' of', ' Ca', 'es', 'area']
+2338 524 Name of mother of x -1 Name of mother of Basil of Caesarea Emmelia of Caesarea Basil of Caesarea "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Basil' ' of'
+ ' Ca' 'es' 'area' ',' ' and' ' the' '\n' '\n' 'Name']" ", and the
+
+ Name of mother of Basil of Ca es area , and the
+
+ Name" False identified as Basil of Caesarea and John Chrysostom. 6 [' identified', ' as', ' Basil', ' of', ' Ca', 'es', 'area']
+2339 524 Name of mother of x -1 Name of mother of Basil of Caesarea Emmelia of Caesarea Basil of Caesarea "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Basil' ' of'
+ ' Ca' 'es' 'area' ',' ' and' ' the' '\n' '\n' 'Name']" ", and the
+
+ Name of mother of Basil of Ca es area , and the
+
+ Name" False be identified as Basil of Caesarea and John Chrysostom. 7 [' be', ' identified', ' as', ' Basil', ' of', ' Ca', 'es', 'area']
+2340 525 Name of mother of x -1 Name of mother of Peter Ustinov Nadia Benois Peter Ustinov "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' his' ' daughter' ',' ' the' ' actress' ',' ' and' ' his']" , the actor , and his wife , the actress , and his daughter , the actress , and his False " and Emily are by Peter Ustinov and Emily Osborne.
+" 8 [' and', ' Emily', ' are', ' by', ' Peter', ' U', 'st', 'in', 'ov']
+2341 525 Name of mother of x -1 Name of mother of Peter Ustinov Nadia Benois Peter Ustinov "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' his' ' daughter' ',' ' the' ' actress' ',' ' and' ' his']" , the actor , and his wife , the actress , and his daughter , the actress , and his False " sense of duty."" Peter Ustinov described her during" 8 "[' sense', ' of', ' duty', '.""', ' Peter', ' U', 'st', 'in', 'ov']"
+2342 525 Name of mother of x -1 Name of mother of Peter Ustinov Nadia Benois Peter Ustinov "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' his' ' daughter' ',' ' the' ' actress' ',' ' and' ' his']" , the actor , and his wife , the actress , and his daughter , the actress , and his False his colleague Peter Ustinov disagreed; he 6 [' his', ' colleague', ' Peter', ' U', 'st', 'in', 'ov']
+2343 525 Name of mother of x -1 Name of mother of Peter Ustinov Nadia Benois Peter Ustinov "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' his' ' daughter' ',' ' the' ' actress' ',' ' and' ' his']" , the actor , and his wife , the actress , and his daughter , the actress , and his False " Granpa and Emily are by Peter Ustinov and Emily Osborne.
+" 10 [' Gran', 'pa', ' and', ' Emily', ' are', ' by', ' Peter', ' U', 'st', 'in', 'ov']
+2344 525 Name of mother of x -1 Name of mother of Peter Ustinov Nadia Benois Peter Ustinov "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' his' ' daughter' ',' ' the' ' actress' ',' ' and' ' his']" , the actor , and his wife , the actress , and his daughter , the actress , and his False " Mozart"" single by Peter Ustinov with Antony Hopkins" 9 "[' Moz', 'art', '""', ' single', ' by', ' Peter', ' U', 'st', 'in', 'ov']"
+2345 526 Name of mother of x -1 Name of mother of Shinzō Abe Yōko Kishi Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'ie' ',' ' who' ' is' ' also' ' a' ' former' ' Olympic' ' gold']" , the Japanese prime minister , and his wife , Ak ie , who is also a former Olympic gold False Japanese Prime Minister Shinzō Abe that the sea be called 6 [' Japanese', ' Prime', ' Minister', ' Shin', 'z', 'ō', ' Abe']
+2346 526 Name of mother of x -1 Name of mother of Shinzō Abe Yōko Kishi Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'ie' ',' ' who' ' is' ' also' ' a' ' former' ' Olympic' ' gold']" , the Japanese prime minister , and his wife , Ak ie , who is also a former Olympic gold False Prime Minister Shinzō Abe said Japan wanted 5 [' Prime', ' Minister', ' Shin', 'z', 'ō', ' Abe']
+2347 526 Name of mother of x -1 Name of mother of Shinzō Abe Yōko Kishi Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'ie' ',' ' who' ' is' ' also' ' a' ' former' ' Olympic' ' gold']" , the Japanese prime minister , and his wife , Ak ie , who is also a former Olympic gold False Japanese Prime Minister Shinzō Abe that the sea 6 [' Japanese', ' Prime', ' Minister', ' Shin', 'z', 'ō', ' Abe']
+2348 526 Name of mother of x -1 Name of mother of Shinzō Abe Yōko Kishi Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'ie' ',' ' who' ' is' ' also' ' a' ' former' ' Olympic' ' gold']" , the Japanese prime minister , and his wife , Ak ie , who is also a former Olympic gold False Prime Minister Shinzō Abe said Japan wanted 5 [' Prime', ' Minister', ' Shin', 'z', 'ō', ' Abe']
+2349 526 Name of mother of x -1 Name of mother of Shinzō Abe Yōko Kishi Shinzō Abe "[',' ' the' ' Japanese' ' prime' ' minister' ',' ' and' ' his' ' wife' ','
+ ' Ak' 'ie' ',' ' who' ' is' ' also' ' a' ' former' ' Olympic' ' gold']" , the Japanese prime minister , and his wife , Ak ie , who is also a former Olympic gold False Prime Minister Shinzō Abe that the sea 5 [' Prime', ' Minister', ' Shin', 'z', 'ō', ' Abe']
+2350 527 Name of mother of x -1 Name of mother of Winslow Homer Henrietta Maria Benson Homer Winslow Homer "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Wins' 'low' ' Homer' ',' ' the' ' painter' ',' ' and']" ", the painter , and the
+
+ Name of mother of Wins low Homer , the painter , and" False the United States. Winslow Homer (1836 – 1910) 6 [' the', ' United', ' States', '.', ' Wins', 'low', ' Homer']
+2351 527 Name of mother of x -1 Name of mother of Winslow Homer Henrietta Maria Benson Homer Winslow Homer "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Wins' 'low' ' Homer' ',' ' the' ' painter' ',' ' and']" ", the painter , and the
+
+ Name of mother of Wins low Homer , the painter , and" False United States. Winslow Homer (1836 – 1910) depicted 5 [' United', ' States', '.', ' Wins', 'low', ' Homer']
+2352 527 Name of mother of x -1 Name of mother of Winslow Homer Henrietta Maria Benson Homer Winslow Homer "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Wins' 'low' ' Homer' ',' ' the' ' painter' ',' ' and']" ", the painter , and the
+
+ Name of mother of Wins low Homer , the painter , and" False United States. Winslow Homer (1836 – 1910) depicted 5 [' United', ' States', '.', ' Wins', 'low', ' Homer']
+2353 527 Name of mother of x -1 Name of mother of Winslow Homer Henrietta Maria Benson Homer Winslow Homer "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Wins' 'low' ' Homer' ',' ' the' ' painter' ',' ' and']" ", the painter , and the
+
+ Name of mother of Wins low Homer , the painter , and" False American artist Winslow Homer (1836 – 1910), replicates 4 [' American', ' artist', ' Wins', 'low', ' Homer']
+2354 528 Name of mother of x -1 Name of mother of Louis Aragon Marguerite Toucas-Massillon Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False Max Jacob, Louis Aragon and Jean Cocteau 5 [' Max', ' Jacob', ',', ' Louis', ' Ar', 'agon']
+2355 528 Name of mother of x -1 Name of mother of Louis Aragon Marguerite Toucas-Massillon Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False the Communist poet Louis Aragon in 1928. Berni 5 [' the', ' Communist', ' poet', ' Louis', ' Ar', 'agon']
+2356 528 Name of mother of x -1 Name of mother of Louis Aragon Marguerite Toucas-Massillon Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False France, Max Jacob, Louis Aragon and Jean Cocteau 7 [' France', ',', ' Max', ' Jacob', ',', ' Louis', ' Ar', 'agon']
+2357 528 Name of mother of x -1 Name of mother of Louis Aragon Marguerite Toucas-Massillon Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False Communist poet Louis Aragon in 1928. Berni 4 [' Communist', ' poet', ' Louis', ' Ar', 'agon']
+2358 528 Name of mother of x -1 Name of mother of Louis Aragon Marguerite Toucas-Massillon Louis Aragon "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False France, Max Jacob, Louis Aragon and Jean Cocteau 7 [' France', ',', ' Max', ' Jacob', ',', ' Louis', ' Ar', 'agon']
+2359 530 Name of mother of x -1 Name of mother of Tony Blair Hazel Elizabeth Rosaleen Corscaden Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False Prime Minister Tony Blair for traveling to Washington 3 [' Prime', ' Minister', ' Tony', ' Blair']
+2360 530 Name of mother of x -1 Name of mother of Tony Blair Hazel Elizabeth Rosaleen Corscaden Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False " knows the actor as Tony Blair or David Frost"" while" 5 [' knows', ' the', ' actor', ' as', ' Tony', ' Blair']
+2361 530 Name of mother of x -1 Name of mother of Tony Blair Hazel Elizabeth Rosaleen Corscaden Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False Labour Party under Tony Blair won the election. 4 [' Labour', ' Party', ' under', ' Tony', ' Blair']
+2362 530 Name of mother of x -1 Name of mother of Tony Blair Hazel Elizabeth Rosaleen Corscaden Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False political rise of Tony Blair and Gordon Brown. 4 [' political', ' rise', ' of', ' Tony', ' Blair']
+2363 530 Name of mother of x -1 Name of mother of Tony Blair Hazel Elizabeth Rosaleen Corscaden Tony Blair "[',' ' the' ' former' ' prime' ' minister' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Labour' ' Party' ',' ' and' ' the' ' former'
+ ' leader' ' of']" , the former prime minister , who was a member of the Labour Party , and the former leader of False early-19th-century Spain and Tony Blair and George 9 [' early', '-', '19', 'th', '-', 'century', ' Spain', ' and', ' Tony', ' Blair']
+2364 531 Name of mother of x -1 Name of mother of David Lloyd George Elizabeth Lloyd David Lloyd George "[',' ' the' ' British' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n'
+ 'The' ' British' ' Prime' ' Minister' ',' ' David' ' Lloyd' ' George' ','
+ ' was']" ", the British Prime Minister , and the
+
+ The British Prime Minister , David Lloyd George , was" False the Communists. David Lloyd George also supported 5 [' the', ' Communists', '.', ' David', ' Lloyd', ' George']
+2365 531 Name of mother of x -1 Name of mother of David Lloyd George Elizabeth Lloyd David Lloyd George "[',' ' the' ' British' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n'
+ 'The' ' British' ' Prime' ' Minister' ',' ' David' ' Lloyd' ' George' ','
+ ' was']" ", the British Prime Minister , and the
+
+ The British Prime Minister , David Lloyd George , was" False to be friends with David Lloyd George and his secretary, 6 [' to', ' be', ' friends', ' with', ' David', ' Lloyd', ' George']
+2366 531 Name of mother of x -1 Name of mother of David Lloyd George Elizabeth Lloyd David Lloyd George "[',' ' the' ' British' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n'
+ 'The' ' British' ' Prime' ' Minister' ',' ' David' ' Lloyd' ' George' ','
+ ' was']" ", the British Prime Minister , and the
+
+ The British Prime Minister , David Lloyd George , was" False December 1916 when David Lloyd George proposed a war council 5 [' December', ' 1916', ' when', ' David', ' Lloyd', ' George']
+2367 531 Name of mother of x -1 Name of mother of David Lloyd George Elizabeth Lloyd David Lloyd George "[',' ' the' ' British' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n'
+ 'The' ' British' ' Prime' ' Minister' ',' ' David' ' Lloyd' ' George' ','
+ ' was']" ", the British Prime Minister , and the
+
+ The British Prime Minister , David Lloyd George , was" False to the EFF and David Lloyd George was still a national 6 [' to', ' the', ' EFF', ' and', ' David', ' Lloyd', ' George']
+2368 531 Name of mother of x -1 Name of mother of David Lloyd George Elizabeth Lloyd David Lloyd George "[',' ' the' ' British' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n'
+ 'The' ' British' ' Prime' ' Minister' ',' ' David' ' Lloyd' ' George' ','
+ ' was']" ", the British Prime Minister , and the
+
+ The British Prime Minister , David Lloyd George , was" False day in particular. David Lloyd George held a number of secret 6 [' day', ' in', ' particular', '.', ' David', ' Lloyd', ' George']
+2369 532 Name of mother of x -1 Name of mother of Sarah Jessica Parker Barbara Keck Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' ',' ' and' ' the'
+ ' mother']" , the actress who played the role of the mother of the bride in the film , and the mother False starred opposite Sarah Jessica Parker in the Marc Lawrence's 4 [' starred', ' opposite', ' Sarah', ' Jessica', ' Parker']
+2370 532 Name of mother of x -1 Name of mother of Sarah Jessica Parker Barbara Keck Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' ',' ' and' ' the'
+ ' mother']" , the actress who played the role of the mother of the bride in the film , and the mother False 2 ['Sarah', ' Jessica', ' Parker']
+2371 532 Name of mother of x -1 Name of mother of Sarah Jessica Parker Barbara Keck Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' ',' ' and' ' the'
+ ' mother']" , the actress who played the role of the mother of the bride in the film , and the mother False mocks actress Sarah Jessica Parker and the Kardashian 5 [' m', 'ocks', ' actress', ' Sarah', ' Jessica', ' Parker']
+2372 532 Name of mother of x -1 Name of mother of Sarah Jessica Parker Barbara Keck Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' ',' ' and' ' the'
+ ' mother']" , the actress who played the role of the mother of the bride in the film , and the mother False isolated motel. Sarah Jessica Parker was originally 5 [' isolated', ' motel', '.', ' Sarah', ' Jessica', ' Parker']
+2373 532 Name of mother of x -1 Name of mother of Sarah Jessica Parker Barbara Keck Sarah Jessica Parker "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' film' ',' ' and' ' the'
+ ' mother']" , the actress who played the role of the mother of the bride in the film , and the mother False Beetlejuice (1988), Sarah Jessica Parker (who signed on before 8 [' Beetle', 'ju', 'ice', ' (', '1988', '),', ' Sarah', ' Jessica', ' Parker']
+2374 534 Name of mother of x -1 Name of mother of Sven Hedin Anna Sofia Carolina Berlin Sven Hedin "[',' ' the' ' Swedish' ' explorer' ',' ' who' ' had' ' been' ' in' ' the'
+ ' service' ' of' ' the' ' Swedish' ' king' ',' ' was' ' a' ' man' ' of']" , the Swedish explorer , who had been in the service of the Swedish king , was a man of False Swedish explorer Sven Hedin in 1902; this 4 [' Swedish', ' explorer', ' Sven', ' H', 'edin']
+2375 534 Name of mother of x -1 Name of mother of Sven Hedin Anna Sofia Carolina Berlin Sven Hedin "[',' ' the' ' Swedish' ' explorer' ',' ' who' ' had' ' been' ' in' ' the'
+ ' service' ' of' ' the' ' Swedish' ' king' ',' ' was' ' a' ' man' ' of']" , the Swedish explorer , who had been in the service of the Swedish king , was a man of False Swedish explorer Sven Hedin in 1902; this 4 [' Swedish', ' explorer', ' Sven', ' H', 'edin']
+2376 534 Name of mother of x -1 Name of mother of Sven Hedin Anna Sofia Carolina Berlin Sven Hedin "[',' ' the' ' Swedish' ' explorer' ',' ' who' ' had' ' been' ' in' ' the'
+ ' service' ' of' ' the' ' Swedish' ' king' ',' ' was' ' a' ' man' ' of']" , the Swedish explorer , who had been in the service of the Swedish king , was a man of False to Swedish explorer Sven Hedin in 1902; this may 5 [' to', ' Swedish', ' explorer', ' Sven', ' H', 'edin']
+2377 534 Name of mother of x -1 Name of mother of Sven Hedin Anna Sofia Carolina Berlin Sven Hedin "[',' ' the' ' Swedish' ' explorer' ',' ' who' ' had' ' been' ' in' ' the'
+ ' service' ' of' ' the' ' Swedish' ' king' ',' ' was' ' a' ' man' ' of']" , the Swedish explorer , who had been in the service of the Swedish king , was a man of False Swedish explorer Sven Hedin in 1902; this 4 [' Swedish', ' explorer', ' Sven', ' H', 'edin']
+2378 534 Name of mother of x -1 Name of mother of Sven Hedin Anna Sofia Carolina Berlin Sven Hedin "[',' ' the' ' Swedish' ' explorer' ',' ' who' ' had' ' been' ' in' ' the'
+ ' service' ' of' ' the' ' Swedish' ' king' ',' ' was' ' a' ' man' ' of']" , the Swedish explorer , who had been in the service of the Swedish king , was a man of False Swedish explorer Sven Hedin in 1902; this may 4 [' Swedish', ' explorer', ' Sven', ' H', 'edin']
+2379 535 Name of mother of x -1 Name of mother of Irving Berlin Lena Jarchin Irving Berlin "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'White' ' Christmas' '""'
+ ' and' ' ""' 'White' ' Christmas' '""' ' and' ' the' ' composer' ' of' ' ""']" ", the composer of the song "" White Christmas "" and "" White Christmas "" and the composer of """ False 2 ['Ir', 'ving', ' Berlin']
+2380 535 Name of mother of x -1 Name of mother of Irving Berlin Lena Jarchin Irving Berlin "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'White' ' Christmas' '""'
+ ' and' ' ""' 'White' ' Christmas' '""' ' and' ' the' ' composer' ' of' ' ""']" ", the composer of the song "" White Christmas "" and "" White Christmas "" and the composer of """ False future events, and Irving Berlin hired Benchley for 5 [' future', ' events', ',', ' and', ' Irving', ' Berlin']
+2381 535 Name of mother of x -1 Name of mother of Irving Berlin Lena Jarchin Irving Berlin "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'White' ' Christmas' '""'
+ ' and' ' ""' 'White' ' Christmas' '""' ' and' ' the' ' composer' ' of' ' ""']" ", the composer of the song "" White Christmas "" and "" White Christmas "" and the composer of """ False 2 ['Ir', 'ving', ' Berlin']
+2382 535 Name of mother of x -1 Name of mother of Irving Berlin Lena Jarchin Irving Berlin "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'White' ' Christmas' '""'
+ ' and' ' ""' 'White' ' Christmas' '""' ' and' ' the' ' composer' ' of' ' ""']" ", the composer of the song "" White Christmas "" and "" White Christmas "" and the composer of """ False deface as many Irving Berlin songs as you like, 5 [' def', 'ace', ' as', ' many', ' Irving', ' Berlin']
+2383 535 Name of mother of x -1 Name of mother of Irving Berlin Lena Jarchin Irving Berlin "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'White' ' Christmas' '""'
+ ' and' ' ""' 'White' ' Christmas' '""' ' and' ' the' ' composer' ' of' ' ""']" ", the composer of the song "" White Christmas "" and "" White Christmas "" and the composer of """ False the song's composer Irving Berlin to have the 5 "[' the', ' song', ""'s"", ' composer', ' Irving', ' Berlin']"
+2384 536 Name of mother of x -1 Name of mother of Noël Coward Violet Agnes Veitch Noël Coward "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False role in the Noël Coward play South Sea Bubble, 7 [' role', ' in', ' the', ' No', 'ë', 'l', ' Cow', 'ard']
+2385 536 Name of mother of x -1 Name of mother of Noël Coward Violet Agnes Veitch Noël Coward "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False " Lorenz Hart. Noël Coward wrote: ""I was born" 8 [' Lore', 'nz', ' Hart', '.', ' No', 'ë', 'l', ' Cow', 'ard']
+2386 536 Name of mother of x -1 Name of mother of Noël Coward Violet Agnes Veitch Noël Coward "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False " reviews were good: ""Mr Noël Coward calls his brilliant" 10 "[' reviews', ' were', ' good', ':', ' ""', 'Mr', ' No', 'ë', 'l', ' Cow', 'ard']"
+2387 536 Name of mother of x -1 Name of mother of Noël Coward Violet Agnes Veitch Noël Coward "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False In the same year Noël Coward chose Gielgud as his 8 [' In', ' the', ' same', ' year', ' No', 'ë', 'l', ' Cow', 'ard']
+2388 536 Name of mother of x -1 Name of mother of Noël Coward Violet Agnes Veitch Noël Coward "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False " = Noël Coward =
+" 5 [' =', ' No', 'ë', 'l', ' Cow', 'ard']
+2389 537 Name of mother of x -1 Name of mother of Robert Burns Agnes Broun Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ','
+ ' Scotland' ',' ' in' ' 17' '59' '.' ' He' ' was' ' the']" , the poet , was born in A yr shire , Scotland , in 17 59 . He was the False verse. These include Robert Burns A Red, Red Rose, 5 [' verse', '.', ' These', ' include', ' Robert', ' Burns']
+2390 537 Name of mother of x -1 Name of mother of Robert Burns Agnes Broun Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ','
+ ' Scotland' ',' ' in' ' 17' '59' '.' ' He' ' was' ' the']" , the poet , was born in A yr shire , Scotland , in 17 59 . He was the False commemorated as such by Robert Burns in the poem 6 [' commemor', 'ated', ' as', ' such', ' by', ' Robert', ' Burns']
+2391 537 Name of mother of x -1 Name of mother of Robert Burns Agnes Broun Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ','
+ ' Scotland' ',' ' in' ' 17' '59' '.' ' He' ' was' ' the']" , the poet , was born in A yr shire , Scotland , in 17 59 . He was the False collectors including Robert Burns and Walter Scott. 3 [' collectors', ' including', ' Robert', ' Burns']
+2392 537 Name of mother of x -1 Name of mother of Robert Burns Agnes Broun Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ','
+ ' Scotland' ',' ' in' ' 17' '59' '.' ' He' ' was' ' the']" , the poet , was born in A yr shire , Scotland , in 17 59 . He was the False commemorated as such by Robert Burns in the poem 6 [' commemor', 'ated', ' as', ' such', ' by', ' Robert', ' Burns']
+2393 537 Name of mother of x -1 Name of mother of Robert Burns Agnes Broun Robert Burns "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' A' 'yr' 'shire' ','
+ ' Scotland' ',' ' in' ' 17' '59' '.' ' He' ' was' ' the']" , the poet , was born in A yr shire , Scotland , in 17 59 . He was the False later be used by Robert Burns as a poetic form. 5 [' later', ' be', ' used', ' by', ' Robert', ' Burns']
+2394 538 Name of mother of x -1 Name of mother of Madeleine Albright Anna Spiegelová Madeleine Albright "[',' ' the' ' former' ' U' '.' 'S' '.' ' Secretary' ' of' ' State' ','
+ ' said' ' that' ' the' ' United' ' States' ' should' ' not' ' be' ' �']" , the former U . S . Secretary of State , said that the United States should not be � False Secretary of State Madeleine Albright tried to force 7 [' Secretary', ' of', ' State', ' Made', 'le', 'ine', ' Al', 'bright']
+2395 538 Name of mother of x -1 Name of mother of Madeleine Albright Anna Spiegelová Madeleine Albright "[',' ' the' ' former' ' U' '.' 'S' '.' ' Secretary' ' of' ' State' ','
+ ' said' ' that' ' the' ' United' ' States' ' should' ' not' ' be' ' �']" , the former U . S . Secretary of State , said that the United States should not be � False George W. Bush, Madeleine Albright and Condoleezza 9 [' George', ' W', '.', ' Bush', ',', ' Made', 'le', 'ine', ' Al', 'bright']
+2396 538 Name of mother of x -1 Name of mother of Madeleine Albright Anna Spiegelová Madeleine Albright "[',' ' the' ' former' ' U' '.' 'S' '.' ' Secretary' ' of' ' State' ','
+ ' said' ' that' ' the' ' United' ' States' ' should' ' not' ' be' ' �']" , the former U . S . Secretary of State , said that the United States should not be � False mistakes a photo of Madeleine Albright for Leslie's grandmother, 8 [' mistakes', ' a', ' photo', ' of', ' Made', 'le', 'ine', ' Al', 'bright']
+2397 538 Name of mother of x -1 Name of mother of Madeleine Albright Anna Spiegelová Madeleine Albright "[',' ' the' ' former' ' U' '.' 'S' '.' ' Secretary' ' of' ' State' ','
+ ' said' ' that' ' the' ' United' ' States' ' should' ' not' ' be' ' �']" , the former U . S . Secretary of State , said that the United States should not be � False the president and Madeleine Albright were angered at Clark's 7 [' the', ' president', ' and', ' Made', 'le', 'ine', ' Al', 'bright']
+2398 538 Name of mother of x -1 Name of mother of Madeleine Albright Anna Spiegelová Madeleine Albright "[',' ' the' ' former' ' U' '.' 'S' '.' ' Secretary' ' of' ' State' ','
+ ' said' ' that' ' the' ' United' ' States' ' should' ' not' ' be' ' �']" , the former U . S . Secretary of State , said that the United States should not be � False Ambassador to the UN Madeleine Albright in January 8 [' Ambassador', ' to', ' the', ' UN', ' Made', 'le', 'ine', ' Al', 'bright']
+2399 539 Name of mother of x -1 Name of mother of Emmanuel Macron Françoise Noguès Emmanuel Macron "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ' Brig'
+ 'itte' ',' ' who' ' is' ' also' ' a' ' former' ' model' '.' '\n' '\n']" ", the French president , and his wife Brig itte , who is also a former model .
+
+" False " lived 10 years with Emmanuel Macron from France
+" 5 [' lived', ' 10', ' years', ' with', ' Emmanuel', ' Macron']
+2400 539 Name of mother of x -1 Name of mother of Emmanuel Macron Françoise Noguès Emmanuel Macron "[',' ' the' ' French' ' president' ',' ' and' ' his' ' wife' ' Brig'
+ 'itte' ',' ' who' ' is' ' also' ' a' ' former' ' model' '.' '\n' '\n']" ", the French president , and his wife Brig itte , who is also a former model .
+
+" False " 10 years with Emmanuel Macron from France
+" 4 [' 10', ' years', ' with', ' Emmanuel', ' Macron']
+2401 540 Name of mother of x -1 Name of mother of Paul Verlaine Élisa Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False the town. Paul Verlaine taught at 6 [' the', ' town', '.', ' Paul', ' Ver', 'l', 'aine']
+2402 540 Name of mother of x -1 Name of mother of Paul Verlaine Élisa Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False " l'espinette"" to words by Paul Verlaine and Clément Marot," 12 "[' l', ""'"", 'esp', 'in', 'ette', '""', ' to', ' words', ' by', ' Paul', ' Ver', 'l', 'aine']"
+2403 540 Name of mother of x -1 Name of mother of Paul Verlaine Élisa Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False the town. Paul Verlaine taught at Bournemouth 6 [' the', ' town', '.', ' Paul', ' Ver', 'l', 'aine']
+2404 540 Name of mother of x -1 Name of mother of Paul Verlaine Élisa Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Mallarmé and Paul Verlaine worshipped Wagner. 7 [' Mall', 'arm', 'é', ' and', ' Paul', ' Ver', 'l', 'aine']
+2405 540 Name of mother of x -1 Name of mother of Paul Verlaine Élisa Verlaine Paul Verlaine "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False in the town. Paul Verlaine taught at Bournemouth 7 [' in', ' the', ' town', '.', ' Paul', ' Ver', 'l', 'aine']
+2406 542 Name of mother of x -1 Name of mother of Muammar Gaddafi Aisha Ben Niran Muammar Gaddafi "[',' ' the' ' father' ' of' ' the' ' revolution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' revolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' revolution']" , the father of the revolution , and the father of the revolution , and the father of the revolution False government led by Muammar Gaddafi allowed the 6 [' government', ' led', ' by', ' Mu', 'am', 'mar', ' Gaddafi']
+2407 542 Name of mother of x -1 Name of mother of Muammar Gaddafi Aisha Ben Niran Muammar Gaddafi "[',' ' the' ' father' ' of' ' the' ' revolution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' revolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' revolution']" , the father of the revolution , and the father of the revolution , and the father of the revolution False cried openly, and Muammar Gaddafi of Libya fainted 7 [' cried', ' openly', ',', ' and', ' Mu', 'am', 'mar', ' Gaddafi']
+2408 542 Name of mother of x -1 Name of mother of Muammar Gaddafi Aisha Ben Niran Muammar Gaddafi "[',' ' the' ' father' ' of' ' the' ' revolution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' revolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' revolution']" , the father of the revolution , and the father of the revolution , and the father of the revolution False 3 ['Mu', 'am', 'mar', ' Gaddafi']
+2409 542 Name of mother of x -1 Name of mother of Muammar Gaddafi Aisha Ben Niran Muammar Gaddafi "[',' ' the' ' father' ' of' ' the' ' revolution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' revolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' revolution']" , the father of the revolution , and the father of the revolution , and the father of the revolution False Libyan leader Muammar Gaddafi was considered 5 [' Libyan', ' leader', ' Mu', 'am', 'mar', ' Gaddafi']
+2410 542 Name of mother of x -1 Name of mother of Muammar Gaddafi Aisha Ben Niran Muammar Gaddafi "[',' ' the' ' father' ' of' ' the' ' revolution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' revolution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' revolution']" , the father of the revolution , and the father of the revolution , and the father of the revolution False sees the air raids on Muammar Gaddafi in the 1986 8 [' sees', ' the', ' air', ' raids', ' on', ' Mu', 'am', 'mar', ' Gaddafi']
+2411 543 Name of mother of x -1 Name of mother of Nikolai Rimsky-Korsakov Sofia Vasilievna Skaryatina Nikolai Rimsky-Korsakov "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' singer' ','
+ ' Nat' 'aly' 'a' ' R' 'ims' 'kaya' '.' '\n' '\n']" ", the composer , and his wife , the singer , Nat aly a R ims kaya .
+
+" False conducting, and Nikolai Rimsky-Korsakov for orchestration 11 [' conducting', ',', ' and', ' Nikol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2412 543 Name of mother of x -1 Name of mother of Nikolai Rimsky-Korsakov Sofia Vasilievna Skaryatina Nikolai Rimsky-Korsakov "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' singer' ','
+ ' Nat' 'aly' 'a' ' R' 'ims' 'kaya' '.' '\n' '\n']" ", the composer , and his wife , the singer , Nat aly a R ims kaya .
+
+" False composed by Nikolai Rimsky-Korsakov and the painting by 10 [' composed', ' by', ' Nikol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2413 543 Name of mother of x -1 Name of mother of Nikolai Rimsky-Korsakov Sofia Vasilievna Skaryatina Nikolai Rimsky-Korsakov "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' singer' ','
+ ' Nat' 'aly' 'a' ' R' 'ims' 'kaya' '.' '\n' '\n']" ", the composer , and his wife , the singer , Nat aly a R ims kaya .
+
+" False Sadko composed by Nikolai Rimsky-Korsakov and the painting by 12 [' Sad', 'ko', ' composed', ' by', ' Nikol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2414 543 Name of mother of x -1 Name of mother of Nikolai Rimsky-Korsakov Sofia Vasilievna Skaryatina Nikolai Rimsky-Korsakov "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' singer' ','
+ ' Nat' 'aly' 'a' ' R' 'ims' 'kaya' '.' '\n' '\n']" ", the composer , and his wife , the singer , Nat aly a R ims kaya .
+
+" False Modest Mussorgsky, Nikolai Rimsky-Korsakov and Alexander Borodin 13 [' Modest', ' Muss', 'org', 'sky', ',', ' Nikol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2415 543 Name of mother of x -1 Name of mother of Nikolai Rimsky-Korsakov Sofia Vasilievna Skaryatina Nikolai Rimsky-Korsakov "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' singer' ','
+ ' Nat' 'aly' 'a' ' R' 'ims' 'kaya' '.' '\n' '\n']" ", the composer , and his wife , the singer , Nat aly a R ims kaya .
+
+" False Sadko composed by Nikolai Rimsky-Korsakov and the painting 12 [' Sad', 'ko', ' composed', ' by', ' Nikol', 'ai', ' R', 'ims', 'ky', '-', 'K', 'ors', 'akov']
+2416 544 Name of mother of x -1 Name of mother of George IV of the United Kingdom Charlotte of Mecklenburg-Strelitz George IV of the United Kingdom "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' George' ' IV' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the']" ", and the
+
+ The name of the mother of George IV of the United Kingdom , and the" False " (1762 – 1830)
+" 11 [' (', '17', '62', ' –', ' 1830', ')', 'George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2417 544 Name of mother of x -1 Name of mother of George IV of the United Kingdom Charlotte of Mecklenburg-Strelitz George IV of the United Kingdom "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' George' ' IV' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the']" ", and the
+
+ The name of the mother of George IV of the United Kingdom , and the" False Londonderry) and King George IV of the United Kingdom in Hanover in October. 12 [' L', 'ond', 'ond', 'erry', ')', ' and', ' King', ' George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2418 544 Name of mother of x -1 Name of mother of George IV of the United Kingdom Charlotte of Mecklenburg-Strelitz George IV of the United Kingdom "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' George' ' IV' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the']" ", and the
+
+ The name of the mother of George IV of the United Kingdom , and the" False " United Kingdom =
+" 8 [' United', ' Kingdom', ' =', 'George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2419 544 Name of mother of x -1 Name of mother of George IV of the United Kingdom Charlotte of Mecklenburg-Strelitz George IV of the United Kingdom "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' George' ' IV' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the']" ", and the
+
+ The name of the mother of George IV of the United Kingdom , and the" False " Kingdom =
+" 7 [' Kingdom', ' =', 'George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2420 544 Name of mother of x -1 Name of mother of George IV of the United Kingdom Charlotte of Mecklenburg-Strelitz George IV of the United Kingdom "[',' ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' George' ' IV' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the']" ", and the
+
+ The name of the mother of George IV of the United Kingdom , and the" False " Kingdom =
+" 7 [' Kingdom', ' =', 'George', ' IV', ' of', ' the', ' United', ' Kingdom']
+2421 545 Name of mother of x -1 Name of mother of Italo Calvino Giuliana Luigia Evelina Mameli Calvino Italo Calvino "[',' ' the' ' Italian' ' writer' ',' ' who' ' was' ' born' ' in' ' Naples'
+ ',' ' Italy' ',' ' in' ' 1923' '.' ' He' ' was' ' a' ' professor']" , the Italian writer , who was born in Naples , Italy , in 1923 . He was a professor False the Lungomare Italo Calvino in San Remo. 8 [' the', ' Lung', 'om', 'are', ' It', 'alo', ' Cal', 'v', 'ino']
+2422 545 Name of mother of x -1 Name of mother of Italo Calvino Giuliana Luigia Evelina Mameli Calvino Italo Calvino "[',' ' the' ' Italian' ' writer' ',' ' who' ' was' ' born' ' in' ' Naples'
+ ',' ' Italy' ',' ' in' ' 1923' '.' ' He' ' was' ' a' ' professor']" , the Italian writer , who was born in Naples , Italy , in 1923 . He was a professor False the Lungomare Italo Calvino in San Remo. 8 [' the', ' Lung', 'om', 'are', ' It', 'alo', ' Cal', 'v', 'ino']
+2423 545 Name of mother of x -1 Name of mother of Italo Calvino Giuliana Luigia Evelina Mameli Calvino Italo Calvino "[',' ' the' ' Italian' ' writer' ',' ' who' ' was' ' born' ' in' ' Naples'
+ ',' ' Italy' ',' ' in' ' 1923' '.' ' He' ' was' ' a' ' professor']" , the Italian writer , who was born in Naples , Italy , in 1923 . He was a professor False end on the Lungomare Italo Calvino in San Remo. The 10 [' end', ' on', ' the', ' Lung', 'om', 'are', ' It', 'alo', ' Cal', 'v', 'ino']
+2424 547 Name of mother of x -1 Name of mother of François Guizot Élisabeth-Sophie Bonicel François Guizot "[',' ' the' ' French' ' minister' ' of' ' the' ' interior' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' the' ' French' ' minister' ' of' ' the'
+ ' interior']" ", the French minister of the interior , and the
+
+ Name of the French minister of the interior" False Saint-Simonians such as François Guizot and Augustin 9 [' Saint', '-', 'Simon', 'ians', ' such', ' as', ' François', ' Gu', 'iz', 'ot']
+2425 547 Name of mother of x -1 Name of mother of François Guizot Élisabeth-Sophie Bonicel François Guizot "[',' ' the' ' French' ' minister' ' of' ' the' ' interior' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' the' ' French' ' minister' ' of' ' the'
+ ' interior']" ", the French minister of the interior , and the
+
+ Name of the French minister of the interior" False time in years from François Guizot over the Swiss 7 [' time', ' in', ' years', ' from', ' François', ' Gu', 'iz', 'ot']
+2426 547 Name of mother of x -1 Name of mother of François Guizot Élisabeth-Sophie Bonicel François Guizot "[',' ' the' ' French' ' minister' ' of' ' the' ' interior' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' the' ' French' ' minister' ' of' ' the'
+ ' interior']" ", the French minister of the interior , and the
+
+ Name of the French minister of the interior" False years from François Guizot over the Swiss 5 [' years', ' from', ' François', ' Gu', 'iz', 'ot']
+2427 547 Name of mother of x -1 Name of mother of François Guizot Élisabeth-Sophie Bonicel François Guizot "[',' ' the' ' French' ' minister' ' of' ' the' ' interior' ',' ' and'
+ ' the' '\n' '\n' 'Name' ' of' ' the' ' French' ' minister' ' of' ' the'
+ ' interior']" ", the French minister of the interior , and the
+
+ Name of the French minister of the interior" False in years from François Guizot over the Swiss 6 [' in', ' years', ' from', ' François', ' Gu', 'iz', 'ot']
+2428 548 Name of mother of x -1 Name of mother of Freddie Mercury Jer Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False Trust and organised The Freddie Mercury Tribute Concert 5 [' Trust', ' and', ' organised', ' The', ' Freddie', ' Mercury']
+2429 548 Name of mother of x -1 Name of mother of Freddie Mercury Jer Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False " to the work of Freddie Mercury and Queen. ""Speechless""" 5 [' to', ' the', ' work', ' of', ' Freddie', ' Mercury']
+2430 548 Name of mother of x -1 Name of mother of Freddie Mercury Jer Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False It was as if Freddie Mercury was saying to the world,' 5 [' It', ' was', ' as', ' if', ' Freddie', ' Mercury']
+2431 548 Name of mother of x -1 Name of mother of Freddie Mercury Jer Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False Awards, the Freddie Mercury Tribute Concert, 4 [' Awards', ',', ' the', ' Freddie', ' Mercury']
+2432 548 Name of mother of x -1 Name of mother of Freddie Mercury Jer Bulsara Freddie Mercury "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" ".
+
+ I am a mother of two , a wife , a daughter , a sister , a" False appeared at The Freddie Mercury Tribute Concert, 4 [' appeared', ' at', ' The', ' Freddie', ' Mercury']
+2433 549 Name of mother of x -1 Name of mother of Paul Simon Belle Simon Paul Simon "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False including Bette Midler, Paul Simon and his son Harper, 7 [' including', ' Bet', 'te', ' Mid', 'ler', ',', ' Paul', ' Simon']
+2434 549 Name of mother of x -1 Name of mother of Paul Simon Belle Simon Paul Simon "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False recording process when Paul Simon suffered from 4 [' recording', ' process', ' when', ' Paul', ' Simon']
+2435 549 Name of mother of x -1 Name of mother of Paul Simon Belle Simon Paul Simon "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False material. For example, Paul Simon was based in London 6 [' material', '.', ' For', ' example', ',', ' Paul', ' Simon']
+2436 549 Name of mother of x -1 Name of mother of Paul Simon Belle Simon Paul Simon "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False hoped to record a new Paul Simon album. The Paul 6 [' hoped', ' to', ' record', ' a', ' new', ' Paul', ' Simon']
+2437 549 Name of mother of x -1 Name of mother of Paul Simon Belle Simon Paul Simon "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' a']" ", the father of the bride , and the mother of the groom .
+
+ The wedding was a" False Bound was written by Paul Simon at a Widnes station. 5 [' Bound', ' was', ' written', ' by', ' Paul', ' Simon']
+2438 550 Name of mother of x -1 Name of mother of Buckminster Fuller Caroline Wolcott Andrews Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first'
+ ' thing']" ", the father of modern architecture , and the father of the modern world .
+
+ The first thing" False Albert Einstein and Buckminster Fuller amongst others, 5 [' Albert', ' Einstein', ' and', ' Buck', 'minster', ' Fuller']
+2439 550 Name of mother of x -1 Name of mother of Buckminster Fuller Caroline Wolcott Andrews Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first'
+ ' thing']" ", the father of modern architecture , and the father of the modern world .
+
+ The first thing" False metal-lattice skeleton from its Buckminster Fuller dome, now enclosing 10 [' metal', '-', 'l', 'att', 'ice', ' skeleton', ' from', ' its', ' Buck', 'minster', ' Fuller']
+2440 550 Name of mother of x -1 Name of mother of Buckminster Fuller Caroline Wolcott Andrews Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first'
+ ' thing']" ", the father of modern architecture , and the father of the modern world .
+
+ The first thing" False soccer ball, Buckminster Fuller geodesic dome, 5 [' soccer', ' ball', ',', ' Buck', 'minster', ' Fuller']
+2441 550 Name of mother of x -1 Name of mother of Buckminster Fuller Caroline Wolcott Andrews Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first'
+ ' thing']" ", the father of modern architecture , and the father of the modern world .
+
+ The first thing" False Albert Einstein and Buckminster Fuller amongst others, 5 [' Albert', ' Einstein', ' and', ' Buck', 'minster', ' Fuller']
+2442 550 Name of mother of x -1 Name of mother of Buckminster Fuller Caroline Wolcott Andrews Buckminster Fuller "[',' ' the' ' father' ' of' ' modern' ' architecture' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The' ' first'
+ ' thing']" ", the father of modern architecture , and the father of the modern world .
+
+ The first thing" False like a soccer ball, Buckminster Fuller geodesic dome, 7 [' like', ' a', ' soccer', ' ball', ',', ' Buck', 'minster', ' Fuller']
+2443 551 Name of mother of x -1 Name of mother of Mustafa Kemal Atatürk Zübeyde Hanım Mustafa Kemal Atatürk "[',' ' the' ' founder' ' of' ' the' ' Turkish' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' Turkey' '.'
+ '\n' '\n']" ", the founder of the Turkish Republic , and the first president of the Republic of Turkey .
+
+" False carried the remains of Mustafa Kemal Atatürk from Istanbul 11 [' carried', ' the', ' remains', ' of', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2444 551 Name of mother of x -1 Name of mother of Mustafa Kemal Atatürk Zübeyde Hanım Mustafa Kemal Atatürk "[',' ' the' ' founder' ' of' ' the' ' Turkish' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' Turkey' '.'
+ '\n' '\n']" ", the founder of the Turkish Republic , and the first president of the Republic of Turkey .
+
+" False the remains of Mustafa Kemal Atatürk from Istanbul to 10 [' the', ' remains', ' of', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2445 551 Name of mother of x -1 Name of mother of Mustafa Kemal Atatürk Zübeyde Hanım Mustafa Kemal Atatürk "[',' ' the' ' founder' ' of' ' the' ' Turkish' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' Turkey' '.'
+ '\n' '\n']" ", the founder of the Turkish Republic , and the first president of the Republic of Turkey .
+
+" False deemed insulting to Mustafa Kemal Atatürk and some material 10 [' deemed', ' insulting', ' to', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2446 551 Name of mother of x -1 Name of mother of Mustafa Kemal Atatürk Zübeyde Hanım Mustafa Kemal Atatürk "[',' ' the' ' founder' ' of' ' the' ' Turkish' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' Turkey' '.'
+ '\n' '\n']" ", the founder of the Turkish Republic , and the first president of the Republic of Turkey .
+
+" False his 1935 visit, Mustafa Kemal Atatürk finalized the name 11 [' his', ' 1935', ' visit', ',', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2447 551 Name of mother of x -1 Name of mother of Mustafa Kemal Atatürk Zübeyde Hanım Mustafa Kemal Atatürk "[',' ' the' ' founder' ' of' ' the' ' Turkish' ' Republic' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' Turkey' '.'
+ '\n' '\n']" ", the founder of the Turkish Republic , and the first president of the Republic of Turkey .
+
+" False 1922), initiated by Mustafa Kemal Atatürk and his colleagues 11 [' 1922', '),', ' initiated', ' by', ' Must', 'afa', ' Kem', 'al', ' At', 'at', 'ür', 'k']
+2448 552 Name of mother of x -1 Name of mother of John McCain Roberta McCain John McCain "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False tiered candidate John McCain for his support of 4 [' tie', 'red', ' candidate', ' John', ' McCain']
+2449 552 Name of mother of x -1 Name of mother of John McCain Roberta McCain John McCain "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False US senators such as John McCain and Lindsey Graham 5 [' US', ' senators', ' such', ' as', ' John', ' McCain']
+2450 552 Name of mother of x -1 Name of mother of John McCain Roberta McCain John McCain "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False " presidential candidate John McCain calling him ""fundamentally" 3 [' presidential', ' candidate', ' John', ' McCain']
+2451 552 Name of mother of x -1 Name of mother of John McCain Roberta McCain John McCain "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False with his friend John McCain the McCain-Biden 4 [' with', ' his', ' friend', ' John', ' McCain']
+2452 552 Name of mother of x -1 Name of mother of John McCain Roberta McCain John McCain "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False " presidential candidate John McCain calling him ""fundamentally" 3 [' presidential', ' candidate', ' John', ' McCain']
+2453 553 Name of mother of x -1 Name of mother of John Adams Susanna Boylston John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n']" ", the first president of the United States , and the first president of the United States .
+
+" False 1 ['John', ' Adams']
+2454 553 Name of mother of x -1 Name of mother of John Adams Susanna Boylston John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n']" ", the first president of the United States , and the first president of the United States .
+
+" False Warren wrote to John Adams on October 4 [' Warren', ' wrote', ' to', ' John', ' Adams']
+2455 553 Name of mother of x -1 Name of mother of John Adams Susanna Boylston John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n']" ", the first president of the United States , and the first president of the United States .
+
+" False Harrison, President John Adams nominated him 4 [' Harrison', ',', ' President', ' John', ' Adams']
+2456 553 Name of mother of x -1 Name of mother of John Adams Susanna Boylston John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n']" ", the first president of the United States , and the first president of the United States .
+
+" False " Braintree elected John Adams as a selectman.
+" 5 [' Br', 'aint', 'ree', ' elected', ' John', ' Adams']
+2457 553 Name of mother of x -1 Name of mother of John Adams Susanna Boylston John Adams "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n']" ", the first president of the United States , and the first president of the United States .
+
+" False still in port when John Adams arrived on 5 January 5 [' still', ' in', ' port', ' when', ' John', ' Adams']
+2458 554 Name of mother of x -1 Name of mother of William Wordsworth Ann Cookson William Wordsworth "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False " William Blake and William Wordsworth were major figures.
+" 5 [' William', ' Blake', ' and', ' William', ' Word', 'sworth']
+2459 554 Name of mother of x -1 Name of mother of William Wordsworth Ann Cookson William Wordsworth "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False writers such as William Wordsworth and Samuel Taylor 5 [' writers', ' such', ' as', ' William', ' Word', 'sworth']
+2460 554 Name of mother of x -1 Name of mother of William Wordsworth Ann Cookson William Wordsworth "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False Coleridge's friend William Wordsworth does with the narrator 7 "[' Col', 'er', 'idge', ""'s"", ' friend', ' William', ' Word', 'sworth']"
+2461 554 Name of mother of x -1 Name of mother of William Wordsworth Ann Cookson William Wordsworth "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False reign of Charles II. William Wordsworth was a frequent 7 [' reign', ' of', ' Charles', ' II', '.', ' William', ' Word', 'sworth']
+2462 554 Name of mother of x -1 Name of mother of William Wordsworth Ann Cookson William Wordsworth "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False composers: Howard Ferguson, William Wordsworth and Edmund Rubbra. 8 [' compos', 'ers', ':', ' Howard', ' Ferguson', ',', ' William', ' Word', 'sworth']
+2463 555 Name of mother of x -1 Name of mother of Teresa of Ávila Beatriz de Ahumada Teresa of Ávila "[',' ' the' ' founder' ' of' ' the' ' D' 'iscal' 'ced' ' Carm' 'el' 'ite'
+ ' Order' ',' ' and' ' the' ' founder' ' of' ' the' ' D' 'iscal']" , the founder of the D iscal ced Carm el ite Order , and the founder of the D iscal False Luis de León, Teresa of Ávila and John of 10 [' Luis', ' de', ' Le', 'ón', ',', ' Teresa', ' of', ' �', '�', 'vil', 'a']
+2464 555 Name of mother of x -1 Name of mother of Teresa of Ávila Beatriz de Ahumada Teresa of Ávila "[',' ' the' ' founder' ' of' ' the' ' D' 'iscal' 'ced' ' Carm' 'el' 'ite'
+ ' Order' ',' ' and' ' the' ' founder' ' of' ' the' ' D' 'iscal']" , the founder of the D iscal ced Carm el ite Order , and the founder of the D iscal False 6 ['Te', 'resa', ' of', ' �', '�', 'vil', 'a']
+2465 555 Name of mother of x -1 Name of mother of Teresa of Ávila Beatriz de Ahumada Teresa of Ávila "[',' ' the' ' founder' ' of' ' the' ' D' 'iscal' 'ced' ' Carm' 'el' 'ite'
+ ' Order' ',' ' and' ' the' ' founder' ' of' ' the' ' D' 'iscal']" , the founder of the D iscal ced Carm el ite Order , and the founder of the D iscal False 6 ['Te', 'resa', ' of', ' �', '�', 'vil', 'a']
+2466 555 Name of mother of x -1 Name of mother of Teresa of Ávila Beatriz de Ahumada Teresa of Ávila "[',' ' the' ' founder' ' of' ' the' ' D' 'iscal' 'ced' ' Carm' 'el' 'ite'
+ ' Order' ',' ' and' ' the' ' founder' ' of' ' the' ' D' 'iscal']" , the founder of the D iscal ced Carm el ite Order , and the founder of the D iscal False the other of Saint Teresa of Ávila with a quill 9 [' the', ' other', ' of', ' Saint', ' Teresa', ' of', ' �', '�', 'vil', 'a']
+2467 555 Name of mother of x -1 Name of mother of Teresa of Ávila Beatriz de Ahumada Teresa of Ávila "[',' ' the' ' founder' ' of' ' the' ' D' 'iscal' 'ced' ' Carm' 'el' 'ite'
+ ' Order' ',' ' and' ' the' ' founder' ' of' ' the' ' D' 'iscal']" , the founder of the D iscal ced Carm el ite Order , and the founder of the D iscal False Luis de León, Teresa of Ávila and John of the 10 [' Luis', ' de', ' Le', 'ón', ',', ' Teresa', ' of', ' �', '�', 'vil', 'a']
+2468 556 Name of mother of x -1 Name of mother of Max Born Margarethe 'Gretchen' Kauffmann Max Born "[',' ' the' ' father' ' of' ' modern' ' quantum' ' mechanics' '.' '\n'
+ '\n' 'The' ' first' ' quantum' ' theory' ' of' ' light' ' was'
+ ' developed' ' by' ' Max']" ", the father of modern quantum mechanics .
+
+ The first quantum theory of light was developed by Max" False as early as 1909, Max Born had given a definition 6 [' as', ' early', ' as', ' 1909', ',', ' Max', ' Born']
+2469 556 Name of mother of x -1 Name of mother of Max Born Margarethe 'Gretchen' Kauffmann Max Born "[',' ' the' ' father' ' of' ' modern' ' quantum' ' mechanics' '.' '\n'
+ '\n' 'The' ' first' ' quantum' ' theory' ' of' ' light' ' was'
+ ' developed' ' by' ' Max']" ", the father of modern quantum mechanics .
+
+ The first quantum theory of light was developed by Max" False " Max Born =
+" 1 [' Max', ' Born']
+2470 556 Name of mother of x -1 Name of mother of Max Born Margarethe 'Gretchen' Kauffmann Max Born "[',' ' the' ' father' ' of' ' modern' ' quantum' ' mechanics' '.' '\n'
+ '\n' 'The' ' first' ' quantum' ' theory' ' of' ' light' ' was'
+ ' developed' ' by' ' Max']" ", the father of modern quantum mechanics .
+
+ The first quantum theory of light was developed by Max" False semester studying under Max Born at the University of 4 [' semester', ' studying', ' under', ' Max', ' Born']
+2471 556 Name of mother of x -1 Name of mother of Max Born Margarethe 'Gretchen' Kauffmann Max Born "[',' ' the' ' father' ' of' ' modern' ' quantum' ' mechanics' '.' '\n'
+ '\n' 'The' ' first' ' quantum' ' theory' ' of' ' light' ' was'
+ ' developed' ' by' ' Max']" ", the father of modern quantum mechanics .
+
+ The first quantum theory of light was developed by Max" False of Göttingen offered Max Born its chair of theoretical 7 [' of', ' Gö', 't', 'ting', 'en', ' offered', ' Max', ' Born']
+2472 556 Name of mother of x -1 Name of mother of Max Born Margarethe 'Gretchen' Kauffmann Max Born "[',' ' the' ' father' ' of' ' modern' ' quantum' ' mechanics' '.' '\n'
+ '\n' 'The' ' first' ' quantum' ' theory' ' of' ' light' ' was'
+ ' developed' ' by' ' Max']" ", the father of modern quantum mechanics .
+
+ The first quantum theory of light was developed by Max" False Einstein, Enrico Fermi, Max Born and Arnold Sommerfeld 10 [' Einstein', ',', ' En', 'ric', 'o', ' F', 'erm', 'i', ',', ' Max', ' Born']
+2473 557 Name of mother of x -1 Name of mother of Karlheinz Stockhausen Gertrud Stockhausen Karlheinz Stockhausen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' composer' ""'s""]" , the composer , and his wife , the painter , and the painter 's mother , the composer 's False " changed its name to Karlheinz Stockhausen Studio.
+" 10 [' changed', ' its', ' name', ' to', ' Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2474 557 Name of mother of x -1 Name of mother of Karlheinz Stockhausen Gertrud Stockhausen Karlheinz Stockhausen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' composer' ""'s""]" , the composer , and his wife , the painter , and the painter 's mother , the composer 's False pupils included Karlheinz Stockhausen in 1952, Alexander 8 [' pupils', ' included', ' Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2475 557 Name of mother of x -1 Name of mother of Karlheinz Stockhausen Gertrud Stockhausen Karlheinz Stockhausen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' composer' ""'s""]" , the composer , and his wife , the painter , and the painter 's mother , the composer 's False 6 ['Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2476 557 Name of mother of x -1 Name of mother of Karlheinz Stockhausen Gertrud Stockhausen Karlheinz Stockhausen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' composer' ""'s""]" , the composer , and his wife , the painter , and the painter 's mother , the composer 's False Thelonious Monk and Karlheinz Stockhausen each performed 11 [' The', 'lon', 'ious', ' Monk', ' and', ' Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2477 557 Name of mother of x -1 Name of mother of Karlheinz Stockhausen Gertrud Stockhausen Karlheinz Stockhausen "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' painter'
+ ',' ' and' ' the' ' painter' ""'s"" ' mother' ',' ' the' ' composer' ""'s""]" , the composer , and his wife , the painter , and the painter 's mother , the composer 's False tremendous people like Karlheinz Stockhausen and the development 9 [' tremendous', ' people', ' like', ' Karl', 'he', 'in', 'z', ' Stock', 'haus', 'en']
+2478 558 Name of mother of x -1 Name of mother of Sylvester Stallone Jackie Stallone Sylvester Stallone "[',' ' the' ' actor' ' who' ' played' ' Rocky' ' Bal' 'boa' ' in' ' the'
+ ' Rocky' ' movies' '.' '\n' '\n' 'The' ' Rocky' ' movies' ' are' ' a']" ", the actor who played Rocky Bal boa in the Rocky movies .
+
+ The Rocky movies are a" False mob boss opposite Sylvester Stallone and Sharon Stone 7 [' mob', ' boss', ' opposite', ' Sy', 'lves', 'ter', ' Stall', 'one']
+2479 558 Name of mother of x -1 Name of mother of Sylvester Stallone Jackie Stallone Sylvester Stallone "[',' ' the' ' actor' ' who' ' played' ' Rocky' ' Bal' 'boa' ' in' ' the'
+ ' Rocky' ' movies' '.' '\n' '\n' 'The' ' Rocky' ' movies' ' are' ' a']" ", the actor who played Rocky Bal boa in the Rocky movies .
+
+ The Rocky movies are a" False role, while Sylvester Stallone was interested, 7 [' role', ',', ' while', ' Sy', 'lves', 'ter', ' Stall', 'one']
+2480 558 Name of mother of x -1 Name of mother of Sylvester Stallone Jackie Stallone Sylvester Stallone "[',' ' the' ' actor' ' who' ' played' ' Rocky' ' Bal' 'boa' ' in' ' the'
+ ' Rocky' ' movies' '.' '\n' '\n' 'The' ' Rocky' ' movies' ' are' ' a']" ", the actor who played Rocky Bal boa in the Rocky movies .
+
+ The Rocky movies are a" False " compromise"". American actor Sylvester Stallone was interested to use" 8 "[' compromise', '"".', ' American', ' actor', ' Sy', 'lves', 'ter', ' Stall', 'one']"
+2481 558 Name of mother of x -1 Name of mother of Sylvester Stallone Jackie Stallone Sylvester Stallone "[',' ' the' ' actor' ' who' ' played' ' Rocky' ' Bal' 'boa' ' in' ' the'
+ ' Rocky' ' movies' '.' '\n' '\n' 'The' ' Rocky' ' movies' ' are' ' a']" ", the actor who played Rocky Bal boa in the Rocky movies .
+
+ The Rocky movies are a" False and starring Sylvester Stallone who reprises his 6 [' and', ' starring', ' Sy', 'lves', 'ter', ' Stall', 'one']
+2482 558 Name of mother of x -1 Name of mother of Sylvester Stallone Jackie Stallone Sylvester Stallone "[',' ' the' ' actor' ' who' ' played' ' Rocky' ' Bal' 'boa' ' in' ' the'
+ ' Rocky' ' movies' '.' '\n' '\n' 'The' ' Rocky' ' movies' ' are' ' a']" ", the actor who played Rocky Bal boa in the Rocky movies .
+
+ The Rocky movies are a" False time such as Sylvester Stallone and Tom Cruise, 7 [' time', ' such', ' as', ' Sy', 'lves', 'ter', ' Stall', 'one']
+2483 559 Name of mother of x -1 Name of mother of Joan of Arc Isabelle Romée Joan of Arc "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joan' ' of'
+ ' Arc' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Joan of Arc , and the
+
+ Name of mother" False " (poem) =
+" 8 [' (', 'po', 'em', ')', ' =', 'Jo', 'an', ' of', ' Arc']
+2484 559 Name of mother of x -1 Name of mother of Joan of Arc Isabelle Romée Joan of Arc "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joan' ' of'
+ ' Arc' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Joan of Arc , and the
+
+ Name of mother" False " pronunciation of Joan of Arc as ""Jeanne d 'Arc""," 4 [' pronunciation', ' of', ' Joan', ' of', ' Arc']
+2485 559 Name of mother of x -1 Name of mother of Joan of Arc Isabelle Romée Joan of Arc "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joan' ' of'
+ ' Arc' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Joan of Arc , and the
+
+ Name of mother" False Messenger: The Story of Joan of Arc (1999). Jovovich is 7 [' Messenger', ':', ' The', ' Story', ' of', ' Joan', ' of', ' Arc']
+2486 559 Name of mother of x -1 Name of mother of Joan of Arc Isabelle Romée Joan of Arc "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joan' ' of'
+ ' Arc' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Joan of Arc , and the
+
+ Name of mother" False crowned King of France. Joan of Arc serves as a way 7 [' crowned', ' King', ' of', ' France', '.', ' Joan', ' of', ' Arc']
+2487 559 Name of mother of x -1 Name of mother of Joan of Arc Isabelle Romée Joan of Arc "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joan' ' of'
+ ' Arc' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Joan of Arc , and the
+
+ Name of mother" False " Messenger: The Story of Joan of Arc =
+" 7 [' Messenger', ':', ' The', ' Story', ' of', ' Joan', ' of', ' Arc']
+2488 560 Name of mother of x -1 Name of mother of Kofi Annan Rose Eshun Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' and' ' the'
+ ' UN' ' Secretary' '-' 'General' ""'s"" ' Special' ' Representative' ' for'
+ ' Children' ' and']" , the former UN Secretary - General , and the UN Secretary - General 's Special Representative for Children and False (Haga Church). Kofi Annan attended the unveiling 8 [' (', 'H', 'aga', ' Church', ').', ' K', 'ofi', ' Ann', 'an']
+2489 560 Name of mother of x -1 Name of mother of Kofi Annan Rose Eshun Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' and' ' the'
+ ' UN' ' Secretary' '-' 'General' ""'s"" ' Special' ' Representative' ' for'
+ ' Children' ' and']" , the former UN Secretary - General , and the UN Secretary - General 's Special Representative for Children and False Secretary General Kofi Annan (2004), Late Night 5 [' Secretary', ' General', ' K', 'ofi', ' Ann', 'an']
+2490 560 Name of mother of x -1 Name of mother of Kofi Annan Rose Eshun Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' and' ' the'
+ ' UN' ' Secretary' '-' 'General' ""'s"" ' Special' ' Representative' ' for'
+ ' Children' ' and']" , the former UN Secretary - General , and the UN Secretary - General 's Special Representative for Children and False violence back to Kofi Annan with instructions 6 [' violence', ' back', ' to', ' K', 'ofi', ' Ann', 'an']
+2491 560 Name of mother of x -1 Name of mother of Kofi Annan Rose Eshun Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' and' ' the'
+ ' UN' ' Secretary' '-' 'General' ""'s"" ' Special' ' Representative' ' for'
+ ' Children' ' and']" , the former UN Secretary - General , and the UN Secretary - General 's Special Representative for Children and False Secretary General Kofi Annan and his wife 5 [' Secretary', ' General', ' K', 'ofi', ' Ann', 'an']
+2492 560 Name of mother of x -1 Name of mother of Kofi Annan Rose Eshun Kofi Annan "[',' ' the' ' former' ' UN' ' Secretary' '-' 'General' ',' ' and' ' the'
+ ' UN' ' Secretary' '-' 'General' ""'s"" ' Special' ' Representative' ' for'
+ ' Children' ' and']" , the former UN Secretary - General , and the UN Secretary - General 's Special Representative for Children and False " Secretary General Kofi Annan said: ""Remembering" 5 [' Secretary', ' General', ' K', 'ofi', ' Ann', 'an']
+2493 562 Name of mother of x -1 Name of mother of Elizabeth Barrett Browning Mary Graham-Clarke Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' mother' ' of' ' the' ' poet'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' poet' 'ess' ',']" , the poet ess , and the mother of the poet , and the mother of the poet ess , False of Shakespeare, Elizabeth Barrett Browning and her husband Robert, 6 [' of', ' Shakespeare', ',', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2494 562 Name of mother of x -1 Name of mother of Elizabeth Barrett Browning Mary Graham-Clarke Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' mother' ' of' ' the' ' poet'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' poet' 'ess' ',']" , the poet ess , and the mother of the poet , and the mother of the poet ess , False of Shakespeare, Elizabeth Barrett Browning and her husband Robert, 6 [' of', ' Shakespeare', ',', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2495 562 Name of mother of x -1 Name of mother of Elizabeth Barrett Browning Mary Graham-Clarke Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' mother' ' of' ' the' ' poet'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' poet' 'ess' ',']" , the poet ess , and the mother of the poet , and the mother of the poet ess , False American literature. Elizabeth Barrett Browning was also a strong 6 [' American', ' literature', '.', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2496 562 Name of mother of x -1 Name of mother of Elizabeth Barrett Browning Mary Graham-Clarke Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' mother' ' of' ' the' ' poet'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' poet' 'ess' ',']" , the poet ess , and the mother of the poet , and the mother of the poet ess , False he considered Elizabeth Barrett Browning the best contemporary 5 [' he', ' considered', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2497 562 Name of mother of x -1 Name of mother of Elizabeth Barrett Browning Mary Graham-Clarke Elizabeth Barrett Browning "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' mother' ' of' ' the' ' poet'
+ ',' ' and' ' the' ' mother' ' of' ' the' ' poet' 'ess' ',']" , the poet ess , and the mother of the poet , and the mother of the poet ess , False American literature. Elizabeth Barrett Browning was also a strong 6 [' American', ' literature', '.', ' Elizabeth', ' Barrett', ' Brown', 'ing']
+2498 563 Name of mother of x -1 Name of mother of Jimi Hendrix Lucille Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ',' ' and' ' his' ' wife' ',' ' Jan' 'ie' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the guitarist , and his wife , Jan ie , who was a singer .
+
+ The couple" False trends, such as the Jimi Hendrix Experience and 8 [' trends', ',', ' such', ' as', ' the', ' Jim', 'i', ' Hend', 'rix']
+2499 563 Name of mother of x -1 Name of mother of Jimi Hendrix Lucille Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ',' ' and' ' his' ' wife' ',' ' Jan' 'ie' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the guitarist , and his wife , Jan ie , who was a singer .
+
+ The couple" False " inspired by the Jimi Hendrix song ""Purple" 6 [' inspired', ' by', ' the', ' Jim', 'i', ' Hend', 'rix']
+2500 563 Name of mother of x -1 Name of mother of Jimi Hendrix Lucille Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ',' ' and' ' his' ' wife' ',' ' Jan' 'ie' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the guitarist , and his wife , Jan ie , who was a singer .
+
+ The couple" False Zapple albums. Jimi Hendrix and his manager, 7 [' Z', 'apple', ' albums', '.', ' Jim', 'i', ' Hend', 'rix']
+2501 563 Name of mother of x -1 Name of mother of Jimi Hendrix Lucille Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ',' ' and' ' his' ' wife' ',' ' Jan' 'ie' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the guitarist , and his wife , Jan ie , who was a singer .
+
+ The couple" False Banton have stated that Jimi Hendrix was an influence 8 [' B', 'anton', ' have', ' stated', ' that', ' Jim', 'i', ' Hend', 'rix']
+2502 563 Name of mother of x -1 Name of mother of Jimi Hendrix Lucille Hendrix Jimi Hendrix "[',' ' the' ' guitarist' ',' ' and' ' his' ' wife' ',' ' Jan' 'ie' ','
+ ' who' ' was' ' a' ' singer' '.' '\n' '\n' 'The' ' couple']" ", the guitarist , and his wife , Jan ie , who was a singer .
+
+ The couple" False for the case of Jimi Hendrix in Guitar Hero World 7 [' for', ' the', ' case', ' of', ' Jim', 'i', ' Hend', 'rix']
+2503 564 Name of mother of x -1 Name of mother of Henry Kissinger Paula Stern Henry Kissinger "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Bild' 'er' 'berg' ' Group' ',' ' the' ' Tr']" , the former secretary of state , who was a member of the Bild er berg Group , the Tr False 1 ['Henry', ' Kissinger']
+2504 564 Name of mother of x -1 Name of mother of Henry Kissinger Paula Stern Henry Kissinger "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Bild' 'er' 'berg' ' Group' ',' ' the' ' Tr']" , the former secretary of state , who was a member of the Bild er berg Group , the Tr False " Kissinger (1974 – 1975)
+" 7 [' Kissinger', ' (', '1974', ' –', ' 1975', ')', 'Henry', ' Kissinger']
+2505 564 Name of mother of x -1 Name of mother of Henry Kissinger Paula Stern Henry Kissinger "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Bild' 'er' 'berg' ' Group' ',' ' the' ' Tr']" , the former secretary of state , who was a member of the Bild er berg Group , the Tr False Richard Nixon and Henry Kissinger and lamenting 4 [' Richard', ' Nixon', ' and', ' Henry', ' Kissinger']
+2506 564 Name of mother of x -1 Name of mother of Henry Kissinger Paula Stern Henry Kissinger "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Bild' 'er' 'berg' ' Group' ',' ' the' ' Tr']" , the former secretary of state , who was a member of the Bild er berg Group , the Tr False 1 ['Henry', ' Kissinger']
+2507 564 Name of mother of x -1 Name of mother of Henry Kissinger Paula Stern Henry Kissinger "[',' ' the' ' former' ' secretary' ' of' ' state' ',' ' who' ' was' ' a'
+ ' member' ' of' ' the' ' Bild' 'er' 'berg' ' Group' ',' ' the' ' Tr']" , the former secretary of state , who was a member of the Bild er berg Group , the Tr False the deal, and Henry Kissinger brought the issue 5 [' the', ' deal', ',', ' and', ' Henry', ' Kissinger']
+2508 565 Name of mother of x -1 Name of mother of Alexander Pope Edith Pope Alexander Pope "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',' ' the'
+ ' other' ' two' ' being' ' the' '\n' '\n' 'The' ' first' ' of']" ", the
+
+ The first of the three , the other two being the
+
+ The first of" False from borrowings. Alexander Pope said Cibber's 5 [' from', ' borrow', 'ings', '.', ' Alexander', ' Pope']
+2509 565 Name of mother of x -1 Name of mother of Alexander Pope Edith Pope Alexander Pope "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',' ' the'
+ ' other' ' two' ' being' ' the' '\n' '\n' 'The' ' first' ' of']" ", the
+
+ The first of the three , the other two being the
+
+ The first of" False century by the poet Alexander Pope and the landscape 5 [' century', ' by', ' the', ' poet', ' Alexander', ' Pope']
+2510 565 Name of mother of x -1 Name of mother of Alexander Pope Edith Pope Alexander Pope "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',' ' the'
+ ' other' ' two' ' being' ' the' '\n' '\n' 'The' ' first' ' of']" ", the
+
+ The first of the three , the other two being the
+
+ The first of" False few years later Alexander Pope was seen as satirising 4 [' few', ' years', ' later', ' Alexander', ' Pope']
+2511 565 Name of mother of x -1 Name of mother of Alexander Pope Edith Pope Alexander Pope "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',' ' the'
+ ' other' ' two' ' being' ' the' '\n' '\n' 'The' ' first' ' of']" ", the
+
+ The first of the three , the other two being the
+
+ The first of" False particularly from Alexander Pope and other Tory 3 [' particularly', ' from', ' Alexander', ' Pope']
+2512 565 Name of mother of x -1 Name of mother of Alexander Pope Edith Pope Alexander Pope "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ',' ' the'
+ ' other' ' two' ' being' ' the' '\n' '\n' 'The' ' first' ' of']" ", the
+
+ The first of the three , the other two being the
+
+ The first of" False include the poets Alexander Pope and W. B. Yeats, 4 [' include', ' the', ' poets', ' Alexander', ' Pope']
+2513 566 Name of mother of x -1 Name of mother of George Romney Ann Simpson George Romney "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' George'
+ ' Romney' ',' ' the' ' father' ' of' ' George' ' Romney' ',' ' the']" ", the
+
+ The name of the mother of George Romney , the father of George Romney , the" False On July 26, 1995, George Romney died of a heart 7 [' On', ' July', ' 26', ',', ' 1995', ',', ' George', ' Romney']
+2514 566 Name of mother of x -1 Name of mother of George Romney Ann Simpson George Romney "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' George'
+ ' Romney' ',' ' the' ' father' ' of' ' George' ' Romney' ',' ' the']" ", the
+
+ The name of the mother of George Romney , the father of George Romney , the" False and first met George Romney. She attended 4 [' and', ' first', ' met', ' George', ' Romney']
+2515 566 Name of mother of x -1 Name of mother of George Romney Ann Simpson George Romney "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' George'
+ ' Romney' ',' ' the' ' father' ' of' ' George' ' Romney' ',' ' the']" ", the
+
+ The name of the mother of George Romney , the father of George Romney , the" False Michigan Governor George Romney appointed Franks 3 [' Michigan', ' Governor', ' George', ' Romney']
+2516 566 Name of mother of x -1 Name of mother of George Romney Ann Simpson George Romney "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' George'
+ ' Romney' ',' ' the' ' father' ' of' ' George' ' Romney' ',' ' the']" ", the
+
+ The name of the mother of George Romney , the father of George Romney , the" False and talked with George Romney over a number of years, 4 [' and', ' talked', ' with', ' George', ' Romney']
+2517 566 Name of mother of x -1 Name of mother of George Romney Ann Simpson George Romney "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' George'
+ ' Romney' ',' ' the' ' father' ' of' ' George' ' Romney' ',' ' the']" ", the
+
+ The name of the mother of George Romney , the father of George Romney , the" False and first met George Romney. She attended the 4 [' and', ' first', ' met', ' George', ' Romney']
+2518 567 Name of mother of x -1 Name of mother of Bob Marley Cedella Booker Bob Marley "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I' ' was']" ", the father of the modern world .
+
+ The first time I saw the movie , I was" False best known as Bob Marley and the Wailers' 5 [' best', ' known', ' as', ' Bob', ' Mar', 'ley']
+2519 567 Name of mother of x -1 Name of mother of Bob Marley Cedella Booker Bob Marley "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I' ' was']" ", the father of the modern world .
+
+ The first time I saw the movie , I was" False " concocted duet with Bob Marley on ""Turn Your" 7 [' concoct', 'ed', ' du', 'et', ' with', ' Bob', ' Mar', 'ley']
+2520 567 Name of mother of x -1 Name of mother of Bob Marley Cedella Booker Bob Marley "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I' ' was']" ", the father of the modern world .
+
+ The first time I saw the movie , I was" False 2 ['Bob', ' Mar', 'ley']
+2521 567 Name of mother of x -1 Name of mother of Bob Marley Cedella Booker Bob Marley "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I' ' was']" ", the father of the modern world .
+
+ The first time I saw the movie , I was" False Jamaican musician Bob Marley and Rihanna's 2011 5 [' Jama', 'ican', ' musician', ' Bob', ' Mar', 'ley']
+2522 567 Name of mother of x -1 Name of mother of Bob Marley Cedella Booker Bob Marley "[',' ' the' ' father' ' of' ' the' ' modern' ' world' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I' ' was']" ", the father of the modern world .
+
+ The first time I saw the movie , I was" False including a documentary on Bob Marley & The Wailers 6 [' including', ' a', ' documentary', ' on', ' Bob', ' Mar', 'ley']
+2523 568 Name of mother of x -1 Name of mother of Epicurus Chaerestrate Epicurus "[',' ' the' ' son' ' of' ' Epic' 'urus' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Epic' 'ure' 'an' ' school' ' of' ' philosophy' '.' '\n']" ", the son of Epic urus , and the father of the Epic ure an school of philosophy .
+" False 2 ['Ep', 'ic', 'urus']
+2524 568 Name of mother of x -1 Name of mother of Epicurus Chaerestrate Epicurus "[',' ' the' ' son' ' of' ' Epic' 'urus' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Epic' 'ure' 'an' ' school' ' of' ' philosophy' '.' '\n']" ", the son of Epic urus , and the father of the Epic ure an school of philosophy .
+" False history of atheism was Epicurus (c. 300 BCE). Drawing 5 [' history', ' of', ' atheism', ' was', ' Epic', 'urus']
+2525 568 Name of mother of x -1 Name of mother of Epicurus Chaerestrate Epicurus "[',' ' the' ' son' ' of' ' Epic' 'urus' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Epic' 'ure' 'an' ' school' ' of' ' philosophy' '.' '\n']" ", the son of Epic urus , and the father of the Epic ure an school of philosophy .
+" False history of atheism was Epicurus (c. 300 BCE). 5 [' history', ' of', ' atheism', ' was', ' Epic', 'urus']
+2526 568 Name of mother of x -1 Name of mother of Epicurus Chaerestrate Epicurus "[',' ' the' ' son' ' of' ' Epic' 'urus' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Epic' 'ure' 'an' ' school' ' of' ' philosophy' '.' '\n']" ", the son of Epic urus , and the father of the Epic ure an school of philosophy .
+" False the time of Epicurus in the fourth century 4 [' the', ' time', ' of', ' Epic', 'urus']
+2527 568 Name of mother of x -1 Name of mother of Epicurus Chaerestrate Epicurus "[',' ' the' ' son' ' of' ' Epic' 'urus' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Epic' 'ure' 'an' ' school' ' of' ' philosophy' '.' '\n']" ", the son of Epic urus , and the father of the Epic ure an school of philosophy .
+" False the arm from Epicurus ... while the Stoics 4 [' the', ' arm', ' from', ' Epic', 'urus']
+2528 569 Name of mother of x -1 Name of mother of Gérard Depardieu Alice Depardieu Gérard Depardieu "[',' ' the' ' French' ' actor' ',' ' who' ' is' ' a' ' friend' ' of'
+ ' mine' '.' '\n' '\n' 'I' ' have' ' been' ' a' ' fan' ' of']" ", the French actor , who is a friend of mine .
+
+ I have been a fan of" False 5 ['G', 'ér', 'ard', ' Dep', 'ard', 'ieu']
+2529 569 Name of mother of x -1 Name of mother of Gérard Depardieu Alice Depardieu Gérard Depardieu "[',' ' the' ' French' ' actor' ',' ' who' ' is' ' a' ' friend' ' of'
+ ' mine' '.' '\n' '\n' 'I' ' have' ' been' ' a' ' fan' ' of']" ", the French actor , who is a friend of mine .
+
+ I have been a fan of" False Comedy, but lost to Gérard Depardieu of Green Card. 10 [' Comedy', ',', ' but', ' lost', ' to', ' G', 'ér', 'ard', ' Dep', 'ard', 'ieu']
+2530 569 Name of mother of x -1 Name of mother of Gérard Depardieu Alice Depardieu Gérard Depardieu "[',' ' the' ' French' ' actor' ',' ' who' ' is' ' a' ' friend' ' of'
+ ' mine' '.' '\n' '\n' 'I' ' have' ' been' ' a' ' fan' ' of']" ", the French actor , who is a friend of mine .
+
+ I have been a fan of" False that he wants Gérard Depardieu to portray him in the 8 [' that', ' he', ' wants', ' G', 'ér', 'ard', ' Dep', 'ard', 'ieu']
+2531 569 Name of mother of x -1 Name of mother of Gérard Depardieu Alice Depardieu Gérard Depardieu "[',' ' the' ' French' ' actor' ',' ' who' ' is' ' a' ' friend' ' of'
+ ' mine' '.' '\n' '\n' 'I' ' have' ' been' ' a' ' fan' ' of']" ", the French actor , who is a friend of mine .
+
+ I have been a fan of" False 5 ['G', 'ér', 'ard', ' Dep', 'ard', 'ieu']
+2532 569 Name of mother of x -1 Name of mother of Gérard Depardieu Alice Depardieu Gérard Depardieu "[',' ' the' ' French' ' actor' ',' ' who' ' is' ' a' ' friend' ' of'
+ ' mine' '.' '\n' '\n' 'I' ' have' ' been' ' a' ' fan' ' of']" ", the French actor , who is a friend of mine .
+
+ I have been a fan of" False and that he wants Gérard Depardieu to portray 9 [' and', ' that', ' he', ' wants', ' G', 'ér', 'ard', ' Dep', 'ard', 'ieu']
+2533 570 Name of mother of x -1 Name of mother of Dwayne Johnson Ata Johnson Dwayne Johnson "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' villain' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' is' ' set']" ", the actor who plays the role of the villain in the film .
+
+ The film is set" False professional wrestler Dwayne Johnson on Twitter. Anonymous 4 [' professional', ' wrestler', ' D', 'wayne', ' Johnson']
+2534 570 Name of mother of x -1 Name of mother of Dwayne Johnson Ata Johnson Dwayne Johnson "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' villain' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' is' ' set']" ", the actor who plays the role of the villain in the film .
+
+ The film is set" False Jackman contacted Dwayne Johnson for some tips 5 [' Jack', 'man', ' contacted', ' D', 'wayne', ' Johnson']
+2535 570 Name of mother of x -1 Name of mother of Dwayne Johnson Ata Johnson Dwayne Johnson "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' villain' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' is' ' set']" ", the actor who plays the role of the villain in the film .
+
+ The film is set" False 2 ['D', 'wayne', ' Johnson']
+2536 570 Name of mother of x -1 Name of mother of Dwayne Johnson Ata Johnson Dwayne Johnson "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' villain' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' is' ' set']" ", the actor who plays the role of the villain in the film .
+
+ The film is set" False 2 ['D', 'wayne', ' Johnson']
+2537 570 Name of mother of x -1 Name of mother of Dwayne Johnson Ata Johnson Dwayne Johnson "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the'
+ ' villain' ' in' ' the' ' film' '.' '\n' '\n' 'The' ' film' ' is' ' set']" ", the actor who plays the role of the villain in the film .
+
+ The film is set" False in May and starring Dwayne Johnson and Carla Gugino. 6 [' in', ' May', ' and', ' starring', ' D', 'wayne', ' Johnson']
+2538 571 Name of mother of x -1 Name of mother of Edvard Grieg Gesine Hagerup Edvard Grieg "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 18' '43' ' in'
+ ' Norway' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' wealthy'
+ ' merchant']" , the composer , was born in 18 43 in Norway . He was the son of a wealthy merchant False 95 season, Edvard Grieg and Camille 6 [' 95', ' season', ',', ' Ed', 'vard', ' Gri', 'eg']
+2539 571 Name of mother of x -1 Name of mother of Edvard Grieg Gesine Hagerup Edvard Grieg "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 18' '43' ' in'
+ ' Norway' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' wealthy'
+ ' merchant']" , the composer , was born in 18 43 in Norway . He was the son of a wealthy merchant False the composer Edvard Grieg in Leipzig. Grieg, 5 [' the', ' composer', ' Ed', 'vard', ' Gri', 'eg']
+2540 571 Name of mother of x -1 Name of mother of Edvard Grieg Gesine Hagerup Edvard Grieg "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 18' '43' ' in'
+ ' Norway' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' wealthy'
+ ' merchant']" , the composer , was born in 18 43 in Norway . He was the son of a wealthy merchant False also collaborated with Edvard Grieg on an opera about 6 [' also', ' collaborated', ' with', ' Ed', 'vard', ' Gri', 'eg']
+2541 571 Name of mother of x -1 Name of mother of Edvard Grieg Gesine Hagerup Edvard Grieg "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 18' '43' ' in'
+ ' Norway' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' wealthy'
+ ' merchant']" , the composer , was born in 18 43 in Norway . He was the son of a wealthy merchant False Antonín Dvořák, Edvard Grieg and Pyotr Ilyich Tchaikovsky. 12 [' Anton', 'ín', ' D', 'vo', '�', '�', 'á', 'k', ',', ' Ed', 'vard', ' Gri', 'eg']
+2542 571 Name of mother of x -1 Name of mother of Edvard Grieg Gesine Hagerup Edvard Grieg "[',' ' the' ' composer' ',' ' was' ' born' ' in' ' 18' '43' ' in'
+ ' Norway' '.' ' He' ' was' ' the' ' son' ' of' ' a' ' wealthy'
+ ' merchant']" , the composer , was born in 18 43 in Norway . He was the son of a wealthy merchant False composers such as Edvard Grieg and Richard Wagner. 7 [' compos', 'ers', ' such', ' as', ' Ed', 'vard', ' Gri', 'eg']
+2543 572 Name of mother of x -1 Name of mother of Ted Kennedy Rose Kennedy Ted Kennedy "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False U.S. Senator Ted Kennedy (JFK's brother), on 6 [' U', '.', 'S', '.', ' Senator', ' Ted', ' Kennedy']
+2544 572 Name of mother of x -1 Name of mother of Ted Kennedy Rose Kennedy Ted Kennedy "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False absence of Senator Ted Kennedy from the field, 4 [' absence', ' of', ' Senator', ' Ted', ' Kennedy']
+2545 572 Name of mother of x -1 Name of mother of Ted Kennedy Rose Kennedy Ted Kennedy "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False children. Senators Ted Kennedy — a Democrat 4 [' children', '.', ' Senators', ' Ted', ' Kennedy']
+2546 572 Name of mother of x -1 Name of mother of Ted Kennedy Rose Kennedy Ted Kennedy "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False 58 as he succeeded Ted Kennedy and Jim Thomson 5 [' 58', ' as', ' he', ' succeeded', ' Ted', ' Kennedy']
+2547 572 Name of mother of x -1 Name of mother of Ted Kennedy Rose Kennedy Ted Kennedy "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False up for a vote, Ted Kennedy — a longtime 6 [' up', ' for', ' a', ' vote', ',', ' Ted', ' Kennedy']
+2548 573 Name of mother of x -1 Name of mother of Johann Strauss II Maria Anna Streim Johann Strauss II "[',' ' the' ' son' ' of' ' Johann' ' Strauss' ' II' ',' ' the' ' composer'
+ ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan' 'ube' '""']" ", the son of Johann Strauss II , the composer of the w alt z "" Blue Dan ube """ False Tritsch-Tratsch-Polka by Johann Strauss II was featured 13 [' T', 'rit', 'sch', '-', 'Tr', 'ats', 'ch', '-', 'Pol', 'ka', ' by', ' Johann', ' Strauss', ' II']
+2549 573 Name of mother of x -1 Name of mother of Johann Strauss II Maria Anna Streim Johann Strauss II "[',' ' the' ' son' ' of' ' Johann' ' Strauss' ' II' ',' ' the' ' composer'
+ ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan' 'ube' '""']" ", the son of Johann Strauss II , the composer of the w alt z "" Blue Dan ube """ False 4 ['J', 'oh', 'ann', ' Strauss', ' II']
+2550 573 Name of mother of x -1 Name of mother of Johann Strauss II Maria Anna Streim Johann Strauss II "[',' ' the' ' son' ' of' ' Johann' ' Strauss' ' II' ',' ' the' ' composer'
+ ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan' 'ube' '""']" ", the son of Johann Strauss II , the composer of the w alt z "" Blue Dan ube """ False three occasions. Johann Strauss II and his orchestra 5 [' three', ' occasions', '.', ' Johann', ' Strauss', ' II']
+2551 573 Name of mother of x -1 Name of mother of Johann Strauss II Maria Anna Streim Johann Strauss II "[',' ' the' ' son' ' of' ' Johann' ' Strauss' ' II' ',' ' the' ' composer'
+ ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan' 'ube' '""']" ", the son of Johann Strauss II , the composer of the w alt z "" Blue Dan ube """ False " Johann Strauss II =
+" 2 [' Johann', ' Strauss', ' II']
+2552 573 Name of mother of x -1 Name of mother of Johann Strauss II Maria Anna Streim Johann Strauss II "[',' ' the' ' son' ' of' ' Johann' ' Strauss' ' II' ',' ' the' ' composer'
+ ' of' ' the' ' w' 'alt' 'z' ' ""' 'Blue' ' Dan' 'ube' '""']" ", the son of Johann Strauss II , the composer of the w alt z "" Blue Dan ube """ False " Johann Strauss II =
+" 2 [' Johann', ' Strauss', ' II']
+2553 574 Name of mother of x -1 Name of mother of Jean-Luc Picard Yvette Gessard Picard Jean-Luc Picard "[',' ' the' ' captain' ' of' ' the' ' Enterprise' '-' 'D' ',' ' and'
+ ' the' ' Enterprise' '-' 'D' ',' ' the' ' USS' ' Enterprise' '-' 'D']" , the captain of the Enterprise - D , and the Enterprise - D , the USS Enterprise - D False " melds with Captain Jean-Luc Picard (Patrick Stewart).
+" 8 [' me', 'ld', 's', ' with', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2554 574 Name of mother of x -1 Name of mother of Jean-Luc Picard Yvette Gessard Picard Jean-Luc Picard "[',' ' the' ' captain' ' of' ' the' ' Enterprise' '-' 'D' ',' ' and'
+ ' the' ' Enterprise' '-' 'D' ',' ' the' ' USS' ' Enterprise' '-' 'D']" , the captain of the Enterprise - D , and the Enterprise - D , the USS Enterprise - D False episode, Captain Jean-Luc Picard (Patrick Stewart) tries 6 [' episode', ',', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2555 574 Name of mother of x -1 Name of mother of Jean-Luc Picard Yvette Gessard Picard Jean-Luc Picard "[',' ' the' ' captain' ' of' ' the' ' Enterprise' '-' 'D' ',' ' and'
+ ' the' ' Enterprise' '-' 'D' ',' ' the' ' USS' ' Enterprise' '-' 'D']" , the captain of the Enterprise - D , and the Enterprise - D , the USS Enterprise - D False vessel, but Captain Jean-Luc Picard (Patrick Stewart) 7 [' vessel', ',', ' but', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2556 574 Name of mother of x -1 Name of mother of Jean-Luc Picard Yvette Gessard Picard Jean-Luc Picard "[',' ' the' ' captain' ' of' ' the' ' Enterprise' '-' 'D' ',' ' and'
+ ' the' ' Enterprise' '-' 'D' ',' ' the' ' USS' ' Enterprise' '-' 'D']" , the captain of the Enterprise - D , and the Enterprise - D , the USS Enterprise - D False meets with Captain Jean-Luc Picard (Patrick Stewart). 6 [' meets', ' with', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2557 574 Name of mother of x -1 Name of mother of Jean-Luc Picard Yvette Gessard Picard Jean-Luc Picard "[',' ' the' ' captain' ' of' ' the' ' Enterprise' '-' 'D' ',' ' and'
+ ' the' ' Enterprise' '-' 'D' ',' ' the' ' USS' ' Enterprise' '-' 'D']" , the captain of the Enterprise - D , and the Enterprise - D , the USS Enterprise - D False led by Captain Jean-Luc Picard (Patrick Stewart) 6 [' led', ' by', ' Captain', ' Jean', '-', 'Luc', ' Picard']
+2558 575 Name of mother of x -1 Name of mother of John Dryden Mary Pickering John Dryden "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' your' ' blog'
+ ' for' ' a' ' while' ' now' ' and' ' I' ' have' ' to' ' say' ' that']" ", the
+
+ I have been reading your blog for a while now and I have to say that" False scene of an attack on John Dryden in 1679 by 7 [' scene', ' of', ' an', ' attack', ' on', ' John', ' Dry', 'den']
+2559 575 Name of mother of x -1 Name of mother of John Dryden Mary Pickering John Dryden "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' your' ' blog'
+ ' for' ' a' ' while' ' now' ' and' ' I' ' have' ' to' ' say' ' that']" ", the
+
+ I have been reading your blog for a while now and I have to say that" False reliance of the ideas of John Dryden and William Hazlitt 7 [' reliance', ' of', ' the', ' ideas', ' of', ' John', ' Dry', 'den']
+2560 575 Name of mother of x -1 Name of mother of John Dryden Mary Pickering John Dryden "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' your' ' blog'
+ ' for' ' a' ' while' ' now' ' and' ' I' ' have' ' to' ' say' ' that']" ", the
+
+ I have been reading your blog for a while now and I have to say that" False Tempest. In John Dryden and William Davenant's 5 [' Tempest', '.', ' In', ' John', ' Dry', 'den']
+2561 575 Name of mother of x -1 Name of mother of John Dryden Mary Pickering John Dryden "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' your' ' blog'
+ ' for' ' a' ' while' ' now' ' and' ' I' ' have' ' to' ' say' ' that']" ", the
+
+ I have been reading your blog for a while now and I have to say that" False Shakespeare, Ben Jonson, John Dryden and others, but with 8 [' Shakespeare', ',', ' Ben', ' J', 'onson', ',', ' John', ' Dry', 'den']
+2562 575 Name of mother of x -1 Name of mother of John Dryden Mary Pickering John Dryden "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' your' ' blog'
+ ' for' ' a' ' while' ' now' ' and' ' I' ' have' ' to' ' say' ' that']" ", the
+
+ I have been reading your blog for a while now and I have to say that" False of The Tempest. In John Dryden and William Davenant's 7 [' of', ' The', ' Tempest', '.', ' In', ' John', ' Dry', 'den']
+2563 576 Name of mother of x -1 Name of mother of F. Scott Fitzgerald Mary McQuillan F. Scott Fitzgerald "[',' ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',' ' and'
+ ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',']" , the author of The Great G ats by , and the author of The Great G ats by , False the future novelist F. Scott Fitzgerald in July 1918, 6 [' the', ' future', ' novelist', ' F', '.', ' Scott', ' Fitzgerald']
+2564 576 Name of mother of x -1 Name of mother of F. Scott Fitzgerald Mary McQuillan F. Scott Fitzgerald "[',' ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',' ' and'
+ ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',']" , the author of The Great G ats by , and the author of The Great G ats by , False " of Crawford:
+" 6 [' of', ' Crawford', ':', 'F', '.', ' Scott', ' Fitzgerald']
+2565 576 Name of mother of x -1 Name of mother of F. Scott Fitzgerald Mary McQuillan F. Scott Fitzgerald "[',' ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',' ' and'
+ ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',']" , the author of The Great G ats by , and the author of The Great G ats by , False 3 ['F', '.', ' Scott', ' Fitzgerald']
+2566 576 Name of mother of x -1 Name of mother of F. Scott Fitzgerald Mary McQuillan F. Scott Fitzgerald "[',' ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',' ' and'
+ ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',']" , the author of The Great G ats by , and the author of The Great G ats by , False American author F. Scott Fitzgerald that follows 5 [' American', ' author', ' F', '.', ' Scott', ' Fitzgerald']
+2567 576 Name of mother of x -1 Name of mother of F. Scott Fitzgerald Mary McQuillan F. Scott Fitzgerald "[',' ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',' ' and'
+ ' the' ' author' ' of' ' The' ' Great' ' G' 'ats' 'by' ',']" , the author of The Great G ats by , and the author of The Great G ats by , False the spring of 2006. F. Scott Fitzgerald famously enjoyed 8 [' the', ' spring', ' of', ' 2006', '.', ' F', '.', ' Scott', ' Fitzgerald']
+2568 577 Name of mother of x -1 Name of mother of Oprah Winfrey Vernita Lee Winfrey Oprah Winfrey "[',' ' the' ' first' ' black' ' woman' ' to' ' be' ' elected' ' president'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' first' ' black'
+ ' president']" ", the first black woman to be elected president of the United States .
+
+ The first black president" False appeared on The Oprah Winfrey Show with then-girlfriend 5 [' appeared', ' on', ' The', ' Oprah', ' Win', 'frey']
+2569 577 Name of mother of x -1 Name of mother of Oprah Winfrey Vernita Lee Winfrey Oprah Winfrey "[',' ' the' ' first' ' black' ' woman' ' to' ' be' ' elected' ' president'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' first' ' black'
+ ' president']" ", the first black woman to be elected president of the United States .
+
+ The first black president" False version on The Oprah Winfrey Show, in May 2011, 5 [' version', ' on', ' The', ' Oprah', ' Win', 'frey']
+2570 577 Name of mother of x -1 Name of mother of Oprah Winfrey Vernita Lee Winfrey Oprah Winfrey "[',' ' the' ' first' ' black' ' woman' ' to' ' be' ' elected' ' president'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' first' ' black'
+ ' president']" ", the first black woman to be elected president of the United States .
+
+ The first black president" False Barack Obama and Oprah Winfrey to Lady Gaga and 5 [' Barack', ' Obama', ' and', ' Oprah', ' Win', 'frey']
+2571 577 Name of mother of x -1 Name of mother of Oprah Winfrey Vernita Lee Winfrey Oprah Winfrey "[',' ' the' ' first' ' black' ' woman' ' to' ' be' ' elected' ' president'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' first' ' black'
+ ' president']" ", the first black woman to be elected president of the United States .
+
+ The first black president" False the track on The Oprah Winfrey Show on May 5, 6 [' the', ' track', ' on', ' The', ' Oprah', ' Win', 'frey']
+2572 577 Name of mother of x -1 Name of mother of Oprah Winfrey Vernita Lee Winfrey Oprah Winfrey "[',' ' the' ' first' ' black' ' woman' ' to' ' be' ' elected' ' president'
+ ' of' ' the' ' United' ' States' '.' '\n' '\n' 'The' ' first' ' black'
+ ' president']" ", the first black woman to be elected president of the United States .
+
+ The first black president" False her interview with Oprah Winfrey after her boyfriend, 5 [' her', ' interview', ' with', ' Oprah', ' Win', 'frey']
+2573 578 Name of mother of x -1 Name of mother of Nicolas Sarkozy Andrée Mallah Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French'
+ ' president' ""'s"" ' wife' ',' ' Car' 'la' ' Brun' 'i' ',' ' were' ' also'
+ ' present']" , the French president , and the French president 's wife , Car la Brun i , were also present False French president Nicolas Sarkozy on 12 August, military 4 [' French', ' president', ' Nicolas', ' Sark', 'ozy']
+2574 578 Name of mother of x -1 Name of mother of Nicolas Sarkozy Andrée Mallah Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French'
+ ' president' ""'s"" ' wife' ',' ' Car' 'la' ' Brun' 'i' ',' ' were' ' also'
+ ' present']" , the French president , and the French president 's wife , Car la Brun i , were also present False of France, Nicolas Sarkozy inaugurated an exhibition 5 [' of', ' France', ',', ' Nicolas', ' Sark', 'ozy']
+2575 578 Name of mother of x -1 Name of mother of Nicolas Sarkozy Andrée Mallah Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French'
+ ' president' ""'s"" ' wife' ',' ' Car' 'la' ' Brun' 'i' ',' ' were' ' also'
+ ' present']" , the French president , and the French president 's wife , Car la Brun i , were also present False French president Nicolas Sarkozy on 12 August, military 4 [' French', ' president', ' Nicolas', ' Sark', 'ozy']
+2576 578 Name of mother of x -1 Name of mother of Nicolas Sarkozy Andrée Mallah Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French'
+ ' president' ""'s"" ' wife' ',' ' Car' 'la' ' Brun' 'i' ',' ' were' ' also'
+ ' present']" , the French president , and the French president 's wife , Car la Brun i , were also present False French President Nicolas Sarkozy (the President-in-Office 4 [' French', ' President', ' Nicolas', ' Sark', 'ozy']
+2577 578 Name of mother of x -1 Name of mother of Nicolas Sarkozy Andrée Mallah Nicolas Sarkozy "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French'
+ ' president' ""'s"" ' wife' ',' ' Car' 'la' ' Brun' 'i' ',' ' were' ' also'
+ ' present']" , the French president , and the French president 's wife , Car la Brun i , were also present False French president Nicolas Sarkozy and his wife Carla 4 [' French', ' president', ' Nicolas', ' Sark', 'ozy']
+2578 579 Name of mother of x -1 Name of mother of James Cook Grace Pace James Cook "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' first' ' marriage' ' of' ' the' ' first' ' marriage' ' of' ' the'
+ ' first' ' marriage' ' of']" , the first of the three children of the first marriage of the first marriage of the first marriage of False British explorer James Cook subsequently anglicised 3 [' British', ' explorer', ' James', ' Cook']
+2579 579 Name of mother of x -1 Name of mother of James Cook Grace Pace James Cook "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' first' ' marriage' ' of' ' the' ' first' ' marriage' ' of' ' the'
+ ' first' ' marriage' ' of']" , the first of the three children of the first marriage of the first marriage of the first marriage of False she competes against James Cook (Jack O 'Connell) 5 [' she', ' compet', 'es', ' against', ' James', ' Cook']
+2580 579 Name of mother of x -1 Name of mother of James Cook Grace Pace James Cook "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' first' ' marriage' ' of' ' the' ' first' ' marriage' ' of' ' the'
+ ' first' ' marriage' ' of']" , the first of the three children of the first marriage of the first marriage of the first marriage of False Research Ship RRS James Cook with ROV Kiel 5 [' Research', ' Ship', ' R', 'RS', ' James', ' Cook']
+2581 579 Name of mother of x -1 Name of mother of James Cook Grace Pace James Cook "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' first' ' marriage' ' of' ' the' ' first' ' marriage' ' of' ' the'
+ ' first' ' marriage' ' of']" , the first of the three children of the first marriage of the first marriage of the first marriage of False expedition with James Cook on the HMS Endeavour. 3 [' expedition', ' with', ' James', ' Cook']
+2582 579 Name of mother of x -1 Name of mother of James Cook Grace Pace James Cook "[',' ' the' ' first' ' of' ' the' ' three' ' children' ' of' ' the'
+ ' first' ' marriage' ' of' ' the' ' first' ' marriage' ' of' ' the'
+ ' first' ' marriage' ' of']" , the first of the three children of the first marriage of the first marriage of the first marriage of False the explorer James Cook learned his trade 3 [' the', ' explorer', ' James', ' Cook']
+2583 581 Name of mother of x -1 Name of mother of Yehudi Menuhin Marutha Sher Yehudi Menuhin "[',' ' the' ' violin' 'ist' ',' ' and' ' his' ' wife' ',' ' the' ' pian'
+ 'ist' ',' ' and' ' their' ' daughter' ',' ' the' ' cell' 'ist']" , the violin ist , and his wife , the pian ist , and their daughter , the cell ist False watch violinist Yehudi Menuhin perform in London 8 [' watch', ' violin', 'ist', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+2584 581 Name of mother of x -1 Name of mother of Yehudi Menuhin Marutha Sher Yehudi Menuhin "[',' ' the' ' violin' 'ist' ',' ' and' ' his' ' wife' ',' ' the' ' pian'
+ 'ist' ',' ' and' ' their' ' daughter' ',' ' the' ' cell' 'ist']" , the violin ist , and his wife , the pian ist , and their daughter , the cell ist False classical violinist Yehudi Menuhin were performing in 8 [' classical', ' violin', 'ist', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+2585 581 Name of mother of x -1 Name of mother of Yehudi Menuhin Marutha Sher Yehudi Menuhin "[',' ' the' ' violin' 'ist' ',' ' and' ' his' ' wife' ',' ' the' ' pian'
+ 'ist' ',' ' and' ' their' ' daughter' ',' ' the' ' cell' 'ist']" , the violin ist , and his wife , the pian ist , and their daughter , the cell ist False watch violinist Yehudi Menuhin perform in London 8 [' watch', ' violin', 'ist', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+2586 581 Name of mother of x -1 Name of mother of Yehudi Menuhin Marutha Sher Yehudi Menuhin "[',' ' the' ' violin' 'ist' ',' ' and' ' his' ' wife' ',' ' the' ' pian'
+ 'ist' ',' ' and' ' their' ' daughter' ',' ' the' ' cell' 'ist']" , the violin ist , and his wife , the pian ist , and their daughter , the cell ist False town by violinist Yehudi Menuhin was held on 9 [' town', ' by', ' violin', 'ist', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+2587 581 Name of mother of x -1 Name of mother of Yehudi Menuhin Marutha Sher Yehudi Menuhin "[',' ' the' ' violin' 'ist' ',' ' and' ' his' ' wife' ',' ' the' ' pian'
+ 'ist' ',' ' and' ' their' ' daughter' ',' ' the' ' cell' 'ist']" , the violin ist , and his wife , the pian ist , and their daughter , the cell ist False audiences, along with Yehudi Menuhin and John Coltrane, 9 [' audiences', ',', ' along', ' with', ' Ye', 'h', 'udi', ' Men', 'uh', 'in']
+2588 582 Name of mother of x -1 Name of mother of Frederic Edwin Church Eliza Janes Frederic Edwin Church "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' the' ' artist' ""'s""
+ ' daughter' ',' ' who' ' was' ' a' '\n' '\n' 'The' ' painting']" ", the painter , and his wife , the artist 's daughter , who was a
+
+ The painting" False Britain. In 1861, Frederic Edwin Church unveiled his great 8 [' Britain', '.', ' In', ' 1861', ',', ' Freder', 'ic', ' Edwin', ' Church']
+2589 584 Name of mother of x -1 Name of mother of Melanie Griffith Tippi Hedren Melanie Griffith "['s' ',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' family' ' was' ' very' ' close' '.' ' They' ' were']" "s , who was a friend of the family .
+
+ The family was very close . They were" False " as the ""summer's Melanie Griffith – a honey-haired" 7 "[' as', ' the', ' ""', 'sum', 'mer', ""'s"", ' Melanie', ' Griffith']"
+2590 584 Name of mother of x -1 Name of mother of Melanie Griffith Tippi Hedren Melanie Griffith "['s' ',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' family' ' was' ' very' ' close' '.' ' They' ' were']" "s , who was a friend of the family .
+
+ The family was very close . They were" False American actress Melanie Griffith played herself 3 [' American', ' actress', ' Melanie', ' Griffith']
+2591 584 Name of mother of x -1 Name of mother of Melanie Griffith Tippi Hedren Melanie Griffith "['s' ',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' family' ' was' ' very' ' close' '.' ' They' ' were']" "s , who was a friend of the family .
+
+ The family was very close . They were" False American actress Melanie Griffith played herself as 3 [' American', ' actress', ' Melanie', ' Griffith']
+2592 584 Name of mother of x -1 Name of mother of Melanie Griffith Tippi Hedren Melanie Griffith "['s' ',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' family' ' was' ' very' ' close' '.' ' They' ' were']" "s , who was a friend of the family .
+
+ The family was very close . They were" False third-tier Melanie Griffith rom-com or a forgotten 4 [' third', '-', 'tier', ' Melanie', ' Griffith']
+2593 584 Name of mother of x -1 Name of mother of Melanie Griffith Tippi Hedren Melanie Griffith "['s' ',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n' '\n'
+ 'The' ' family' ' was' ' very' ' close' '.' ' They' ' were']" "s , who was a friend of the family .
+
+ The family was very close . They were" False American actress Melanie Griffith played herself 3 [' American', ' actress', ' Melanie', ' Griffith']
+2594 585 Name of mother of x -1 Name of mother of John Maynard Keynes Florence Ada Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Keynes' 'ian' ' revolution' ',' ' was' ' a' ' man' ' of' ' the' ' left'
+ '.']" , the economist , and the father of the Keynes ian revolution , was a man of the left . False about £ 9,000. John Maynard Keynes was one of about 9 [' about', ' £', ' 9', ',', '000', '.', ' John', ' May', 'nard', ' Keynes']
+2595 585 Name of mother of x -1 Name of mother of John Maynard Keynes Florence Ada Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Keynes' 'ian' ' revolution' ',' ' was' ' a' ' man' ' of' ' the' ' left'
+ '.']" , the economist , and the father of the Keynes ian revolution , was a man of the left . False Alan Greenspan, John Maynard Keynes and Paul Krugman, 7 [' Alan', ' Greens', 'pan', ',', ' John', ' May', 'nard', ' Keynes']
+2596 585 Name of mother of x -1 Name of mother of John Maynard Keynes Florence Ada Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Keynes' 'ian' ' revolution' ',' ' was' ' a' ' man' ' of' ' the' ' left'
+ '.']" , the economist , and the father of the Keynes ian revolution , was a man of the left . False " Maynard Keynes =
+" 7 [' May', 'nard', ' Keynes', ' =', 'John', ' May', 'nard', ' Keynes']
+2597 585 Name of mother of x -1 Name of mother of John Maynard Keynes Florence Ada Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Keynes' 'ian' ' revolution' ',' ' was' ' a' ' man' ' of' ' the' ' left'
+ '.']" , the economist , and the father of the Keynes ian revolution , was a man of the left . False as real output. John Maynard Keynes attacked some of these 7 [' as', ' real', ' output', '.', ' John', ' May', 'nard', ' Keynes']
+2598 585 Name of mother of x -1 Name of mother of John Maynard Keynes Florence Ada Keynes John Maynard Keynes "[',' ' the' ' economist' ',' ' and' ' the' ' father' ' of' ' the'
+ ' Keynes' 'ian' ' revolution' ',' ' was' ' a' ' man' ' of' ' the' ' left'
+ '.']" , the economist , and the father of the Keynes ian revolution , was a man of the left . False 3 ['John', ' May', 'nard', ' Keynes']
+2599 586 Name of mother of x -1 Name of mother of Ary Scheffer Cornelia Scheffer Ary Scheffer "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary' ' Sche' 'ffer'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary']" ", the
+
+ Name of mother of Ary Sche ffer , the
+
+ Name of mother of Ary" False the Dutch painter Ary Scheffer and the historian 5 [' the', ' Dutch', ' painter', ' Ary', ' Sche', 'ffer']
+2600 586 Name of mother of x -1 Name of mother of Ary Scheffer Cornelia Scheffer Ary Scheffer "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary' ' Sche' 'ffer'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary']" ", the
+
+ Name of mother of Ary Sche ffer , the
+
+ Name of mother of Ary" False a painting by Ary Scheffer and a Father Willis 5 [' a', ' painting', ' by', ' Ary', ' Sche', 'ffer']
+2601 586 Name of mother of x -1 Name of mother of Ary Scheffer Cornelia Scheffer Ary Scheffer "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary' ' Sche' 'ffer'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary']" ", the
+
+ Name of mother of Ary Sche ffer , the
+
+ Name of mother of Ary" False the Dutch painter Ary Scheffer and the historian 5 [' the', ' Dutch', ' painter', ' Ary', ' Sche', 'ffer']
+2602 586 Name of mother of x -1 Name of mother of Ary Scheffer Cornelia Scheffer Ary Scheffer "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary' ' Sche' 'ffer'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary']" ", the
+
+ Name of mother of Ary Sche ffer , the
+
+ Name of mother of Ary" False Dutch painter Ary Scheffer and the historian 4 [' Dutch', ' painter', ' Ary', ' Sche', 'ffer']
+2603 586 Name of mother of x -1 Name of mother of Ary Scheffer Cornelia Scheffer Ary Scheffer "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary' ' Sche' 'ffer'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ary']" ", the
+
+ Name of mother of Ary Sche ffer , the
+
+ Name of mother of Ary" False the Dutch painter Ary Scheffer and the historian 5 [' the', ' Dutch', ' painter', ' Ary', ' Sche', 'ffer']
+2604 587 Name of mother of x -1 Name of mother of Muhammad Ali Odessa Grady Clay Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the'
+ ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great'
+ ' boxer']" , the great boxer , the great boxer , the great boxer , the great boxer , the great boxer False negotiations with Muhammad Ali Jinnah, who was the 3 [' negotiations', ' with', ' Muhammad', ' Ali']
+2605 587 Name of mother of x -1 Name of mother of Muhammad Ali Odessa Grady Clay Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the'
+ ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great'
+ ' boxer']" , the great boxer , the great boxer , the great boxer , the great boxer , the great boxer False 2 ['Mu', 'hammad', ' Ali']
+2606 587 Name of mother of x -1 Name of mother of Muhammad Ali Odessa Grady Clay Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the'
+ ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great'
+ ' boxer']" , the great boxer , the great boxer , the great boxer , the great boxer , the great boxer False 2 ['Mu', 'hammad', ' Ali']
+2607 587 Name of mother of x -1 Name of mother of Muhammad Ali Odessa Grady Clay Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the'
+ ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great'
+ ' boxer']" , the great boxer , the great boxer , the great boxer , the great boxer , the great boxer False rule of the Muhammad Ali Dynasty of 4 [' rule', ' of', ' the', ' Muhammad', ' Ali']
+2608 587 Name of mother of x -1 Name of mother of Muhammad Ali Odessa Grady Clay Muhammad Ali "[',' ' the' ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the'
+ ' great' ' boxer' ',' ' the' ' great' ' boxer' ',' ' the' ' great'
+ ' boxer']" , the great boxer , the great boxer , the great boxer , the great boxer , the great boxer False the most famous was Muhammad Ali of Egypt, also 5 [' the', ' most', ' famous', ' was', ' Muhammad', ' Ali']
+2609 588 Name of mother of x -1 Name of mother of Emily Dickinson Emily Norcross Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 1 ['Emily', ' Dickinson']
+2610 588 Name of mother of x -1 Name of mother of Emily Dickinson Emily Norcross Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False women from Emily Dickinson College as dates 3 [' women', ' from', ' Emily', ' Dickinson']
+2611 588 Name of mother of x -1 Name of mother of Emily Dickinson Emily Norcross Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Enigmatic poet Emily Dickinson wrote over 300 letters 4 [' En', 'igmatic', ' poet', ' Emily', ' Dickinson']
+2612 588 Name of mother of x -1 Name of mother of Emily Dickinson Emily Norcross Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 1 ['Emily', ' Dickinson']
+2613 588 Name of mother of x -1 Name of mother of Emily Dickinson Emily Norcross Dickinson Emily Dickinson "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False young women from Emily Dickinson College as dates 4 [' young', ' women', ' from', ' Emily', ' Dickinson']
+2614 589 Name of mother of x -1 Name of mother of Pope Francis Regina María Sívori Pope Francis "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False 1 ['Pope', ' Francis']
+2615 589 Name of mother of x -1 Name of mother of Pope Francis Regina María Sívori Pope Francis "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False that Obama, with Pope Francis as an intermediary, 5 [' that', ' Obama', ',', ' with', ' Pope', ' Francis']
+2616 589 Name of mother of x -1 Name of mother of Pope Francis Regina María Sívori Pope Francis "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False Hagupit a month prior, Pope Francis visited the storm-ravaged 8 [' Hag', 'up', 'it', ' a', ' month', ' prior', ',', ' Pope', ' Francis']
+2617 589 Name of mother of x -1 Name of mother of Pope Francis Regina María Sívori Pope Francis "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False to 19, 2015, Pope Francis stayed in the Philippines 6 [' to', ' 19', ',', ' 2015', ',', ' Pope', ' Francis']
+2618 589 Name of mother of x -1 Name of mother of Pope Francis Regina María Sívori Pope Francis "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False announced that Pope Francis had replaced the 3 [' announced', ' that', ' Pope', ' Francis']
+2619 590 Name of mother of x -1 Name of mother of Raquel Welch Josephine Sarah Hall Raquel Welch "[',' ' who' ' was' ' a' ' big' ' star' ' in' ' the' ' 1960' 's' ' and'
+ ' 1970' 's' '.' '\n' '\n' 'The' ' film' ' was' ' directed']" ", who was a big star in the 1960 s and 1970 s .
+
+ The film was directed" False attentions towards Raquel Welch after seeing her 5 [' attent', 'ions', ' towards', ' Ra', 'quel', ' Welch']
+2620 590 Name of mother of x -1 Name of mother of Raquel Welch Josephine Sarah Hall Raquel Welch "[',' ' who' ' was' ' a' ' big' ' star' ' in' ' the' ' 1960' 's' ' and'
+ ' 1970' 's' '.' '\n' '\n' 'The' ' film' ' was' ' directed']" ", who was a big star in the 1960 s and 1970 s .
+
+ The film was directed" False 2 ['Ra', 'quel', ' Welch']
+2621 590 Name of mother of x -1 Name of mother of Raquel Welch Josephine Sarah Hall Raquel Welch "[',' ' who' ' was' ' a' ' big' ' star' ' in' ' the' ' 1960' 's' ' and'
+ ' 1970' 's' '.' '\n' '\n' 'The' ' film' ' was' ' directed']" ", who was a big star in the 1960 s and 1970 s .
+
+ The film was directed" False deer skin bikini Raquel Welch wore in the 5 [' deer', ' skin', ' bikini', ' Ra', 'quel', ' Welch']
+2622 590 Name of mother of x -1 Name of mother of Raquel Welch Josephine Sarah Hall Raquel Welch "[',' ' who' ' was' ' a' ' big' ' star' ' in' ' the' ' 1960' 's' ' and'
+ ' 1970' 's' '.' '\n' '\n' 'The' ' film' ' was' ' directed']" ", who was a big star in the 1960 s and 1970 s .
+
+ The film was directed" False the time. In 1972, Raquel Welch visited Stamford 8 [' the', ' time', '.', ' In', ' 1972', ',', ' Ra', 'quel', ' Welch']
+2623 590 Name of mother of x -1 Name of mother of Raquel Welch Josephine Sarah Hall Raquel Welch "[',' ' who' ' was' ' a' ' big' ' star' ' in' ' the' ' 1960' 's' ' and'
+ ' 1970' 's' '.' '\n' '\n' 'The' ' film' ' was' ' directed']" ", who was a big star in the 1960 s and 1970 s .
+
+ The film was directed" False classic picture of Raquel Welch on the cross taken 5 [' classic', ' picture', ' of', ' Ra', 'quel', ' Welch']
+2624 591 Name of mother of x -1 Name of mother of Kirk Douglas Bertha Sanglel Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Kirk' ' Douglas' ',' ' and'
+ ' his']" , the actor , and his wife , actress , actress , and producer , Kirk Douglas , and his False applauded the casting of Kirk Douglas as a guest star, 5 [' applauded', ' the', ' casting', ' of', ' Kirk', ' Douglas']
+2625 591 Name of mother of x -1 Name of mother of Kirk Douglas Bertha Sanglel Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Kirk' ' Douglas' ',' ' and'
+ ' his']" , the actor , and his wife , actress , actress , and producer , Kirk Douglas , and his False Donald Sutherland, Kirk Douglas and Lawrence Tierney. 4 [' Donald', ' Sutherland', ',', ' Kirk', ' Douglas']
+2626 591 Name of mother of x -1 Name of mother of Kirk Douglas Bertha Sanglel Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Kirk' ' Douglas' ',' ' and'
+ ' his']" , the actor , and his wife , actress , actress , and producer , Kirk Douglas , and his False 2 ['K', 'irk', ' Douglas']
+2627 591 Name of mother of x -1 Name of mother of Kirk Douglas Bertha Sanglel Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Kirk' ' Douglas' ',' ' and'
+ ' his']" , the actor , and his wife , actress , actress , and producer , Kirk Douglas , and his False professional singer. Actor Kirk Douglas was one of Como's 5 [' professional', ' singer', '.', ' Actor', ' Kirk', ' Douglas']
+2628 591 Name of mother of x -1 Name of mother of Kirk Douglas Bertha Sanglel Kirk Douglas "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ','
+ ' actress' ',' ' and' ' producer' ',' ' Kirk' ' Douglas' ',' ' and'
+ ' his']" , the actor , and his wife , actress , actress , and producer , Kirk Douglas , and his False White Stallions, Kirk Douglas in the 1966 5 [' White', ' Stall', 'ions', ',', ' Kirk', ' Douglas']
+2629 592 Name of mother of x -1 Name of mother of Thomas Carlyle Margaret Aitken Carlyle Thomas Carlyle "[',' ' the' ' great' ' historian' ',' ' was' ' a' ' man' ' of' ' the'
+ ' people' ',' ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He']" , the great historian , was a man of the people , and a man of the people . He False and a great man. Thomas Carlyle in his book Heroes 7 [' and', ' a', ' great', ' man', '.', ' Thomas', ' Carly', 'le']
+2630 592 Name of mother of x -1 Name of mother of Thomas Carlyle Margaret Aitken Carlyle Thomas Carlyle "[',' ' the' ' great' ' historian' ',' ' was' ' a' ' man' ' of' ' the'
+ ' people' ',' ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He']" , the great historian , was a man of the people , and a man of the people . He False century; for example, Thomas Carlyle (1840) sometimes 7 [' century', ';', ' for', ' example', ',', ' Thomas', ' Carly', 'le']
+2631 592 Name of mother of x -1 Name of mother of Thomas Carlyle Margaret Aitken Carlyle Thomas Carlyle "[',' ' the' ' great' ' historian' ',' ' was' ' a' ' man' ' of' ' the'
+ ' people' ',' ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He']" , the great historian , was a man of the people , and a man of the people . He False English historian Thomas Carlyle (1795 – 1881) 4 [' English', ' historian', ' Thomas', ' Carly', 'le']
+2632 592 Name of mother of x -1 Name of mother of Thomas Carlyle Margaret Aitken Carlyle Thomas Carlyle "[',' ' the' ' great' ' historian' ',' ' was' ' a' ' man' ' of' ' the'
+ ' people' ',' ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He']" , the great historian , was a man of the people , and a man of the people . He False historian and sage Thomas Carlyle (who was Froude's 5 [' historian', ' and', ' sage', ' Thomas', ' Carly', 'le']
+2633 592 Name of mother of x -1 Name of mother of Thomas Carlyle Margaret Aitken Carlyle Thomas Carlyle "[',' ' the' ' great' ' historian' ',' ' was' ' a' ' man' ' of' ' the'
+ ' people' ',' ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He']" , the great historian , was a man of the people , and a man of the people . He False and a great man. Thomas Carlyle in his book 7 [' and', ' a', ' great', ' man', '.', ' Thomas', ' Carly', 'le']
+2634 593 Name of mother of x -1 Name of mother of Shimon Peres Sara Meltzer Shimon Peres "[',' ' the' ' Israeli' ' prime' ' minister' ',' ' and' ' the' ' Israeli'
+ ' prime' ' minister' ',' ' Sh' 'imon' ' Pe' 'res' ',' ' is' ' a' ' man']" , the Israeli prime minister , and the Israeli prime minister , Sh imon Pe res , is a man False to Israeli President Shimon Peres on 8 August 6 [' to', ' Israeli', ' President', ' Sh', 'imon', ' Pe', 'res']
+2635 593 Name of mother of x -1 Name of mother of Shimon Peres Sara Meltzer Shimon Peres "[',' ' the' ' Israeli' ' prime' ' minister' ',' ' and' ' the' ' Israeli'
+ ' prime' ' minister' ',' ' Sh' 'imon' ' Pe' 'res' ',' ' is' ' a' ' man']" , the Israeli prime minister , and the Israeli prime minister , Sh imon Pe res , is a man False Prime Minister Shimon Peres organized a committee 5 [' Prime', ' Minister', ' Sh', 'imon', ' Pe', 'res']
+2636 593 Name of mother of x -1 Name of mother of Shimon Peres Sara Meltzer Shimon Peres "[',' ' the' ' Israeli' ' prime' ' minister' ',' ' and' ' the' ' Israeli'
+ ' prime' ' minister' ',' ' Sh' 'imon' ' Pe' 'res' ',' ' is' ' a' ' man']" , the Israeli prime minister , and the Israeli prime minister , Sh imon Pe res , is a man False Israeli Prime Minister Shimon Peres for help in the 6 [' Israeli', ' Prime', ' Minister', ' Sh', 'imon', ' Pe', 'res']
+2637 593 Name of mother of x -1 Name of mother of Shimon Peres Sara Meltzer Shimon Peres "[',' ' the' ' Israeli' ' prime' ' minister' ',' ' and' ' the' ' Israeli'
+ ' prime' ' minister' ',' ' Sh' 'imon' ' Pe' 'res' ',' ' is' ' a' ' man']" , the Israeli prime minister , and the Israeli prime minister , Sh imon Pe res , is a man False Israeli president Shimon Peres and visited Refaeli's 5 [' Israeli', ' president', ' Sh', 'imon', ' Pe', 'res']
+2638 593 Name of mother of x -1 Name of mother of Shimon Peres Sara Meltzer Shimon Peres "[',' ' the' ' Israeli' ' prime' ' minister' ',' ' and' ' the' ' Israeli'
+ ' prime' ' minister' ',' ' Sh' 'imon' ' Pe' 'res' ',' ' is' ' a' ' man']" , the Israeli prime minister , and the Israeli prime minister , Sh imon Pe res , is a man False Israeli President Shimon Peres described Obama's 5 [' Israeli', ' President', ' Sh', 'imon', ' Pe', 'res']
+2639 594 Name of mother of x -1 Name of mother of Albert, Prince Consort Princess Louise of Saxe-Gotha-Altenburg Albert, Prince Consort "[' of' ' England' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" " of England , and the
+ " False " Consort =
+" 7 [' Cons', 'ort', ' =', 'Albert', ',', ' Prince', ' Cons', 'ort']
+2640 594 Name of mother of x -1 Name of mother of Albert, Prince Consort Princess Louise of Saxe-Gotha-Altenburg Albert, Prince Consort "[' of' ' England' ',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ' ' ' ' ']" " of England , and the
+ " False " Prince Consort =
+" 8 [' Prince', ' Cons', 'ort', ' =', 'Albert', ',', ' Prince', ' Cons', 'ort']
+2641 595 Name of mother of x -1 Name of mother of Albert I of Belgium Princess Marie, Countess of Flanders Albert I of Belgium "[',' ' the' ' King' ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King'
+ ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King' ' of']" , the King of the Belg ians , and the King of the Belg ians , and the King of False horse of King Albert I of Belgium was a Charolais 6 [' horse', ' of', ' King', ' Albert', ' I', ' of', ' Belgium']
+2642 595 Name of mother of x -1 Name of mother of Albert I of Belgium Princess Marie, Countess of Flanders Albert I of Belgium "[',' ' the' ' King' ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King'
+ ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King' ' of']" , the King of the Belg ians , and the King of the Belg ians , and the King of False horse of King Albert I of Belgium was a Charolais 6 [' horse', ' of', ' King', ' Albert', ' I', ' of', ' Belgium']
+2643 595 Name of mother of x -1 Name of mother of Albert I of Belgium Princess Marie, Countess of Flanders Albert I of Belgium "[',' ' the' ' King' ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King'
+ ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King' ' of']" , the King of the Belg ians , and the King of the Belg ians , and the King of False favorite horse of King Albert I of Belgium was a Charolais named 7 [' favorite', ' horse', ' of', ' King', ' Albert', ' I', ' of', ' Belgium']
+2644 595 Name of mother of x -1 Name of mother of Albert I of Belgium Princess Marie, Countess of Flanders Albert I of Belgium "[',' ' the' ' King' ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King'
+ ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King' ' of']" , the King of the Belg ians , and the King of the Belg ians , and the King of False favorite horse of King Albert I of Belgium was a Charolais 7 [' favorite', ' horse', ' of', ' King', ' Albert', ' I', ' of', ' Belgium']
+2645 595 Name of mother of x -1 Name of mother of Albert I of Belgium Princess Marie, Countess of Flanders Albert I of Belgium "[',' ' the' ' King' ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King'
+ ' of' ' the' ' Belg' 'ians' ',' ' and' ' the' ' King' ' of']" , the King of the Belg ians , and the King of the Belg ians , and the King of False favorite horse of King Albert I of Belgium was a Charolais named 7 [' favorite', ' horse', ' of', ' King', ' Albert', ' I', ' of', ' Belgium']
+2646 596 Name of mother of x -1 Name of mother of Kim Il-sung Kang Pan-sŏk Kim Il-sung "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Korean' ' people' '.' '\n' '\n' 'The' ' Korean' ' people']" ", the father of the nation , and the father of the Korean people .
+
+ The Korean people" False the WPK, with Kim Il-sung planning to formalize 8 [' the', ' WP', 'K', ',', ' with', ' Kim', ' Il', '-', 'sung']
+2647 596 Name of mother of x -1 Name of mother of Kim Il-sung Kang Pan-sŏk Kim Il-sung "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Korean' ' people' '.' '\n' '\n' 'The' ' Korean' ' people']" ", the father of the nation , and the father of the Korean people .
+
+ The Korean people" False reappointment of Kim Il-sung as WPK General 6 [' reapp', 'ointment', ' of', ' Kim', ' Il', '-', 'sung']
+2648 596 Name of mother of x -1 Name of mother of Kim Il-sung Kang Pan-sŏk Kim Il-sung "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Korean' ' people' '.' '\n' '\n' 'The' ' Korean' ' people']" ", the father of the nation , and the father of the Korean people .
+
+ The Korean people" False At the congress, Kim Il-sung stressed the importance 7 [' At', ' the', ' congress', ',', ' Kim', ' Il', '-', 'sung']
+2649 596 Name of mother of x -1 Name of mother of Kim Il-sung Kang Pan-sŏk Kim Il-sung "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Korean' ' people' '.' '\n' '\n' 'The' ' Korean' ' people']" ", the father of the nation , and the father of the Korean people .
+
+ The Korean people" False speech given by Kim Il-sung after the liberation, 6 [' speech', ' given', ' by', ' Kim', ' Il', '-', 'sung']
+2650 596 Name of mother of x -1 Name of mother of Kim Il-sung Kang Pan-sŏk Kim Il-sung "[',' ' the' ' father' ' of' ' the' ' nation' ',' ' and' ' the' ' father'
+ ' of' ' the' ' Korean' ' people' '.' '\n' '\n' 'The' ' Korean' ' people']" ", the father of the nation , and the father of the Korean people .
+
+ The Korean people" False was believed to be Kim Il-sung's first choice as 7 [' was', ' believed', ' to', ' be', ' Kim', ' Il', '-', 'sung']
+2651 597 Name of mother of x -1 Name of mother of François Arago Marie Arago François Arago "[',' ' the' ' French' ' astronomer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' discover' ' the' ' planet' ' Neptune' '.' '\n' '\n' 'The'
+ ' planet' ' Neptune']" ", the French astronomer , who was the first to discover the planet Neptune .
+
+ The planet Neptune" False attempts by François Arago and Claude-Louis Mathieu 4 [' attempts', ' by', ' François', ' Ar', 'ago']
+2652 597 Name of mother of x -1 Name of mother of François Arago Marie Arago François Arago "[',' ' the' ' French' ' astronomer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' discover' ' the' ' planet' ' Neptune' '.' '\n' '\n' 'The'
+ ' planet' ' Neptune']" ", the French astronomer , who was the first to discover the planet Neptune .
+
+ The planet Neptune" False including attempts by François Arago and Claude-Louis Mathieu 5 [' including', ' attempts', ' by', ' François', ' Ar', 'ago']
+2653 597 Name of mother of x -1 Name of mother of François Arago Marie Arago François Arago "[',' ' the' ' French' ' astronomer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' discover' ' the' ' planet' ' Neptune' '.' '\n' '\n' 'The'
+ ' planet' ' Neptune']" ", the French astronomer , who was the first to discover the planet Neptune .
+
+ The planet Neptune" False including attempts by François Arago and Claude-Louis 5 [' including', ' attempts', ' by', ' François', ' Ar', 'ago']
+2654 597 Name of mother of x -1 Name of mother of François Arago Marie Arago François Arago "[',' ' the' ' French' ' astronomer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' discover' ' the' ' planet' ' Neptune' '.' '\n' '\n' 'The'
+ ' planet' ' Neptune']" ", the French astronomer , who was the first to discover the planet Neptune .
+
+ The planet Neptune" False attempts by François Arago and Claude-Louis 4 [' attempts', ' by', ' François', ' Ar', 'ago']
+2655 597 Name of mother of x -1 Name of mother of François Arago Marie Arago François Arago "[',' ' the' ' French' ' astronomer' ',' ' who' ' was' ' the' ' first'
+ ' to' ' discover' ' the' ' planet' ' Neptune' '.' '\n' '\n' 'The'
+ ' planet' ' Neptune']" ", the French astronomer , who was the first to discover the planet Neptune .
+
+ The planet Neptune" False communicated to François Arago the idea that 4 [' communicated', ' to', ' François', ' Ar', 'ago']
+2656 598 Name of mother of x -1 Name of mother of Boris Pasternak Rosa Kaufman Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the Russian writer , who was a friend of the family .
+
+ The house was a large" False II. In 1958, Boris Pasternak declined his 8 [' II', '.', ' In', ' 1958', ',', ' Boris', ' P', 'astern', 'ak']
+2657 598 Name of mother of x -1 Name of mother of Boris Pasternak Rosa Kaufman Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the Russian writer , who was a friend of the family .
+
+ The house was a large" False War II. In 1958, Boris Pasternak declined his prize 9 [' War', ' II', '.', ' In', ' 1958', ',', ' Boris', ' P', 'astern', 'ak']
+2658 598 Name of mother of x -1 Name of mother of Boris Pasternak Rosa Kaufman Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the Russian writer , who was a friend of the family .
+
+ The house was a large" False translation by Boris Pasternak and directed by 5 [' translation', ' by', ' Boris', ' P', 'astern', 'ak']
+2659 598 Name of mother of x -1 Name of mother of Boris Pasternak Rosa Kaufman Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the Russian writer , who was a friend of the family .
+
+ The house was a large" False close friends with Boris Pasternak (who, though married, 6 [' close', ' friends', ' with', ' Boris', ' P', 'astern', 'ak']
+2660 598 Name of mother of x -1 Name of mother of Boris Pasternak Rosa Kaufman Boris Pasternak "[',' ' the' ' Russian' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the Russian writer , who was a friend of the family .
+
+ The house was a large" False fierce attack on Boris Pasternak after his novel 6 [' fierce', ' attack', ' on', ' Boris', ' P', 'astern', 'ak']
+2661 599 Name of mother of x -1 Name of mother of Mikhail Lermontov Mariya Arsenyeva Mikhail Lermontov "[',' ' the' ' Russian' ' poet' ',' ' was' ' born' ' in' ' 18' '14' '.'
+ '\n' '\n' 'The' ' Russian' ' poet' ',' ' who' ' was' ' born']" ", the Russian poet , was born in 18 14 .
+
+ The Russian poet , who was born" False Caucasus, the Russian poet Mikhail Lermontov wrote the romantic 9 [' Caucasus', ',', ' the', ' Russian', ' poet', ' Mikhail', ' L', 'erm', 'ont', 'ov']
+2662 599 Name of mother of x -1 Name of mother of Mikhail Lermontov Mariya Arsenyeva Mikhail Lermontov "[',' ' the' ' Russian' ' poet' ',' ' was' ' born' ' in' ' 18' '14' '.'
+ '\n' '\n' 'The' ' Russian' ' poet' ',' ' who' ' was' ' born']" ", the Russian poet , was born in 18 14 .
+
+ The Russian poet , who was born" False 5 ['M', 'ikhail', ' L', 'erm', 'ont', 'ov']
+2663 601 Name of mother of x -1 Name of mother of Antoine-Jean Gros Pierrette-Madeleine-Cécile Durand Antoine-Jean Gros "['j' 'ean' ',' ' the' ' French' ' painter' ',' ' who' ' was' ' born' ' in'
+ ' 17' '55' ',' ' and' ' died' ' in' ' 18' '36' '.']" j ean , the French painter , who was born in 17 55 , and died in 18 36 . False David (1748 – 1825) and Antoine-Jean Gros (1771 – 1835), to 14 [' David', ' (', '17', '48', ' –', ' 18', '25', ')', ' and', ' Ant', 'oine', '-', 'Jean', ' G', 'ros']
+2664 601 Name of mother of x -1 Name of mother of Antoine-Jean Gros Pierrette-Madeleine-Cécile Durand Antoine-Jean Gros "['j' 'ean' ',' ' the' ' French' ' painter' ',' ' who' ' was' ' born' ' in'
+ ' 17' '55' ',' ' and' ' died' ' in' ' 18' '36' '.']" j ean , the French painter , who was born in 17 55 , and died in 18 36 . False (1748 – 1825) and Antoine-Jean Gros (1771 – 1835), 13 [' (', '17', '48', ' –', ' 18', '25', ')', ' and', ' Ant', 'oine', '-', 'Jean', ' G', 'ros']
+2665 601 Name of mother of x -1 Name of mother of Antoine-Jean Gros Pierrette-Madeleine-Cécile Durand Antoine-Jean Gros "['j' 'ean' ',' ' the' ' French' ' painter' ',' ' who' ' was' ' born' ' in'
+ ' 17' '55' ',' ' and' ' died' ' in' ' 18' '36' '.']" j ean , the French painter , who was born in 17 55 , and died in 18 36 . False – 1825) and Antoine-Jean Gros (1771 – 1835), 10 [' –', ' 18', '25', ')', ' and', ' Ant', 'oine', '-', 'Jean', ' G', 'ros']
+2666 603 Name of mother of x -1 Name of mother of Germaine de Staël Suzanne Curchod Germaine de Staël "[',' ' the' ' daughter' ' of' ' the' ' French' ' writer' ' and'
+ ' diplomat' ',' ' who' ' was' ' born' ' in' ' 17' '66' '.' '\n' '\n'
+ 'The']" ", the daughter of the French writer and diplomat , who was born in 17 66 .
+
+ The" False (1712 – 1778) and Germaine de Staël (1766 – 1817). More 13 [' (', '17', '12', ' –', ' 17', '78', ')', ' and', ' Germ', 'aine', ' de', ' Sta', 'ë', 'l']
+2667 604 Name of mother of x -1 Name of mother of Tim Berners-Lee Mary Lee Woods Tim Berners-Lee "[',' ' the' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and'
+ ' the' ' father' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and' ' the']" , the inventor of the World Wide Web , and the father of the World Wide Web , and the False Zürich, Switzerland, Tim Berners-Lee had met with Michael 10 [' Z', 'ü', 'rich', ',', ' Switzerland', ',', ' Tim', ' Bern', 'ers', '-', 'Lee']
+2668 604 Name of mother of x -1 Name of mother of Tim Berners-Lee Mary Lee Woods Tim Berners-Lee "[',' ' the' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and'
+ ' the' ' father' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and' ' the']" , the inventor of the World Wide Web , and the father of the World Wide Web , and the False not knowing that Tim Berners-Lee was the inventor 7 [' not', ' knowing', ' that', ' Tim', ' Bern', 'ers', '-', 'Lee']
+2669 604 Name of mother of x -1 Name of mother of Tim Berners-Lee Mary Lee Woods Tim Berners-Lee "[',' ' the' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and'
+ ' the' ' father' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and' ' the']" , the inventor of the World Wide Web , and the father of the World Wide Web , and the False pioneering programs. Tim Berners-Lee used a NeXT Computer 7 [' pioneering', ' programs', '.', ' Tim', ' Bern', 'ers', '-', 'Lee']
+2670 604 Name of mother of x -1 Name of mother of Tim Berners-Lee Mary Lee Woods Tim Berners-Lee "[',' ' the' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and'
+ ' the' ' father' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and' ' the']" , the inventor of the World Wide Web , and the father of the World Wide Web , and the False Switzerland, Tim Berners-Lee had met with 6 [' Switzerland', ',', ' Tim', ' Bern', 'ers', '-', 'Lee']
+2671 604 Name of mother of x -1 Name of mother of Tim Berners-Lee Mary Lee Woods Tim Berners-Lee "[',' ' the' ' inventor' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and'
+ ' the' ' father' ' of' ' the' ' World' ' Wide' ' Web' ',' ' and' ' the']" , the inventor of the World Wide Web , and the father of the World Wide Web , and the False Zürich, Switzerland, Tim Berners-Lee had met with 10 [' Z', 'ü', 'rich', ',', ' Switzerland', ',', ' Tim', ' Bern', 'ers', '-', 'Lee']
+2672 605 Name of mother of x -1 Name of mother of Juliette Binoche Monique Stalens Juliette Binoche "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False French actress Juliette Binoche as they watch 5 [' French', ' actress', ' Juliet', 'te', ' Bin', 'oche']
+2673 605 Name of mother of x -1 Name of mother of Juliette Binoche Monique Stalens Juliette Binoche "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False French actress Juliette Binoche as they watch a 5 [' French', ' actress', ' Juliet', 'te', ' Bin', 'oche']
+2674 605 Name of mother of x -1 Name of mother of Juliette Binoche Monique Stalens Juliette Binoche "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False the French actress Juliette Binoche as they watch 6 [' the', ' French', ' actress', ' Juliet', 'te', ' Bin', 'oche']
+2675 605 Name of mother of x -1 Name of mother of Juliette Binoche Monique Stalens Juliette Binoche "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' actress' ',' ' the' ' actress' ',' ' the' ' actress' ',']" , the actress , and the mother of the actress , the actress , the actress , the actress , False French actress Juliette Binoche as they watch a 5 [' French', ' actress', ' Juliet', 'te', ' Bin', 'oche']
+2676 606 Name of mother of x -1 Name of mother of Alexis de Tocqueville Louise Le Peletier de Rosanbo Alexis de Tocqueville "[',' ' the' ' French' ' political' ' philosopher' ',' ' who' ' was'
+ ' born' ' in' ' 18' '05' ',' ' and' ' died' ' in' ' 18' '59' '.' '\n']" ", the French political philosopher , who was born in 18 05 , and died in 18 59 .
+" False director of the Alexis de Tocqueville Institution, where 9 [' director', ' of', ' the', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+2677 606 Name of mother of x -1 Name of mother of Alexis de Tocqueville Louise Le Peletier de Rosanbo Alexis de Tocqueville "[',' ' the' ' French' ' political' ' philosopher' ',' ' who' ' was'
+ ' born' ' in' ' 18' '05' ',' ' and' ' died' ' in' ' 18' '59' '.' '\n']" ", the French political philosopher , who was born in 18 05 , and died in 18 59 .
+" False director of the Alexis de Tocqueville Institution, where 9 [' director', ' of', ' the', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+2678 606 Name of mother of x -1 Name of mother of Alexis de Tocqueville Louise Le Peletier de Rosanbo Alexis de Tocqueville "[',' ' the' ' French' ' political' ' philosopher' ',' ' who' ' was'
+ ' born' ' in' ' 18' '05' ',' ' and' ' died' ' in' ' 18' '59' '.' '\n']" ", the French political philosopher , who was born in 18 05 , and died in 18 59 .
+" False became director of the Alexis de Tocqueville Institution, 10 [' became', ' director', ' of', ' the', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+2679 606 Name of mother of x -1 Name of mother of Alexis de Tocqueville Louise Le Peletier de Rosanbo Alexis de Tocqueville "[',' ' the' ' French' ' political' ' philosopher' ',' ' who' ' was'
+ ' born' ' in' ' 18' '05' ',' ' and' ' died' ' in' ' 18' '59' '.' '\n']" ", the French political philosopher , who was born in 18 05 , and died in 18 59 .
+" False Parliament drew ire. Alexis de Tocqueville described Blackstone 10 [' Parliament', ' drew', ' ire', '.', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+2680 606 Name of mother of x -1 Name of mother of Alexis de Tocqueville Louise Le Peletier de Rosanbo Alexis de Tocqueville "[',' ' the' ' French' ' political' ' philosopher' ',' ' who' ' was'
+ ' born' ' in' ' 18' '05' ',' ' and' ' died' ' in' ' 18' '59' '.' '\n']" ", the French political philosopher , who was born in 18 05 , and died in 18 59 .
+" False director of the Alexis de Tocqueville Institution, where 9 [' director', ' of', ' the', ' Alexis', ' de', ' T', 'oc', 'qu', 'ev', 'ille']
+2681 608 Name of mother of x -1 Name of mother of Cecil Beaton Etty Sisson Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' the' ' Duchess' ' of' ' Windsor' ',' ' and' ' the'
+ ' Duchess']" , the famous photographer , and his wife , the actress , the Duchess of Windsor , and the Duchess False flamboyance, advising Cecil Beaton to tone down 8 [' fl', 'amb', 'oy', 'ance', ',', ' advising', ' Cecil', ' Beat', 'on']
+2682 608 Name of mother of x -1 Name of mother of Cecil Beaton Etty Sisson Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' the' ' Duchess' ' of' ' Windsor' ',' ' and' ' the'
+ ' Duchess']" , the famous photographer , and his wife , the actress , the Duchess of Windsor , and the Duchess False Callas, taken by Cecil Beaton in 1957. In page 7 [' Call', 'as', ',', ' taken', ' by', ' Cecil', ' Beat', 'on']
+2683 608 Name of mother of x -1 Name of mother of Cecil Beaton Etty Sisson Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' the' ' Duchess' ' of' ' Windsor' ',' ' and' ' the'
+ ' Duchess']" , the famous photographer , and his wife , the actress , the Duchess of Windsor , and the Duchess False Callas, taken by Cecil Beaton in 1957. In page 7 [' Call', 'as', ',', ' taken', ' by', ' Cecil', ' Beat', 'on']
+2684 608 Name of mother of x -1 Name of mother of Cecil Beaton Etty Sisson Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' the' ' Duchess' ' of' ' Windsor' ',' ' and' ' the'
+ ' Duchess']" , the famous photographer , and his wife , the actress , the Duchess of Windsor , and the Duchess False such as Bill Brandt, Cecil Beaton and Bert Hardy. 8 [' such', ' as', ' Bill', ' Brand', 't', ',', ' Cecil', ' Beat', 'on']
+2685 608 Name of mother of x -1 Name of mother of Cecil Beaton Etty Sisson Cecil Beaton "[',' ' the' ' famous' ' photographer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' the' ' Duchess' ' of' ' Windsor' ',' ' and' ' the'
+ ' Duchess']" , the famous photographer , and his wife , the actress , the Duchess of Windsor , and the Duchess False flamboyance, advising Cecil Beaton to tone down 8 [' fl', 'amb', 'oy', 'ance', ',', ' advising', ' Cecil', ' Beat', 'on']
+2686 610 Name of mother of x -1 Name of mother of Peter Kropotkin Ekaterina Nikolaevna Kropotkina (nee Sulima) Peter Kropotkin "[',' ' the' ' Russian' ' anarchist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Russian' ' Social' ' Democratic' ' Party' ',' ' and' ' who'
+ ' was' ' executed']" , the Russian anarchist , who was a member of the Russian Social Democratic Party , and who was executed False individual acts. Peter Kropotkin thus wrote 7 [' individual', ' acts', '.', ' Peter', ' K', 'rop', 'ot', 'kin']
+2687 610 Name of mother of x -1 Name of mother of Peter Kropotkin Ekaterina Nikolaevna Kropotkina (nee Sulima) Peter Kropotkin "[',' ' the' ' Russian' ' anarchist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Russian' ' Social' ' Democratic' ' Party' ',' ' and' ' who'
+ ' was' ' executed']" , the Russian anarchist , who was a member of the Russian Social Democratic Party , and who was executed False Russian anarchist Peter Kropotkin described the 6 [' Russian', ' anarchist', ' Peter', ' K', 'rop', 'ot', 'kin']
+2688 610 Name of mother of x -1 Name of mother of Peter Kropotkin Ekaterina Nikolaevna Kropotkina (nee Sulima) Peter Kropotkin "[',' ' the' ' Russian' ' anarchist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Russian' ' Social' ' Democratic' ' Party' ',' ' and' ' who'
+ ' was' ' executed']" , the Russian anarchist , who was a member of the Russian Social Democratic Party , and who was executed False of anarcho-communist Peter Kropotkin as a threat to 10 [' of', ' anarch', 'o', '-', 'commun', 'ist', ' Peter', ' K', 'rop', 'ot', 'kin']
+2689 610 Name of mother of x -1 Name of mother of Peter Kropotkin Ekaterina Nikolaevna Kropotkina (nee Sulima) Peter Kropotkin "[',' ' the' ' Russian' ' anarchist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Russian' ' Social' ' Democratic' ' Party' ',' ' and' ' who'
+ ' was' ' executed']" , the Russian anarchist , who was a member of the Russian Social Democratic Party , and who was executed False Russian anarchist Peter Kropotkin described 6 [' Russian', ' anarchist', ' Peter', ' K', 'rop', 'ot', 'kin']
+2690 610 Name of mother of x -1 Name of mother of Peter Kropotkin Ekaterina Nikolaevna Kropotkina (nee Sulima) Peter Kropotkin "[',' ' the' ' Russian' ' anarchist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Russian' ' Social' ' Democratic' ' Party' ',' ' and' ' who'
+ ' was' ' executed']" , the Russian anarchist , who was a member of the Russian Social Democratic Party , and who was executed False Russian scientist Peter Kropotkin first proposed the 6 [' Russian', ' scientist', ' Peter', ' K', 'rop', 'ot', 'kin']
+2691 611 Name of mother of x -1 Name of mother of Franz Werfel Albine Kussi Franz Werfel "[',' ' the' ' author' ' of' ' ""' 'The' ' Forty' '-' 'nine' ' Rules' ' of'
+ ' Love' '""' ' and' ' ""' 'The' ' Forty' '-' 'nine' ' Rules']" ", the author of "" The Forty - nine Rules of Love "" and "" The Forty - nine Rules" False married the writer Franz Werfel in 1929. In 1940 she 6 [' married', ' the', ' writer', ' Franz', ' Wer', 'f', 'el']
+2692 611 Name of mother of x -1 Name of mother of Franz Werfel Albine Kussi Franz Werfel "[',' ' the' ' author' ' of' ' ""' 'The' ' Forty' '-' 'nine' ' Rules' ' of'
+ ' Love' '""' ' and' ' ""' 'The' ' Forty' '-' 'nine' ' Rules']" ", the author of "" The Forty - nine Rules of Love "" and "" The Forty - nine Rules" False married the writer Franz Werfel in 1929. In 1940 6 [' married', ' the', ' writer', ' Franz', ' Wer', 'f', 'el']
+2693 612 Name of mother of x -1 Name of mother of Jennifer Garner Patricia Ann English Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False " Elektra Natchios
+" 7 [' Ele', 'k', 'tra', ' N', 'atch', 'ios', 'Jennifer', ' Garner']
+2694 612 Name of mother of x -1 Name of mother of Jennifer Garner Patricia Ann English Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False vigilante Daredevil. Jennifer Garner plays his love 4 [' vigilante', ' Daredevil', '.', ' Jennifer', ' Garner']
+2695 612 Name of mother of x -1 Name of mother of Jennifer Garner Patricia Ann English Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False 1 ['Jennifer', ' Garner']
+2696 612 Name of mother of x -1 Name of mother of Jennifer Garner Patricia Ann English Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False in Baton Rouge. Jennifer Garner has stated that the 5 [' in', ' Baton', ' Rouge', '.', ' Jennifer', ' Garner']
+2697 612 Name of mother of x -1 Name of mother of Jennifer Garner Patricia Ann English Jennifer Garner "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False Supporting Actress – Jennifer Garner (Lost to Amy 4 [' Supporting', ' Actress', ' –', ' Jennifer', ' Garner']
+2698 615 Name of mother of x -1 Name of mother of Isabelle Adjani Emma Augusta Adjani Isabelle Adjani "[',' ' who' ' was' ' born' ' in' ' Paris' ',' ' France' ',' ' and' ' is'
+ ' a' ' French' ' citizen' '.' ' She' ' is' ' the' ' daughter' ' of']" , who was born in Paris , France , and is a French citizen . She is the daughter of False age 14, after seeing Isabelle Adjani in The Story of 9 [' age', ' 14', ',', ' after', ' seeing', ' Isabel', 'le', ' Ad', 'j', 'ani']
+2699 616 Name of mother of x -1 Name of mother of Helmut Schmidt Ludovica Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German'
+ ' chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German chancellor , and the German chancellor , Angela Merkel , has been in the news lately . False former chancellors Helmut Schmidt and Gerhard Schröder. 6 [' former', ' chance', 'll', 'ors', ' Hel', 'mut', ' Schmidt']
+2700 616 Name of mother of x -1 Name of mother of Helmut Schmidt Ludovica Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German'
+ ' chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German chancellor , and the German chancellor , Angela Merkel , has been in the news lately . False representative for chancellor Helmut Schmidt during a debt crisis 5 [' representative', ' for', ' chancellor', ' Hel', 'mut', ' Schmidt']
+2701 616 Name of mother of x -1 Name of mother of Helmut Schmidt Ludovica Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German'
+ ' chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German chancellor , and the German chancellor , Angela Merkel , has been in the news lately . False unseat incumbent Helmut Schmidt as chancellor. Between 5 [' un', 'seat', ' incumbent', ' Hel', 'mut', ' Schmidt']
+2702 616 Name of mother of x -1 Name of mother of Helmut Schmidt Ludovica Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German'
+ ' chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German chancellor , and the German chancellor , Angela Merkel , has been in the news lately . False chancellors Helmut Schmidt and Gerhard Schröder. 5 [' chance', 'll', 'ors', ' Hel', 'mut', ' Schmidt']
+2703 616 Name of mother of x -1 Name of mother of Helmut Schmidt Ludovica Schmidt Helmut Schmidt "[',' ' the' ' German' ' chancellor' ',' ' and' ' the' ' German'
+ ' chancellor' ',' ' Angela' ' Merkel' ',' ' has' ' been' ' in' ' the'
+ ' news' ' lately' '.']" , the German chancellor , and the German chancellor , Angela Merkel , has been in the news lately . False chancellors Helmut Schmidt and Gerhard Schröder. 5 [' chance', 'll', 'ors', ' Hel', 'mut', ' Schmidt']
+2704 617 Name of mother of x -1 Name of mother of Mary Wollstonecraft Elizabeth Dixon Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' _' 'V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '_' ',' ' and' ' the' ' _' 'V']" , the author of the _ V ind ication of the Rights of Woman _ , and the _ V False 1814, Hogg first met Mary Wollstonecraft Godwin while visiting 11 [' 18', '14', ',', ' H', 'ogg', ' first', ' met', ' Mary', ' W', 'oll', 'stone', 'craft']
+2705 617 Name of mother of x -1 Name of mother of Mary Wollstonecraft Elizabeth Dixon Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' _' 'V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '_' ',' ' and' ' the' ' _' 'V']" , the author of the _ V ind ication of the Rights of Woman _ , and the _ V False period such as Mary Wollstonecraft argued for 7 [' period', ' such', ' as', ' Mary', ' W', 'oll', 'stone', 'craft']
+2706 617 Name of mother of x -1 Name of mother of Mary Wollstonecraft Elizabeth Dixon Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' _' 'V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '_' ',' ' and' ' the' ' _' 'V']" , the author of the _ V ind ication of the Rights of Woman _ , and the _ V False Henry Fuseli and Mary Wollstonecraft reviewed their own 9 [' Henry', ' Fu', 'sel', 'i', ' and', ' Mary', ' W', 'oll', 'stone', 'craft']
+2707 617 Name of mother of x -1 Name of mother of Mary Wollstonecraft Elizabeth Dixon Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' _' 'V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '_' ',' ' and' ' the' ' _' 'V']" , the author of the _ V ind ication of the Rights of Woman _ , and the _ V False credited for Mary Wollstonecraft Shelley, but in 6 [' credited', ' for', ' Mary', ' W', 'oll', 'stone', 'craft']
+2708 617 Name of mother of x -1 Name of mother of Mary Wollstonecraft Elizabeth Dixon Mary Wollstonecraft "[',' ' the' ' author' ' of' ' the' ' _' 'V' 'ind' 'ication' ' of' ' the'
+ ' Rights' ' of' ' Woman' '_' ',' ' and' ' the' ' _' 'V']" , the author of the _ V ind ication of the Rights of Woman _ , and the _ V False " Wollstonecraft =
+" 9 [' W', 'oll', 'stone', 'craft', ' =', 'Mary', ' W', 'oll', 'stone', 'craft']
+2709 618 Name of mother of x -1 Name of mother of Norman Foster Lillian Smith Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' ',' ' said' ' the'
+ ' stadium' ' will' ' be' ' a' ' �' '�' 'world' '-' 'class' ' facility']" , the architect of the new stadium , said the stadium will be a � � world - class facility False architect Norman Foster once believed to be 2 [' architect', ' Norman', ' Foster']
+2710 618 Name of mother of x -1 Name of mother of Norman Foster Lillian Smith Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' ',' ' said' ' the'
+ ' stadium' ' will' ' be' ' a' ' �' '�' 'world' '-' 'class' ' facility']" , the architect of the new stadium , said the stadium will be a � � world - class facility False (1997), designed by Norman Foster (1935 –) and 6 [' (', '1997', '),', ' designed', ' by', ' Norman', ' Foster']
+2711 618 Name of mother of x -1 Name of mother of Norman Foster Lillian Smith Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' ',' ' said' ' the'
+ ' stadium' ' will' ' be' ' a' ' �' '�' 'world' '-' 'class' ' facility']" , the architect of the new stadium , said the stadium will be a � � world - class facility False designed by Sir Norman Foster and Chris 4 [' designed', ' by', ' Sir', ' Norman', ' Foster']
+2712 618 Name of mother of x -1 Name of mother of Norman Foster Lillian Smith Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' ',' ' said' ' the'
+ ' stadium' ' will' ' be' ' a' ' �' '�' 'world' '-' 'class' ' facility']" , the architect of the new stadium , said the stadium will be a � � world - class facility False feat architect Norman Foster once believed to 3 [' feat', ' architect', ' Norman', ' Foster']
+2713 618 Name of mother of x -1 Name of mother of Norman Foster Lillian Smith Norman Foster "[',' ' the' ' architect' ' of' ' the' ' new' ' stadium' ',' ' said' ' the'
+ ' stadium' ' will' ' be' ' a' ' �' '�' 'world' '-' 'class' ' facility']" , the architect of the new stadium , said the stadium will be a � � world - class facility False designed by Sir Norman Foster and Chris Wise 4 [' designed', ' by', ' Sir', ' Norman', ' Foster']
+2714 619 Name of mother of x -1 Name of mother of Harold Pinter Frances Moskowitz Harold Pinter "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False reception hosted by Harold Pinter on 19 May 2008, exactly 5 [' reception', ' hosted', ' by', ' Harold', ' P', 'inter']
+2715 619 Name of mother of x -1 Name of mother of Harold Pinter Frances Moskowitz Harold Pinter "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False by Edith Sitwell, Harold Pinter and T. S. Eliot, 8 [' by', ' Ed', 'ith', ' Sit', 'well', ',', ' Harold', ' P', 'inter']
+2716 619 Name of mother of x -1 Name of mother of Harold Pinter Frances Moskowitz Harold Pinter "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False by Edith Sitwell, Harold Pinter and T. S. Eliot, 8 [' by', ' Ed', 'ith', ' Sit', 'well', ',', ' Harold', ' P', 'inter']
+2717 619 Name of mother of x -1 Name of mother of Harold Pinter Frances Moskowitz Harold Pinter "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False 31 July 2001, a Harold Pinter Festival celebrating 7 [' 31', ' July', ' 2001', ',', ' a', ' Harold', ' P', 'inter']
+2718 619 Name of mother of x -1 Name of mother of Harold Pinter Frances Moskowitz Harold Pinter "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' play' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the play , and the author of the book False reception hosted by Harold Pinter on 19 May 2008, 5 [' reception', ' hosted', ' by', ' Harold', ' P', 'inter']
+2719 620 Name of mother of x -1 Name of mother of Narendra Modi Heeraben Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' said' ' that' ' the' ' BJP' ' is' ' the' ' only' ' party' ' that'
+ ' can']" , the BJP � � s prime ministerial candidate , said that the BJP is the only party that can False Chief Minister Narendra Modi in connection with 3 [' Chief', ' Minister', ' Narendra', ' Modi']
+2720 620 Name of mother of x -1 Name of mother of Narendra Modi Heeraben Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' said' ' that' ' the' ' BJP' ' is' ' the' ' only' ' party' ' that'
+ ' can']" , the BJP � � s prime ministerial candidate , said that the BJP is the only party that can False Indian Prime Minister Narendra Modi for the Clean India 4 [' Indian', ' Prime', ' Minister', ' Narendra', ' Modi']
+2721 620 Name of mother of x -1 Name of mother of Narendra Modi Heeraben Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' said' ' that' ' the' ' BJP' ' is' ' the' ' only' ' party' ' that'
+ ' can']" , the BJP � � s prime ministerial candidate , said that the BJP is the only party that can False ministerial candidate Narendra Modi for being involved 3 [' ministerial', ' candidate', ' Narendra', ' Modi']
+2722 620 Name of mother of x -1 Name of mother of Narendra Modi Heeraben Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' said' ' that' ' the' ' BJP' ' is' ' the' ' only' ' party' ' that'
+ ' can']" , the BJP � � s prime ministerial candidate , said that the BJP is the only party that can False Rights Watch that Narendra Modi and state government 4 [' Rights', ' Watch', ' that', ' Narendra', ' Modi']
+2723 620 Name of mother of x -1 Name of mother of Narendra Modi Heeraben Modi Narendra Modi "[',' ' the' ' BJP' '�' '�' 's' ' prime' ' ministerial' ' candidate' ','
+ ' said' ' that' ' the' ' BJP' ' is' ' the' ' only' ' party' ' that'
+ ' can']" , the BJP � � s prime ministerial candidate , said that the BJP is the only party that can False Richard said that Narendra Modi visited Naroda Patiya 4 [' Richard', ' said', ' that', ' Narendra', ' Modi']
+2724 622 Name of mother of x -1 Name of mother of Mila Kunis Elvira Mila Kunis "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ',' ' and' ' her'
+ ' husband' ',' ' actor' ' Ashton' ' Kut' 'cher' ',' ' are' ' expecting'
+ ' their']" , who is a former Miss Universe , and her husband , actor Ashton Kut cher , are expecting their False " Mila Kunis =
+" 3 [' Mil', 'a', ' Kun', 'is']
+2725 622 Name of mother of x -1 Name of mother of Mila Kunis Elvira Mila Kunis "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ',' ' and' ' her'
+ ' husband' ',' ' actor' ' Ashton' ' Kut' 'cher' ',' ' are' ' expecting'
+ ' their']" , who is a former Miss Universe , and her husband , actor Ashton Kut cher , are expecting their False episode to have Mila Kunis providing the 6 [' episode', ' to', ' have', ' Mil', 'a', ' Kun', 'is']
+2726 622 Name of mother of x -1 Name of mother of Mila Kunis Elvira Mila Kunis "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ',' ' and' ' her'
+ ' husband' ',' ' actor' ' Ashton' ' Kut' 'cher' ',' ' are' ' expecting'
+ ' their']" , who is a former Miss Universe , and her husband , actor Ashton Kut cher , are expecting their False episode to have Mila Kunis providing the voice 6 [' episode', ' to', ' have', ' Mil', 'a', ' Kun', 'is']
+2727 622 Name of mother of x -1 Name of mother of Mila Kunis Elvira Mila Kunis "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ',' ' and' ' her'
+ ' husband' ',' ' actor' ' Ashton' ' Kut' 'cher' ',' ' are' ' expecting'
+ ' their']" , who is a former Miss Universe , and her husband , actor Ashton Kut cher , are expecting their False Timberlake and Mila Kunis is almost enough 6 [' Timber', 'lake', ' and', ' Mil', 'a', ' Kun', 'is']
+2728 622 Name of mother of x -1 Name of mother of Mila Kunis Elvira Mila Kunis "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ',' ' and' ' her'
+ ' husband' ',' ' actor' ' Ashton' ' Kut' 'cher' ',' ' are' ' expecting'
+ ' their']" , who is a former Miss Universe , and her husband , actor Ashton Kut cher , are expecting their False aired to feature Mila Kunis as the voice 6 [' aired', ' to', ' feature', ' Mil', 'a', ' Kun', 'is']
+2729 623 Name of mother of x -1 Name of mother of John Quincy Adams Abigail Adams John Quincy Adams "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' John' ' Quincy' ' Adams' ' and' ' Louis' 'a'
+ ' Catherine']" ", the
+
+ The following is a list of the children of John Quincy Adams and Louis a Catherine" False served under John Quincy Adams and continued under 4 [' served', ' under', ' John', ' Quincy', ' Adams']
+2730 623 Name of mother of x -1 Name of mother of John Quincy Adams Abigail Adams John Quincy Adams "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' John' ' Quincy' ' Adams' ' and' ' Louis' 'a'
+ ' Catherine']" ", the
+
+ The following is a list of the children of John Quincy Adams and Louis a Catherine" False colonies; in 1829 John Quincy Adams described him as 7 [' colonies', ';', ' in', ' 18', '29', ' John', ' Quincy', ' Adams']
+2731 623 Name of mother of x -1 Name of mother of John Quincy Adams Abigail Adams John Quincy Adams "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' John' ' Quincy' ' Adams' ' and' ' Louis' 'a'
+ ' Catherine']" ", the
+
+ The following is a list of the children of John Quincy Adams and Louis a Catherine" False reception with President John Quincy Adams at the White 5 [' reception', ' with', ' President', ' John', ' Quincy', ' Adams']
+2732 623 Name of mother of x -1 Name of mother of John Quincy Adams Abigail Adams John Quincy Adams "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' John' ' Quincy' ' Adams' ' and' ' Louis' 'a'
+ ' Catherine']" ", the
+
+ The following is a list of the children of John Quincy Adams and Louis a Catherine" False Sheridan's friend John Quincy Adams Ward, who had 5 "[' Sheridan', ""'s"", ' friend', ' John', ' Quincy', ' Adams']"
+2733 623 Name of mother of x -1 Name of mother of John Quincy Adams Abigail Adams John Quincy Adams "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' John' ' Quincy' ' Adams' ' and' ' Louis' 'a'
+ ' Catherine']" ", the
+
+ The following is a list of the children of John Quincy Adams and Louis a Catherine" False citizens. President John Quincy Adams was a harsh 5 [' citizens', '.', ' President', ' John', ' Quincy', ' Adams']
+2734 624 Name of mother of x -1 Name of mother of Gerald Ford Dorothy Ayer Gardner Ford Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False " Gerald Ford =
+" 1 [' Gerald', ' Ford']
+2735 624 Name of mother of x -1 Name of mother of Gerald Ford Dorothy Ayer Gardner Ford Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False 2 ['G', 'erald', ' Ford']
+2736 624 Name of mother of x -1 Name of mother of Gerald Ford Dorothy Ayer Gardner Ford Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False " former President Gerald Ford among others.
+" 3 [' former', ' President', ' Gerald', ' Ford']
+2737 624 Name of mother of x -1 Name of mother of Gerald Ford Dorothy Ayer Gardner Ford Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False 21, then-President Gerald Ford issued a major disaster 6 [' 21', ',', ' then', '-', 'President', ' Gerald', ' Ford']
+2738 624 Name of mother of x -1 Name of mother of Gerald Ford Dorothy Ayer Gardner Ford Gerald Ford "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the former president of the United States , and the False 1975, President Gerald Ford gave a televised 4 [' 1975', ',', ' President', ' Gerald', ' Ford']
+2739 625 Name of mother of x -1 Name of mother of Neil Armstrong Viola Louise Engel Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False his own guitar) and Neil Armstrong (by hitting him 6 [' his', ' own', ' guitar', ')', ' and', ' Neil', ' Armstrong']
+2740 625 Name of mother of x -1 Name of mother of Neil Armstrong Viola Louise Engel Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False are history books, Neil Armstrong will be included 5 [' are', ' history', ' books', ',', ' Neil', ' Armstrong']
+2741 625 Name of mother of x -1 Name of mother of Neil Armstrong Viola Louise Engel Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False 11 crew members Neil Armstrong and Buzz Aldrin. 4 [' 11', ' crew', ' members', ' Neil', ' Armstrong']
+2742 625 Name of mother of x -1 Name of mother of Neil Armstrong Viola Louise Engel Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False the Moon. Americans Neil Armstrong and Buzz Aldrin landed 5 [' the', ' Moon', '.', ' Americans', ' Neil', ' Armstrong']
+2743 625 Name of mother of x -1 Name of mother of Neil Armstrong Viola Louise Engel Neil Armstrong "[',' ' the' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon' '.' '\n'
+ '\n' 'The' ' first' ' man' ' to' ' walk' ' on' ' the' ' moon']" ", the first man to walk on the moon .
+
+ The first man to walk on the moon" False 1 ['Neil', ' Armstrong']
+2744 627 Name of mother of x -1 Name of mother of Juan Carlos I of Spain Infanta María de las Mercedes, Countess of Barcelona Juan Carlos I of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' the' ' King' ' of']" , the King of Spain , and the Queen of Spain , the Queen of Spain , the King of False " Greece and Denmark and Juan Carlos I of Spain (1962)
+" 8 [' Greece', ' and', ' Denmark', ' and', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+2745 627 Name of mother of x -1 Name of mother of Juan Carlos I of Spain Infanta María de las Mercedes, Countess of Barcelona Juan Carlos I of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' the' ' King' ' of']" , the King of Spain , and the Queen of Spain , the Queen of Spain , the King of False Ireland House, the King Juan Carlos I of Spain Center, the 9 [' Ireland', ' House', ',', ' the', ' King', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+2746 627 Name of mother of x -1 Name of mother of Juan Carlos I of Spain Infanta María de las Mercedes, Countess of Barcelona Juan Carlos I of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' the' ' King' ' of']" , the King of Spain , and the Queen of Spain , the Queen of Spain , the King of False Ireland House, the King Juan Carlos I of Spain Center, the Hagop 9 [' Ireland', ' House', ',', ' the', ' King', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+2747 627 Name of mother of x -1 Name of mother of Juan Carlos I of Spain Infanta María de las Mercedes, Countess of Barcelona Juan Carlos I of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' the' ' King' ' of']" , the King of Spain , and the Queen of Spain , the Queen of Spain , the King of False " and Denmark and Juan Carlos I of Spain (1962)
+" 7 [' and', ' Denmark', ' and', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+2748 627 Name of mother of x -1 Name of mother of Juan Carlos I of Spain Infanta María de las Mercedes, Countess of Barcelona Juan Carlos I of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' the' ' King' ' of']" , the King of Spain , and the Queen of Spain , the Queen of Spain , the King of False " and Denmark and Juan Carlos I of Spain (1962)
+" 7 [' and', ' Denmark', ' and', ' Juan', ' Carlos', ' I', ' of', ' Spain']
+2749 629 Name of mother of x -1 Name of mother of Tycho Brahe Beate Clausdatter Bille Tycho Brahe "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ty' 'cho' ' Bra' 'he'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Ty cho Bra he , the
+
+ Name of mother of" False astronomers such as Tycho Brahe identified new stars 6 [' astronomers', ' such', ' as', ' Ty', 'cho', ' Bra', 'he']
+2750 629 Name of mother of x -1 Name of mother of Tycho Brahe Beate Clausdatter Bille Tycho Brahe "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ty' 'cho' ' Bra' 'he'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Ty cho Bra he , the
+
+ Name of mother of" False Kepler met Tycho Brahe and his assistants 5 [' Kepler', ' met', ' Ty', 'cho', ' Bra', 'he']
+2751 629 Name of mother of x -1 Name of mother of Tycho Brahe Beate Clausdatter Bille Tycho Brahe "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ty' 'cho' ' Bra' 'he'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Ty cho Bra he , the
+
+ Name of mother of" False opposition came from Tycho Brahe and others, 6 [' opposition', ' came', ' from', ' Ty', 'cho', ' Bra', 'he']
+2752 629 Name of mother of x -1 Name of mother of Tycho Brahe Beate Clausdatter Bille Tycho Brahe "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ty' 'cho' ' Bra' 'he'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Ty cho Bra he , the
+
+ Name of mother of" False 3 ['Ty', 'cho', ' Bra', 'he']
+2753 629 Name of mother of x -1 Name of mother of Tycho Brahe Beate Clausdatter Bille Tycho Brahe "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Ty' 'cho' ' Bra' 'he'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the
+
+ Name of mother of Ty cho Bra he , the
+
+ Name of mother of" False that Kepler murdered Tycho Brahe to gain access 6 [' that', ' Kepler', ' murdered', ' Ty', 'cho', ' Bra', 'he']
+2754 630 Name of mother of x -1 Name of mother of Marlon Brando Dodie Brando Marlon Brando "[',' ' the' ' actor' ',' ' and' ' the' ' mother' ' of' ' the' ' actor'
+ ' Mar' 'lon' ' Brand' 'o' ' Jr' '.' '\n' '\n' 'The' ' actor']" ", the actor , and the mother of the actor Mar lon Brand o Jr .
+
+ The actor" False " visits by Marlon Brando and Robert Redford.
+" 5 [' visits', ' by', ' Mar', 'lon', ' Brand', 'o']
+2755 630 Name of mother of x -1 Name of mother of Marlon Brando Dodie Brando Marlon Brando "[',' ' the' ' actor' ',' ' and' ' the' ' mother' ' of' ' the' ' actor'
+ ' Mar' 'lon' ' Brand' 'o' ' Jr' '.' '\n' '\n' 'The' ' actor']" ", the actor , and the mother of the actor Mar lon Brand o Jr .
+
+ The actor" False action film starring Marlon Brando and Jack Nicholson 6 [' action', ' film', ' starring', ' Mar', 'lon', ' Brand', 'o']
+2756 630 Name of mother of x -1 Name of mother of Marlon Brando Dodie Brando Marlon Brando "[',' ' the' ' actor' ',' ' and' ' the' ' mother' ' of' ' the' ' actor'
+ ' Mar' 'lon' ' Brand' 'o' ' Jr' '.' '\n' '\n' 'The' ' actor']" ", the actor , and the mother of the actor Mar lon Brand o Jr .
+
+ The actor" False Actor winner Marlon Brando — who also boycotted 5 [' Actor', ' winner', ' Mar', 'lon', ' Brand', 'o']
+2757 630 Name of mother of x -1 Name of mother of Marlon Brando Dodie Brando Marlon Brando "[',' ' the' ' actor' ',' ' and' ' the' ' mother' ' of' ' the' ' actor'
+ ' Mar' 'lon' ' Brand' 'o' ' Jr' '.' '\n' '\n' 'The' ' actor']" ", the actor , and the mother of the actor Mar lon Brand o Jr .
+
+ The actor" False After watching a Marlon Brando movie on TV in 6 [' After', ' watching', ' a', ' Mar', 'lon', ' Brand', 'o']
+2758 630 Name of mother of x -1 Name of mother of Marlon Brando Dodie Brando Marlon Brando "[',' ' the' ' actor' ',' ' and' ' the' ' mother' ' of' ' the' ' actor'
+ ' Mar' 'lon' ' Brand' 'o' ' Jr' '.' '\n' '\n' 'The' ' actor']" ", the actor , and the mother of the actor Mar lon Brand o Jr .
+
+ The actor" False " experienced past visits by Marlon Brando and Robert Redford.
+" 7 [' experienced', ' past', ' visits', ' by', ' Mar', 'lon', ' Brand', 'o']
+2759 631 Name of mother of x -1 Name of mother of Lev Landau Q18786020 Lev Landau "[',' ' the' ' son' ' of' ' a' ' wealthy' ' Jewish' ' family' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet'
+ ' Union']" , the son of a wealthy Jewish family , was a member of the Communist Party of the Soviet Union False quantum mechanics, Lev Landau in 1930 developed 5 [' quantum', ' mechanics', ',', ' Lev', ' Land', 'au']
+2760 631 Name of mother of x -1 Name of mother of Lev Landau Q18786020 Lev Landau "[',' ' the' ' son' ' of' ' a' ' wealthy' ' Jewish' ' family' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet'
+ ' Union']" , the son of a wealthy Jewish family , was a member of the Communist Party of the Soviet Union False been echoed by Lev Landau and Evgeny Lifshitz, 5 [' been', ' echoed', ' by', ' Lev', ' Land', 'au']
+2761 631 Name of mother of x -1 Name of mother of Lev Landau Q18786020 Lev Landau "[',' ' the' ' son' ' of' ' a' ' wealthy' ' Jewish' ' family' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet'
+ ' Union']" , the son of a wealthy Jewish family , was a member of the Communist Party of the Soviet Union False Russian physicist Lev Landau used the idea 4 [' Russian', ' physicist', ' Lev', ' Land', 'au']
+2762 631 Name of mother of x -1 Name of mother of Lev Landau Q18786020 Lev Landau "[',' ' the' ' son' ' of' ' a' ' wealthy' ' Jewish' ' family' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet'
+ ' Union']" , the son of a wealthy Jewish family , was a member of the Communist Party of the Soviet Union False mechanics, Lev Landau in 1930 developed the 4 [' mechanics', ',', ' Lev', ' Land', 'au']
+2763 631 Name of mother of x -1 Name of mother of Lev Landau Q18786020 Lev Landau "[',' ' the' ' son' ' of' ' a' ' wealthy' ' Jewish' ' family' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Communist' ' Party' ' of' ' the' ' Soviet'
+ ' Union']" , the son of a wealthy Jewish family , was a member of the Communist Party of the Soviet Union False has been echoed by Lev Landau and Evgeny 6 [' has', ' been', ' echoed', ' by', ' Lev', ' Land', 'au']
+2764 632 Name of mother of x -1 Name of mother of David Attenborough Mary Clegg David Attenborough "[',' ' the' ' famous' ' natural' 'ist' ' and' ' broadcaster' ',' ' who'
+ ' died' ' in' ' London' ' in' ' 1984' '.' '\n' '\n' 'The' ' BBC' ' has']" ", the famous natural ist and broadcaster , who died in London in 1984 .
+
+ The BBC has" False naturalist Sir David Attenborough in July 2014. It was 6 [' natural', 'ist', ' Sir', ' David', ' Att', 'en', 'borough']
+2765 632 Name of mother of x -1 Name of mother of David Attenborough Mary Clegg David Attenborough "[',' ' the' ' famous' ' natural' 'ist' ' and' ' broadcaster' ',' ' who'
+ ' died' ' in' ' London' ' in' ' 1984' '.' '\n' '\n' 'The' ' BBC' ' has']" ", the famous natural ist and broadcaster , who died in London in 1984 .
+
+ The BBC has" False Wallace's death, Sir David Attenborough unveiled a statue 8 "[' Wallace', ""'s"", ' death', ',', ' Sir', ' David', ' Att', 'en', 'borough']"
+2766 632 Name of mother of x -1 Name of mother of David Attenborough Mary Clegg David Attenborough "[',' ' the' ' famous' ' natural' 'ist' ' and' ' broadcaster' ',' ' who'
+ ' died' ' in' ' London' ' in' ' 1984' '.' '\n' '\n' 'The' ' BBC' ' has']" ", the famous natural ist and broadcaster , who died in London in 1984 .
+
+ The BBC has" False the naturalist David Attenborough and his wife Jane 6 [' the', ' natural', 'ist', ' David', ' Att', 'en', 'borough']
+2767 632 Name of mother of x -1 Name of mother of David Attenborough Mary Clegg David Attenborough "[',' ' the' ' famous' ' natural' 'ist' ' and' ' broadcaster' ',' ' who'
+ ' died' ' in' ' London' ' in' ' 1984' '.' '\n' '\n' 'The' ' BBC' ' has']" ", the famous natural ist and broadcaster , who died in London in 1984 .
+
+ The BBC has" False featured in the BBC David Attenborough wildlife documentary 7 [' featured', ' in', ' the', ' BBC', ' David', ' Att', 'en', 'borough']
+2768 632 Name of mother of x -1 Name of mother of David Attenborough Mary Clegg David Attenborough "[',' ' the' ' famous' ' natural' 'ist' ' and' ' broadcaster' ',' ' who'
+ ' died' ' in' ' London' ' in' ' 1984' '.' '\n' '\n' 'The' ' BBC' ' has']" ", the famous natural ist and broadcaster , who died in London in 1984 .
+
+ The BBC has" False featured in the BBC David Attenborough wildlife documentary 7 [' featured', ' in', ' the', ' BBC', ' David', ' Att', 'en', 'borough']
+2769 633 Name of mother of x -1 Name of mother of Richard von Weizsäcker Marianne von Weizsäcker Richard von Weizsäcker "[',' ' the' ' German' ' ambassador' ' to' ' the' ' United' ' States' ','
+ ' who' ' was' ' a' ' close' ' friend' ' of' ' Hitler' ""'s"" '.' '\n' '\n']" ", the German ambassador to the United States , who was a close friend of Hitler 's .
+
+" False " von Weizsäcker =
+" 13 [' von', ' We', 'iz', 's', 'ä', 'cker', ' =', 'Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+2770 633 Name of mother of x -1 Name of mother of Richard von Weizsäcker Marianne von Weizsäcker Richard von Weizsäcker "[',' ' the' ' German' ' ambassador' ' to' ' the' ' United' ' States' ','
+ ' who' ' was' ' a' ' close' ' friend' ' of' ' Hitler' ""'s"" '.' '\n' '\n']" ", the German ambassador to the United States , who was a close friend of Hitler 's .
+
+" False 6 ['Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+2771 633 Name of mother of x -1 Name of mother of Richard von Weizsäcker Marianne von Weizsäcker Richard von Weizsäcker "[',' ' the' ' German' ' ambassador' ' to' ' the' ' United' ' States' ','
+ ' who' ' was' ' a' ' close' ' friend' ' of' ' Hitler' ""'s"" '.' '\n' '\n']" ", the German ambassador to the United States , who was a close friend of Hitler 's .
+
+" False creation of the Richard von Weizsäcker Professorship at 9 [' creation', ' of', ' the', ' Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+2772 633 Name of mother of x -1 Name of mother of Richard von Weizsäcker Marianne von Weizsäcker Richard von Weizsäcker "[',' ' the' ' German' ' ambassador' ' to' ' the' ' United' ' States' ','
+ ' who' ' was' ' a' ' close' ' friend' ' of' ' Hitler' ""'s"" '.' '\n' '\n']" ", the German ambassador to the United States , who was a close friend of Hitler 's .
+
+" False occupied France, Richard von Weizsäcker served as his assistant 9 [' occupied', ' France', ',', ' Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+2773 633 Name of mother of x -1 Name of mother of Richard von Weizsäcker Marianne von Weizsäcker Richard von Weizsäcker "[',' ' the' ' German' ' ambassador' ' to' ' the' ' United' ' States' ','
+ ' who' ' was' ' a' ' close' ' friend' ' of' ' Hitler' ""'s"" '.' '\n' '\n']" ", the German ambassador to the United States , who was a close friend of Hitler 's .
+
+" False creation of the Richard von Weizsäcker Professorship 9 [' creation', ' of', ' the', ' Richard', ' von', ' We', 'iz', 's', 'ä', 'cker']
+2774 634 Name of mother of x -1 Name of mother of Joan Baez Joan Bridge Joan Baez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Harrison, Bob Dylan, Joan Baez and Paul Simon 7 [' Harrison', ',', ' Bob', ' Dylan', ',', ' Joan', ' B', 'aez']
+2775 634 Name of mother of x -1 Name of mother of Joan Baez Joan Bridge Joan Baez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False performed with Joan Baez at the Monterey 4 [' performed', ' with', ' Joan', ' B', 'aez']
+2776 634 Name of mother of x -1 Name of mother of Joan Baez Joan Bridge Joan Baez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False not chart. Joan Baez included a gender-switched 5 [' not', ' chart', '.', ' Joan', ' B', 'aez']
+2777 634 Name of mother of x -1 Name of mother of Joan Baez Joan Bridge Joan Baez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " versions of ""Imagine"". Joan Baez included it" 7 "[' versions', ' of', ' ""', 'Imagine', '"".', ' Joan', ' B', 'aez']"
+2778 634 Name of mother of x -1 Name of mother of Joan Baez Joan Bridge Joan Baez "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False others. Ochs and Joan Baez sang a duet 8 [' others', '.', ' O', 'ch', 's', ' and', ' Joan', ' B', 'aez']
+2779 635 Name of mother of x -1 Name of mother of Nicolas Cage Joy Vogelsang Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' writer' ',' ' a' ' reader' ',' ' a']" ".
+
+ I am a mother of two , a wife , a writer , a reader , a" False Man, which starred Nicolas Cage and was filmed 5 [' Man', ',', ' which', ' starred', ' Nicolas', ' Cage']
+2780 635 Name of mother of x -1 Name of mother of Nicolas Cage Joy Vogelsang Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' writer' ',' ' a' ' reader' ',' ' a']" ".
+
+ I am a mother of two , a wife , a writer , a reader , a" False In April 2007, Nicolas Cage bought the LaLaurie 5 [' In', ' April', ' 2007', ',', ' Nicolas', ' Cage']
+2781 635 Name of mother of x -1 Name of mother of Nicolas Cage Joy Vogelsang Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' writer' ',' ' a' ' reader' ',' ' a']" ".
+
+ I am a mother of two , a wife , a writer , a reader , a" False reportedly wished for Nicolas Cage to play the 4 [' reportedly', ' wished', ' for', ' Nicolas', ' Cage']
+2782 635 Name of mother of x -1 Name of mother of Nicolas Cage Joy Vogelsang Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' writer' ',' ' a' ' reader' ',' ' a']" ".
+
+ I am a mother of two , a wife , a writer , a reader , a" False 2 ['Nic', 'olas', ' Cage']
+2783 635 Name of mother of x -1 Name of mother of Nicolas Cage Joy Vogelsang Nicolas Cage "['.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' writer' ',' ' a' ' reader' ',' ' a']" ".
+
+ I am a mother of two , a wife , a writer , a reader , a" False falsely linked to Nicolas Cage after his relationship 4 [' falsely', ' linked', ' to', ' Nicolas', ' Cage']
+2784 636 Name of mother of x -1 Name of mother of John James Audubon Jeanne Rabine John James Audubon "[',' ' the' ' great' ' American' ' orn' 'ith' 'ologist' ',' ' who' ' was'
+ ' born' ' in' ' the' ' year' ' of' ' the' ' great' ' American' ' Orn'
+ 'ith']" , the great American orn ith ologist , who was born in the year of the great American Orn ith False naturalist and artist John James Audubon described a migration 8 [' natural', 'ist', ' and', ' artist', ' John', ' James', ' Aud', 'ub', 'on']
+2785 636 Name of mother of x -1 Name of mother of John James Audubon Jeanne Rabine John James Audubon "[',' ' the' ' great' ' American' ' orn' 'ith' 'ologist' ',' ' who' ' was'
+ ' born' ' in' ' the' ' year' ' of' ' the' ' great' ' American' ' Orn'
+ 'ith']" , the great American orn ith ologist , who was born in the year of the great American Orn ith False and bird painter John James Audubon came to Britain 7 [' and', ' bird', ' painter', ' John', ' James', ' Aud', 'ub', 'on']
+2786 636 Name of mother of x -1 Name of mother of John James Audubon Jeanne Rabine John James Audubon "[',' ' the' ' great' ' American' ' orn' 'ith' 'ologist' ',' ' who' ' was'
+ ' born' ' in' ' the' ' year' ' of' ' the' ' great' ' American' ' Orn'
+ 'ith']" , the great American orn ith ologist , who was born in the year of the great American Orn ith False Alexander Wilson and John James Audubon both witnessed large 7 [' Alexander', ' Wilson', ' and', ' John', ' James', ' Aud', 'ub', 'on']
+2787 636 Name of mother of x -1 Name of mother of John James Audubon Jeanne Rabine John James Audubon "[',' ' the' ' great' ' American' ' orn' 'ith' 'ologist' ',' ' who' ' was'
+ ' born' ' in' ' the' ' year' ' of' ' the' ' great' ' American' ' Orn'
+ 'ith']" , the great American orn ith ologist , who was born in the year of the great American Orn ith False Bierstadt and John James Audubon among the many 8 [' B', 'ier', 'stadt', ' and', ' John', ' James', ' Aud', 'ub', 'on']
+2788 636 Name of mother of x -1 Name of mother of John James Audubon Jeanne Rabine John James Audubon "[',' ' the' ' great' ' American' ' orn' 'ith' 'ologist' ',' ' who' ' was'
+ ' born' ' in' ' the' ' year' ' of' ' the' ' great' ' American' ' Orn'
+ 'ith']" , the great American orn ith ologist , who was born in the year of the great American Orn ith False and bird painter John James Audubon came to Britain 7 [' and', ' bird', ' painter', ' John', ' James', ' Aud', 'ub', 'on']
+2789 638 Name of mother of x -1 Name of mother of Diana, Princess of Wales Frances Shand Kydd Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' Duchess' ' of' ' Cambridge' ',' ' who'
+ ' is']" ", and the mother of Prince William and Prince Harry .
+
+ The Duchess of Cambridge , who is" False impact the death of Diana, Princess of Wales had on Tony Blair 8 [' impact', ' the', ' death', ' of', ' Diana', ',', ' Princess', ' of', ' Wales']
+2790 638 Name of mother of x -1 Name of mother of Diana, Princess of Wales Frances Shand Kydd Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' Duchess' ' of' ' Cambridge' ',' ' who'
+ ' is']" ", and the mother of Prince William and Prince Harry .
+
+ The Duchess of Cambridge , who is" False a visit by Diana, Princess of Wales to the gardens in 7 [' a', ' visit', ' by', ' Diana', ',', ' Princess', ' of', ' Wales']
+2791 638 Name of mother of x -1 Name of mother of Diana, Princess of Wales Frances Shand Kydd Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' Duchess' ' of' ' Cambridge' ',' ' who'
+ ' is']" ", and the mother of Prince William and Prince Harry .
+
+ The Duchess of Cambridge , who is" False Fayed and Diana, Princess of Wales in 1997, briefly 7 [' F', 'ayed', ' and', ' Diana', ',', ' Princess', ' of', ' Wales']
+2792 638 Name of mother of x -1 Name of mother of Diana, Princess of Wales Frances Shand Kydd Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' Duchess' ' of' ' Cambridge' ',' ' who'
+ ' is']" ", and the mother of Prince William and Prince Harry .
+
+ The Duchess of Cambridge , who is" False people, including Diana, Princess of Wales and Charles, Prince 7 [' people', ',', ' including', ' Diana', ',', ' Princess', ' of', ' Wales']
+2793 638 Name of mother of x -1 Name of mother of Diana, Princess of Wales Frances Shand Kydd Diana, Princess of Wales "[',' ' and' ' the' ' mother' ' of' ' Prince' ' William' ' and' ' Prince'
+ ' Harry' '.' '\n' '\n' 'The' ' Duchess' ' of' ' Cambridge' ',' ' who'
+ ' is']" ", and the mother of Prince William and Prince Harry .
+
+ The Duchess of Cambridge , who is" False through the death of Diana, Princess of Wales in 1997, and 8 [' through', ' the', ' death', ' of', ' Diana', ',', ' Princess', ' of', ' Wales']
+2794 639 Name of mother of x -1 Name of mother of Lyndon B. Johnson Rebekah Baines Johnson Lyndon B. Johnson "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the first lady of the United States , and the first lady of the United States of America . False and swamping the Lyndon B. Johnson National Historical 7 [' and', ' swamp', 'ing', ' the', ' Lyndon', ' B', '.', ' Johnson']
+2795 639 Name of mother of x -1 Name of mother of Lyndon B. Johnson Rebekah Baines Johnson Lyndon B. Johnson "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the first lady of the United States , and the first lady of the United States of America . False first since that of Lyndon B. Johnson in 1973. Richard 7 [' first', ' since', ' that', ' of', ' Lyndon', ' B', '.', ' Johnson']
+2796 639 Name of mother of x -1 Name of mother of Lyndon B. Johnson Rebekah Baines Johnson Lyndon B. Johnson "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the first lady of the United States , and the first lady of the United States of America . False 5 ['Ly', 'nd', 'on', ' B', '.', ' Johnson']
+2797 639 Name of mother of x -1 Name of mother of Lyndon B. Johnson Rebekah Baines Johnson Lyndon B. Johnson "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the first lady of the United States , and the first lady of the United States of America . False President Lyndon B. Johnson appointed Cooper 4 [' President', ' Lyndon', ' B', '.', ' Johnson']
+2798 639 Name of mother of x -1 Name of mother of Lyndon B. Johnson Rebekah Baines Johnson Lyndon B. Johnson "[',' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the first lady of the United States , and the first lady of the United States of America . False Larner's Nobody Knows. Lyndon B. Johnson was a repeated source 10 "[' Lar', 'ner', ""'s"", ' Nobody', ' Kn', 'ows', '.', ' Lyndon', ' B', '.', ' Johnson']"
+2799 640 Name of mother of x -1 Name of mother of Joseph Addison Jane Gulston Joseph Addison "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' ' Add' 'ison'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph']" ", the
+
+ Name of mother of Joseph Add ison , the
+
+ Name of mother of Joseph" False memorialized in historian Joseph Addison Waddell's Annals of 6 [' memorial', 'ized', ' in', ' historian', ' Joseph', ' Add', 'ison']
+2800 640 Name of mother of x -1 Name of mother of Joseph Addison Jane Gulston Joseph Addison "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' ' Add' 'ison'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph']" ", the
+
+ Name of mother of Joseph Add ison , the
+
+ Name of mother of Joseph" False by the writers Joseph Addison and Richard Steele, 5 [' by', ' the', ' writers', ' Joseph', ' Add', 'ison']
+2801 640 Name of mother of x -1 Name of mother of Joseph Addison Jane Gulston Joseph Addison "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' ' Add' 'ison'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph']" ", the
+
+ Name of mother of Joseph Add ison , the
+
+ Name of mother of Joseph" False compared to both Joseph Addison and Samuel 5 [' compared', ' to', ' both', ' Joseph', ' Add', 'ison']
+2802 640 Name of mother of x -1 Name of mother of Joseph Addison Jane Gulston Joseph Addison "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' ' Add' 'ison'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph']" ", the
+
+ Name of mother of Joseph Add ison , the
+
+ Name of mother of Joseph" False gardens included Joseph Addison and Lord Shaftesbury. 4 [' gardens', ' included', ' Joseph', ' Add', 'ison']
+2803 640 Name of mother of x -1 Name of mother of Joseph Addison Jane Gulston Joseph Addison "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph' ' Add' 'ison'
+ ',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Joseph']" ", the
+
+ Name of mother of Joseph Add ison , the
+
+ Name of mother of Joseph" False memorialized in historian Joseph Addison Waddell's Annals 6 [' memorial', 'ized', ' in', ' historian', ' Joseph', ' Add', 'ison']
+2804 641 Name of mother of x -1 Name of mother of Niki de Saint Phalle Jeanne Jacqueline Marguerite Harper Niki de Saint Phalle "[',' ' the' ' artist' ',' ' and' ' the' ' mother' ' of' ' the' ' artist'
+ ""'s"" ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',' ' and' ' the']" , the artist , and the mother of the artist 's daughter , the artist 's daughter , and the False California, the sculptor Niki de Saint Phalle built her 10 [' California', ',', ' the', ' sculpt', 'or', ' N', 'iki', ' de', ' Saint', ' Ph', 'alle']
+2805 641 Name of mother of x -1 Name of mother of Niki de Saint Phalle Jeanne Jacqueline Marguerite Harper Niki de Saint Phalle "[',' ' the' ' artist' ',' ' and' ' the' ' mother' ' of' ' the' ' artist'
+ ""'s"" ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',' ' and' ' the']" , the artist , and the mother of the artist 's daughter , the artist 's daughter , and the False California, the sculptor Niki de Saint Phalle built her multiple-piece 10 [' California', ',', ' the', ' sculpt', 'or', ' N', 'iki', ' de', ' Saint', ' Ph', 'alle']
+2806 641 Name of mother of x -1 Name of mother of Niki de Saint Phalle Jeanne Jacqueline Marguerite Harper Niki de Saint Phalle "[',' ' the' ' artist' ',' ' and' ' the' ' mother' ' of' ' the' ' artist'
+ ""'s"" ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',' ' and' ' the']" , the artist , and the mother of the artist 's daughter , the artist 's daughter , and the False California, the sculptor Niki de Saint Phalle built her multiple-piece 10 [' California', ',', ' the', ' sculpt', 'or', ' N', 'iki', ' de', ' Saint', ' Ph', 'alle']
+2807 641 Name of mother of x -1 Name of mother of Niki de Saint Phalle Jeanne Jacqueline Marguerite Harper Niki de Saint Phalle "[',' ' the' ' artist' ',' ' and' ' the' ' mother' ' of' ' the' ' artist'
+ ""'s"" ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',' ' and' ' the']" , the artist , and the mother of the artist 's daughter , the artist 's daughter , and the False the sculptor Niki de Saint Phalle built her multiple-piece 8 [' the', ' sculpt', 'or', ' N', 'iki', ' de', ' Saint', ' Ph', 'alle']
+2808 641 Name of mother of x -1 Name of mother of Niki de Saint Phalle Jeanne Jacqueline Marguerite Harper Niki de Saint Phalle "[',' ' the' ' artist' ',' ' and' ' the' ' mother' ' of' ' the' ' artist'
+ ""'s"" ' daughter' ',' ' the' ' artist' ""'s"" ' daughter' ',' ' and' ' the']" , the artist , and the mother of the artist 's daughter , the artist 's daughter , and the False California, the sculptor Niki de Saint Phalle built her multiple-piece 10 [' California', ',', ' the', ' sculpt', 'or', ' N', 'iki', ' de', ' Saint', ' Ph', 'alle']
+2809 642 Name of mother of x -1 Name of mother of Elie Wiesel Sarah Feig Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False October 2002, the Elie Wiesel Foundation for 7 [' October', ' 2002', ',', ' the', ' El', 'ie', ' W', 'iesel']
+2810 642 Name of mother of x -1 Name of mother of Elie Wiesel Sarah Feig Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False October 2002, the Elie Wiesel Foundation for Humanity 7 [' October', ' 2002', ',', ' the', ' El', 'ie', ' W', 'iesel']
+2811 642 Name of mother of x -1 Name of mother of Elie Wiesel Sarah Feig Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False is a work by Elie Wiesel about his experience 7 [' is', ' a', ' work', ' by', ' El', 'ie', ' W', 'iesel']
+2812 642 Name of mother of x -1 Name of mother of Elie Wiesel Sarah Feig Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False 3 ['El', 'ie', ' W', 'iesel']
+2813 642 Name of mother of x -1 Name of mother of Elie Wiesel Sarah Feig Elie Wiesel "[',' ' the' ' famous' ' Jewish' ' writer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' Sig' 'het' ',' ' Romania' ',' ' in' ' 1928']" , the famous Jewish writer , who was born in the town of Sig het , Romania , in 1928 False Prize winner Elie Wiesel calls the Holocaust 5 [' Prize', ' winner', ' El', 'ie', ' W', 'iesel']
+2814 643 Name of mother of x -1 Name of mother of Kate Beckinsale Judy Loe Kate Beckinsale "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False starred with Kate Beckinsale in Columbia 5 [' starred', ' with', ' Kate', ' Beck', 'ins', 'ale']
+2815 643 Name of mother of x -1 Name of mother of Kate Beckinsale Judy Loe Kate Beckinsale "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Festival starring Kate Beckinsale and Chloe Sevigny 5 [' Festival', ' starring', ' Kate', ' Beck', 'ins', 'ale']
+2816 643 Name of mother of x -1 Name of mother of Kate Beckinsale Judy Loe Kate Beckinsale "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False " Kate Beckinsale =
+" 3 [' Kate', ' Beck', 'ins', 'ale']
+2817 643 Name of mother of x -1 Name of mother of Kate Beckinsale Judy Loe Kate Beckinsale "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False English actress Kate Beckinsale from 1995 until 5 [' English', ' actress', ' Kate', ' Beck', 'ins', 'ale']
+2818 643 Name of mother of x -1 Name of mother of Kate Beckinsale Judy Loe Kate Beckinsale "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False with English actress Kate Beckinsale from 1995 6 [' with', ' English', ' actress', ' Kate', ' Beck', 'ins', 'ale']
+2819 646 Name of mother of x -1 Name of mother of Beatrix Potter Helen Leech Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False The Tales of Beatrix Potter. In the autumn of 5 [' The', ' Tales', ' of', ' Beat', 'rix', ' Potter']
+2820 646 Name of mother of x -1 Name of mother of Beatrix Potter Helen Leech Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False collaborators on Beatrix Potter 1866-1943: The 4 [' collaborators', ' on', ' Beat', 'rix', ' Potter']
+2821 646 Name of mother of x -1 Name of mother of Beatrix Potter Helen Leech Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False Rabbit. However, Beatrix Potter refused to give 6 [' Rabbit', '.', ' However', ',', ' Beat', 'rix', ' Potter']
+2822 646 Name of mother of x -1 Name of mother of Beatrix Potter Helen Leech Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False at the time of her Beatrix Potter (1986) argues that 7 [' at', ' the', ' time', ' of', ' her', ' Beat', 'rix', ' Potter']
+2823 646 Name of mother of x -1 Name of mother of Beatrix Potter Helen Leech Beatrix Potter "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False Evans offered Beatrix Potter an interest in the 4 [' Evans', ' offered', ' Beat', 'rix', ' Potter']
+2824 648 Name of mother of x -1 Name of mother of Eric Hobsbawm Nelly Hobsbaum Eric Hobsbawm "[',' ' the' ' author' ' of' ' the' ' book' ',' ' The' ' Age' ' of'
+ ' Extrem' 'es' ':' ' The' ' Short' ' Tw' 'ent' 'ieth' ' Century' ',']" , the author of the book , The Age of Extrem es : The Short Tw ent ieth Century , False " historian Eric Hobsbawm put it, ""Suez and the" 6 [' historian', ' Eric', ' H', 'obs', 'b', 'aw', 'm']
+2825 648 Name of mother of x -1 Name of mother of Eric Hobsbawm Nelly Hobsbaum Eric Hobsbawm "[',' ' the' ' author' ' of' ' the' ' book' ',' ' The' ' Age' ' of'
+ ' Extrem' 'es' ':' ' The' ' Short' ' Tw' 'ent' 'ieth' ' Century' ',']" , the author of the book , The Age of Extrem es : The Short Tw ent ieth Century , False Marxist historian Eric Hobsbawm remarked that 7 [' Marxist', ' historian', ' Eric', ' H', 'obs', 'b', 'aw', 'm']
+2826 648 Name of mother of x -1 Name of mother of Eric Hobsbawm Nelly Hobsbaum Eric Hobsbawm "[',' ' the' ' author' ' of' ' the' ' book' ',' ' The' ' Age' ' of'
+ ' Extrem' 'es' ':' ' The' ' Short' ' Tw' 'ent' 'ieth' ' Century' ',']" , the author of the book , The Age of Extrem es : The Short Tw ent ieth Century , False " Marxist historian Eric Hobsbawm remarked that ""One" 7 [' Marxist', ' historian', ' Eric', ' H', 'obs', 'b', 'aw', 'm']
+2827 649 Name of mother of x -1 Name of mother of Charles I of England Anne of Denmark Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False execution of King Charles I of England and the outbreak 6 [' execution', ' of', ' King', ' Charles', ' I', ' of', ' England']
+2828 649 Name of mother of x -1 Name of mother of Charles I of England Anne of Denmark Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " England =
+" 5 [' England', ' =', 'Charles', ' I', ' of', ' England']
+2829 649 Name of mother of x -1 Name of mother of Charles I of England Anne of Denmark Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False equestrian statue of Charles I of England and the two pubs; 8 [' equ', 'est', 'rian', ' statue', ' of', ' Charles', ' I', ' of', ' England']
+2830 649 Name of mother of x -1 Name of mother of Charles I of England Anne of Denmark Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False history of the deposed Charles I of England to be brought from 8 [' history', ' of', ' the', ' dep', 'osed', ' Charles', ' I', ' of', ' England']
+2831 649 Name of mother of x -1 Name of mother of Charles I of England Anne of Denmark Charles I of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False execution of King Charles I of England and the outbreak 6 [' execution', ' of', ' King', ' Charles', ' I', ' of', ' England']
+2832 650 Name of mother of x -1 Name of mother of Frederick II, Holy Roman Emperor Constance Frederick II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False the Tunisians in 1231 Frederick II, Holy Roman Emperor minted the augustalis. 11 [' the', ' Tunis', 'ians', ' in', ' 12', '31', ' Frederick', ' II', ',', ' Holy', ' Roman', ' Emperor']
+2833 651 Name of mother of x -1 Name of mother of Julian Basilina Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' Pirate' ' Bay' ',' ' was' ' arrested' ' in'
+ ' Sweden' '.']" Assange , the founder of Wikileaks , and the founder of the Pirate Bay , was arrested in Sweden . False emphasize her romance with Julian and difficult 4 [' emphasize', ' her', ' romance', ' with', ' Julian']
+2834 651 Name of mother of x -1 Name of mother of Julian Basilina Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' Pirate' ' Bay' ',' ' was' ' arrested' ' in'
+ ' Sweden' '.']" Assange , the founder of Wikileaks , and the founder of the Pirate Bay , was arrested in Sweden . False introduced by Julian Huxley. Evolutionary 2 [' introduced', ' by', ' Julian']
+2835 651 Name of mother of x -1 Name of mother of Julian Basilina Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' Pirate' ' Bay' ',' ' was' ' arrested' ' in'
+ ' Sweden' '.']" Assange , the founder of Wikileaks , and the founder of the Pirate Bay , was arrested in Sweden . False brothers Jean and Julian Aberbach, perceived 3 [' brothers', ' Jean', ' and', ' Julian']
+2836 651 Name of mother of x -1 Name of mother of Julian Basilina Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' Pirate' ' Bay' ',' ' was' ' arrested' ' in'
+ ' Sweden' '.']" Assange , the founder of Wikileaks , and the founder of the Pirate Bay , was arrested in Sweden . False Numismatic writer R.W. Julian believes that there 8 [' Num', 'ism', 'atic', ' writer', ' R', '.', 'W', '.', ' Julian']
+2837 651 Name of mother of x -1 Name of mother of Julian Basilina Julian "[' Assange' ',' ' the' ' founder' ' of' ' Wikileaks' ',' ' and' ' the'
+ ' founder' ' of' ' the' ' Pirate' ' Bay' ',' ' was' ' arrested' ' in'
+ ' Sweden' '.']" Assange , the founder of Wikileaks , and the founder of the Pirate Bay , was arrested in Sweden . False relationship with businessman Julian Crane, Mr. Sanbourne 3 [' relationship', ' with', ' businessman', ' Julian']
+2838 652 Name of mother of x -1 Name of mother of Paris Hilton Kathy Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' very' ' rich' ' woman' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' famous']" , the daughter of a wealthy , famous , and very rich woman , and the daughter of a famous False smack in the middle of Paris Hilton time. But there 6 [' smack', ' in', ' the', ' middle', ' of', ' Paris', ' Hilton']
+2839 652 Name of mother of x -1 Name of mother of Paris Hilton Kathy Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' very' ' rich' ' woman' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' famous']" , the daughter of a wealthy , famous , and very rich woman , and the daughter of a famous False celebrity Paris Hilton in a swimsuit, 2 [' celebrity', ' Paris', ' Hilton']
+2840 652 Name of mother of x -1 Name of mother of Paris Hilton Kathy Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' very' ' rich' ' woman' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' famous']" , the daughter of a wealthy , famous , and very rich woman , and the daughter of a famous False script and alludes to Paris Hilton (London spoofing 6 [' script', ' and', ' all', 'udes', ' to', ' Paris', ' Hilton']
+2841 652 Name of mother of x -1 Name of mother of Paris Hilton Kathy Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' very' ' rich' ' woman' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' famous']" , the daughter of a wealthy , famous , and very rich woman , and the daughter of a famous False one week after Paris Hilton and her friend Nicole 4 [' one', ' week', ' after', ' Paris', ' Hilton']
+2842 652 Name of mother of x -1 Name of mother of Paris Hilton Kathy Hilton Paris Hilton "[',' ' the' ' daughter' ' of' ' a' ' wealthy' ',' ' famous' ',' ' and'
+ ' very' ' rich' ' woman' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' famous']" , the daughter of a wealthy , famous , and very rich woman , and the daughter of a famous False one week after Paris Hilton and her friend 4 [' one', ' week', ' after', ' Paris', ' Hilton']
+2843 653 Name of mother of x -1 Name of mother of Plácido Domingo Pepita Embil Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False Malfitano, Plácido Domingo and Ruggero 10 [' Malf', 'it', 'ano', ',', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+2844 653 Name of mother of x -1 Name of mother of Plácido Domingo Pepita Embil Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False performances. Plácido Domingo first recorded Cavaradossi 8 [' performances', '.', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+2845 653 Name of mother of x -1 Name of mother of Plácido Domingo Pepita Embil Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False Luis Rodríguez and Plácido Domingo to record modern versions 13 [' Luis', ' Rod', 'r', 'í', 'g', 'uez', ' and', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+2846 653 Name of mother of x -1 Name of mother of Plácido Domingo Pepita Embil Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False Fourth Symphony, with Plácido Domingo as baritone 10 [' Fourth', ' Symphony', ',', ' with', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+2847 653 Name of mother of x -1 Name of mother of Plácido Domingo Pepita Embil Plácido Domingo "[',' ' the' ' famous' ' ten' 'or' ',' ' and' ' his' ' wife' ',' ' the'
+ ' famous' ' s' 'op' 'rano' ',' ' and' ' the' ' famous' ' s']" , the famous ten or , and his wife , the famous s op rano , and the famous s False live performances. Plácido Domingo first recorded 9 [' live', ' performances', '.', ' Pl', 'á', 'c', 'ido', ' D', 'oming', 'o']
+2848 654 Name of mother of x -1 Name of mother of Nikola Tesla Đuka Madic Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False and scientist Nikola Tesla was one of those 3 [' and', ' scientist', ' Nikola', ' Tesla']
+2849 654 Name of mother of x -1 Name of mother of Nikola Tesla Đuka Madic Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False 2 ['Nik', 'ola', ' Tesla']
+2850 654 Name of mother of x -1 Name of mother of Nikola Tesla Đuka Madic Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False popular notion that Nikola Tesla and Thomas Edison 4 [' popular', ' notion', ' that', ' Nikola', ' Tesla']
+2851 654 Name of mother of x -1 Name of mother of Nikola Tesla Đuka Madic Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False served by Belgrade Nikola Tesla Airport, 12 kilometres 5 [' served', ' by', ' Bel', 'grade', ' Nikola', ' Tesla']
+2852 654 Name of mother of x -1 Name of mother of Nikola Tesla Đuka Madic Nikola Tesla "[',' ' the' ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the'
+ ' inventor' ' of' ' the' ' Tesla' ' coil' ',' ' and' ' the' ' inventor'
+ ' of']" , the inventor of the Tesla coil , and the inventor of the Tesla coil , and the inventor of False World War I, inventor Nikola Tesla lived in the 6 [' World', ' War', ' I', ',', ' inventor', ' Nikola', ' Tesla']
+2853 655 Name of mother of x -1 Name of mother of Christopher Lee Estelle Maria Carandini Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' actor' ' Christopher' ' Lee' ','
+ ' who' ' died' ' in' ' 2004' '.' '\n' '\n' 'The' ' actor' ',']" ", the son of the late actor Christopher Lee , who died in 2004 .
+
+ The actor ," False %. The architect, Christopher Lee of Populous, 6 [' %', '.', ' The', ' architect', ',', ' Christopher', ' Lee']
+2854 655 Name of mother of x -1 Name of mother of Christopher Lee Estelle Maria Carandini Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' actor' ' Christopher' ' Lee' ','
+ ' who' ' died' ' in' ' 2004' '.' '\n' '\n' 'The' ' actor' ',']" ", the son of the late actor Christopher Lee , who died in 2004 .
+
+ The actor ," False 1 ['Christopher', ' Lee']
+2855 655 Name of mother of x -1 Name of mother of Christopher Lee Estelle Maria Carandini Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' actor' ' Christopher' ' Lee' ','
+ ' who' ' died' ' in' ' 2004' '.' '\n' '\n' 'The' ' actor' ',']" ", the son of the late actor Christopher Lee , who died in 2004 .
+
+ The actor ," False %. The architect, Christopher Lee of Populous, described 6 [' %', '.', ' The', ' architect', ',', ' Christopher', ' Lee']
+2856 655 Name of mother of x -1 Name of mother of Christopher Lee Estelle Maria Carandini Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' actor' ' Christopher' ' Lee' ','
+ ' who' ' died' ' in' ' 2004' '.' '\n' '\n' 'The' ' actor' ',']" ", the son of the late actor Christopher Lee , who died in 2004 .
+
+ The actor ," False 1 ['Christopher', ' Lee']
+2857 655 Name of mother of x -1 Name of mother of Christopher Lee Estelle Maria Carandini Christopher Lee "[',' ' the' ' son' ' of' ' the' ' late' ' actor' ' Christopher' ' Lee' ','
+ ' who' ' died' ' in' ' 2004' '.' '\n' '\n' 'The' ' actor' ',']" ", the son of the late actor Christopher Lee , who died in 2004 .
+
+ The actor ," False voice of 3 [' voice', ' o', 'Christopher', ' Lee']
+2858 656 Name of mother of x -1 Name of mother of James Clerk Maxwell Frances Cay James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' James'
+ ' Clerk' ' Maxwell' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of James Clerk Maxwell , the
+
+ The name of" False 2 ['James', ' Clerk', ' Maxwell']
+2859 656 Name of mother of x -1 Name of mother of James Clerk Maxwell Frances Cay James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' James'
+ ' Clerk' ' Maxwell' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of James Clerk Maxwell , the
+
+ The name of" False discoveries by James Clerk Maxwell to the effect that 4 [' discoveries', ' by', ' James', ' Clerk', ' Maxwell']
+2860 656 Name of mother of x -1 Name of mother of James Clerk Maxwell Frances Cay James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' James'
+ ' Clerk' ' Maxwell' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of James Clerk Maxwell , the
+
+ The name of" False (later Lord Kelvin), James Clerk Maxwell and James Prescott 7 [' (', 'later', ' Lord', ' Kelvin', '),', ' James', ' Clerk', ' Maxwell']
+2861 656 Name of mother of x -1 Name of mother of James Clerk Maxwell Frances Cay James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' James'
+ ' Clerk' ' Maxwell' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of James Clerk Maxwell , the
+
+ The name of" False 2 ['James', ' Clerk', ' Maxwell']
+2862 656 Name of mother of x -1 Name of mother of James Clerk Maxwell Frances Cay James Clerk Maxwell "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' James'
+ ' Clerk' ' Maxwell' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of James Clerk Maxwell , the
+
+ The name of" False others of James Clerk Maxwell and John Witherspoon. 4 [' others', ' of', ' James', ' Clerk', ' Maxwell']
+2863 657 Name of mother of x -1 Name of mother of John Cox Nancy Chevallier Forman John Cox "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' bride' 'g'
+ 'room' ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride']" ", the father of the bride , and the bride g room of the bride .
+
+ The bride" False Professor of English John Cox suggests that Shakespeare 4 [' Professor', ' of', ' English', ' John', ' Cox']
+2864 657 Name of mother of x -1 Name of mother of John Cox Nancy Chevallier Forman John Cox "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' bride' 'g'
+ 'room' ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride']" ", the father of the bride , and the bride g room of the bride .
+
+ The bride" False Professor of English John Cox suggests that Shakespeare 4 [' Professor', ' of', ' English', ' John', ' Cox']
+2865 657 Name of mother of x -1 Name of mother of John Cox Nancy Chevallier Forman John Cox "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' bride' 'g'
+ 'room' ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride']" ", the father of the bride , and the bride g room of the bride .
+
+ The bride" False Professor of English John Cox suggests that Shakespeare 4 [' Professor', ' of', ' English', ' John', ' Cox']
+2866 657 Name of mother of x -1 Name of mother of John Cox Nancy Chevallier Forman John Cox "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' bride' 'g'
+ 'room' ' of' ' the' ' bride' '.' '\n' '\n' 'The' ' bride']" ", the father of the bride , and the bride g room of the bride .
+
+ The bride" False Professor of English John Cox suggests that Shakespeare 4 [' Professor', ' of', ' English', ' John', ' Cox']
+2867 659 Name of mother of x -1 Name of mother of Hilary Swank Judy Kay Clough Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actress , and a great person .
+
+ I am a huge fan of" False May 11, 2011, Hilary Swank was reportedly 8 [' May', ' 11', ',', ' 2011', ',', ' Hil', 'ary', ' Sw', 'ank']
+2868 659 Name of mother of x -1 Name of mother of Hilary Swank Judy Kay Clough Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actress , and a great person .
+
+ I am a huge fan of" False On May 11, 2011, Hilary Swank was reportedly 9 [' On', ' May', ' 11', ',', ' 2011', ',', ' Hil', 'ary', ' Sw', 'ank']
+2869 659 Name of mother of x -1 Name of mother of Hilary Swank Judy Kay Clough Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actress , and a great person .
+
+ I am a huge fan of" False whether actress Hilary Swank can be considered 5 [' whether', ' actress', ' Hil', 'ary', ' Sw', 'ank']
+2870 659 Name of mother of x -1 Name of mother of Hilary Swank Judy Kay Clough Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actress , and a great person .
+
+ I am a huge fan of" False drama starring Hilary Swank and Chiwetel 5 [' drama', ' starring', ' Hil', 'ary', ' Sw', 'ank']
+2871 659 Name of mother of x -1 Name of mother of Hilary Swank Judy Kay Clough Hilary Swank "[',' ' who' ' is' ' a' ' great' ' actress' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' am' ' a' ' huge' ' fan' ' of']" ", who is a great actress , and a great person .
+
+ I am a huge fan of" False While discussing Hilary Swank, Kevin said he finds 5 [' While', ' discussing', ' Hil', 'ary', ' Sw', 'ank']
+2872 660 Name of mother of x -1 Name of mother of Mikhail Bakunin Varvara Muravyova Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' Russian' ' noble' 'man' ' himself' '.' ' He' ' was' ' a' ' man']" , the son of a Russian noble man , and a Russian noble man himself . He was a man False associated with Mikhail Bakunin and Johann Most. Collectivist 4 [' associated', ' with', ' Mikhail', ' Bak', 'unin']
+2873 660 Name of mother of x -1 Name of mother of Mikhail Bakunin Varvara Muravyova Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' Russian' ' noble' 'man' ' himself' '.' ' He' ' was' ' a' ' man']" , the son of a Russian noble man , and a Russian noble man himself . He was a man False anarchist theorist Mikhail Bakunin from its bookstore 4 [' anarchist', ' theorist', ' Mikhail', ' Bak', 'unin']
+2874 660 Name of mother of x -1 Name of mother of Mikhail Bakunin Varvara Muravyova Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' Russian' ' noble' 'man' ' himself' '.' ' He' ' was' ' a' ' man']" , the son of a Russian noble man , and a Russian noble man himself . He was a man False associated with Mikhail Bakunin and Johann 4 [' associated', ' with', ' Mikhail', ' Bak', 'unin']
+2875 660 Name of mother of x -1 Name of mother of Mikhail Bakunin Varvara Muravyova Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' Russian' ' noble' 'man' ' himself' '.' ' He' ' was' ' a' ' man']" , the son of a Russian noble man , and a Russian noble man himself . He was a man False Hawke starred as Mikhail Bakunin in Tom Stoppard's 6 [' Haw', 'ke', ' starred', ' as', ' Mikhail', ' Bak', 'unin']
+2876 660 Name of mother of x -1 Name of mother of Mikhail Bakunin Varvara Muravyova Mikhail Bakunin "[',' ' the' ' son' ' of' ' a' ' Russian' ' noble' 'man' ',' ' and' ' a'
+ ' Russian' ' noble' 'man' ' himself' '.' ' He' ' was' ' a' ' man']" , the son of a Russian noble man , and a Russian noble man himself . He was a man False thinkers such as Mikhail Bakunin and Peter Kropotkin, 5 [' thinkers', ' such', ' as', ' Mikhail', ' Bak', 'unin']
+2877 661 Name of mother of x -1 Name of mother of Romain Gary Mina Owczyńska Romain Gary "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False novel written by Romain Gary of the same 5 [' novel', ' written', ' by', ' Rom', 'ain', ' Gary']
+2878 661 Name of mother of x -1 Name of mother of Romain Gary Mina Owczyńska Romain Gary "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False autobiographical novel written by Romain Gary of the same name. 7 [' autobi', 'ographical', ' novel', ' written', ' by', ' Rom', 'ain', ' Gary']
+2879 661 Name of mother of x -1 Name of mother of Romain Gary Mina Owczyńska Romain Gary "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False novel written by Romain Gary of the same 5 [' novel', ' written', ' by', ' Rom', 'ain', ' Gary']
+2880 661 Name of mother of x -1 Name of mother of Romain Gary Mina Owczyńska Romain Gary "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False In the original Romain Gary novel, this 5 [' In', ' the', ' original', ' Rom', 'ain', ' Gary']
+2881 661 Name of mother of x -1 Name of mother of Romain Gary Mina Owczyńska Romain Gary "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' writer' ',' ' a' ' reader' ',' ' a' ' dream']" "
+
+ I am a mother of two , a wife , a writer , a reader , a dream" False the original Romain Gary novel, this 4 [' the', ' original', ' Rom', 'ain', ' Gary']
+2882 662 Name of mother of x -1 Name of mother of Pliny the Younger Plinia Marcella Pliny the Younger "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Tra' 'jan' ',' ' and' ' the'
+ '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' a' ' vast' ',']" ", the son of the Emperor Tra jan , and the
+
+ The Roman Empire was a vast ," False authors such as Tacitus, Pliny the Younger and Suetonius 9 [' authors', ' such', ' as', ' Tac', 'itus', ',', ' Pl', 'iny', ' the', ' Younger']
+2883 662 Name of mother of x -1 Name of mother of Pliny the Younger Plinia Marcella Pliny the Younger "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Tra' 'jan' ',' ' and' ' the'
+ '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' a' ' vast' ',']" ", the son of the Emperor Tra jan , and the
+
+ The Roman Empire was a vast ," False letter to Tacitus, Pliny the Younger suggested that his 8 [' letter', ' to', ' Tac', 'itus', ',', ' Pl', 'iny', ' the', ' Younger']
+2884 662 Name of mother of x -1 Name of mother of Pliny the Younger Plinia Marcella Pliny the Younger "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Tra' 'jan' ',' ' and' ' the'
+ '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' a' ' vast' ',']" ", the son of the Emperor Tra jan , and the
+
+ The Roman Empire was a vast ," False as Tacitus, Pliny the Younger and Suetonius 7 [' as', ' Tac', 'itus', ',', ' Pl', 'iny', ' the', ' Younger']
+2885 662 Name of mother of x -1 Name of mother of Pliny the Younger Plinia Marcella Pliny the Younger "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Tra' 'jan' ',' ' and' ' the'
+ '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' a' ' vast' ',']" ", the son of the Emperor Tra jan , and the
+
+ The Roman Empire was a vast ," False region; the writer Pliny the Younger even wrote that 7 [' region', ';', ' the', ' writer', ' Pl', 'iny', ' the', ' Younger']
+2886 662 Name of mother of x -1 Name of mother of Pliny the Younger Plinia Marcella Pliny the Younger "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Tra' 'jan' ',' ' and' ' the'
+ '\n' '\n' 'The' ' Roman' ' Empire' ' was' ' a' ' vast' ',']" ", the son of the Emperor Tra jan , and the
+
+ The Roman Empire was a vast ," False 4 ['P', 'lin', 'y', ' the', ' Younger']
+2887 663 Name of mother of x -1 Name of mother of Henry IV of France Jeanne d'Albret Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False French ownership. Henry IV of France ordered the demolition 6 [' French', ' ownership', '.', ' Henry', ' IV', ' of', ' France']
+2888 663 Name of mother of x -1 Name of mother of Henry IV of France Jeanne d'Albret Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False by the future king Henry IV of France (1553 – 1610), Catherine 7 [' by', ' the', ' future', ' king', ' Henry', ' IV', ' of', ' France']
+2889 663 Name of mother of x -1 Name of mother of Henry IV of France Jeanne d'Albret Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " Supporting Henry IV of France ===
+" 4 [' Supporting', ' Henry', ' IV', ' of', ' France']
+2890 663 Name of mother of x -1 Name of mother of Henry IV of France Jeanne d'Albret Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False fighting with King Henry IV of France during the French 6 [' fighting', ' with', ' King', ' Henry', ' IV', ' of', ' France']
+2891 663 Name of mother of x -1 Name of mother of Henry IV of France Jeanne d'Albret Henry IV of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False 3 ['Henry', ' IV', ' of', ' France']
+2892 664 Name of mother of x -1 Name of mother of Nina Hagen Eva-Maria Hagen Nina Hagen "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ',' ' and' ' the' ' mother' ' of' ' the' ' child']" , the mother of the child , and the mother of the child , and the mother of the child False 3 ['N', 'ina', ' H', 'agen']
+2893 664 Name of mother of x -1 Name of mother of Nina Hagen Eva-Maria Hagen Nina Hagen "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ',' ' and' ' the' ' mother' ' of' ' the' ' child']" , the mother of the child , and the mother of the child , and the mother of the child False — punk rock singer Nina Hagen — in the early 6 [' —', ' punk', ' rock', ' singer', ' Nina', ' H', 'agen']
+2894 664 Name of mother of x -1 Name of mother of Nina Hagen Eva-Maria Hagen Nina Hagen "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ',' ' and' ' the' ' mother' ' of' ' the' ' child']" , the mother of the child , and the mother of the child , and the mother of the child False 3 ['N', 'ina', ' H', 'agen']
+2895 664 Name of mother of x -1 Name of mother of Nina Hagen Eva-Maria Hagen Nina Hagen "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ',' ' and' ' the' ' mother' ' of' ' the' ' child']" , the mother of the child , and the mother of the child , and the mother of the child False punk rock singer Nina Hagen — in the early 1980s. 5 [' punk', ' rock', ' singer', ' Nina', ' H', 'agen']
+2896 664 Name of mother of x -1 Name of mother of Nina Hagen Eva-Maria Hagen Nina Hagen "[',' ' the' ' mother' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' ',' ' and' ' the' ' mother' ' of' ' the' ' child']" , the mother of the child , and the mother of the child , and the mother of the child False 3 ['N', 'ina', ' H', 'agen']
+2897 665 Name of mother of x -1 Name of mother of Alexander VI Isabel de Borja y Cavanilles Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' famous' ' of' ' all' ' the' ' Roman' ' em' 'perors' ',' ' and']" ".
+
+ The first of the two is the most famous of all the Roman em perors , and" False colonial rule. Pope Alexander VI had awarded colonial 5 [' colonial', ' rule', '.', ' Pope', ' Alexander', ' VI']
+2898 665 Name of mother of x -1 Name of mother of Alexander VI Isabel de Borja y Cavanilles Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' famous' ' of' ' all' ' the' ' Roman' ' em' 'perors' ',' ' and']" ".
+
+ The first of the two is the most famous of all the Roman em perors , and" False 1493 donation by Pope Alexander VI that had divided 6 [' 14', '93', ' donation', ' by', ' Pope', ' Alexander', ' VI']
+2899 665 Name of mother of x -1 Name of mother of Alexander VI Isabel de Borja y Cavanilles Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' famous' ' of' ' all' ' the' ' Roman' ' em' 'perors' ',' ' and']" ".
+
+ The first of the two is the most famous of all the Roman em perors , and" False the Pope, as Pope Alexander VI had granted the Indies 6 [' the', ' Pope', ',', ' as', ' Pope', ' Alexander', ' VI']
+2900 665 Name of mother of x -1 Name of mother of Alexander VI Isabel de Borja y Cavanilles Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' famous' ' of' ' all' ' the' ' Roman' ' em' 'perors' ',' ' and']" ".
+
+ The first of the two is the most famous of all the Roman em perors , and" False the Pope, as Pope Alexander VI had granted the 6 [' the', ' Pope', ',', ' as', ' Pope', ' Alexander', ' VI']
+2901 665 Name of mother of x -1 Name of mother of Alexander VI Isabel de Borja y Cavanilles Alexander VI "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' famous' ' of' ' all' ' the' ' Roman' ' em' 'perors' ',' ' and']" ".
+
+ The first of the two is the most famous of all the Roman em perors , and" False Italian War, Pope Alexander VI had, with French 5 [' Italian', ' War', ',', ' Pope', ' Alexander', ' VI']
+2902 666 Name of mother of x -1 Name of mother of Gore Vidal Nina S. Gore Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " morality"". Author Gore Vidal placed Lessing's" 5 "[' morality', '"".', ' Author', ' Gore', ' V', 'idal']"
+2903 666 Name of mother of x -1 Name of mother of Gore Vidal Nina S. Gore Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 3 ['G', 'ore', ' V', 'idal']
+2904 666 Name of mother of x -1 Name of mother of Gore Vidal Nina S. Gore Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False of Wills's book. Gore Vidal also draws attention 8 "[' of', ' W', 'ills', ""'s"", ' book', '.', ' Gore', ' V', 'idal']"
+2905 666 Name of mother of x -1 Name of mother of Gore Vidal Nina S. Gore Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Bisexual author Gore Vidal (1925-2012) is a documented 5 [' B', 'isexual', ' author', ' Gore', ' V', 'idal']
+2906 666 Name of mother of x -1 Name of mother of Gore Vidal Nina S. Gore Gore Vidal "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False stepbrother of author Gore Vidal after Olds'father 6 [' step', 'brother', ' of', ' author', ' Gore', ' V', 'idal']
+2907 668 Name of mother of x -1 Name of mother of Roger Penrose Margaret Leathes Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False mathematician Roger Penrose and his father 3 [' mathematician', ' Roger', ' Pen', 'rose']
+2908 668 Name of mother of x -1 Name of mother of Roger Penrose Margaret Leathes Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False causal structure, Roger Penrose and others developed 5 [' causal', ' structure', ',', ' Roger', ' Pen', 'rose']
+2909 668 Name of mother of x -1 Name of mother of Roger Penrose Margaret Leathes Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False participants. Both Roger Penrose and H. S. 5 [' participants', '.', ' Both', ' Roger', ' Pen', 'rose']
+2910 668 Name of mother of x -1 Name of mother of Roger Penrose Margaret Leathes Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False the late 1960s Roger Penrose and Stephen Hawking 6 [' the', ' late', ' 1960', 's', ' Roger', ' Pen', 'rose']
+2911 668 Name of mother of x -1 Name of mother of Roger Penrose Margaret Leathes Roger Penrose "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False introduced by Roger Penrose in a 1974 paper, 4 [' introduced', ' by', ' Roger', ' Pen', 'rose']
+2912 669 Name of mother of x -1 Name of mother of Robert Frost Isabel Moodie Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False 47-mile (76 km) Robert Frost Trail, and the 8 [' 47', '-', 'mile', ' (', '76', ' km', ')', ' Robert', ' Frost']
+2913 669 Name of mother of x -1 Name of mother of Robert Frost Isabel Moodie Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False based on the Robert Frost poem of the same 4 [' based', ' on', ' the', ' Robert', ' Frost']
+2914 669 Name of mother of x -1 Name of mother of Robert Frost Isabel Moodie Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False recitation since Robert Frost at John F. Kennedy's 4 [' rec', 'itation', ' since', ' Robert', ' Frost']
+2915 669 Name of mother of x -1 Name of mother of Robert Frost Isabel Moodie Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False recitation since Robert Frost at John F. 4 [' rec', 'itation', ' since', ' Robert', ' Frost']
+2916 669 Name of mother of x -1 Name of mother of Robert Frost Isabel Moodie Robert Frost "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ""'s"" ' mother' ',' ' and' ' the']" , the poet , and the poet 's mother , and the poet 's mother 's mother , and the False recitation since Robert Frost at John F. 4 [' rec', 'itation', ' since', ' Robert', ' Frost']
+2917 670 Name of mother of x -1 Name of mother of Carl Philipp Emanuel Bach Maria Barbara Bach Carl Philipp Emanuel Bach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'Brand' 'enburg'
+ ' Concert' 'os' '""' ' and' ' the' ' ""' 'Gold' 'berg' ' Vari' 'ations' '""']" ", the composer of the famous "" Brand enburg Concert os "" and the "" Gold berg Vari ations """ False May 1714. His son Carl Philipp Emanuel Bach remembered 9 [' May', ' 17', '14', '.', ' His', ' son', ' Carl', ' Philipp', ' Emanuel', ' Bach']
+2918 670 Name of mother of x -1 Name of mother of Carl Philipp Emanuel Bach Maria Barbara Bach Carl Philipp Emanuel Bach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'Brand' 'enburg'
+ ' Concert' 'os' '""' ' and' ' the' ' ""' 'Gold' 'berg' ' Vari' 'ations' '""']" ", the composer of the famous "" Brand enburg Concert os "" and the "" Gold berg Vari ations """ False 1714. His son Carl Philipp Emanuel Bach remembered that 8 [' 17', '14', '.', ' His', ' son', ' Carl', ' Philipp', ' Emanuel', ' Bach']
+2919 670 Name of mother of x -1 Name of mother of Carl Philipp Emanuel Bach Maria Barbara Bach Carl Philipp Emanuel Bach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'Brand' 'enburg'
+ ' Concert' 'os' '""' ' and' ' the' ' ""' 'Gold' 'berg' ' Vari' 'ations' '""']" ", the composer of the famous "" Brand enburg Concert os "" and the "" Gold berg Vari ations """ False May 1714. His son Carl Philipp Emanuel Bach remembered that 9 [' May', ' 17', '14', '.', ' His', ' son', ' Carl', ' Philipp', ' Emanuel', ' Bach']
+2920 670 Name of mother of x -1 Name of mother of Carl Philipp Emanuel Bach Maria Barbara Bach Carl Philipp Emanuel Bach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'Brand' 'enburg'
+ ' Concert' 'os' '""' ' and' ' the' ' ""' 'Gold' 'berg' ' Vari' 'ations' '""']" ", the composer of the famous "" Brand enburg Concert os "" and the "" Gold berg Vari ations """ False 3 ['Carl', ' Philipp', ' Emanuel', ' Bach']
+2921 670 Name of mother of x -1 Name of mother of Carl Philipp Emanuel Bach Maria Barbara Bach Carl Philipp Emanuel Bach "[',' ' the' ' composer' ' of' ' the' ' famous' ' ""' 'Brand' 'enburg'
+ ' Concert' 'os' '""' ' and' ' the' ' ""' 'Gold' 'berg' ' Vari' 'ations' '""']" ", the composer of the famous "" Brand enburg Concert os "" and the "" Gold berg Vari ations """ False May 1714. His son Carl Philipp Emanuel Bach remembered that he 9 [' May', ' 17', '14', '.', ' His', ' son', ' Carl', ' Philipp', ' Emanuel', ' Bach']
+2922 671 Name of mother of x -1 Name of mother of Wilhelm von Humboldt Marie-Elisabeth von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False Alexander and Wilhelm von Humboldt famously championed 7 [' Alexander', ' and', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+2923 671 Name of mother of x -1 Name of mother of Wilhelm von Humboldt Marie-Elisabeth von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False and diplomat Wilhelm von Humboldt were close 7 [' and', ' diplomat', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+2924 671 Name of mother of x -1 Name of mother of Wilhelm von Humboldt Marie-Elisabeth von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False brothers Alexander and Wilhelm von Humboldt famously championed 8 [' brothers', ' Alexander', ' and', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+2925 671 Name of mother of x -1 Name of mother of Wilhelm von Humboldt Marie-Elisabeth von Humboldt Wilhelm von Humboldt "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False and diplomat Wilhelm von Humboldt were close friends 7 [' and', ' diplomat', ' Wilhelm', ' von', ' H', 'umb', 'old', 't']
+2926 672 Name of mother of x -1 Name of mother of Demosthenes Kleobule Demosthenes "[',' ' the' ' son' ' of' ' Philip' ',' ' and' ' the' ' son' ' of'
+ ' Philip' ',' ' and' ' the' ' son' ' of' ' Philip' ',' ' and' ' the']" , the son of Philip , and the son of Philip , and the son of Philip , and the False teacher of rhetoric, Demosthenes represented the 7 [' teacher', ' of', ' rhetoric', ',', ' Dem', 'ost', 'hen', 'es']
+2927 672 Name of mother of x -1 Name of mother of Demosthenes Kleobule Demosthenes "[',' ' the' ' son' ' of' ' Philip' ',' ' and' ' the' ' son' ' of'
+ ' Philip' ',' ' and' ' the' ' son' ' of' ' Philip' ',' ' and' ' the']" , the son of Philip , and the son of Philip , and the son of Philip , and the False Assembly convened. Demosthenes delivered On the 6 [' Assembly', ' convened', '.', ' Dem', 'ost', 'hen', 'es']
+2928 672 Name of mother of x -1 Name of mother of Demosthenes Kleobule Demosthenes "[',' ' the' ' son' ' of' ' Philip' ',' ' and' ' the' ' son' ' of'
+ ' Philip' ',' ' and' ' the' ' son' ' of' ' Philip' ',' ' and' ' the']" , the son of Philip , and the son of Philip , and the son of Philip , and the False 3 ['Dem', 'ost', 'hen', 'es']
+2929 672 Name of mother of x -1 Name of mother of Demosthenes Kleobule Demosthenes "[',' ' the' ' son' ' of' ' Philip' ',' ' and' ' the' ' son' ' of'
+ ' Philip' ',' ' and' ' the' ' son' ' of' ' Philip' ',' ' and' ' the']" , the son of Philip , and the son of Philip , and the son of Philip , and the False dramatists, Demosthenes and Thucydides) 7 [' dram', 'at', 'ists', ',', ' Dem', 'ost', 'hen', 'es']
+2930 672 Name of mother of x -1 Name of mother of Demosthenes Kleobule Demosthenes "[',' ' the' ' son' ' of' ' Philip' ',' ' and' ' the' ' son' ' of'
+ ' Philip' ',' ' and' ' the' ' son' ' of' ' Philip' ',' ' and' ' the']" , the son of Philip , and the son of Philip , and the son of Philip , and the False cases in which Demosthenes was personally involved, 6 [' cases', ' in', ' which', ' Dem', 'ost', 'hen', 'es']
+2931 674 Name of mother of x -1 Name of mother of Kamala Harris Shyamala Gopalan Kamala Harris "[',' ' the' ' daughter' ' of' ' a' ' former' ' president' ' of' ' the'
+ ' United' ' States' ',' ' and' ' a' ' former' ' attorney' ' general'
+ ' of' ' California' ',']" , the daughter of a former president of the United States , and a former attorney general of California , False District Attorney Kamala Harris issued a warrant 4 [' District', ' Attorney', ' Kam', 'ala', ' Harris']
+2932 674 Name of mother of x -1 Name of mother of Kamala Harris Shyamala Gopalan Kamala Harris "[',' ' the' ' daughter' ' of' ' a' ' former' ' president' ' of' ' the'
+ ' United' ' States' ',' ' and' ' a' ' former' ' attorney' ' general'
+ ' of' ' California' ',']" , the daughter of a former president of the United States , and a former attorney general of California , False District Attorney Kamala Harris issued a warrant 4 [' District', ' Attorney', ' Kam', 'ala', ' Harris']
+2933 675 Name of mother of x -1 Name of mother of John le Carré Olive Moore Cornwell John le Carré "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' film' ',' ' is' ' a' ' former' ' British' ' intelligence'
+ ' officer']" , the author of the book , and the author of the film , is a former British intelligence officer False (He has worked John le Carré and Graham Greene 7 [' (', 'He', ' has', ' worked', ' John', ' le', ' Carr', 'é']
+2934 675 Name of mother of x -1 Name of mother of John le Carré Olive Moore Cornwell John le Carré "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' film' ',' ' is' ' a' ' former' ' British' ' intelligence'
+ ' officer']" , the author of the book , and the author of the film , is a former British intelligence officer False coming across John le Carré novels at a 5 [' coming', ' across', ' John', ' le', ' Carr', 'é']
+2935 675 Name of mother of x -1 Name of mother of John le Carré Olive Moore Cornwell John le Carré "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' film' ',' ' is' ' a' ' former' ' British' ' intelligence'
+ ' officer']" , the author of the book , and the author of the film , is a former British intelligence officer False a character from a John le Carré novel. They 7 [' a', ' character', ' from', ' a', ' John', ' le', ' Carr', 'é']
+2936 675 Name of mother of x -1 Name of mother of John le Carré Olive Moore Cornwell John le Carré "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' film' ',' ' is' ' a' ' former' ' British' ' intelligence'
+ ' officer']" , the author of the book , and the author of the film , is a former British intelligence officer False coming across John le Carré novels at a Waterstone's 5 [' coming', ' across', ' John', ' le', ' Carr', 'é']
+2937 675 Name of mother of x -1 Name of mother of John le Carré Olive Moore Cornwell John le Carré "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' film' ',' ' is' ' a' ' former' ' British' ' intelligence'
+ ' officer']" , the author of the book , and the author of the film , is a former British intelligence officer False character from a John le Carré novel. Referring 6 [' character', ' from', ' a', ' John', ' le', ' Carr', 'é']
+2938 676 Name of mother of x -1 Name of mother of Wilhelm Ostwald Elisabeth Ostwald Wilhelm Ostwald "[',' ' the' ' German' ' chemist' ' who' ' discovered' ' the' ' law' ' of'
+ ' the' ' conservation' ' of' ' mass' '.' '\n' '\n' 'The' ' law' ' of'
+ ' the']" ", the German chemist who discovered the law of the conservation of mass .
+
+ The law of the" False chemistry) by Wilhelm Ostwald in 1892 and into 5 [' chemistry', ')', ' by', ' Wilhelm', ' Ost', 'wald']
+2939 676 Name of mother of x -1 Name of mother of Wilhelm Ostwald Elisabeth Ostwald Wilhelm Ostwald "[',' ' the' ' German' ' chemist' ' who' ' discovered' ' the' ' law' ' of'
+ ' the' ' conservation' ' of' ' mass' '.' '\n' '\n' 'The' ' law' ' of'
+ ' the']" ", the German chemist who discovered the law of the conservation of mass .
+
+ The law of the" False for chemistry) by Wilhelm Ostwald in 1892 and 6 [' for', ' chemistry', ')', ' by', ' Wilhelm', ' Ost', 'wald']
+2940 676 Name of mother of x -1 Name of mother of Wilhelm Ostwald Elisabeth Ostwald Wilhelm Ostwald "[',' ' the' ' German' ' chemist' ' who' ' discovered' ' the' ' law' ' of'
+ ' the' ' conservation' ' of' ' mass' '.' '\n' '\n' 'The' ' law' ' of'
+ ' the']" ", the German chemist who discovered the law of the conservation of mass .
+
+ The law of the" False for chemistry) by Wilhelm Ostwald in 1892 and 6 [' for', ' chemistry', ')', ' by', ' Wilhelm', ' Ost', 'wald']
+2941 677 Name of mother of x -1 Name of mother of Leoš Janáček Amálie Janáčková Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' and' ' singer' ',' ' was' ' born' ' in' ' Prague' ','
+ ' Czech']" , the Czech composer , and his wife , the actress and singer , was born in Prague , Czech False Season to work by Leoš Janáček and Symphony No. 1 9 [' Season', ' to', ' work', ' by', ' Leo', 'š', ' Jan', 'á', 'č', 'ek']
+2942 677 Name of mother of x -1 Name of mother of Leoš Janáček Amálie Janáčková Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' and' ' singer' ',' ' was' ' born' ' in' ' Prague' ','
+ ' Czech']" , the Czech composer , and his wife , the actress and singer , was born in Prague , Czech False The Operas of Leoš Janáček shortly before 9 [' The', ' Oper', 'as', ' of', ' Leo', 'š', ' Jan', 'á', 'č', 'ek']
+2943 677 Name of mother of x -1 Name of mother of Leoš Janáček Amálie Janáčková Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' and' ' singer' ',' ' was' ' born' ' in' ' Prague' ','
+ ' Czech']" , the Czech composer , and his wife , the actress and singer , was born in Prague , Czech False 6 ['Le', 'o', 'š', ' Jan', 'á', 'č', 'ek']
+2944 677 Name of mother of x -1 Name of mother of Leoš Janáček Amálie Janáčková Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' and' ' singer' ',' ' was' ' born' ' in' ' Prague' ','
+ ' Czech']" , the Czech composer , and his wife , the actress and singer , was born in Prague , Czech False Season to work by Leoš Janáček and Symphony No. 9 [' Season', ' to', ' work', ' by', ' Leo', 'š', ' Jan', 'á', 'č', 'ek']
+2945 677 Name of mother of x -1 Name of mother of Leoš Janáček Amálie Janáčková Leoš Janáček "[',' ' the' ' Czech' ' composer' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ' and' ' singer' ',' ' was' ' born' ' in' ' Prague' ','
+ ' Czech']" , the Czech composer , and his wife , the actress and singer , was born in Prague , Czech False " Leoš Janáček =
+" 5 [' Leo', 'š', ' Jan', 'á', 'č', 'ek']
+2946 678 Name of mother of x -1 Name of mother of Heinrich von Kleist Juliane Ulrike von Pannwitz Heinrich von Kleist "[',' ' the' ' German' ' poet' ',' ' who' ' was' ' born' ' in' ' 17' '77'
+ ',' ' and' ' died' ' in' ' 18' '11' '.' '\n' '\n']" ", the German poet , who was born in 17 77 , and died in 18 11 .
+
+" False " Grillparzer, and Heinrich von Kleist to be his ""true blood" 9 [' Grill', 'par', 'zer', ',', ' and', ' Hein', 'rich', ' von', ' Kle', 'ist']
+2947 679 Name of mother of x -1 Name of mother of Serge Gainsbourg Olia Ginsburg Serge Gainsbourg "[',' ' the' ' French' ' singer' '-' 'song' 'writer' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Jane' ' Bir' 'kin' '.']" , the French singer - song writer , and his wife , the actress and singer Jane Bir kin . False French duet between Serge Gainsbourg and Jane Birkin. 7 [' French', ' du', 'et', ' between', ' Serge', ' G', 'ains', 'bourg']
+2948 679 Name of mother of x -1 Name of mother of Serge Gainsbourg Olia Ginsburg Serge Gainsbourg "[',' ' the' ' French' ' singer' '-' 'song' 'writer' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Jane' ' Bir' 'kin' '.']" , the French singer - song writer , and his wife , the actress and singer Jane Bir kin . False French duet between Serge Gainsbourg and Jane Birkin. 7 [' French', ' du', 'et', ' between', ' Serge', ' G', 'ains', 'bourg']
+2949 679 Name of mother of x -1 Name of mother of Serge Gainsbourg Olia Ginsburg Serge Gainsbourg "[',' ' the' ' French' ' singer' '-' 'song' 'writer' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Jane' ' Bir' 'kin' '.']" , the French singer - song writer , and his wife , the actress and singer Jane Bir kin . False from 1958, while Serge Gainsbourg used the theme 7 [' from', ' 1958', ',', ' while', ' Serge', ' G', 'ains', 'bourg']
+2950 680 Name of mother of x -1 Name of mother of Lise Meitner Hedwig Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False explanation by Lise Meitner and Otto Frisch, 6 [' explanation', ' by', ' L', 'ise', ' Me', 'it', 'ner']
+2951 680 Name of mother of x -1 Name of mother of Lise Meitner Hedwig Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False fission. The news of Lise Meitner and Otto Frisch's 10 [' f', 'ission', '.', ' The', ' news', ' of', ' L', 'ise', ' Me', 'it', 'ner']
+2952 680 Name of mother of x -1 Name of mother of Lise Meitner Hedwig Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False fission by Lise Meitner in the February 7 [' f', 'ission', ' by', ' L', 'ise', ' Me', 'it', 'ner']
+2953 680 Name of mother of x -1 Name of mother of Lise Meitner Hedwig Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False by the Austrian Lise Meitner and Otto Hahn 7 [' by', ' the', ' Austrian', ' L', 'ise', ' Me', 'it', 'ner']
+2954 680 Name of mother of x -1 Name of mother of Lise Meitner Hedwig Meitner Lise Meitner "[',' ' the' ' daughter' ' of' ' the' ' famous' ' physicist' ',' ' who'
+ ' was' ' a' ' member' ' of' ' the' ' Nazi' ' Party' '.' '\n' '\n' 'The']" ", the daughter of the famous physicist , who was a member of the Nazi Party .
+
+ The" False neutrons, which Lise Meitner and her nephew 8 [' neut', 'rons', ',', ' which', ' L', 'ise', ' Me', 'it', 'ner']
+2955 681 Name of mother of x -1 Name of mother of Louis XV of France Marie Adélaïde of Savoy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' mother'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the mother of Louis XVI of France , and the" False it and survived). Louis XV of France succeeded his 7 [' it', ' and', ' survived', ').', ' Louis', ' XV', ' of', ' France']
+2956 681 Name of mother of x -1 Name of mother of Louis XV of France Marie Adélaïde of Savoy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' mother'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the mother of Louis XVI of France , and the" False Polish Succession. Louis XV of France demanded that 7 [' Polish', ' Success', 'ion', '.', ' Louis', ' XV', ' of', ' France']
+2957 681 Name of mother of x -1 Name of mother of Louis XV of France Marie Adélaïde of Savoy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' mother'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the mother of Louis XVI of France , and the" False made King Louis XV of France declare it the only 5 [' made', ' King', ' Louis', ' XV', ' of', ' France']
+2958 681 Name of mother of x -1 Name of mother of Louis XV of France Marie Adélaïde of Savoy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' mother'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the mother of Louis XVI of France , and the" False it and survived). Louis XV of France succeeded his great-grandfather 7 [' it', ' and', ' survived', ').', ' Louis', ' XV', ' of', ' France']
+2959 681 Name of mother of x -1 Name of mother of Louis XV of France Marie Adélaïde of Savoy Louis XV of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' mother'
+ ' of' ' Louis' ' XVI' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the mother of Louis XVI of France , and the" False and survived). Louis XV of France succeeded his great-grandfather 6 [' and', ' survived', ').', ' Louis', ' XV', ' of', ' France']
+2960 682 Name of mother of x -1 Name of mother of Adam Smith Margaret Douglas Adam Smith "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' economics' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' economics']" , the father of modern economics , and the father of modern economics , and the father of modern economics False physiocrats and Adam Smith applied Natural 4 [' physi', 'ocrats', ' and', ' Adam', ' Smith']
+2961 682 Name of mother of x -1 Name of mother of Adam Smith Margaret Douglas Adam Smith "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' economics' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' economics']" , the father of modern economics , and the father of modern economics , and the father of modern economics False and economist Adam Smith, who wrote his magnum 3 [' and', ' economist', ' Adam', ' Smith']
+2962 682 Name of mother of x -1 Name of mother of Adam Smith Margaret Douglas Adam Smith "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' economics' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' economics']" , the father of modern economics , and the father of modern economics , and the father of modern economics False she had worked with Adam Smith before. Patrick Moore 5 [' she', ' had', ' worked', ' with', ' Adam', ' Smith']
+2963 682 Name of mother of x -1 Name of mother of Adam Smith Margaret Douglas Adam Smith "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' economics' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' economics']" , the father of modern economics , and the father of modern economics , and the father of modern economics False " to apply. Adam Smith called them ""an absolute" 4 [' to', ' apply', '.', ' Adam', ' Smith']
+2964 682 Name of mother of x -1 Name of mother of Adam Smith Margaret Douglas Adam Smith "[',' ' the' ' father' ' of' ' modern' ' economics' ',' ' and' ' the'
+ ' father' ' of' ' modern' ' economics' ',' ' and' ' the' ' father' ' of'
+ ' modern' ' economics']" , the father of modern economics , and the father of modern economics , and the father of modern economics False free will. Adam Smith considered that 4 [' free', ' will', '.', ' Adam', ' Smith']
+2965 683 Name of mother of x -1 Name of mother of George Soros Elizabeth Soros George Soros "[',' ' the' ' billionaire' ' fin' 'anc' 'ier' ' and' ' philanthrop' 'ist'
+ ',' ' who' ' has' ' been' ' a' ' major' ' donor' ' to' ' the' ' Clinton'
+ ' Foundation']" , the billionaire fin anc ier and philanthrop ist , who has been a major donor to the Clinton Foundation False represented George Soros as one of their 2 [' represented', ' George', ' Soros']
+2966 683 Name of mother of x -1 Name of mother of George Soros Elizabeth Soros George Soros "[',' ' the' ' billionaire' ' fin' 'anc' 'ier' ' and' ' philanthrop' 'ist'
+ ',' ' who' ' has' ' been' ' a' ' major' ' donor' ' to' ' the' ' Clinton'
+ ' Foundation']" , the billionaire fin anc ier and philanthrop ist , who has been a major donor to the Clinton Foundation False Peter Lewis, and George Soros were the principal 5 [' Peter', ' Lewis', ',', ' and', ' George', ' Soros']
+2967 685 Name of mother of x -1 Name of mother of Mickey Rooney Nellie W. Carter Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' ' and' ' actress' '.' '\n' '\n']" ", the actor , and his wife , L illian , who was a singer and actress .
+
+" False and actors Mickey Rooney and Lash La 3 [' and', ' actors', ' Mickey', ' Rooney']
+2968 685 Name of mother of x -1 Name of mother of Mickey Rooney Nellie W. Carter Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' ' and' ' actress' '.' '\n' '\n']" ", the actor , and his wife , L illian , who was a singer and actress .
+
+" False with guest star Mickey Rooney helping to lift 4 [' with', ' guest', ' star', ' Mickey', ' Rooney']
+2969 685 Name of mother of x -1 Name of mother of Mickey Rooney Nellie W. Carter Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' ' and' ' actress' '.' '\n' '\n']" ", the actor , and his wife , L illian , who was a singer and actress .
+
+" False former child star Mickey Rooney, Milhouse gives up 4 [' former', ' child', ' star', ' Mickey', ' Rooney']
+2970 685 Name of mother of x -1 Name of mother of Mickey Rooney Nellie W. Carter Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' ' and' ' actress' '.' '\n' '\n']" ", the actor , and his wife , L illian , who was a singer and actress .
+
+" False including nine with Mickey Rooney and her most iconic 4 [' including', ' nine', ' with', ' Mickey', ' Rooney']
+2971 685 Name of mother of x -1 Name of mother of Mickey Rooney Nellie W. Carter Mickey Rooney "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' L' 'illian' ','
+ ' who' ' was' ' a' ' singer' ' and' ' actress' '.' '\n' '\n']" ", the actor , and his wife , L illian , who was a singer and actress .
+
+" False Susie Dietter. Mickey Rooney guest starred 6 [' Sus', 'ie', ' Diet', 'ter', '.', ' Mickey', ' Rooney']
+2972 686 Name of mother of x -1 Name of mother of Louis de Broglie Pauline Célestine de La Forest d'Armaillé Louis de Broglie "[',' ' the' ' French' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' French' ' Academy' ' of' ' Sciences' ',' ' and' ' a' ' member'
+ ' of']" , the French physicist , who was a member of the French Academy of Sciences , and a member of False 4 ['Louis', ' de', ' Bro', 'gl', 'ie']
+2973 686 Name of mother of x -1 Name of mother of Louis de Broglie Pauline Célestine de La Forest d'Armaillé Louis de Broglie "[',' ' the' ' French' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' French' ' Academy' ' of' ' Sciences' ',' ' and' ' a' ' member'
+ ' of']" , the French physicist , who was a member of the French Academy of Sciences , and a member of False 4 ['Louis', ' de', ' Bro', 'gl', 'ie']
+2974 686 Name of mother of x -1 Name of mother of Louis de Broglie Pauline Célestine de La Forest d'Armaillé Louis de Broglie "[',' ' the' ' French' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' French' ' Academy' ' of' ' Sciences' ',' ' and' ' a' ' member'
+ ' of']" , the French physicist , who was a member of the French Academy of Sciences , and a member of False 4 ['Louis', ' de', ' Bro', 'gl', 'ie']
+2975 686 Name of mother of x -1 Name of mother of Louis de Broglie Pauline Célestine de La Forest d'Armaillé Louis de Broglie "[',' ' the' ' French' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' French' ' Academy' ' of' ' Sciences' ',' ' and' ' a' ' member'
+ ' of']" , the French physicist , who was a member of the French Academy of Sciences , and a member of False 4 ['Louis', ' de', ' Bro', 'gl', 'ie']
+2976 687 Name of mother of x -1 Name of mother of Anna Akhmatova Inna Stogova Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' Russian' ' arist' 'ocrat' ',' ' was' ' a' ' great' ' admire' 'r' ' of']" , the poet ess , and the daughter of a Russian arist ocrat , was a great admire r of False be explored at the Anna Akhmatova Literary and 8 [' be', ' explored', ' at', ' the', ' Anna', ' Ak', 'h', 'mat', 'ova']
+2977 687 Name of mother of x -1 Name of mother of Anna Akhmatova Inna Stogova Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' Russian' ' arist' 'ocrat' ',' ' was' ' a' ' great' ' admire' 'r' ' of']" , the poet ess , and the daughter of a Russian arist ocrat , was a great admire r of False to greet us. Anna Akhmatova was immensely 8 [' to', ' greet', ' us', '.', ' Anna', ' Ak', 'h', 'mat', 'ova']
+2978 687 Name of mother of x -1 Name of mother of Anna Akhmatova Inna Stogova Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' Russian' ' arist' 'ocrat' ',' ' was' ' a' ' great' ' admire' 'r' ' of']" , the poet ess , and the daughter of a Russian arist ocrat , was a great admire r of False to greet us. Anna Akhmatova was immensely 8 [' to', ' greet', ' us', '.', ' Anna', ' Ak', 'h', 'mat', 'ova']
+2979 687 Name of mother of x -1 Name of mother of Anna Akhmatova Inna Stogova Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' Russian' ' arist' 'ocrat' ',' ' was' ' a' ' great' ' admire' 'r' ' of']" , the poet ess , and the daughter of a Russian arist ocrat , was a great admire r of False by the pen name Anna Akhmatova (/ ɑːkˈmɑːtɔːvə /; 8 [' by', ' the', ' pen', ' name', ' Anna', ' Ak', 'h', 'mat', 'ova']
+2980 687 Name of mother of x -1 Name of mother of Anna Akhmatova Inna Stogova Anna Akhmatova "[',' ' the' ' poet' 'ess' ',' ' and' ' the' ' daughter' ' of' ' a'
+ ' Russian' ' arist' 'ocrat' ',' ' was' ' a' ' great' ' admire' 'r' ' of']" , the poet ess , and the daughter of a Russian arist ocrat , was a great admire r of False " Anna Akhmatova =
+" 4 [' Anna', ' Ak', 'h', 'mat', 'ova']
+2981 688 Name of mother of x -1 Name of mother of Paracelsus NN Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of mother of Par ac els" False corpus, c. 1310), Paracelsus (De Natura Rerum 10 [' corpus', ',', ' c', '.', ' 13', '10', '),', ' Par', 'ac', 'els', 'us']
+2982 688 Name of mother of x -1 Name of mother of Paracelsus NN Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of mother of Par ac els" False end is from Paracelsus by Robert Browning. 6 [' end', ' is', ' from', ' Par', 'ac', 'els', 'us']
+2983 688 Name of mother of x -1 Name of mother of Paracelsus NN Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of mother of Par ac els" False gases, although Paracelsus around 1500, Robert 6 [' gases', ',', ' although', ' Par', 'ac', 'els', 'us']
+2984 688 Name of mother of x -1 Name of mother of Paracelsus NN Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of mother of Par ac els" False corpus, c. 1310), Paracelsus (De Natura Rerum 10 [' corpus', ',', ' c', '.', ' 13', '10', '),', ' Par', 'ac', 'els', 'us']
+2985 688 Name of mother of x -1 Name of mother of Paracelsus NN Paracelsus "[',' ' the' ' great' ' al' 'chemist' ',' ' who' ' was' ' the' ' first'
+ ' to' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Par' 'ac' 'els']" ", the great al chemist , who was the first to
+
+ Name of mother of Par ac els" False c. 1310), Paracelsus (De Natura Rerum 8 [' c', '.', ' 13', '10', '),', ' Par', 'ac', 'els', 'us']
+2986 689 Name of mother of x -1 Name of mother of Janis Joplin Dorothy Bonita Joplin Janis Joplin "[',' ' the' ' singer' ',' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I']" ", the singer , who died in 1970 .
+
+ The first time I saw the movie , I" False " ""Mercedes Benz"" by Janis Joplin and ""Love" 10 "[' ""', 'Mer', 'cedes', ' Benz', '""', ' by', ' Jan', 'is', ' J', 'op', 'lin']"
+2987 689 Name of mother of x -1 Name of mother of Janis Joplin Dorothy Bonita Joplin Janis Joplin "[',' ' the' ' singer' ',' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I']" ", the singer , who died in 1970 .
+
+ The first time I saw the movie , I" False biographical film on singer Janis Joplin, and she decides 9 [' bi', 'ographical', ' film', ' on', ' singer', ' Jan', 'is', ' J', 'op', 'lin']
+2988 689 Name of mother of x -1 Name of mother of Janis Joplin Dorothy Bonita Joplin Janis Joplin "[',' ' the' ' singer' ',' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I']" ", the singer , who died in 1970 .
+
+ The first time I saw the movie , I" False " song ""Otis"". Janis Joplin was influenced" 9 "[' song', ' ""', 'O', 'tis', '"".', ' Jan', 'is', ' J', 'op', 'lin']"
+2989 689 Name of mother of x -1 Name of mother of Janis Joplin Dorothy Bonita Joplin Janis Joplin "[',' ' the' ' singer' ',' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I']" ", the singer , who died in 1970 .
+
+ The first time I saw the movie , I" False biographical film on singer Janis Joplin, and she decides 9 [' bi', 'ographical', ' film', ' on', ' singer', ' Jan', 'is', ' J', 'op', 'lin']
+2990 689 Name of mother of x -1 Name of mother of Janis Joplin Dorothy Bonita Joplin Janis Joplin "[',' ' the' ' singer' ',' ' who' ' died' ' in' ' 1970' '.' '\n' '\n' 'The'
+ ' first' ' time' ' I' ' saw' ' the' ' movie' ',' ' I']" ", the singer , who died in 1970 .
+
+ The first time I saw the movie , I" False exclamations of Janis Joplin in the late' 60s 9 [' ex', 'cl', 'am', 'ations', ' of', ' Jan', 'is', ' J', 'op', 'lin']
+2991 690 Name of mother of x -1 Name of mother of Maria Theresa of Austria Elisabeth Christine of Brunswick-Wolfenbüttel Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' ' mother' ' of' ' the' ' Emperor' ' Francis' ' II' '.' '\n' '\n']" ", the daughter of the Emperor Joseph II , and the mother of the Emperor Francis II .
+
+" False and Archduchess Maria Theresa of Austria that the balance 8 [' and', ' Arch', 'du', 'che', 'ss', ' Maria', ' Theresa', ' of', ' Austria']
+2992 690 Name of mother of x -1 Name of mother of Maria Theresa of Austria Elisabeth Christine of Brunswick-Wolfenbüttel Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' ' mother' ' of' ' the' ' Emperor' ' Francis' ' II' '.' '\n' '\n']" ", the daughter of the Emperor Joseph II , and the mother of the Emperor Francis II .
+
+" False ceased when Empress Maria Theresa of Austria sent her personal physician, 6 [' ceased', ' when', ' Empress', ' Maria', ' Theresa', ' of', ' Austria']
+2993 690 Name of mother of x -1 Name of mother of Maria Theresa of Austria Elisabeth Christine of Brunswick-Wolfenbüttel Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' ' mother' ' of' ' the' ' Emperor' ' Francis' ' II' '.' '\n' '\n']" ", the daughter of the Emperor Joseph II , and the mother of the Emperor Francis II .
+
+" False the court of Maria Theresa of Austria at Schönbrunn 6 [' the', ' court', ' of', ' Maria', ' Theresa', ' of', ' Austria']
+2994 690 Name of mother of x -1 Name of mother of Maria Theresa of Austria Elisabeth Christine of Brunswick-Wolfenbüttel Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' ' mother' ' of' ' the' ' Emperor' ' Francis' ' II' '.' '\n' '\n']" ", the daughter of the Emperor Joseph II , and the mother of the Emperor Francis II .
+
+" False charter issued by Maria Theresa of Austria in 1741. The unit 6 [' charter', ' issued', ' by', ' Maria', ' Theresa', ' of', ' Austria']
+2995 690 Name of mother of x -1 Name of mother of Maria Theresa of Austria Elisabeth Christine of Brunswick-Wolfenbüttel Maria Theresa of Austria "[',' ' the' ' daughter' ' of' ' the' ' Emperor' ' Joseph' ' II' ',' ' and'
+ ' the' ' mother' ' of' ' the' ' Emperor' ' Francis' ' II' '.' '\n' '\n']" ", the daughter of the Emperor Joseph II , and the mother of the Emperor Francis II .
+
+" False a charter issued by Maria Theresa of Austria in 1741. The 7 [' a', ' charter', ' issued', ' by', ' Maria', ' Theresa', ' of', ' Austria']
+2996 691 Name of mother of x -1 Name of mother of Gustaf VI Adolf of Sweden Victoria of Baden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' the' ' current' ' monarch']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , the current monarch" False 11 July by King Gustaf VI Adolf of Sweden before returning 9 [' 11', ' July', ' by', ' King', ' Gust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+2997 691 Name of mother of x -1 Name of mother of Gustaf VI Adolf of Sweden Victoria of Baden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' the' ' current' ' monarch']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , the current monarch" False 6 ['G', 'ust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+2998 691 Name of mother of x -1 Name of mother of Gustaf VI Adolf of Sweden Victoria of Baden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' the' ' current' ' monarch']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , the current monarch" False 11 July by King Gustaf VI Adolf of Sweden before returning 9 [' 11', ' July', ' by', ' King', ' Gust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+2999 691 Name of mother of x -1 Name of mother of Gustaf VI Adolf of Sweden Victoria of Baden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' the' ' current' ' monarch']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , the current monarch" False on 11 July by King Gustaf VI Adolf of Sweden before returning 10 [' on', ' 11', ' July', ' by', ' King', ' Gust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+3000 691 Name of mother of x -1 Name of mother of Gustaf VI Adolf of Sweden Victoria of Baden Gustaf VI Adolf of Sweden "['\n' '\n' 'The' ' King' ' of' ' Sweden' ' is' ' the' ' head' ' of' ' the'
+ ' House' ' of' ' Bern' 'ad' 'otte' ',' ' the' ' current' ' monarch']" "
+
+ The King of Sweden is the head of the House of Bern ad otte , the current monarch" False July by King Gustaf VI Adolf of Sweden before returning home 8 [' July', ' by', ' King', ' Gust', 'af', ' VI', ' Adolf', ' of', ' Sweden']
+3001 692 Name of mother of x -1 Name of mother of Prince Philip, Duke of Edinburgh Princess Alice of Battenberg Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False officially opened by Prince Philip, Duke of Edinburgh on 26 October 8 [' officially', ' opened', ' by', ' Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3002 692 Name of mother of x -1 Name of mother of Prince Philip, Duke of Edinburgh Princess Alice of Battenberg Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False 5 ['Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3003 692 Name of mother of x -1 Name of mother of Prince Philip, Duke of Edinburgh Princess Alice of Battenberg Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False officially opened by Prince Philip, Duke of Edinburgh on 26 October 8 [' officially', ' opened', ' by', ' Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3004 692 Name of mother of x -1 Name of mother of Prince Philip, Duke of Edinburgh Princess Alice of Battenberg Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False engagement to the Prince Philip, Duke of Edinburgh on 10 July 1947. 8 [' engagement', ' to', ' the', ' Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3005 692 Name of mother of x -1 Name of mother of Prince Philip, Duke of Edinburgh Princess Alice of Battenberg Prince Philip, Duke of Edinburgh "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Duke'
+ ' of' ' Edinburgh' ',' ' Prince' ' Philip' ',' ' and' ' the' ' Queen']" ", and the Queen of England .
+
+ The Duke of Edinburgh , Prince Philip , and the Queen" False engagement to the Prince Philip, Duke of Edinburgh on 10 July 1947. 8 [' engagement', ' to', ' the', ' Prince', ' Philip', ',', ' Duke', ' of', ' Edinburgh']
+3006 693 Name of mother of x -1 Name of mother of Charles II of England Henrietta Maria of France Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False grant that Charles II of England awarded to seven 5 [' grant', ' that', ' Charles', ' II', ' of', ' England']
+3007 693 Name of mother of x -1 Name of mother of Charles II of England Henrietta Maria of France Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False land grant from Charles II of England to seven of his 6 [' land', ' grant', ' from', ' Charles', ' II', ' of', ' England']
+3008 693 Name of mother of x -1 Name of mother of Charles II of England Henrietta Maria of France Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False marriage treaty of Charles II of England and Catherine of Braganza, 6 [' marriage', ' treaty', ' of', ' Charles', ' II', ' of', ' England']
+3009 693 Name of mother of x -1 Name of mother of Charles II of England Henrietta Maria of France Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False the reign of Charles II of England and is commanded 6 [' the', ' reign', ' of', ' Charles', ' II', ' of', ' England']
+3010 693 Name of mother of x -1 Name of mother of Charles II of England Henrietta Maria of France Charles II of England "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False lands granted by King Charles II of England to the eight 7 [' lands', ' granted', ' by', ' King', ' Charles', ' II', ' of', ' England']
+3011 694 Name of mother of x -1 Name of mother of Greta Garbo Anna Lovisa Johansdotter Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Daughters (1928). Greta Garbo was the one non 9 [' D', 'aughters', ' (', '19', '28', ').', ' Gret', 'a', ' Gar', 'bo']
+3012 694 Name of mother of x -1 Name of mother of Greta Garbo Anna Lovisa Johansdotter Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False Swedish-American actress Greta Garbo accosted him on 7 [' Swedish', '-', 'American', ' actress', ' Gret', 'a', ' Gar', 'bo']
+3013 694 Name of mother of x -1 Name of mother of Greta Garbo Anna Lovisa Johansdotter Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False French woman in a Greta Garbo film and as 7 [' French', ' woman', ' in', ' a', ' Gret', 'a', ' Gar', 'bo']
+3014 694 Name of mother of x -1 Name of mother of Greta Garbo Anna Lovisa Johansdotter Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False together with Greta Garbo in the 1932 film 5 [' together', ' with', ' Gret', 'a', ' Gar', 'bo']
+3015 694 Name of mother of x -1 Name of mother of Greta Garbo Anna Lovisa Johansdotter Greta Garbo "[',' ' the' ' actress' ',' ' and' ' the' ' mother' ' of' ' the' ' actress'
+ ',' ' the' ' mother' ' of' ' the' ' actress' ',' ' the' ' mother' ' of']" , the actress , and the mother of the actress , the mother of the actress , the mother of False " Salinger ""the Greta Garbo of literature"".
+" 7 "[' Sal', 'inger', ' ""', 'the', ' Gret', 'a', ' Gar', 'bo']"
+3016 695 Name of mother of x -1 Name of mother of Henry Fielding Sarah Gould Henry Fielding "[',' ' the' ' son' ' of' ' the' ' late' ' Henry' ' Field' 'ing' ',' ' and'
+ ' the' ' father' ' of' ' the' ' present' ' Henry' ' Field' 'ing' '.']" , the son of the late Henry Field ing , and the father of the present Henry Field ing . False English novel after Henry Fielding (1707 – 1754) 5 [' English', ' novel', ' after', ' Henry', ' Field', 'ing']
+3017 695 Name of mother of x -1 Name of mother of Henry Fielding Sarah Gould Henry Fielding "[',' ' the' ' son' ' of' ' the' ' late' ' Henry' ' Field' 'ing' ',' ' and'
+ ' the' ' father' ' of' ' the' ' present' ' Henry' ' Field' 'ing' '.']" , the son of the late Henry Field ing , and the father of the present Henry Field ing . False the writer Henry Fielding (1707 – 54), and 4 [' the', ' writer', ' Henry', ' Field', 'ing']
+3018 695 Name of mother of x -1 Name of mother of Henry Fielding Sarah Gould Henry Fielding "[',' ' the' ' son' ' of' ' the' ' late' ' Henry' ' Field' 'ing' ',' ' and'
+ ' the' ' father' ' of' ' the' ' present' ' Henry' ' Field' 'ing' '.']" , the son of the late Henry Field ing , and the father of the present Henry Field ing . False as a man again. Henry Fielding wrote a pamphlet 7 [' as', ' a', ' man', ' again', '.', ' Henry', ' Field', 'ing']
+3019 695 Name of mother of x -1 Name of mother of Henry Fielding Sarah Gould Henry Fielding "[',' ' the' ' son' ' of' ' the' ' late' ' Henry' ' Field' 'ing' ',' ' and'
+ ' the' ' father' ' of' ' the' ' present' ' Henry' ' Field' 'ing' '.']" , the son of the late Henry Field ing , and the father of the present Henry Field ing . False The early plays of Henry Fielding mark the beginning 6 [' The', ' early', ' plays', ' of', ' Henry', ' Field', 'ing']
+3020 695 Name of mother of x -1 Name of mother of Henry Fielding Sarah Gould Henry Fielding "[',' ' the' ' son' ' of' ' the' ' late' ' Henry' ' Field' 'ing' ',' ' and'
+ ' the' ' father' ' of' ' the' ' present' ' Henry' ' Field' 'ing' '.']" , the son of the late Henry Field ing , and the father of the present Henry Field ing . False " Cervantes and Henry Fielding as one of ""the" 6 [' C', 'erv', 'antes', ' and', ' Henry', ' Field', 'ing']
+3021 696 Name of mother of x -1 Name of mother of Keanu Reeves Patricia Taylor Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' since']" ", who is a very good friend of mine .
+
+ I have been a fan of his since" False of contest with Keanu Reeves to see how long he 5 [' of', ' contest', ' with', ' Ke', 'anu', ' Reeves']
+3022 696 Name of mother of x -1 Name of mother of Keanu Reeves Patricia Taylor Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' since']" ", who is a very good friend of mine .
+
+ I have been a fan of his since" False Smith had met Keanu Reeves before - on 5 [' Smith', ' had', ' met', ' Ke', 'anu', ' Reeves']
+3023 696 Name of mother of x -1 Name of mother of Keanu Reeves Patricia Taylor Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' since']" ", who is a very good friend of mine .
+
+ I have been a fan of his since" False also starred Keanu Reeves and Morgan Freeman. 4 [' also', ' starred', ' Ke', 'anu', ' Reeves']
+3024 696 Name of mother of x -1 Name of mother of Keanu Reeves Patricia Taylor Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' since']" ", who is a very good friend of mine .
+
+ I have been a fan of his since" False also starred Keanu Reeves and Morgan Freeman. 4 [' also', ' starred', ' Ke', 'anu', ' Reeves']
+3025 696 Name of mother of x -1 Name of mother of Keanu Reeves Patricia Taylor Keanu Reeves "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' since']" ", who is a very good friend of mine .
+
+ I have been a fan of his since" False produced the successful Keanu Reeves film Speed. Stoff was 5 [' produced', ' the', ' successful', ' Ke', 'anu', ' Reeves']
+3026 697 Name of mother of x -1 Name of mother of Justinian I Vigilantia Justinian I "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' important' ',' ' because' ' it' ' is' ' the' ' one' ' that' ' is']" ".
+
+ The first of the two is the most important , because it is the one that is" False nephew of Emperor Justinian I (r. 527 – 565), he 5 [' nephew', ' of', ' Emperor', ' Justin', 'ian', ' I']
+3027 697 Name of mother of x -1 Name of mother of Justinian I Vigilantia Justinian I "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' important' ',' ' because' ' it' ' is' ' the' ' one' ' that' ' is']" ".
+
+ The first of the two is the most important , because it is the one that is" False of Emperor Justinian I (r. 527 – 565), 4 [' of', ' Emperor', ' Justin', 'ian', ' I']
+3028 697 Name of mother of x -1 Name of mother of Justinian I Vigilantia Justinian I "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' important' ',' ' because' ' it' ' is' ' the' ' one' ' that' ' is']" ".
+
+ The first of the two is the most important , because it is the one that is" False The legislation of Justinian I (r. 527 – 565) often 5 [' The', ' legislation', ' of', ' Justin', 'ian', ' I']
+3029 697 Name of mother of x -1 Name of mother of Justinian I Vigilantia Justinian I "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' important' ',' ' because' ' it' ' is' ' the' ' one' ' that' ' is']" ".
+
+ The first of the two is the most important , because it is the one that is" False reign of emperor Justinian I in the 6th 5 [' reign', ' of', ' emperor', ' Justin', 'ian', ' I']
+3030 697 Name of mother of x -1 Name of mother of Justinian I Vigilantia Justinian I "['.' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ' is' ' the' ' most'
+ ' important' ',' ' because' ' it' ' is' ' the' ' one' ' that' ' is']" ".
+
+ The first of the two is the most important , because it is the one that is" False Roman predecessors. Justinian I (r. 527 – 565), for 5 [' Roman', ' predecessors', '.', ' Justin', 'ian', ' I']
+3031 698 Name of mother of x -1 Name of mother of Hermann Göring Franziska Tiefenbrunn Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False 4 ['H', 'erman', 'n', ' Gö', 'ring']
+3032 698 Name of mother of x -1 Name of mother of Hermann Göring Franziska Tiefenbrunn Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False On 14 November, Hermann Göring — Commander-in-Chief 7 [' On', ' 14', ' November', ',', ' Herman', 'n', ' Gö', 'ring']
+3033 698 Name of mother of x -1 Name of mother of Hermann Göring Franziska Tiefenbrunn Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False " Hermann Göring =
+" 3 [' Herman', 'n', ' Gö', 'ring']
+3034 698 Name of mother of x -1 Name of mother of Hermann Göring Franziska Tiefenbrunn Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False Reichsmarschall Hermann Göring exempted the He 7 [' Reich', 'sm', 'ars', 'chall', ' Herman', 'n', ' Gö', 'ring']
+3035 698 Name of mother of x -1 Name of mother of Hermann Göring Franziska Tiefenbrunn Hermann Göring "[',' ' the' ' Nazi' ' leader' ',' ' was' ' a' ' member' ' of' ' the'
+ ' Nazi' ' Party' '.' '\n' '\n' 'The' ' Nazis' ' were' ' not' ' the']" ", the Nazi leader , was a member of the Nazi Party .
+
+ The Nazis were not the" False launching ceremony, Hermann Göring gave a speech. 6 [' launching', ' ceremony', ',', ' Herman', 'n', ' Gö', 'ring']
+3036 699 Name of mother of x -1 Name of mother of Augustus John Augusta Smith Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False James. Notable artists Augustus John and Walter Sickert. 6 [' James', '.', ' Not', 'able', ' artists', ' Augustus', ' John']
+3037 699 Name of mother of x -1 Name of mother of Augustus John Augusta Smith Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False capitals of Europe. Augustus John and his sister Gwen 5 [' capitals', ' of', ' Europe', '.', ' Augustus', ' John']
+3038 699 Name of mother of x -1 Name of mother of Augustus John Augusta Smith Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False Notable artists Augustus John and Walter Sickert. 4 [' Not', 'able', ' artists', ' Augustus', ' John']
+3039 699 Name of mother of x -1 Name of mother of Augustus John Augusta Smith Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False of Europe. Augustus John and his sister Gwen 4 [' of', ' Europe', '.', ' Augustus', ' John']
+3040 699 Name of mother of x -1 Name of mother of Augustus John Augusta Smith Augustus John "[',' ' the' ' son' ' of' ' the' ' late' ' John' ' and' ' Mary' ' Ann' ' ('
+ 'nee' ' Smith' ')' ' John' '.' '\n' '\n' 'I' ' am']" ", the son of the late John and Mary Ann ( nee Smith ) John .
+
+ I am" False capitals of Europe. Augustus John and his sister 5 [' capitals', ' of', ' Europe', '.', ' Augustus', ' John']
+3041 700 Name of mother of x -1 Name of mother of Serena Williams Oracene Price Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' her' ' husband' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and her husband , Alexis Oh anian , the co - founder of Reddit , False career win over Serena Williams before losing in 5 [' career', ' win', ' over', ' Sere', 'na', ' Williams']
+3042 700 Name of mother of x -1 Name of mother of Serena Williams Oracene Price Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' her' ' husband' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and her husband , Alexis Oh anian , the co - founder of Reddit , False top seed, due to Serena Williams and Safina's withdrawal 7 [' top', ' seed', ',', ' due', ' to', ' Sere', 'na', ' Williams']
+3043 700 Name of mother of x -1 Name of mother of Serena Williams Oracene Price Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' her' ' husband' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and her husband , Alexis Oh anian , the co - founder of Reddit , False 1 tennis player Serena Williams during the 2014 US 5 [' 1', ' tennis', ' player', ' Sere', 'na', ' Williams']
+3044 700 Name of mother of x -1 Name of mother of Serena Williams Oracene Price Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' her' ' husband' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and her husband , Alexis Oh anian , the co - founder of Reddit , False sisters Venus and Serena Williams as themselves. 5 [' sisters', ' Venus', ' and', ' Sere', 'na', ' Williams']
+3045 700 Name of mother of x -1 Name of mother of Serena Williams Oracene Price Serena Williams "[',' ' the' ' tennis' ' player' ',' ' and' ' her' ' husband' ',' ' Alexis'
+ ' Oh' 'anian' ',' ' the' ' co' '-' 'founder' ' of' ' Reddit' ',']" , the tennis player , and her husband , Alexis Oh anian , the co - founder of Reddit , False eventual champion Serena Williams in three sets. 4 [' eventual', ' champion', ' Sere', 'na', ' Williams']
+3046 702 Name of mother of x -1 Name of mother of Samuel Johnson Sarah Ford Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False usually critical Samuel Johnson had to admit that it 3 [' usually', ' critical', ' Samuel', ' Johnson']
+3047 702 Name of mother of x -1 Name of mother of Samuel Johnson Sarah Ford Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False members included Samuel Johnson concluded that the 3 [' members', ' included', ' Samuel', ' Johnson']
+3048 702 Name of mother of x -1 Name of mother of Samuel Johnson Sarah Ford Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False publications. In 1755 Samuel Johnson published his A Dictionary 6 [' publications', '.', ' In', ' 17', '55', ' Samuel', ' Johnson']
+3049 702 Name of mother of x -1 Name of mother of Samuel Johnson Sarah Ford Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False by English author Samuel Johnson in a similar 4 [' by', ' English', ' author', ' Samuel', ' Johnson']
+3050 702 Name of mother of x -1 Name of mother of Samuel Johnson Sarah Ford Samuel Johnson "[',' ' the' ' great' '-' 'grand' 'son' ' of' ' the' ' great' '-' 'grand'
+ 'son' ' of' ' the' ' great' '-' 'great' '-' 'grand' 'son']" , the great - grand son of the great - grand son of the great - great - grand son False Sketch of Dr Samuel Johnson (1784); Boswell's The 4 [' Sketch', ' of', ' Dr', ' Samuel', ' Johnson']
+3051 703 Name of mother of x -1 Name of mother of Heinrich Himmler Anna Heyder Heinrich Himmler "[',' ' the' ' Nazi' ' leader' ',' ' was' ' born' ' in' ' the' ' town'
+ ' of' ' W' 'ür' 'z' 'burg' ',' ' Germany' ',' ' on' ' the']" , the Nazi leader , was born in the town of W ür z burg , Germany , on the False speeches made by Heinrich Himmler in October 1943 7 [' speeches', ' made', ' by', ' Hein', 'rich', ' H', 'imm', 'ler']
+3052 703 Name of mother of x -1 Name of mother of Heinrich Himmler Anna Heyder Heinrich Himmler "[',' ' the' ' Nazi' ' leader' ',' ' was' ' born' ' in' ' the' ' town'
+ ' of' ' W' 'ür' 'z' 'burg' ',' ' Germany' ',' ' on' ' the']" , the Nazi leader , was born in the town of W ür z burg , Germany , on the False While SS leader Heinrich Himmler remained concerned 7 [' While', ' SS', ' leader', ' Hein', 'rich', ' H', 'imm', 'ler']
+3053 703 Name of mother of x -1 Name of mother of Heinrich Himmler Anna Heyder Heinrich Himmler "[',' ' the' ' Nazi' ' leader' ',' ' was' ' born' ' in' ' the' ' town'
+ ' of' ' W' 'ür' 'z' 'burg' ',' ' Germany' ',' ' on' ' the']" , the Nazi leader , was born in the town of W ür z burg , Germany , on the False of Reichsführer-SS Heinrich Himmler as head of the sports 12 [' of', ' Reich', 'sf', 'ü', 'h', 'rer', '-', 'SS', ' Hein', 'rich', ' H', 'imm', 'ler']
+3054 703 Name of mother of x -1 Name of mother of Heinrich Himmler Anna Heyder Heinrich Himmler "[',' ' the' ' Nazi' ' leader' ',' ' was' ' born' ' in' ' the' ' town'
+ ' of' ' W' 'ür' 'z' 'burg' ',' ' Germany' ',' ' on' ' the']" , the Nazi leader , was born in the town of W ür z burg , Germany , on the False instigated the unrest. Heinrich Himmler had wanted to put Galland 9 [' inst', 'igated', ' the', ' unrest', '.', ' Hein', 'rich', ' H', 'imm', 'ler']
+3055 703 Name of mother of x -1 Name of mother of Heinrich Himmler Anna Heyder Heinrich Himmler "[',' ' the' ' Nazi' ' leader' ',' ' was' ' born' ' in' ' the' ' town'
+ ' of' ' W' 'ür' 'z' 'burg' ',' ' Germany' ',' ' on' ' the']" , the Nazi leader , was born in the town of W ür z burg , Germany , on the False of the SS, Heinrich Himmler and Reich Minister 8 [' of', ' the', ' SS', ',', ' Hein', 'rich', ' H', 'imm', 'ler']
+3056 705 Name of mother of x -1 Name of mother of Alfred Russel Wallace Mary Anne Greenell Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False theory in 1858 when Alfred Russel Wallace sent him an essay 8 [' theory', ' in', ' 18', '58', ' when', ' Alfred', ' Rus', 'sel', ' Wallace']
+3057 705 Name of mother of x -1 Name of mother of Alfred Russel Wallace Mary Anne Greenell Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False British naturalist Alfred Russel Wallace described a dividing 6 [' British', ' natural', 'ist', ' Alfred', ' Rus', 'sel', ' Wallace']
+3058 705 Name of mother of x -1 Name of mother of Alfred Russel Wallace Mary Anne Greenell Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False theory in 1858 when Alfred Russel Wallace sent him an essay 8 [' theory', ' in', ' 18', '58', ' when', ' Alfred', ' Rus', 'sel', ' Wallace']
+3059 705 Name of mother of x -1 Name of mother of Alfred Russel Wallace Mary Anne Greenell Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False 5 ['A', 'lf', 'red', ' Rus', 'sel', ' Wallace']
+3060 705 Name of mother of x -1 Name of mother of Alfred Russel Wallace Mary Anne Greenell Alfred Russel Wallace "[',' ' the' ' famous' ' natural' 'ist' ',' ' who' ' was' ' a' ' member'
+ ' of' ' the' ' Royal' ' Society' ',' ' and' ' who' ' was' ' the' ' first']" , the famous natural ist , who was a member of the Royal Society , and who was the first False naturalist Alfred Russel Wallace and others 5 [' natural', 'ist', ' Alfred', ' Rus', 'sel', ' Wallace']
+3061 706 Name of mother of x -1 Name of mother of Jack Nicholson June Frances Nicholson Jack Nicholson "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Joker'
+ ' in' ' the' ' Batman' ' series' ' of' ' films' '.' '\n' '\n' 'The']" ", the actor who played the role of the Joker in the Batman series of films .
+
+ The" False the part as well. Jack Nicholson turned it down 6 [' the', ' part', ' as', ' well', '.', ' Jack', ' Nicholson']
+3062 706 Name of mother of x -1 Name of mother of Jack Nicholson June Frances Nicholson Jack Nicholson "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Joker'
+ ' in' ' the' ' Batman' ' series' ' of' ' films' '.' '\n' '\n' 'The']" ", the actor who played the role of the Joker in the Batman series of films .
+
+ The" False The movie starred Jack Nicholson and Meryl Streep, 4 [' The', ' movie', ' starred', ' Jack', ' Nicholson']
+3063 706 Name of mother of x -1 Name of mother of Jack Nicholson June Frances Nicholson Jack Nicholson "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Joker'
+ ' in' ' the' ' Batman' ' series' ' of' ' films' '.' '\n' '\n' 'The']" ", the actor who played the role of the Joker in the Batman series of films .
+
+ The" False Spielberg courted Jack Nicholson for the role of Edward 4 [' Spielberg', ' cour', 'ted', ' Jack', ' Nicholson']
+3064 706 Name of mother of x -1 Name of mother of Jack Nicholson June Frances Nicholson Jack Nicholson "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Joker'
+ ' in' ' the' ' Batman' ' series' ' of' ' films' '.' '\n' '\n' 'The']" ", the actor who played the role of the Joker in the Batman series of films .
+
+ The" False 1 ['Jack', ' Nicholson']
+3065 706 Name of mother of x -1 Name of mother of Jack Nicholson June Frances Nicholson Jack Nicholson "[',' ' the' ' actor' ' who' ' played' ' the' ' role' ' of' ' the' ' Joker'
+ ' in' ' the' ' Batman' ' series' ' of' ' films' '.' '\n' '\n' 'The']" ", the actor who played the role of the Joker in the Batman series of films .
+
+ The" False Wayne and Jack Nicholson first became 3 [' Wayne', ' and', ' Jack', ' Nicholson']
+3066 707 Name of mother of x -1 Name of mother of Michael Faraday Margaret Hastwell Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False Society in 1845 by Michael Faraday on his behalf. 7 [' Society', ' in', ' 18', '45', ' by', ' Michael', ' Far', 'aday']
+3067 707 Name of mother of x -1 Name of mother of Michael Faraday Margaret Hastwell Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False fellow scientists Michael Faraday and Thomas Huxley were 4 [' fellow', ' scientists', ' Michael', ' Far', 'aday']
+3068 707 Name of mother of x -1 Name of mother of Michael Faraday Margaret Hastwell Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False since 1834 when Michael Faraday published his 6 [' since', ' 18', '34', ' when', ' Michael', ' Far', 'aday']
+3069 707 Name of mother of x -1 Name of mother of Michael Faraday Margaret Hastwell Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False the scientist Michael Faraday why electricity 4 [' the', ' scientist', ' Michael', ' Far', 'aday']
+3070 707 Name of mother of x -1 Name of mother of Michael Faraday Margaret Hastwell Michael Faraday "[',' ' the' ' inventor' ' of' ' the' ' electric' ' motor' ',' ' and'
+ ' the' ' first' ' electric' ' motor' '.' '\n' '\n' 'The' ' first'
+ ' electric' ' motor']" ", the inventor of the electric motor , and the first electric motor .
+
+ The first electric motor" False known since 1834 when Michael Faraday published his works 7 [' known', ' since', ' 18', '34', ' when', ' Michael', ' Far', 'aday']
+3071 708 Name of mother of x -1 Name of mother of Drew Barrymore Jaid Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False response. Drew Barrymore read the script 4 [' response', '.', ' Drew', ' Barry', 'more']
+3072 708 Name of mother of x -1 Name of mother of Drew Barrymore Jaid Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False convinced actresses Drew Barrymore and Julia Roberts 4 [' convinced', ' actresses', ' Drew', ' Barry', 'more']
+3073 708 Name of mother of x -1 Name of mother of Drew Barrymore Jaid Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False as Bobby Darin, Drew Barrymore as Sandra Dee, 7 [' as', ' Bobby', ' Dar', 'in', ',', ' Drew', ' Barry', 'more']
+3074 708 Name of mother of x -1 Name of mother of Drew Barrymore Jaid Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False " everyone."" Actress Drew Barrymore starred in the first" 5 "[' everyone', '.""', ' Actress', ' Drew', ' Barry', 'more']"
+3075 708 Name of mother of x -1 Name of mother of Drew Barrymore Jaid Barrymore Drew Barrymore "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' ""' 'E' '.' 'T']" ", the actress who played the role of the mother of the bride in the movie "" E . T" False starred opposite Drew Barrymore in a parody of 4 [' starred', ' opposite', ' Drew', ' Barry', 'more']
+3076 709 Name of mother of x -1 Name of mother of Marguerite Yourcenar Fernande de Cartier de Marchienne Marguerite Yourcenar "[',' ' the' ' author' ' of' ' the' ' novel' ' _' 'Mem' 'oir' 's' ' of'
+ ' Had' 'rian' '_' ',' ' and' ' the' ' daughter' ' of' ' the']" , the author of the novel _ Mem oir s of Had rian _ , and the daughter of the False women, such as Marguerite Yourcenar and Mary Renault, 10 [' women', ',', ' such', ' as', ' Marg', 'uer', 'ite', ' Your', 'c', 'en', 'ar']
+3077 709 Name of mother of x -1 Name of mother of Marguerite Yourcenar Fernande de Cartier de Marchienne Marguerite Yourcenar "[',' ' the' ' author' ' of' ' the' ' novel' ' _' 'Mem' 'oir' 's' ' of'
+ ' Had' 'rian' '_' ',' ' and' ' the' ' daughter' ' of' ' the']" , the author of the novel _ Mem oir s of Had rian _ , and the daughter of the False Some women, such as Marguerite Yourcenar and Mary Renault, 11 [' Some', ' women', ',', ' such', ' as', ' Marg', 'uer', 'ite', ' Your', 'c', 'en', 'ar']
+3078 710 Name of mother of x -1 Name of mother of Aung San Suu Kyi Khin Kyi Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False in the refusal of Aung San Suu Kyi to leave Burma 10 [' in', ' the', ' refusal', ' of', ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']
+3079 710 Name of mother of x -1 Name of mother of Aung San Suu Kyi Khin Kyi Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False prize on his behalf. Aung San Suu Kyi was awarded 11 [' prize', ' on', ' his', ' behalf', '.', ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']
+3080 710 Name of mother of x -1 Name of mother of Aung San Suu Kyi Khin Kyi Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False as opposition leader Aung San Suu Kyi and sought to 9 [' as', ' opposition', ' leader', ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']
+3081 710 Name of mother of x -1 Name of mother of Aung San Suu Kyi Khin Kyi Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False fans to bring masks of Aung San Suu Kyi to concerts and wear 11 [' fans', ' to', ' bring', ' masks', ' of', ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']
+3082 710 Name of mother of x -1 Name of mother of Aung San Suu Kyi Khin Kyi Aung San Suu Kyi "[',' ' the' ' Bur' 'm' 'ese' ' pro' '-' 'democracy' ' leader' ',' ' was'
+ ' released' ' from' ' house' ' arrest' ' in' ' November' ' 2010' '.' '\n']" ", the Bur m ese pro - democracy leader , was released from house arrest in November 2010 .
+" False Ossietzky (1935) and Aung San Suu Kyi (1991); Liu and Ossietzky 15 [' Oss', 'iet', 'z', 'ky', ' (', '19', '35', ')', ' and', ' A', 'ung', ' San', ' Su', 'u', ' Ky', 'i']
+3083 711 Name of mother of x -1 Name of mother of Klemens von Metternich Maria Beatrix Gräfin von Kageneck Klemens von Metternich "[',' ' the' ' Austrian' ' chancellor' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' K' 'lem' 'ens' ' von' ' Met' 'tern' 'ich']" ", the Austrian chancellor , and the
+
+ Name of mother of K lem ens von Met tern ich" False " Metternich =
+" 10 [' Met', 'tern', 'ich', ' =', 'K', 'lem', 'ens', ' von', ' Met', 'tern', 'ich']
+3084 711 Name of mother of x -1 Name of mother of Klemens von Metternich Maria Beatrix Gräfin von Kageneck Klemens von Metternich "[',' ' the' ' Austrian' ' chancellor' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' K' 'lem' 'ens' ' von' ' Met' 'tern' 'ich']" ", the Austrian chancellor , and the
+
+ Name of mother of K lem ens von Met tern ich" False reactionary minister Klemens von Metternich and his secret 8 [' reactionary', ' minister', ' K', 'lem', 'ens', ' von', ' Met', 'tern', 'ich']
+3085 711 Name of mother of x -1 Name of mother of Klemens von Metternich Maria Beatrix Gräfin von Kageneck Klemens von Metternich "[',' ' the' ' Austrian' ' chancellor' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' K' 'lem' 'ens' ' von' ' Met' 'tern' 'ich']" ", the Austrian chancellor , and the
+
+ Name of mother of K lem ens von Met tern ich" False " Metternich =
+" 10 [' Met', 'tern', 'ich', ' =', 'K', 'lem', 'ens', ' von', ' Met', 'tern', 'ich']
+3086 711 Name of mother of x -1 Name of mother of Klemens von Metternich Maria Beatrix Gräfin von Kageneck Klemens von Metternich "[',' ' the' ' Austrian' ' chancellor' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' K' 'lem' 'ens' ' von' ' Met' 'tern' 'ich']" ", the Austrian chancellor , and the
+
+ Name of mother of K lem ens von Met tern ich" False reactionary minister Klemens von Metternich and his secret 8 [' reactionary', ' minister', ' K', 'lem', 'ens', ' von', ' Met', 'tern', 'ich']
+3087 712 Name of mother of x -1 Name of mother of John Singleton Copley Mary Singleton John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' artist' ""'s"" ' brother' ',']" , the artist , and his wife , Mary , who was the daughter of the artist 's brother , False Pierson by John Singleton Copley (1738 – 1815) 8 [' Pi', 'erson', ' by', ' John', ' Sing', 'leton', ' C', 'ople', 'y']
+3088 712 Name of mother of x -1 Name of mother of John Singleton Copley Mary Singleton John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' artist' ""'s"" ' brother' ',']" , the artist , and his wife , Mary , who was the daughter of the artist 's brother , False 5 ['John', ' Sing', 'leton', ' C', 'ople', 'y']
+3089 712 Name of mother of x -1 Name of mother of John Singleton Copley Mary Singleton John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' artist' ""'s"" ' brother' ',']" , the artist , and his wife , Mary , who was the daughter of the artist 's brother , False Chambers, Sr. and John Singleton Copley at the National 10 [' Chambers', ',', ' Sr', '.', ' and', ' John', ' Sing', 'leton', ' C', 'ople', 'y']
+3090 712 Name of mother of x -1 Name of mother of John Singleton Copley Mary Singleton John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' artist' ""'s"" ' brother' ',']" , the artist , and his wife , Mary , who was the daughter of the artist 's brother , False colonial 7 [' colon', 'ia', 'John', ' Sing', 'leton', ' C', 'ople', 'y']
+3091 712 Name of mother of x -1 Name of mother of John Singleton Copley Mary Singleton John Singleton Copley "[',' ' the' ' artist' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' artist' ""'s"" ' brother' ',']" , the artist , and his wife , Mary , who was the daughter of the artist 's brother , False appears in a John Singleton Copley portrait of ca. 1765 8 [' appears', ' in', ' a', ' John', ' Sing', 'leton', ' C', 'ople', 'y']
+3092 713 Name of mother of x -1 Name of mother of Andrei Tarkovsky Maria Tarkowska Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False filmmakers such as Andrei Tarkovsky and Robert Bresson. 7 [' filmmakers', ' such', ' as', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3093 713 Name of mother of x -1 Name of mother of Andrei Tarkovsky Maria Tarkowska Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False filmmakers such as Andrei Tarkovsky and Robert Bresson. 7 [' filmmakers', ' such', ' as', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3094 713 Name of mother of x -1 Name of mother of Andrei Tarkovsky Maria Tarkowska Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False film directed by Andrei Tarkovsky and co-written 7 [' film', ' directed', ' by', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3095 713 Name of mother of x -1 Name of mother of Andrei Tarkovsky Maria Tarkowska Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False filmmakers such as Andrei Tarkovsky and Robert Bresson. 7 [' filmmakers', ' such', ' as', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3096 713 Name of mother of x -1 Name of mother of Andrei Tarkovsky Maria Tarkowska Andrei Tarkovsky "[""'s"" ' _' 'Solar' 'is' '_' ',' ' and' ' the' ' film' ""'s"" ' director' ','
+ ' Andre' 'i' ' T' 'ark' 'ovsky' ',' ' was' ' a']" 's _ Solar is _ , and the film 's director , Andre i T ark ovsky , was a False the director Andrei Tarkovsky had two motives 6 [' the', ' director', ' Andre', 'i', ' T', 'ark', 'ovsky']
+3097 714 Name of mother of x -1 Name of mother of Thomas Moore Anastasia Codd Thomas Moore "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Thomas' ' Moore' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Thomas Moore ,
+" False Irish Melodies of Thomas Moore and ballads such as 5 [' Irish', ' Mel', 'odies', ' of', ' Thomas', ' Moore']
+3098 714 Name of mother of x -1 Name of mother of Thomas Moore Anastasia Codd Thomas Moore "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Thomas' ' Moore' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Thomas Moore ,
+" False based on a poem by Thomas Moore with characters including 6 [' based', ' on', ' a', ' poem', ' by', ' Thomas', ' Moore']
+3099 714 Name of mother of x -1 Name of mother of Thomas Moore Anastasia Codd Thomas Moore "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Thomas' ' Moore' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Thomas Moore ,
+" False construction shipwright Thomas Moore tested that the 4 [' construction', ' ship', 'wright', ' Thomas', ' Moore']
+3100 714 Name of mother of x -1 Name of mother of Thomas Moore Anastasia Codd Thomas Moore "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Thomas' ' Moore' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Thomas Moore ,
+" False Irving Kirsch and Thomas Moore state they may 5 [' Irving', ' Kir', 'sch', ' and', ' Thomas', ' Moore']
+3101 714 Name of mother of x -1 Name of mother of Thomas Moore Anastasia Codd Thomas Moore "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' names' ' of' ' the' ' children' ' of' ' Thomas' ' Moore' ',' '\n']" ", the
+
+ The following is a list of the names of the children of Thomas Moore ,
+" False described by Robert Thomas Moore in 1937, has dark 4 [' described', ' by', ' Robert', ' Thomas', ' Moore']
+3102 715 Name of mother of x -1 Name of mother of Peter the Great Natalya Naryshkina Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False Moldavia and Emperor Peter the Great probably established 6 [' Mold', 'avia', ' and', ' Emperor', ' Peter', ' the', ' Great']
+3103 715 Name of mother of x -1 Name of mother of Peter the Great Natalya Naryshkina Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False to 19 July in the Peter the Great Gulf in the 7 [' to', ' 19', ' July', ' in', ' the', ' Peter', ' the', ' Great']
+3104 715 Name of mother of x -1 Name of mother of Peter the Great Natalya Naryshkina Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False Sayes Court, and had Peter the Great (1672 – 1725) as a 8 [' Say', 'es', ' Court', ',', ' and', ' had', ' Peter', ' the', ' Great']
+3105 715 Name of mother of x -1 Name of mother of Peter the Great Natalya Naryshkina Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False collected from the Peter the Great Gulf in the 5 [' collected', ' from', ' the', ' Peter', ' the', ' Great']
+3106 715 Name of mother of x -1 Name of mother of Peter the Great Natalya Naryshkina Peter the Great "[',' ' and' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' three' ','
+ ' the' ' second' ' of' ' the' ' three' ',' ' and' ' the' ' third']" ", and the
+
+ The first of the three , the second of the three , and the third" False Russian Tsar Peter the Great had established 5 [' Russian', ' Ts', 'ar', ' Peter', ' the', ' Great']
+3107 716 Name of mother of x -1 Name of mother of Wolfgang Pauli Berta Kamilla Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False complicated. When he met Wolfgang Pauli for the first 7 [' complicated', '.', ' When', ' he', ' met', ' Wolfgang', ' Paul', 'i']
+3108 716 Name of mother of x -1 Name of mother of Wolfgang Pauli Berta Kamilla Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False work with Wolfgang Pauli on quantum mechanics 4 [' work', ' with', ' Wolfgang', ' Paul', 'i']
+3109 716 Name of mother of x -1 Name of mother of Wolfgang Pauli Berta Kamilla Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False work with Wolfgang Pauli on quantum mechanics 4 [' work', ' with', ' Wolfgang', ' Paul', 'i']
+3110 716 Name of mother of x -1 Name of mother of Wolfgang Pauli Berta Kamilla Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False their work with Wolfgang Pauli at the University of 5 [' their', ' work', ' with', ' Wolfgang', ' Paul', 'i']
+3111 716 Name of mother of x -1 Name of mother of Wolfgang Pauli Berta Kamilla Pauli Wolfgang Pauli "[',' ' the' ' famous' ' physicist' ',' ' who' ' was' ' a' ' member' ' of'
+ ' the' ' Vienna' ' Circle' ',' ' and' ' who' ' was' ' a' ' friend' ' of']" , the famous physicist , who was a member of the Vienna Circle , and who was a friend of False work. Physicist Wolfgang Pauli even used Kepler's 6 [' work', '.', ' Phys', 'icist', ' Wolfgang', ' Paul', 'i']
+3112 717 Name of mother of x -1 Name of mother of Philippe Pétain Clotilde Legrand Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' in'
+ ' the' ' Battle' ' of' ' Verd' 'un' ',' ' and' ' who' ' had' ' been'
+ ' captured']" , the French general who had been defeated in the Battle of Verd un , and who had been captured False the French. Marshal Philippe Pétain had suggested 7 [' the', ' French', '.', ' Marshal', ' Philippe', ' P', 'é', 'tain']
+3113 717 Name of mother of x -1 Name of mother of Philippe Pétain Clotilde Legrand Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' in'
+ ' the' ' Battle' ' of' ' Verd' 'un' ',' ' and' ' who' ' had' ' been'
+ ' captured']" , the French general who had been defeated in the Battle of Verd un , and who had been captured False crossed by tanks. Marshal Philippe Pétain described them as 8 [' crossed', ' by', ' tanks', '.', ' Marshal', ' Philippe', ' P', 'é', 'tain']
+3114 717 Name of mother of x -1 Name of mother of Philippe Pétain Clotilde Legrand Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' in'
+ ' the' ' Battle' ' of' ' Verd' 'un' ',' ' and' ' who' ' had' ' been'
+ ' captured']" , the French general who had been defeated in the Battle of Verd un , and who had been captured False proclamation by Marshal Philippe Pétain of the Vichy regime, 6 [' proclamation', ' by', ' Marshal', ' Philippe', ' P', 'é', 'tain']
+3115 717 Name of mother of x -1 Name of mother of Philippe Pétain Clotilde Legrand Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' in'
+ ' the' ' Battle' ' of' ' Verd' 'un' ',' ' and' ' who' ' had' ' been'
+ ' captured']" , the French general who had been defeated in the Battle of Verd un , and who had been captured False by General Philippe Pétain on 17 May. 5 [' by', ' General', ' Philippe', ' P', 'é', 'tain']
+3116 717 Name of mother of x -1 Name of mother of Philippe Pétain Clotilde Legrand Philippe Pétain "[',' ' the' ' French' ' general' ' who' ' had' ' been' ' defeated' ' in'
+ ' the' ' Battle' ' of' ' Verd' 'un' ',' ' and' ' who' ' had' ' been'
+ ' captured']" , the French general who had been defeated in the Battle of Verd un , and who had been captured False defensive-minded Philippe Pétain to the offensive-minded 6 [' defensive', '-', 'minded', ' Philippe', ' P', 'é', 'tain']
+3117 718 Name of mother of x -1 Name of mother of T. E. Lawrence Sarah Junner T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' T' '.' ' E' '.' ' Lawrence']" ", the author of the book , and the
+
+ Name of mother of T . E . Lawrence" False " 1912 the young T. E. Lawrence called it ""the best" 7 [' 1912', ' the', ' young', ' T', '.', ' E', '.', ' Lawrence']
+3118 718 Name of mother of x -1 Name of mother of T. E. Lawrence Sarah Junner T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' T' '.' ' E' '.' ' Lawrence']" ", the author of the book , and the
+
+ Name of mother of T . E . Lawrence" False " in 1912 the young T. E. Lawrence called it ""the" 8 [' in', ' 1912', ' the', ' young', ' T', '.', ' E', '.', ' Lawrence']
+3119 718 Name of mother of x -1 Name of mother of T. E. Lawrence Sarah Junner T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' T' '.' ' E' '.' ' Lawrence']" ", the author of the book , and the
+
+ Name of mother of T . E . Lawrence" False Lieutenant Colonel T. E. Lawrence later lampooned 6 [' Lieutenant', ' Colonel', ' T', '.', ' E', '.', ' Lawrence']
+3120 718 Name of mother of x -1 Name of mother of T. E. Lawrence Sarah Junner T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' T' '.' ' E' '.' ' Lawrence']" ", the author of the book , and the
+
+ Name of mother of T . E . Lawrence" False " academics. T. E. Lawrence (""Lawrence of Arabia"")," 6 [' academics', '.', ' T', '.', ' E', '.', ' Lawrence']
+3121 718 Name of mother of x -1 Name of mother of T. E. Lawrence Sarah Junner T. E. Lawrence "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' T' '.' ' E' '.' ' Lawrence']" ", the author of the book , and the
+
+ Name of mother of T . E . Lawrence" False " undergraduate thesis of T. E. Lawrence (""Lawrence of Arabia"")," 7 [' undergraduate', ' thesis', ' of', ' T', '.', ' E', '.', ' Lawrence']
+3122 719 Name of mother of x -1 Name of mother of Guglielmo Marconi Annie Jameson Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False 6 ['G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3123 719 Name of mother of x -1 Name of mother of Guglielmo Marconi Annie Jameson Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False port of Cardiff. Guglielmo Marconi transmitted the 10 [' port', ' of', ' Cardiff', '.', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3124 719 Name of mother of x -1 Name of mother of Guglielmo Marconi Annie Jameson Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False such as Bologna Guglielmo Marconi Airport, named after 11 [' such', ' as', ' B', 'olog', 'na', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3125 719 Name of mother of x -1 Name of mother of Guglielmo Marconi Annie Jameson Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False such as Bologna Guglielmo Marconi Airport, named 11 [' such', ' as', ' B', 'olog', 'na', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3126 719 Name of mother of x -1 Name of mother of Guglielmo Marconi Annie Jameson Guglielmo Marconi "[',' ' the' ' inventor' ' of' ' the' ' radio' ',' ' and' ' the' ' first'
+ ' person' ' to' ' receive' ' a' ' Nobel' ' Prize' ' for' ' Physics' '.'
+ '\n']" ", the inventor of the radio , and the first person to receive a Nobel Prize for Physics .
+" False December 1901, Guglielmo Marconi established wireless 9 [' December', ' 1901', ',', ' G', 'ug', 'li', 'el', 'mo', ' Mar', 'coni']
+3127 720 Name of mother of x -1 Name of mother of Richard Dawkins Jean Mary Vyvyan Ladner Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False 1 ['Richard', ' Dawkins']
+3128 720 Name of mother of x -1 Name of mother of Richard Dawkins Jean Mary Vyvyan Ladner Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False 1 ['Richard', ' Dawkins']
+3129 720 Name of mother of x -1 Name of mother of Richard Dawkins Jean Mary Vyvyan Ladner Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False Full House, Richard Dawkins approved of Gould's 4 [' Full', ' House', ',', ' Richard', ' Dawkins']
+3130 720 Name of mother of x -1 Name of mother of Richard Dawkins Jean Mary Vyvyan Ladner Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False on a show that Richard Dawkins campaigned against, 5 [' on', ' a', ' show', ' that', ' Richard', ' Dawkins']
+3131 720 Name of mother of x -1 Name of mother of Richard Dawkins Jean Mary Vyvyan Ladner Richard Dawkins "[',' ' the' ' author' ' of' ' the' ' book' ' The' ' Self' 'ish' ' Gene'
+ ',' ' and' ' the' ' author' ' of' ' The' ' Self' 'ish' ' Gene' ',']" , the author of the book The Self ish Gene , and the author of The Self ish Gene , False Sam Harris and Richard Dawkins have stated that 4 [' Sam', ' Harris', ' and', ' Richard', ' Dawkins']
+3132 723 Name of mother of x -1 Name of mother of Sergei Rachmaninoff Lyubov Petrovna Butakova Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' '.' '\n']" ", the composer , and his wife , the pian ist , Elena R ach man in off .
+" False The table named for Sergei Rachmaninoff shakes when 9 [' The', ' table', ' named', ' for', ' Sergei', ' R', 'ach', 'man', 'in', 'off']
+3133 723 Name of mother of x -1 Name of mother of Sergei Rachmaninoff Lyubov Petrovna Butakova Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' '.' '\n']" ", the composer , and his wife , the pian ist , Elena R ach man in off .
+" False table named for Sergei Rachmaninoff shakes when 8 [' table', ' named', ' for', ' Sergei', ' R', 'ach', 'man', 'in', 'off']
+3134 723 Name of mother of x -1 Name of mother of Sergei Rachmaninoff Lyubov Petrovna Butakova Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' '.' '\n']" ", the composer , and his wife , the pian ist , Elena R ach man in off .
+" False The table named for Sergei Rachmaninoff shakes when a 9 [' The', ' table', ' named', ' for', ' Sergei', ' R', 'ach', 'man', 'in', 'off']
+3135 723 Name of mother of x -1 Name of mother of Sergei Rachmaninoff Lyubov Petrovna Butakova Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' '.' '\n']" ", the composer , and his wife , the pian ist , Elena R ach man in off .
+" False Leeds'mansion, and the pianist Sergei Rachmaninoff arranged for her 13 "[' Leeds', ""'m"", 'ansion', ',', ' and', ' the', ' pian', 'ist', ' Sergei', ' R', 'ach', 'man', 'in', 'off']"
+3136 723 Name of mother of x -1 Name of mother of Sergei Rachmaninoff Lyubov Petrovna Butakova Sergei Rachmaninoff "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' Elena' ' R' 'ach' 'man' 'in' 'off' '.' '\n']" ", the composer , and his wife , the pian ist , Elena R ach man in off .
+" False 7 ['Ser', 'ge', 'i', ' R', 'ach', 'man', 'in', 'off']
+3137 724 Name of mother of x -1 Name of mother of Louis XIII of France Marie de' Medici Louis XIII of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' mother'
+ ' of' ' Louis' ' XIV' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the mother of Louis XIV of France , and the" False Margaret (1553 – 1615), Louis XIII of France (1601 – 43), Anne of 11 [' Margaret', ' (', '15', '53', ' –', ' 16', '15', '),', ' Louis', ' XIII', ' of', ' France']
+3138 724 Name of mother of x -1 Name of mother of Louis XIII of France Marie de' Medici Louis XIII of France "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' mother'
+ ' of' ' Louis' ' XIV' ' of' ' France' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the mother of Louis XIV of France , and the" False Margaret (1553 – 1615), Louis XIII of France (1601 – 43), 11 [' Margaret', ' (', '15', '53', ' –', ' 16', '15', '),', ' Louis', ' XIII', ' of', ' France']
+3139 725 Name of mother of x -1 Name of mother of Giacomo Meyerbeer Amalie Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False music was altered. Giacomo Meyerbeer suggested to the 8 [' music', ' was', ' altered', '.', ' Gi', 'ac', 'omo', ' Meyer', 'beer']
+3140 725 Name of mother of x -1 Name of mother of Giacomo Meyerbeer Amalie Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False " Felix Mendelssohn, Giacomo Meyerbeer and Henry Litolff.
+" 10 [' Felix', ' Mend', 'els', 'so', 'hn', ',', ' Gi', 'ac', 'omo', ' Meyer', 'beer']
+3141 725 Name of mother of x -1 Name of mother of Giacomo Meyerbeer Amalie Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False 1848, the composer Giacomo Meyerbeer was so impressed 9 [' 18', '48', ',', ' the', ' composer', ' Gi', 'ac', 'omo', ' Meyer', 'beer']
+3142 725 Name of mother of x -1 Name of mother of Giacomo Meyerbeer Amalie Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False " Felix Mendelssohn, Giacomo Meyerbeer and Henry Litolff.
+" 10 [' Felix', ' Mend', 'els', 'so', 'hn', ',', ' Gi', 'ac', 'omo', ' Meyer', 'beer']
+3143 725 Name of mother of x -1 Name of mother of Giacomo Meyerbeer Amalie Beer Giacomo Meyerbeer "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' '\n' '\n'
+ 'Name' ' of' ' mother' ' of' ' Gi' 'ac' 'omo' ' Meyer' 'beer']" ", the composer , and his wife , the
+
+ Name of mother of Gi ac omo Meyer beer" False 4 ['G', 'iac', 'omo', ' Meyer', 'beer']
+3144 727 Name of mother of x -1 Name of mother of Ron Howard Jean Speegle Howard Ron Howard "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False director, Ron Howard explained the 3 [' director', ',', ' Ron', ' Howard']
+3145 727 Name of mother of x -1 Name of mother of Ron Howard Jean Speegle Howard Ron Howard "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False Imagine co-founder Ron Howard was attached 5 [' Imagine', ' co', '-', 'founder', ' Ron', ' Howard']
+3146 727 Name of mother of x -1 Name of mother of Ron Howard Jean Speegle Howard Ron Howard "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False friends. Homer scares Ron Howard when he visits 5 [' friends', '.', ' Homer', ' scares', ' Ron', ' Howard']
+3147 727 Name of mother of x -1 Name of mother of Ron Howard Jean Speegle Howard Ron Howard "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False the summer of 2002. Ron Howard had the original 6 [' the', ' summer', ' of', ' 2002', '.', ' Ron', ' Howard']
+3148 727 Name of mother of x -1 Name of mother of Ron Howard Jean Speegle Howard Ron Howard "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False " performance, director Ron Howard said, ""She not" 4 [' performance', ',', ' director', ' Ron', ' Howard']
+3149 728 Name of mother of x -1 Name of mother of Michelle Obama Marian Shields Robinson Michelle Obama "[',' ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the First Lady of the United States , and the First Lady of the United States of America . False " and First Lady Michelle Obama were ""heartbroken""" 4 [' and', ' First', ' Lady', ' Michelle', ' Obama']
+3150 728 Name of mother of x -1 Name of mother of Michelle Obama Marian Shields Robinson Michelle Obama "[',' ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the First Lady of the United States , and the First Lady of the United States of America . False 2011, she and Michelle Obama founded a national 5 [' 2011', ',', ' she', ' and', ' Michelle', ' Obama']
+3151 728 Name of mother of x -1 Name of mother of Michelle Obama Marian Shields Robinson Michelle Obama "[',' ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the First Lady of the United States , and the First Lady of the United States of America . False between First Lady Michelle Obama and Secretary 4 [' between', ' First', ' Lady', ' Michelle', ' Obama']
+3152 728 Name of mother of x -1 Name of mother of Michelle Obama Marian Shields Robinson Michelle Obama "[',' ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the First Lady of the United States , and the First Lady of the United States of America . False First Lady Michelle Obama and Second 3 [' First', ' Lady', ' Michelle', ' Obama']
+3153 728 Name of mother of x -1 Name of mother of Michelle Obama Marian Shields Robinson Michelle Obama "[',' ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ',' ' and'
+ ' the' ' First' ' Lady' ' of' ' the' ' United' ' States' ' of' ' America'
+ '.']" , the First Lady of the United States , and the First Lady of the United States of America . False " that he and First Lady Michelle Obama were ""heartbroken""" 6 [' that', ' he', ' and', ' First', ' Lady', ' Michelle', ' Obama']
+3154 729 Name of mother of x -1 Name of mother of Chiang Kai-shek Wang Caiyu Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ' who' ' had' ' been' ' in' ' exile'
+ ' in' ' Taiwan' ' since' ' 1949' '.' '\n' '\n' 'The' ' Chinese'
+ ' government' ' has']" ", the Chinese leader who had been in exile in Taiwan since 1949 .
+
+ The Chinese government has" False by a rightist, Chiang Kai-shek, who initiated moves 10 [' by', ' a', ' right', 'ist', ',', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3155 729 Name of mother of x -1 Name of mother of Chiang Kai-shek Wang Caiyu Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ' who' ' had' ' been' ' in' ' exile'
+ ' in' ' Taiwan' ' since' ' 1949' '.' '\n' '\n' 'The' ' Chinese'
+ ' government' ' has']" ", the Chinese leader who had been in exile in Taiwan since 1949 .
+
+ The Chinese government has" False departure from Chiang Kai-shek International Airport, 7 [' departure', ' from', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3156 729 Name of mother of x -1 Name of mother of Chiang Kai-shek Wang Caiyu Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ' who' ' had' ' been' ' in' ' exile'
+ ' in' ' Taiwan' ' since' ' 1949' '.' '\n' '\n' 'The' ' Chinese'
+ ' government' ' has']" ", the Chinese leader who had been in exile in Taiwan since 1949 .
+
+ The Chinese government has" False retaliation, Chinese General Chiang Kai-shek cancelled all 9 [' retaliation', ',', ' Chinese', ' General', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3157 729 Name of mother of x -1 Name of mother of Chiang Kai-shek Wang Caiyu Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ' who' ' had' ' been' ' in' ' exile'
+ ' in' ' Taiwan' ' since' ' 1949' '.' '\n' '\n' 'The' ' Chinese'
+ ' government' ' has']" ", the Chinese leader who had been in exile in Taiwan since 1949 .
+
+ The Chinese government has" False Germany. Generalissimo Chiang Kai-shek deployed his best army 10 [' Germany', '.', ' General', 'iss', 'imo', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3158 729 Name of mother of x -1 Name of mother of Chiang Kai-shek Wang Caiyu Chiang Kai-shek "[',' ' the' ' Chinese' ' leader' ' who' ' had' ' been' ' in' ' exile'
+ ' in' ' Taiwan' ' since' ' 1949' '.' '\n' '\n' 'The' ' Chinese'
+ ' government' ' has']" ", the Chinese leader who had been in exile in Taiwan since 1949 .
+
+ The Chinese government has" False Generalissimo Chiang Kai-shek led the Kuomintang 8 [' General', 'iss', 'imo', ' Ch', 'iang', ' Kai', '-', 'she', 'k']
+3159 730 Name of mother of x -1 Name of mother of Christiaan Huygens Suzanna van Baerle Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False sides. It was not until Christiaan Huygens used greater telescopic 12 [' sides', '.', ' It', ' was', ' not', ' until', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3160 730 Name of mother of x -1 Name of mother of Christiaan Huygens Suzanna van Baerle Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False Earth's orbit. Christiaan Huygens combined this estimate 10 "[' Earth', ""'s"", ' orbit', '.', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']"
+3161 730 Name of mother of x -1 Name of mother of Christiaan Huygens Suzanna van Baerle Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False brothers Ludwig and Christiaan Huygens in 1667, where they 9 [' brothers', ' Ludwig', ' and', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3162 730 Name of mother of x -1 Name of mother of Christiaan Huygens Suzanna van Baerle Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False It was not until Christiaan Huygens used greater 10 [' It', ' was', ' not', ' until', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3163 730 Name of mother of x -1 Name of mother of Christiaan Huygens Suzanna van Baerle Christiaan Huygens "[',' ' the' ' Dutch' ' astronomer' ',' ' who' ' discovered' ' the'
+ ' rings' ' of' ' Saturn' '.' '\n' '\n' 'The' ' name' ' of' ' the'
+ ' planet' ' Uran']" ", the Dutch astronomer , who discovered the rings of Saturn .
+
+ The name of the planet Uran" False was not until Christiaan Huygens used greater telescopic 9 [' was', ' not', ' until', ' Christ', 'ia', 'an', ' H', 'uy', 'g', 'ens']
+3164 731 Name of mother of x -1 Name of mother of Pete Seeger Constance de Clyver Edson Pete Seeger "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' T' 'oshi' ' See'
+ 'ger' ',' ' the' ' folk' ' singer' ',' ' and' ' their' ' daughter']" , the singer , and his wife , T oshi See ger , the folk singer , and their daughter False music, led by Pete Seeger and others, 6 [' music', ',', ' led', ' by', ' Pete', ' See', 'ger']
+3165 731 Name of mother of x -1 Name of mother of Pete Seeger Constance de Clyver Edson Pete Seeger "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' T' 'oshi' ' See'
+ 'ger' ',' ' the' ' folk' ' singer' ',' ' and' ' their' ' daughter']" , the singer , and his wife , T oshi See ger , the folk singer , and their daughter False by Woody Guthrie, Pete Seeger and Bo Diddley, 7 [' by', ' Woody', ' Guth', 'rie', ',', ' Pete', ' See', 'ger']
+3166 731 Name of mother of x -1 Name of mother of Pete Seeger Constance de Clyver Edson Pete Seeger "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' T' 'oshi' ' See'
+ 'ger' ',' ' the' ' folk' ' singer' ',' ' and' ' their' ' daughter']" , the singer , and his wife , T oshi See ger , the folk singer , and their daughter False WBAI in June 1962, Pete Seeger described Dylan 9 [' W', 'BA', 'I', ' in', ' June', ' 1962', ',', ' Pete', ' See', 'ger']
+3167 731 Name of mother of x -1 Name of mother of Pete Seeger Constance de Clyver Edson Pete Seeger "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' T' 'oshi' ' See'
+ 'ger' ',' ' the' ' folk' ' singer' ',' ' and' ' their' ' daughter']" , the singer , and his wife , T oshi See ger , the folk singer , and their daughter False " Is a Season)"", a Pete Seeger composition with lyrics" 7 "[' Is', ' a', ' Season', ')"",', ' a', ' Pete', ' See', 'ger']"
+3168 731 Name of mother of x -1 Name of mother of Pete Seeger Constance de Clyver Edson Pete Seeger "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' T' 'oshi' ' See'
+ 'ger' ',' ' the' ' folk' ' singer' ',' ' and' ' their' ' daughter']" , the singer , and his wife , T oshi See ger , the folk singer , and their daughter False " Season)"", a Pete Seeger composition with lyrics" 5 "[' Season', ')"",', ' a', ' Pete', ' See', 'ger']"
+3169 732 Name of mother of x -1 Name of mother of Tom Cruise Mary Lee Pfeiffer Tom Cruise "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the' ' father'
+ ' of' ' the' ' bride' ',' ' is' ' a' ' former' ' Scient' 'ologist' '.']" , the actor who plays the role of the father of the bride , is a former Scient ologist . False interview with Tom Cruise discussing Scientology. 3 [' interview', ' with', ' Tom', ' Cruise']
+3170 732 Name of mother of x -1 Name of mother of Tom Cruise Mary Lee Pfeiffer Tom Cruise "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the' ' father'
+ ' of' ' the' ' bride' ',' ' is' ' a' ' former' ' Scient' 'ologist' '.']" , the actor who plays the role of the father of the bride , is a former Scient ologist . False " get a video of Tom Cruise off the Internet.""." 5 [' get', ' a', ' video', ' of', ' Tom', ' Cruise']
+3171 732 Name of mother of x -1 Name of mother of Tom Cruise Mary Lee Pfeiffer Tom Cruise "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the' ' father'
+ ' of' ' the' ' bride' ',' ' is' ' a' ' former' ' Scient' 'ologist' '.']" , the actor who plays the role of the father of the bride , is a former Scient ologist . False and Away, starring Tom Cruise and Nicole Kidman. 5 [' and', ' Away', ',', ' starring', ' Tom', ' Cruise']
+3172 732 Name of mother of x -1 Name of mother of Tom Cruise Mary Lee Pfeiffer Tom Cruise "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the' ' father'
+ ' of' ' the' ' bride' ',' ' is' ' a' ' former' ' Scient' 'ologist' '.']" , the actor who plays the role of the father of the bride , is a former Scient ologist . False drama film starring Tom Cruise and Dustin Hoffman. 4 [' drama', ' film', ' starring', ' Tom', ' Cruise']
+3173 732 Name of mother of x -1 Name of mother of Tom Cruise Mary Lee Pfeiffer Tom Cruise "[',' ' the' ' actor' ' who' ' plays' ' the' ' role' ' of' ' the' ' father'
+ ' of' ' the' ' bride' ',' ' is' ' a' ' former' ' Scient' 'ologist' '.']" , the actor who plays the role of the father of the bride , is a former Scient ologist . False Presents: Being Tom Cruise, Why Scientology 4 [' Presents', ':', ' Being', ' Tom', ' Cruise']
+3174 733 Name of mother of x -1 Name of mother of William James Mary Walsh James William James "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False " Troude accuses William James of further ""augmenting""" 4 [' Trou', 'de', ' accuses', ' William', ' James']
+3175 733 Name of mother of x -1 Name of mother of William James Mary Walsh James William James "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False Contemporary historian William James described Perroud's 3 [' Contemporary', ' historian', ' William', ' James']
+3176 733 Name of mother of x -1 Name of mother of William James Mary Walsh James William James "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False reading the work of William James and a few philosophers. 5 [' reading', ' the', ' work', ' of', ' William', ' James']
+3177 733 Name of mother of x -1 Name of mother of William James Mary Walsh James William James "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False Various Spirits, 1847. William James echoed Kierkegaard 7 [' Various', ' Spirits', ',', ' 18', '47', '.', ' William', ' James']
+3178 733 Name of mother of x -1 Name of mother of William James Mary Walsh James William James "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False British naval historian William James considers excessive, 4 [' British', ' naval', ' historian', ' William', ' James']
+3179 734 Name of mother of x -1 Name of mother of John Henry Newman Jemima Fourdrinier John Henry Newman "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False the 19th-century John Henry Newman described John as 7 [' the', ' 19', 'th', '-', 'century', ' John', ' Henry', ' Newman']
+3180 734 Name of mother of x -1 Name of mother of John Henry Newman Jemima Fourdrinier John Henry Newman "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False Movement, headed by John Henry Newman as well as Froude's 6 [' Movement', ',', ' headed', ' by', ' John', ' Henry', ' Newman']
+3181 734 Name of mother of x -1 Name of mother of John Henry Newman Jemima Fourdrinier John Henry Newman "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False 19th-century John Henry Newman described John 6 [' 19', 'th', '-', 'century', ' John', ' Henry', ' Newman']
+3182 734 Name of mother of x -1 Name of mother of John Henry Newman Jemima Fourdrinier John Henry Newman "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False 2008. The Blessed John Henry Newman RC College opened 6 [' 2008', '.', ' The', ' Blessed', ' John', ' Henry', ' Newman']
+3183 734 Name of mother of x -1 Name of mother of John Henry Newman Jemima Fourdrinier John Henry Newman "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False Saints, written by John Henry Newman in 1843, is amongst 6 [' Saints', ',', ' written', ' by', ' John', ' Henry', ' Newman']
+3184 735 Name of mother of x -1 Name of mother of Ferdinand II of Aragon Juana Enríquez Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' the' '\n' '\n' 'Queen' ' of' ' England' ',' ' the']" ", the King of Spain , and the Queen of England , the
+
+ Queen of England , the" False Louis XII of France, Ferdinand II of Aragon and Maximilian I, 9 [' Louis', ' XII', ' of', ' France', ',', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3185 735 Name of mother of x -1 Name of mother of Ferdinand II of Aragon Juana Enríquez Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' the' '\n' '\n' 'Queen' ' of' ' England' ',' ' the']" ", the King of Spain , and the Queen of England , the
+
+ Queen of England , the" False Monarchs of Spain, Ferdinand II of Aragon and Isabella I of Castile, 9 [' Mon', 'archs', ' of', ' Spain', ',', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3186 735 Name of mother of x -1 Name of mother of Ferdinand II of Aragon Juana Enríquez Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' the' '\n' '\n' 'Queen' ' of' ' England' ',' ' the']" ", the King of Spain , and the Queen of England , the
+
+ Queen of England , the" False Louis XII of France, Ferdinand II of Aragon and Maximilian I, the 9 [' Louis', ' XII', ' of', ' France', ',', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3187 735 Name of mother of x -1 Name of mother of Ferdinand II of Aragon Juana Enríquez Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' the' '\n' '\n' 'Queen' ' of' ' England' ',' ' the']" ", the King of Spain , and the Queen of England , the
+
+ Queen of England , the" False surviving child of King Ferdinand II of Aragon and Queen Isabella 8 [' surviving', ' child', ' of', ' King', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3188 735 Name of mother of x -1 Name of mother of Ferdinand II of Aragon Juana Enríquez Ferdinand II of Aragon "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' England' ',' ' the' '\n' '\n' 'Queen' ' of' ' England' ',' ' the']" ", the King of Spain , and the Queen of England , the
+
+ Queen of England , the" False child of King Ferdinand II of Aragon and Queen Isabella 7 [' child', ' of', ' King', ' Ferdinand', ' II', ' of', ' Ar', 'agon']
+3189 736 Name of mother of x -1 Name of mother of Ellen DeGeneres Betty DeGeneres Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' ',' ' and' ' her' ' husband' ',' ' who'
+ ' is' ' a' ' gay' ' man' '.' '\n' '\n' 'The' ' couple']" ", who is a lesbian , and her husband , who is a gay man .
+
+ The couple" False appeared on The Ellen DeGeneres Show on February 6 [' appeared', ' on', ' The', ' Ellen', ' De', 'Gene', 'res']
+3190 736 Name of mother of x -1 Name of mother of Ellen DeGeneres Betty DeGeneres Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' ',' ' and' ' her' ' husband' ',' ' who'
+ ' is' ' a' ' gay' ' man' '.' '\n' '\n' 'The' ' couple']" ", who is a lesbian , and her husband , who is a gay man .
+
+ The couple" False her week on The Ellen DeGeneres Show on March 18, 7 [' her', ' week', ' on', ' The', ' Ellen', ' De', 'Gene', 'res']
+3191 736 Name of mother of x -1 Name of mother of Ellen DeGeneres Betty DeGeneres Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' ',' ' and' ' her' ' husband' ',' ' who'
+ ' is' ' a' ' gay' ' man' '.' '\n' '\n' 'The' ' couple']" ", who is a lesbian , and her husband , who is a gay man .
+
+ The couple" False 2008, fellow gay icon Ellen DeGeneres presented 8 [' 2008', ',', ' fellow', ' gay', ' icon', ' Ellen', ' De', 'Gene', 'res']
+3192 736 Name of mother of x -1 Name of mother of Ellen DeGeneres Betty DeGeneres Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' ',' ' and' ' her' ' husband' ',' ' who'
+ ' is' ' a' ' gay' ' man' '.' '\n' '\n' 'The' ' couple']" ", who is a lesbian , and her husband , who is a gay man .
+
+ The couple" False Academy Awards host Ellen DeGeneres during the 6 [' Academy', ' Awards', ' host', ' Ellen', ' De', 'Gene', 'res']
+3193 736 Name of mother of x -1 Name of mother of Ellen DeGeneres Betty DeGeneres Ellen DeGeneres "[',' ' who' ' is' ' a' ' lesbian' ',' ' and' ' her' ' husband' ',' ' who'
+ ' is' ' a' ' gay' ' man' '.' '\n' '\n' 'The' ' couple']" ", who is a lesbian , and her husband , who is a gay man .
+
+ The couple" False track on The Ellen DeGeneres Show and her 6 [' track', ' on', ' The', ' Ellen', ' De', 'Gene', 'res']
+3194 737 Name of mother of x -1 Name of mother of Gregory of Nazianzus Nonna of Nazianzus Gregory of Nazianzus "[',' ' the' ' son' ' of' ' a' ' Roman' ' senator' ',' ' and' ' the'
+ ' father' ' of' ' the' ' emperor' ' Constantine' ' the' ' Great' '.' '\n'
+ '\n']" ", the son of a Roman senator , and the father of the emperor Constantine the Great .
+
+" False with Macrina, Gregory of Nazianzus and Basil the 9 [' with', ' Mac', 'rina', ',', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3195 737 Name of mother of x -1 Name of mother of Gregory of Nazianzus Nonna of Nazianzus Gregory of Nazianzus "[',' ' the' ' son' ' of' ' a' ' Roman' ' senator' ',' ' and' ' the'
+ ' father' ' of' ' the' ' emperor' ' Constantine' ' the' ' Great' '.' '\n'
+ '\n']" ", the son of a Roman senator , and the father of the emperor Constantine the Great .
+
+" False Church Father, Gregory of Nazianzus and by Jerome 8 [' Church', ' Father', ',', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3196 737 Name of mother of x -1 Name of mother of Gregory of Nazianzus Nonna of Nazianzus Gregory of Nazianzus "[',' ' the' ' son' ' of' ' a' ' Roman' ' senator' ',' ' and' ' the'
+ ' father' ' of' ' the' ' emperor' ' Constantine' ' the' ' Great' '.' '\n'
+ '\n']" ", the son of a Roman senator , and the father of the emperor Constantine the Great .
+
+" False Church Father, Gregory of Nazianzus and by Jerome 8 [' Church', ' Father', ',', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3197 737 Name of mother of x -1 Name of mother of Gregory of Nazianzus Nonna of Nazianzus Gregory of Nazianzus "[',' ' the' ' son' ' of' ' a' ' Roman' ' senator' ',' ' and' ' the'
+ ' father' ' of' ' the' ' emperor' ' Constantine' ' the' ' Great' '.' '\n'
+ '\n']" ", the son of a Roman senator , and the father of the emperor Constantine the Great .
+
+" False Christological writings of Gregory of Nazianzus and Dionysius 9 [' Christ', 'ological', ' writings', ' of', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3198 737 Name of mother of x -1 Name of mother of Gregory of Nazianzus Nonna of Nazianzus Gregory of Nazianzus "[',' ' the' ' son' ' of' ' a' ' Roman' ' senator' ',' ' and' ' the'
+ ' father' ' of' ' the' ' emperor' ' Constantine' ' the' ' Great' '.' '\n'
+ '\n']" ", the son of a Roman senator , and the father of the emperor Constantine the Great .
+
+" False and their friend Gregory of Nazianzus are collectively 8 [' and', ' their', ' friend', ' Gregory', ' of', ' Naz', 'ian', 'z', 'us']
+3199 738 Name of mother of x -1 Name of mother of Stevie Wonder Lula Mae Hardaway Stevie Wonder "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", and the other is the name of the mother of the child .
+
+ The name of the" False Mayfield, Barry White, Stevie Wonder and Prince. According 8 [' May', 'field', ',', ' Barry', ' White', ',', ' Stev', 'ie', ' Wonder']
+3200 738 Name of mother of x -1 Name of mother of Stevie Wonder Lula Mae Hardaway Stevie Wonder "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", and the other is the name of the mother of the child .
+
+ The name of the" False " (Everything's Alright)"" by Stevie Wonder that it was removed" 8 "[' (', 'Everything', ""'s"", ' Alright', ')""', ' by', ' Stev', 'ie', ' Wonder']"
+3201 738 Name of mother of x -1 Name of mother of Stevie Wonder Lula Mae Hardaway Stevie Wonder "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", and the other is the name of the mother of the child .
+
+ The name of the" False Babyface and Stevie Wonder duetted on 5 [' Baby', 'face', ' and', ' Stev', 'ie', ' Wonder']
+3202 738 Name of mother of x -1 Name of mother of Stevie Wonder Lula Mae Hardaway Stevie Wonder "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", and the other is the name of the mother of the child .
+
+ The name of the" False as Paul McCartney, Stevie Wonder and Rod Temperton. 6 [' as', ' Paul', ' McCartney', ',', ' Stev', 'ie', ' Wonder']
+3203 738 Name of mother of x -1 Name of mother of Stevie Wonder Lula Mae Hardaway Stevie Wonder "[',' ' and' ' the' ' other' ' is' ' the' ' name' ' of' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' name' ' of' ' the']" ", and the other is the name of the mother of the child .
+
+ The name of the" False with albums like Stevie Wonder ’ s Music 5 [' with', ' albums', ' like', ' Stev', 'ie', ' Wonder']
+3204 739 Name of mother of x -1 Name of mother of Ben Nicholson Mabel Pryde Ben Nicholson "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' bride' 'g'
+ 'room' ',' ' and' ' the' ' bride' ""'s"" ' mother' ',' ' and']" , the father of the bride , and the bride g room , and the bride 's mother , and False her second husband Ben Nicholson moved into a studio 4 [' her', ' second', ' husband', ' Ben', ' Nicholson']
+3205 739 Name of mother of x -1 Name of mother of Ben Nicholson Mabel Pryde Ben Nicholson "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' bride' 'g'
+ 'room' ',' ' and' ' the' ' bride' ""'s"" ' mother' ',' ' and']" , the father of the bride , and the bride g room , and the bride 's mother , and False second husband Ben Nicholson moved into a studio 3 [' second', ' husband', ' Ben', ' Nicholson']
+3206 739 Name of mother of x -1 Name of mother of Ben Nicholson Mabel Pryde Ben Nicholson "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' bride' 'g'
+ 'room' ',' ' and' ' the' ' bride' ""'s"" ' mother' ',' ' and']" , the father of the bride , and the bride g room , and the bride 's mother , and False second husband Ben Nicholson moved into 3 [' second', ' husband', ' Ben', ' Nicholson']
+3207 740 Name of mother of x -1 Name of mother of Nicholas Roerich Maria Vasiljevna Rerich Nicholas Roerich "[',' ' the' ' Russian' ' painter' ',' ' and' ' his' ' wife' ',' ' the'
+ '\n' '\n' 'Russian' ' painter' ',' ' and' ' his' ' wife' ',' ' the']" ", the Russian painter , and his wife , the
+
+ Russian painter , and his wife , the" False Massine, with the Nicholas Roerich designs retained; 8 [' Mass', 'ine', ',', ' with', ' the', ' Nicholas', ' Ro', 'er', 'ich']
+3208 740 Name of mother of x -1 Name of mother of Nicholas Roerich Maria Vasiljevna Rerich Nicholas Roerich "[',' ' the' ' Russian' ' painter' ',' ' and' ' his' ' wife' ',' ' the'
+ '\n' '\n' 'Russian' ' painter' ',' ' and' ' his' ' wife' ',' ' the']" ", the Russian painter , and his wife , the
+
+ Russian painter , and his wife , the" False 4 ['Nich', 'olas', ' Ro', 'er', 'ich']
+3209 740 Name of mother of x -1 Name of mother of Nicholas Roerich Maria Vasiljevna Rerich Nicholas Roerich "[',' ' the' ' Russian' ' painter' ',' ' and' ' his' ' wife' ',' ' the'
+ '\n' '\n' 'Russian' ' painter' ',' ' and' ' his' ' wife' ',' ' the']" ", the Russian painter , and his wife , the
+
+ Russian painter , and his wife , the" False Massine, with the Nicholas Roerich designs retained; 8 [' Mass', 'ine', ',', ' with', ' the', ' Nicholas', ' Ro', 'er', 'ich']
+3210 741 Name of mother of x -1 Name of mother of Alan Turing Ethel Sara Stoney Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' computer' '.' '\n' '\n' 'The'
+ ' Turing' ' Award']" ", the father of modern computing , and the father of the modern computer .
+
+ The Turing Award" False an EP titled For Alan Turing in 2006, which 5 [' an', ' EP', ' titled', ' For', ' Alan', ' Turing']
+3211 741 Name of mother of x -1 Name of mother of Alan Turing Ethel Sara Stoney Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' computer' '.' '\n' '\n' 'The'
+ ' Turing' ' Award']" ", the father of modern computing , and the father of the modern computer .
+
+ The Turing Award" False " road) was named ""Alan Turing Way"". A bridge carrying" 6 "[' road', ')', ' was', ' named', ' ""', 'Alan', ' Turing']"
+3212 741 Name of mother of x -1 Name of mother of Alan Turing Ethel Sara Stoney Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' computer' '.' '\n' '\n' 'The'
+ ' Turing' ' Award']" ", the father of modern computing , and the father of the modern computer .
+
+ The Turing Award" False demand justice for Alan Turing and recognition 4 [' demand', ' justice', ' for', ' Alan', ' Turing']
+3213 741 Name of mother of x -1 Name of mother of Alan Turing Ethel Sara Stoney Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' computer' '.' '\n' '\n' 'The'
+ ' Turing' ' Award']" ", the father of modern computing , and the father of the modern computer .
+
+ The Turing Award" False (TCAC) co-ordinated the Alan Turing Year, a year-long 10 [' (', 'TC', 'AC', ')', ' co', '-', 'ord', 'inated', ' the', ' Alan', ' Turing']
+3214 741 Name of mother of x -1 Name of mother of Alan Turing Ethel Sara Stoney Alan Turing "[',' ' the' ' father' ' of' ' modern' ' computing' ',' ' and' ' the'
+ ' father' ' of' ' the' ' modern' ' computer' '.' '\n' '\n' 'The'
+ ' Turing' ' Award']" ", the father of modern computing , and the father of the modern computer .
+
+ The Turing Award" False artificial intelligence with Alan Turing when the latter 4 [' artificial', ' intelligence', ' with', ' Alan', ' Turing']
+3215 742 Name of mother of x -1 Name of mother of Henry Fonda Elma Herberta Jaynes Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ' Jane' ' F'
+ 'onda' ',' ' who' ' was' ' also' ' a' ' member' ' of' ' the']" , the actor , and his wife , actress Jane F onda , who was also a member of the False Los Angeles Henry Fonda Theater as a 4 [' Los', ' Angeles', ' Henry', ' F', 'onda']
+3216 742 Name of mother of x -1 Name of mother of Henry Fonda Elma Herberta Jaynes Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ' Jane' ' F'
+ 'onda' ',' ' who' ' was' ' also' ' a' ' member' ' of' ' the']" , the actor , and his wife , actress Jane F onda , who was also a member of the False " villainous role"" akin to Henry Fonda in Once Upon a Time" 8 "[' villain', 'ous', ' role', '""', ' akin', ' to', ' Henry', ' F', 'onda']"
+3217 742 Name of mother of x -1 Name of mother of Henry Fonda Elma Herberta Jaynes Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ' Jane' ' F'
+ 'onda' ',' ' who' ' was' ' also' ' a' ' member' ' of' ' the']" , the actor , and his wife , actress Jane F onda , who was also a member of the False Feuer and Martin wanted Henry Fonda to play Doc. The actor 7 [' Fe', 'uer', ' and', ' Martin', ' wanted', ' Henry', ' F', 'onda']
+3218 742 Name of mother of x -1 Name of mother of Henry Fonda Elma Herberta Jaynes Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ' Jane' ' F'
+ 'onda' ',' ' who' ' was' ' also' ' a' ' member' ' of' ' the']" , the actor , and his wife , actress Jane F onda , who was also a member of the False (1946) – Stars Henry Fonda and directed 8 [' (', '19', '46', ')', ' –', ' Stars', ' Henry', ' F', 'onda']
+3219 742 Name of mother of x -1 Name of mother of Henry Fonda Elma Herberta Jaynes Henry Fonda "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' actress' ' Jane' ' F'
+ 'onda' ',' ' who' ' was' ' also' ' a' ' member' ' of' ' the']" , the actor , and his wife , actress Jane F onda , who was also a member of the False daughter of actor Henry Fonda and the Canadian-born 5 [' daughter', ' of', ' actor', ' Henry', ' F', 'onda']
+3220 743 Name of mother of x -1 Name of mother of Sun Yat-sen Madame Yang Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Ku' 'om' 'int' 'ang' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.']" , the founder of the Ku om int ang , and the first president of the Republic of China . False revolutionary figure Sun Yat-sen visited southeast 6 [' revolutionary', ' figure', ' Sun', ' Y', 'at', '-', 'sen']
+3221 743 Name of mother of x -1 Name of mother of Sun Yat-sen Madame Yang Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Ku' 'om' 'int' 'ang' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.']" , the founder of the Ku om int ang , and the first president of the Republic of China . False " avoid a civil war, Sun Yat-sen — proclaimed ""provisional" 9 [' avoid', ' a', ' civil', ' war', ',', ' Sun', ' Y', 'at', '-', 'sen']
+3222 743 Name of mother of x -1 Name of mother of Sun Yat-sen Madame Yang Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Ku' 'om' 'int' 'ang' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.']" , the founder of the Ku om int ang , and the first president of the Republic of China . False in 1977 and named it Sun Yat-sen Playlot Park. However, 9 [' in', ' 1977', ' and', ' named', ' it', ' Sun', ' Y', 'at', '-', 'sen']
+3223 743 Name of mother of x -1 Name of mother of Sun Yat-sen Madame Yang Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Ku' 'om' 'int' 'ang' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.']" , the founder of the Ku om int ang , and the first president of the Republic of China . False Stanford Park). Sun Yat-sen Playlot Park, a small, 7 [' Stanford', ' Park', ').', ' Sun', ' Y', 'at', '-', 'sen']
+3224 743 Name of mother of x -1 Name of mother of Sun Yat-sen Madame Yang Sun Yat-sen "[',' ' the' ' founder' ' of' ' the' ' Ku' 'om' 'int' 'ang' ',' ' and'
+ ' the' ' first' ' president' ' of' ' the' ' Republic' ' of' ' China' '.']" , the founder of the Ku om int ang , and the first president of the Republic of China . False revolutionary figure Sun Yat-sen visited southeast 6 [' revolutionary', ' figure', ' Sun', ' Y', 'at', '-', 'sen']
+3225 744 Name of mother of x -1 Name of mother of Jack Kerouac Gabrielle Ange Lévesque Jack Kerouac "[',' ' the' ' author' ' of' ' the' ' book' ' that' ' inspired' ' the'
+ ' movie' ' �' '�' 'On' ' the' ' Road' '�' '�' '.' '\n' '\n']" ", the author of the book that inspired the movie � � On the Road � � .
+
+" False On the Road by Jack Kerouac in which the author 7 [' On', ' the', ' Road', ' by', ' Jack', ' Ker', 'ou', 'ac']
+3226 744 Name of mother of x -1 Name of mother of Jack Kerouac Gabrielle Ange Lévesque Jack Kerouac "[',' ' the' ' author' ' of' ' the' ' book' ' that' ' inspired' ' the'
+ ' movie' ' �' '�' 'On' ' the' ' Road' '�' '�' '.' '\n' '\n']" ", the author of the book that inspired the movie � � On the Road � � .
+
+" False Gustave Flaubert, Jack Kerouac and Henry James, as 9 [' Gust', 'ave', ' Fl', 'au', 'bert', ',', ' Jack', ' Ker', 'ou', 'ac']
+3227 744 Name of mother of x -1 Name of mother of Jack Kerouac Gabrielle Ange Lévesque Jack Kerouac "[',' ' the' ' author' ' of' ' the' ' book' ' that' ' inspired' ' the'
+ ' movie' ' �' '�' 'On' ' the' ' Road' '�' '�' '.' '\n' '\n']" ", the author of the book that inspired the movie � � On the Road � � .
+
+" False William Burroughs, Jack Kerouac and Allen Ginsberg 8 [' William', ' Bur', 'rough', 's', ',', ' Jack', ' Ker', 'ou', 'ac']
+3228 744 Name of mother of x -1 Name of mother of Jack Kerouac Gabrielle Ange Lévesque Jack Kerouac "[',' ' the' ' author' ' of' ' the' ' book' ' that' ' inspired' ' the'
+ ' movie' ' �' '�' 'On' ' the' ' Road' '�' '�' '.' '\n' '\n']" ", the author of the book that inspired the movie � � On the Road � � .
+
+" False Allen Ginsberg and Jack Kerouac in the 1950s and 7 [' Allen', ' Gins', 'berg', ' and', ' Jack', ' Ker', 'ou', 'ac']
+3229 744 Name of mother of x -1 Name of mother of Jack Kerouac Gabrielle Ange Lévesque Jack Kerouac "[',' ' the' ' author' ' of' ' the' ' book' ' that' ' inspired' ' the'
+ ' movie' ' �' '�' 'On' ' the' ' Road' '�' '�' '.' '\n' '\n']" ", the author of the book that inspired the movie � � On the Road � � .
+
+" False William Burroughs, Jack Kerouac and Allen Ginsberg 8 [' William', ' Bur', 'rough', 's', ',', ' Jack', ' Ker', 'ou', 'ac']
+3230 745 Name of mother of x -1 Name of mother of Kiefer Sutherland Shirley Douglas Kiefer Sutherland "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False in 2003, and Kiefer Sutherland received nominations 7 [' in', ' 2003', ',', ' and', ' K', 'ief', 'er', ' Sutherland']
+3231 745 Name of mother of x -1 Name of mother of Kiefer Sutherland Shirley Douglas Kiefer Sutherland "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False Guild Awards. Kiefer Sutherland was nominated in 6 [' Guild', ' Awards', '.', ' K', 'ief', 'er', ' Sutherland']
+3232 745 Name of mother of x -1 Name of mother of Kiefer Sutherland Shirley Douglas Kiefer Sutherland "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False close friends. Kiefer Sutherland was considered 6 [' close', ' friends', '.', ' K', 'ief', 'er', ' Sutherland']
+3233 745 Name of mother of x -1 Name of mother of Kiefer Sutherland Shirley Douglas Kiefer Sutherland "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False 3 ['K', 'ief', 'er', ' Sutherland']
+3234 745 Name of mother of x -1 Name of mother of Kiefer Sutherland Shirley Douglas Kiefer Sutherland "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' have' ' been' ' a' ' fan' ' of' ' his' ' for']" ", who is a very good friend of mine .
+
+ I have been a fan of his for" False " Sutherland as Jack Bauer
+" 7 [' Sutherland', ' as', ' Jack', ' Bauer', 'K', 'ief', 'er', ' Sutherland']
+3235 746 Name of mother of x -1 Name of mother of Paul Robeson Maria Louisa Bustill Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False Black artist Paul Robeson back to Portland, 4 [' Black', ' artist', ' Paul', ' Rob', 'eson']
+3236 746 Name of mother of x -1 Name of mother of Paul Robeson Maria Louisa Bustill Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False production, with Paul Robeson in the title role. 5 [' production', ',', ' with', ' Paul', ' Rob', 'eson']
+3237 746 Name of mother of x -1 Name of mother of Paul Robeson Maria Louisa Bustill Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False headlined by Paul Robeson in Peekskill, New 4 [' headlined', ' by', ' Paul', ' Rob', 'eson']
+3238 746 Name of mother of x -1 Name of mother of Paul Robeson Maria Louisa Bustill Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False leftist singer Paul Robeson was denied a passport 4 [' leftist', ' singer', ' Paul', ' Rob', 'eson']
+3239 746 Name of mother of x -1 Name of mother of Paul Robeson Maria Louisa Bustill Paul Robeson "[',' ' the' ' son' ' of' ' a' ' slave' ',' ' and' ' the' ' son' ' of' ' a'
+ ' slave' ',' ' and' ' the' ' son' ' of' ' a' ' slave']" , the son of a slave , and the son of a slave , and the son of a slave False intended for bass Paul Robeson and sung in the original 5 [' intended', ' for', ' bass', ' Paul', ' Rob', 'eson']
+3240 747 Name of mother of x -1 Name of mother of Jennifer Aniston Nancy Dow Jennifer Aniston "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Lisa Kudrow and Jennifer Aniston remain the 7 [' Lisa', ' K', 'ud', 'row', ' and', ' Jennifer', ' An', 'iston']
+3241 747 Name of mother of x -1 Name of mother of Jennifer Aniston Nancy Dow Jennifer Aniston "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False film Wanderlust, with Jennifer Aniston and Paul Rudd. 7 [' film', ' Wander', 'lust', ',', ' with', ' Jennifer', ' An', 'iston']
+3242 747 Name of mother of x -1 Name of mother of Jennifer Aniston Nancy Dow Jennifer Aniston "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False refusal to smoke, Jennifer Aniston at the time of 6 [' refusal', ' to', ' smoke', ',', ' Jennifer', ' An', 'iston']
+3243 747 Name of mother of x -1 Name of mother of Jennifer Aniston Nancy Dow Jennifer Aniston "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Liz's), as well as Jennifer Aniston playing Liz's ex-roommate. 8 "[' Liz', ""'s"", '),', ' as', ' well', ' as', ' Jennifer', ' An', 'iston']"
+3244 747 Name of mother of x -1 Name of mother of Jennifer Aniston Nancy Dow Jennifer Aniston "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False reuniting actress Jennifer Aniston and actor Brad 5 [' reun', 'iting', ' actress', ' Jennifer', ' An', 'iston']
+3245 749 Name of mother of x -1 Name of mother of Ginger Rogers Lela E. Rogers Ginger Rogers "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False co-starring with Ginger Rogers and Marilyn 6 [' co', '-', 'star', 'ring', ' with', ' Ginger', ' Rogers']
+3246 749 Name of mother of x -1 Name of mother of Ginger Rogers Lela E. Rogers Ginger Rogers "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Nancy danced like Ginger Rogers and could administer 4 [' Nancy', ' danced', ' like', ' Ginger', ' Rogers']
+3247 749 Name of mother of x -1 Name of mother of Ginger Rogers Lela E. Rogers Ginger Rogers "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False (1937), paired her with Ginger Rogers in a role which mirrored 8 [' (', '19', '37', '),', ' paired', ' her', ' with', ' Ginger', ' Rogers']
+3248 749 Name of mother of x -1 Name of mother of Ginger Rogers Lela E. Rogers Ginger Rogers "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False of the decade. Ginger Rogers had already made 5 [' of', ' the', ' decade', '.', ' Ginger', ' Rogers']
+3249 749 Name of mother of x -1 Name of mother of Ginger Rogers Lela E. Rogers Ginger Rogers "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False to that famous Ginger Rogers quote. She 4 [' to', ' that', ' famous', ' Ginger', ' Rogers']
+3250 750 Name of mother of x -1 Name of mother of Oliver Cromwell Elizabeth Steward Oliver Cromwell "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Oliver' ' Crom' 'well'
+ ',' ' and' ' the' '\n' '\n' '1' '.' '\n' '\n' 'The']" ", the son of the late Sir Oliver Crom well , and the
+
+ 1 .
+
+ The" False an idea of what Oliver Cromwell would have made of 6 [' an', ' idea', ' of', ' what', ' Oliver', ' Crom', 'well']
+3251 750 Name of mother of x -1 Name of mother of Oliver Cromwell Elizabeth Steward Oliver Cromwell "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Oliver' ' Crom' 'well'
+ ',' ' and' ' the' '\n' '\n' '1' '.' '\n' '\n' 'The']" ", the son of the late Sir Oliver Crom well , and the
+
+ 1 .
+
+ The" False legend states that Oliver Cromwell was present in the 5 [' legend', ' states', ' that', ' Oliver', ' Crom', 'well']
+3252 750 Name of mother of x -1 Name of mother of Oliver Cromwell Elizabeth Steward Oliver Cromwell "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Oliver' ' Crom' 'well'
+ ',' ' and' ' the' '\n' '\n' '1' '.' '\n' '\n' 'The']" ", the son of the late Sir Oliver Crom well , and the
+
+ 1 .
+
+ The" False September 1658, Oliver Cromwell died and was 6 [' September', ' 16', '58', ',', ' Oliver', ' Crom', 'well']
+3253 750 Name of mother of x -1 Name of mother of Oliver Cromwell Elizabeth Steward Oliver Cromwell "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Oliver' ' Crom' 'well'
+ ',' ' and' ' the' '\n' '\n' '1' '.' '\n' '\n' 'The']" ", the son of the late Sir Oliver Crom well , and the
+
+ 1 .
+
+ The" False English Civil War, Oliver Cromwell invaded Scotland and, 6 [' English', ' Civil', ' War', ',', ' Oliver', ' Crom', 'well']
+3254 750 Name of mother of x -1 Name of mother of Oliver Cromwell Elizabeth Steward Oliver Cromwell "[',' ' the' ' son' ' of' ' the' ' late' ' Sir' ' Oliver' ' Crom' 'well'
+ ',' ' and' ' the' '\n' '\n' '1' '.' '\n' '\n' 'The']" ", the son of the late Sir Oliver Crom well , and the
+
+ 1 .
+
+ The" False " Encounters with Oliver Cromwell ===
+" 5 [' Enc', 'ounters', ' with', ' Oliver', ' Crom', 'well']
+3255 751 Name of mother of x -1 Name of mother of John Donne Elizabeth Heywood John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's wife , and the poet 's mother , and the poet 's False Alexander Pope, John Donne and Jonathan Swift. 5 [' Alexander', ' Pope', ',', ' John', ' Don', 'ne']
+3256 751 Name of mother of x -1 Name of mother of John Donne Elizabeth Heywood John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's wife , and the poet 's mother , and the poet 's False 2 ['John', ' Don', 'ne']
+3257 751 Name of mother of x -1 Name of mother of John Donne Elizabeth Heywood John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's wife , and the poet 's mother , and the poet 's False " 100 Masters, From John Donne to Julia Alvarez.""" 6 [' 100', ' Masters', ',', ' From', ' John', ' Don', 'ne']
+3258 751 Name of mother of x -1 Name of mother of John Donne Elizabeth Heywood John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's wife , and the poet 's mother , and the poet 's False " the Poet from John Donne to Julia Alvarez""." 6 [' the', ' Po', 'et', ' from', ' John', ' Don', 'ne']
+3259 751 Name of mother of x -1 Name of mother of John Donne Elizabeth Heywood John Donne "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' wife' ',' ' and'
+ ' the' ' poet' ""'s"" ' mother' ',' ' and' ' the' ' poet' ""'s""]" , the poet , and the poet 's wife , and the poet 's mother , and the poet 's False Alexander Pope, John Donne and Jonathan Swift. 5 [' Alexander', ' Pope', ',', ' John', ' Don', 'ne']
+3260 752 Name of mother of x -1 Name of mother of Tim Burton Jean Rae Erickson Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' ""' 'Batman' '""' ' series' '.']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the "" Batman "" series ." False Warner Bros. hired Tim Burton to direct Batman. 5 [' Warner', ' Bros', '.', ' hired', ' Tim', ' Burton']
+3261 752 Name of mother of x -1 Name of mother of Tim Burton Jean Rae Erickson Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' ""' 'Batman' '""' ' series' '.']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the "" Batman "" series ." False from the 1989 Tim Burton film Batman where 4 [' from', ' the', ' 1989', ' Tim', ' Burton']
+3262 752 Name of mother of x -1 Name of mother of Tim Burton Jean Rae Erickson Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' ""' 'Batman' '""' ' series' '.']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the "" Batman "" series ." False Schumacher 4 [' Sch', 'um', 'ache', 'Tim', ' Burton']
+3263 752 Name of mother of x -1 Name of mother of Tim Burton Jean Rae Erickson Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' ""' 'Batman' '""' ' series' '.']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the "" Batman "" series ." False Warner Bros. hired Tim Burton to direct in 5 [' Warner', ' Bros', '.', ' hired', ' Tim', ' Burton']
+3264 752 Name of mother of x -1 Name of mother of Tim Burton Jean Rae Erickson Tim Burton "[""'s"" ' ""' 'Alice' ' in' ' Wonderland' '""' ' and' ' ""' 'The' ' Nightmare'
+ ' Before' ' Christmas' '""' ' and' ' the' ' ""' 'Batman' '""' ' series' '.']" "'s "" Alice in Wonderland "" and "" The Nightmare Before Christmas "" and the "" Batman "" series ." False 1996. Director Tim Burton hired Peter Suschitzky 4 [' 1996', '.', ' Director', ' Tim', ' Burton']
+3265 753 Name of mother of x -1 Name of mother of Amartya Sen Amita Sen Amartya Sen "[',' ' the' ' Nobel' ' Prize' '-' 'winning' ' economist' ',' ' who' ' has'
+ ' been' ' a' ' vocal' ' critic' ' of' ' the' ' Modi' ' government' '�'
+ '�']" , the Nobel Prize - winning economist , who has been a vocal critic of the Modi government � � False Schumacher, and Amartya Sen all spent time 8 [' Sch', 'um', 'acher', ',', ' and', ' Am', 'art', 'ya', ' Sen']
+3266 753 Name of mother of x -1 Name of mother of Amartya Sen Amita Sen Amartya Sen "[',' ' the' ' Nobel' ' Prize' '-' 'winning' ' economist' ',' ' who' ' has'
+ ' been' ' a' ' vocal' ' critic' ' of' ' the' ' Modi' ' government' '�'
+ '�']" , the Nobel Prize - winning economist , who has been a vocal critic of the Modi government � � False " economic analysis. Amartya Sen argues that ""the" 6 [' economic', ' analysis', '.', ' Am', 'art', 'ya', ' Sen']
+3267 753 Name of mother of x -1 Name of mother of Amartya Sen Amita Sen Amartya Sen "[',' ' the' ' Nobel' ' Prize' '-' 'winning' ' economist' ',' ' who' ' has'
+ ' been' ' a' ' vocal' ' critic' ' of' ' the' ' Modi' ' government' '�'
+ '�']" , the Nobel Prize - winning economist , who has been a vocal critic of the Modi government � � False economic analysis. Amartya Sen argues that 6 [' economic', ' analysis', '.', ' Am', 'art', 'ya', ' Sen']
+3268 753 Name of mother of x -1 Name of mother of Amartya Sen Amita Sen Amartya Sen "[',' ' the' ' Nobel' ' Prize' '-' 'winning' ' economist' ',' ' who' ' has'
+ ' been' ' a' ' vocal' ' critic' ' of' ' the' ' Modi' ' government' '�'
+ '�']" , the Nobel Prize - winning economist , who has been a vocal critic of the Modi government � � False same year. In 1999, Amartya Sen was awarded the 9 [' same', ' year', '.', ' In', ' 1999', ',', ' Am', 'art', 'ya', ' Sen']
+3269 753 Name of mother of x -1 Name of mother of Amartya Sen Amita Sen Amartya Sen "[',' ' the' ' Nobel' ' Prize' '-' 'winning' ' economist' ',' ' who' ' has'
+ ' been' ' a' ' vocal' ' critic' ' of' ' the' ' Modi' ' government' '�'
+ '�']" , the Nobel Prize - winning economist , who has been a vocal critic of the Modi government � � False 3 ['Am', 'art', 'ya', ' Sen']
+3270 755 Name of mother of x -1 Name of mother of Steve Jobs Joanne Carole Schieble Simpson Steve Jobs "[',' ' the' ' father' ' of' ' the' ' Macintosh' ',' ' and' ' the'
+ ' father' ' of' ' the' ' iPod' ',' ' and' ' the' ' father' ' of' ' the'
+ ' iPhone']" , the father of the Macintosh , and the father of the iPod , and the father of the iPhone False financial conference call, Steve Jobs announced 5 [' financial', ' conference', ' call', ',', ' Steve', ' Jobs']
+3271 755 Name of mother of x -1 Name of mother of Steve Jobs Joanne Carole Schieble Simpson Steve Jobs "[',' ' the' ' father' ' of' ' the' ' Macintosh' ',' ' and' ' the'
+ ' father' ' of' ' the' ' iPod' ',' ' and' ' the' ' father' ' of' ' the'
+ ' iPhone']" , the father of the Macintosh , and the father of the iPod , and the father of the iPhone False conference call, Steve Jobs announced that Apple 4 [' conference', ' call', ',', ' Steve', ' Jobs']
+3272 755 Name of mother of x -1 Name of mother of Steve Jobs Joanne Carole Schieble Simpson Steve Jobs "[',' ' the' ' father' ' of' ' the' ' Macintosh' ',' ' and' ' the'
+ ' father' ' of' ' the' ' iPod' ',' ' and' ' the' ' father' ' of' ' the'
+ ' iPhone']" , the father of the Macintosh , and the father of the iPod , and the father of the iPhone False 2011, CEO Steve Jobs unveiled the device 4 [' 2011', ',', ' CEO', ' Steve', ' Jobs']
+3273 755 Name of mother of x -1 Name of mother of Steve Jobs Joanne Carole Schieble Simpson Steve Jobs "[',' ' the' ' father' ' of' ' the' ' Macintosh' ',' ' and' ' the'
+ ' father' ' of' ' the' ' iPod' ',' ' and' ' the' ' father' ' of' ' the'
+ ' iPhone']" , the father of the Macintosh , and the father of the iPod , and the father of the iPhone False 8, 2010, with Steve Jobs demonstrating a Toy 6 [' 8', ',', ' 2010', ',', ' with', ' Steve', ' Jobs']
+3274 755 Name of mother of x -1 Name of mother of Steve Jobs Joanne Carole Schieble Simpson Steve Jobs "[',' ' the' ' father' ' of' ' the' ' Macintosh' ',' ' and' ' the'
+ ' father' ' of' ' the' ' iPod' ',' ' and' ' the' ' father' ' of' ' the'
+ ' iPhone']" , the father of the Macintosh , and the father of the iPod , and the father of the iPhone False " Inc. co-founder Steve Jobs stated ""If Toy Story" 6 [' Inc', '.', ' co', '-', 'founder', ' Steve', ' Jobs']
+3275 756 Name of mother of x -1 Name of mother of Ralph Vaughan Williams Margaret Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False " Vaughan Williams =
+" 6 [' Vaughan', ' Williams', ' =', 'R', 'alph', ' Vaughan', ' Williams']
+3276 756 Name of mother of x -1 Name of mother of Ralph Vaughan Williams Margaret Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False of portraits of Ralph Vaughan Williams in 1952 – 61. In 5 [' of', ' portraits', ' of', ' Ralph', ' Vaughan', ' Williams']
+3277 756 Name of mother of x -1 Name of mother of Ralph Vaughan Williams Margaret Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False English composer Ralph Vaughan Williams based his Sea 4 [' English', ' composer', ' Ralph', ' Vaughan', ' Williams']
+3278 756 Name of mother of x -1 Name of mother of Ralph Vaughan Williams Margaret Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False Ireland and Ralph Vaughan Williams and the college's 4 [' Ireland', ' and', ' Ralph', ' Vaughan', ' Williams']
+3279 756 Name of mother of x -1 Name of mother of Ralph Vaughan Williams Margaret Vaughan Williams Ralph Vaughan Williams "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' composer'
+ ""'s"" ' daughter' ',' ' and' ' the' ' composer' ""'s"" ' son' ',' ' the']" , the composer , and his wife , the composer 's daughter , and the composer 's son , the False from works by Ralph Vaughan Williams and Morton Gould 5 [' from', ' works', ' by', ' Ralph', ' Vaughan', ' Williams']
+3280 757 Name of mother of x -1 Name of mother of Richard Attenborough Mary Clegg Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' London' ','
+ ' England' ',' ' and' ' who' ' died' ' in' ' London' ',' ' England' ',']" , the actor , who was born in London , England , and who died in London , England , False directors, namely Richard Attenborough and Shyam Benegal; 6 [' directors', ',', ' namely', ' Richard', ' Att', 'en', 'borough']
+3281 757 Name of mother of x -1 Name of mother of Richard Attenborough Mary Clegg Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' London' ','
+ ' England' ',' ' and' ' who' ' died' ' in' ' London' ',' ' England' ',']" , the actor , who was born in London , England , and who died in London , England , False Chelsea chairman Richard Attenborough asking for a loan 5 [' Chelsea', ' chairman', ' Richard', ' Att', 'en', 'borough']
+3282 757 Name of mother of x -1 Name of mother of Richard Attenborough Mary Clegg Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' London' ','
+ ' England' ',' ' and' ' who' ' died' ' in' ' London' ',' ' England' ',']" , the actor , who was born in London , England , and who died in London , England , False Martin Scorsese and Richard Attenborough also pointed out 8 [' Martin', ' Sc', 'ors', 'ese', ' and', ' Richard', ' Att', 'en', 'borough']
+3283 757 Name of mother of x -1 Name of mother of Richard Attenborough Mary Clegg Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' London' ','
+ ' England' ',' ' and' ' who' ' died' ' in' ' London' ',' ' England' ',']" , the actor , who was born in London , England , and who died in London , England , False Martin Scorsese and Richard Attenborough also pointed 8 [' Martin', ' Sc', 'ors', 'ese', ' and', ' Richard', ' Att', 'en', 'borough']
+3284 757 Name of mother of x -1 Name of mother of Richard Attenborough Mary Clegg Richard Attenborough "[',' ' the' ' actor' ',' ' who' ' was' ' born' ' in' ' London' ','
+ ' England' ',' ' and' ' who' ' died' ' in' ' London' ',' ' England' ',']" , the actor , who was born in London , England , and who died in London , England , False 3 ['Richard', ' Att', 'en', 'borough']
+3285 758 Name of mother of x -1 Name of mother of Pius IX Caterina Solazzi Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' French'
+ ' bishops' ',' ' dated' ' from' ' the' ' Vatican' ',' ' was' '\n'
+ 'published']" ".
+
+ The Pope 's letter to the French bishops , dated from the Vatican , was
+ published" False established in 1841, Pope Pius IX ruled that, 8 [' established', ' in', ' 18', '41', ',', ' Pope', ' P', 'ius', ' IX']
+3286 758 Name of mother of x -1 Name of mother of Pius IX Caterina Solazzi Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' French'
+ ' bishops' ',' ' dated' ' from' ' the' ' Vatican' ',' ' was' '\n'
+ 'published']" ".
+
+ The Pope 's letter to the French bishops , dated from the Vatican , was
+ published" False were grateful to Pope Pius IX because he had ended 6 [' were', ' grateful', ' to', ' Pope', ' P', 'ius', ' IX']
+3287 758 Name of mother of x -1 Name of mother of Pius IX Caterina Solazzi Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' French'
+ ' bishops' ',' ' dated' ' from' ' the' ' Vatican' ',' ' was' '\n'
+ 'published']" ".
+
+ The Pope 's letter to the French bishops , dated from the Vatican , was
+ published" False long-lived Pope Pius IX (1846 – 1878) 6 [' long', '-', 'lived', ' Pope', ' P', 'ius', ' IX']
+3288 758 Name of mother of x -1 Name of mother of Pius IX Caterina Solazzi Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' French'
+ ' bishops' ',' ' dated' ' from' ' the' ' Vatican' ',' ' was' '\n'
+ 'published']" ".
+
+ The Pope 's letter to the French bishops , dated from the Vatican , was
+ published" False whether Pope Pius IX was involved in 4 [' whether', ' Pope', ' P', 'ius', ' IX']
+3289 758 Name of mother of x -1 Name of mother of Pius IX Caterina Solazzi Pius IX "['.' '\n' '\n' 'The' ' Pope' ""'s"" ' letter' ' to' ' the' ' French'
+ ' bishops' ',' ' dated' ' from' ' the' ' Vatican' ',' ' was' '\n'
+ 'published']" ".
+
+ The Pope 's letter to the French bishops , dated from the Vatican , was
+ published" False count by Pope Pius IX in 1857. Constantin 5 [' count', ' by', ' Pope', ' P', 'ius', ' IX']
+3290 759 Name of mother of x -1 Name of mother of Robert Capa Júlia Robert Capa "[',' ' the' ' famous' ' photographer' ',' ' and' ' his' ' wife' ',' ' In'
+ 'ge' ' Mor' 'ath' ',' ' who' ' was' ' a' ' photographer' ' for' ' V']" , the famous photographer , and his wife , In ge Mor ath , who was a photographer for V False Soldier by Robert Capa is the third frame 4 [' Soldier', ' by', ' Robert', ' Cap', 'a']
+3291 759 Name of mother of x -1 Name of mother of Robert Capa Júlia Robert Capa "[',' ' the' ' famous' ' photographer' ',' ' and' ' his' ' wife' ',' ' In'
+ 'ge' ' Mor' 'ath' ',' ' who' ' was' ' a' ' photographer' ' for' ' V']" , the famous photographer , and his wife , In ge Mor ath , who was a photographer for V False Soldier by Robert Capa is the third frame 4 [' Soldier', ' by', ' Robert', ' Cap', 'a']
+3292 760 Name of mother of x -1 Name of mother of Xi Jinping Qi Xin Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Changchun, Xi Jinping and Zhou Yongkang, 5 [' Chang', 'ch', 'un', ',', ' Xi', ' Jinping']
+3293 760 Name of mother of x -1 Name of mother of Xi Jinping Qi Xin Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Li Changchun, Xi Jinping and Zhou Yongkang, 6 [' Li', ' Chang', 'ch', 'un', ',', ' Xi', ' Jinping']
+3294 760 Name of mother of x -1 Name of mother of Xi Jinping Qi Xin Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Li Changchun, Xi Jinping and Zhou Yongkang, 6 [' Li', ' Chang', 'ch', 'un', ',', ' Xi', ' Jinping']
+3295 760 Name of mother of x -1 Name of mother of Xi Jinping Qi Xin Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Qinglin, Li Changchun, Xi Jinping and Zhou Yongkang, 9 [' Qing', 'lin', ',', ' Li', ' Chang', 'ch', 'un', ',', ' Xi', ' Jinping']
+3296 760 Name of mother of x -1 Name of mother of Xi Jinping Qi Xin Xi Jinping "[',' ' the' ' Chinese' ' president' ',' ' and' ' the' ' Chinese'
+ ' Communist' ' Party' '.' '\n' '\n' 'The' ' Chinese' ' Communist'
+ ' Party' ' is' ' a' ' totalitarian']" ", the Chinese president , and the Chinese Communist Party .
+
+ The Chinese Communist Party is a totalitarian" False Li Changchun, Xi Jinping and Zhou Yongkang, 6 [' Li', ' Chang', 'ch', 'un', ',', ' Xi', ' Jinping']
+3297 761 Name of mother of x -1 Name of mother of François Hollande Nicole Tribert François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False " Sicília and President François Hollande of France.
+" 6 [' Sic', 'í', 'lia', ' and', ' President', ' François', ' Hollande']
+3298 761 Name of mother of x -1 Name of mother of François Hollande Nicole Tribert François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False " and President François Hollande of France.
+" 3 [' and', ' President', ' François', ' Hollande']
+3299 761 Name of mother of x -1 Name of mother of François Hollande Nicole Tribert François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False " and President François Hollande of France.
+" 3 [' and', ' President', ' François', ' Hollande']
+3300 761 Name of mother of x -1 Name of mother of François Hollande Nicole Tribert François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False French President François Hollande praised the 3 [' French', ' President', ' François', ' Hollande']
+3301 761 Name of mother of x -1 Name of mother of François Hollande Nicole Tribert François Hollande "[',' ' the' ' French' ' president' ',' ' and' ' the' ' French' ' people'
+ '.' '\n' '\n' 'The' ' French' ' people' ' are' ' not' ' stupid' '.'
+ ' They']" ", the French president , and the French people .
+
+ The French people are not stupid . They" False President of France François Hollande was not accompanied 4 [' President', ' of', ' France', ' François', ' Hollande']
+3302 763 Name of mother of x -1 Name of mother of Gilbert du Motier, Marquis de Lafayette Marie Louise Jolie de La Rivière Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False total loss. When Gilbert du Motier, Marquis de Lafayette toured the United 12 [' total', ' loss', '.', ' When', ' Gilbert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3303 763 Name of mother of x -1 Name of mother of Gilbert du Motier, Marquis de Lafayette Marie Louise Jolie de La Rivière Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False " Lafayette =
+" 11 [' Lafayette', ' =', 'Gil', 'bert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3304 763 Name of mother of x -1 Name of mother of Gilbert du Motier, Marquis de Lafayette Marie Louise Jolie de La Rivière Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False a total loss. When Gilbert du Motier, Marquis de Lafayette toured the United 13 [' a', ' total', ' loss', '.', ' When', ' Gilbert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3305 763 Name of mother of x -1 Name of mother of Gilbert du Motier, Marquis de Lafayette Marie Louise Jolie de La Rivière Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False " Lafayette =
+" 11 [' Lafayette', ' =', 'Gil', 'bert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3306 763 Name of mother of x -1 Name of mother of Gilbert du Motier, Marquis de Lafayette Marie Louise Jolie de La Rivière Gilbert du Motier, Marquis de Lafayette "[',' ' and' ' his' ' wife' ',' ' Marie' '-' 'Joseph' 'te' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Marqu' 'is' ' de' ' Lafayette']" , and his wife , Marie - Joseph te , who was the daughter of the Marqu is de Lafayette False " Lafayette =
+" 11 [' Lafayette', ' =', 'Gil', 'bert', ' du', ' Mot', 'ier', ',', ' Marqu', 'is', ' de', ' Lafayette']
+3307 764 Name of mother of x -1 Name of mother of Helen Keller Catherine Adams Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the'
+ ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the']" , the blind and deaf - m ute , and the blind and deaf - m ute , and the False later starred as Helen Keller alongside her mother 4 [' later', ' starred', ' as', ' Helen', ' Keller']
+3308 764 Name of mother of x -1 Name of mother of Helen Keller Catherine Adams Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the'
+ ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the']" , the blind and deaf - m ute , and the blind and deaf - m ute , and the False " Michael Jackson, and Helen Keller are so obvious.""
+" 5 [' Michael', ' Jackson', ',', ' and', ' Helen', ' Keller']
+3309 764 Name of mother of x -1 Name of mother of Helen Keller Catherine Adams Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the'
+ ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the']" , the blind and deaf - m ute , and the blind and deaf - m ute , and the False children of the Helen Keller Institute; she 4 [' children', ' of', ' the', ' Helen', ' Keller']
+3310 764 Name of mother of x -1 Name of mother of Helen Keller Catherine Adams Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the'
+ ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the']" , the blind and deaf - m ute , and the blind and deaf - m ute , and the False challenged children of the Helen Keller Institute; she had 5 [' challenged', ' children', ' of', ' the', ' Helen', ' Keller']
+3311 764 Name of mother of x -1 Name of mother of Helen Keller Catherine Adams Helen Keller "[',' ' the' ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the'
+ ' blind' ' and' ' deaf' '-' 'm' 'ute' ',' ' and' ' the']" , the blind and deaf - m ute , and the blind and deaf - m ute , and the False (1956), The Story of Helen Keller (1958), The Story 8 [' (', '19', '56', '),', ' The', ' Story', ' of', ' Helen', ' Keller']
+3312 765 Name of mother of x -1 Name of mother of Jeremy Bentham Alicia Woodward Grove Jeremy Bentham "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' Jeremy' ' Bent' 'ham' ',']" ", the father of the
+
+ The following is a list of the children of Jeremy Bent ham ," False first made by Jeremy Bentham in 1817 when he 5 [' first', ' made', ' by', ' Jeremy', ' Bent', 'ham']
+3313 765 Name of mother of x -1 Name of mother of Jeremy Bentham Alicia Woodward Grove Jeremy Bentham "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' Jeremy' ' Bent' 'ham' ',']" ", the father of the
+
+ The following is a list of the children of Jeremy Bent ham ," False " Life and Death of Jeremy Bentham ""are the two" 6 [' Life', ' and', ' Death', ' of', ' Jeremy', ' Bent', 'ham']
+3314 765 Name of mother of x -1 Name of mother of Jeremy Bentham Alicia Woodward Grove Jeremy Bentham "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' Jeremy' ' Bent' 'ham' ',']" ", the father of the
+
+ The following is a list of the children of Jeremy Bent ham ," False first made by Jeremy Bentham in 1817 when he 5 [' first', ' made', ' by', ' Jeremy', ' Bent', 'ham']
+3315 765 Name of mother of x -1 Name of mother of Jeremy Bentham Alicia Woodward Grove Jeremy Bentham "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' Jeremy' ' Bent' 'ham' ',']" ", the father of the
+
+ The following is a list of the children of Jeremy Bent ham ," False " Life and Death of Jeremy Bentham ""are the two installments" 6 [' Life', ' and', ' Death', ' of', ' Jeremy', ' Bent', 'ham']
+3316 765 Name of mother of x -1 Name of mother of Jeremy Bentham Alicia Woodward Grove Jeremy Bentham "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'The' ' following' ' is' ' a'
+ ' list' ' of' ' the' ' children' ' of' ' Jeremy' ' Bent' 'ham' ',']" ", the father of the
+
+ The following is a list of the children of Jeremy Bent ham ," False King's Bench; Jeremy Bentham asserted that this 6 "[' King', ""'s"", ' Bench', ';', ' Jeremy', ' Bent', 'ham']"
+3317 766 Name of mother of x -1 Name of mother of Marquis de Sade Marie Eleonore de Maillé Marquis de Sade "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Marqu' 'is' ' de'
+ ' S' 'ade' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Marqu is de S ade , the
+
+ Name of mother" False Besenval, the Marquis de Sade and Baudelaire. 9 [' Bes', 'en', 'val', ',', ' the', ' Marqu', 'is', ' de', ' S', 'ade']
+3318 766 Name of mother of x -1 Name of mother of Marquis de Sade Marie Eleonore de Maillé Marquis de Sade "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Marqu' 'is' ' de'
+ ' S' 'ade' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Marqu is de S ade , the
+
+ Name of mother" False Caylus, Besenval, the Marquis de Sade and Baudelaire. He 12 [' Cay', 'lus', ',', ' Bes', 'en', 'val', ',', ' the', ' Marqu', 'is', ' de', ' S', 'ade']
+3319 766 Name of mother of x -1 Name of mother of Marquis de Sade Marie Eleonore de Maillé Marquis de Sade "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Marqu' 'is' ' de'
+ ' S' 'ade' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Marqu is de S ade , the
+
+ Name of mother" False Besenval, the Marquis de Sade and Baudelaire. He 9 [' Bes', 'en', 'val', ',', ' the', ' Marqu', 'is', ' de', ' S', 'ade']
+3320 767 Name of mother of x -1 Name of mother of Oliver Stone Jacqueline Goddet Oliver Stone "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' friend' ' of' ' the'
+ ' family' ' for' ' years' '.' '\n' '\n' '""' 'I' ""'m"" ' sorry']" ", the man who had been a friend of the family for years .
+
+ "" I 'm sorry" False Wagner. Director Oliver Stone had borrowed 4 [' Wagner', '.', ' Director', ' Oliver', ' Stone']
+3321 767 Name of mother of x -1 Name of mother of Oliver Stone Jacqueline Goddet Oliver Stone "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' friend' ' of' ' the'
+ ' family' ' for' ' years' '.' '\n' '\n' '""' 'I' ""'m"" ' sorry']" ", the man who had been a friend of the family for years .
+
+ "" I 'm sorry" False abandoned when Oliver Stone joined the 3 [' abandoned', ' when', ' Oliver', ' Stone']
+3322 767 Name of mother of x -1 Name of mother of Oliver Stone Jacqueline Goddet Oliver Stone "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' friend' ' of' ' the'
+ ' family' ' for' ' years' '.' '\n' '\n' '""' 'I' ""'m"" ' sorry']" ", the man who had been a friend of the family for years .
+
+ "" I 'm sorry" False not proceed. Oliver Stone was then approached 4 [' not', ' proceed', '.', ' Oliver', ' Stone']
+3323 767 Name of mother of x -1 Name of mother of Oliver Stone Jacqueline Goddet Oliver Stone "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' friend' ' of' ' the'
+ ' family' ' for' ' years' '.' '\n' '\n' '""' 'I' ""'m"" ' sorry']" ", the man who had been a friend of the family for years .
+
+ "" I 'm sorry" False lead role and Oliver Stone to draft a 4 [' lead', ' role', ' and', ' Oliver', ' Stone']
+3324 767 Name of mother of x -1 Name of mother of Oliver Stone Jacqueline Goddet Oliver Stone "[',' ' the' ' man' ' who' ' had' ' been' ' a' ' friend' ' of' ' the'
+ ' family' ' for' ' years' '.' '\n' '\n' '""' 'I' ""'m"" ' sorry']" ", the man who had been a friend of the family for years .
+
+ "" I 'm sorry" False announced that Oliver Stone would direct. In addition, 3 [' announced', ' that', ' Oliver', ' Stone']
+3325 769 Name of mother of x -1 Name of mother of Paul Newman Theresa Fetzer Paul Newman "[',' ' who' ' was' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' actor' ',' ' and' ' a' ' great' ' man' '.' '\n' '\n' 'I']" ", who was a great actor , and a great actor , and a great man .
+
+ I" False Sting, which starred Paul Newman and Robert Redford. 5 [' Sting', ',', ' which', ' starred', ' Paul', ' Newman']
+3326 769 Name of mother of x -1 Name of mother of Paul Newman Theresa Fetzer Paul Newman "[',' ' who' ' was' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' actor' ',' ' and' ' a' ' great' ' man' '.' '\n' '\n' 'I']" ", who was a great actor , and a great actor , and a great man .
+
+ I" False Rosenberg, starring Paul Newman and featuring 4 [' Rosenberg', ',', ' starring', ' Paul', ' Newman']
+3327 769 Name of mother of x -1 Name of mother of Paul Newman Theresa Fetzer Paul Newman "[',' ' who' ' was' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' actor' ',' ' and' ' a' ' great' ' man' '.' '\n' '\n' 'I']" ", who was a great actor , and a great actor , and a great man .
+
+ I" False Sting, which starred Paul Newman and Robert Redford. 5 [' Sting', ',', ' which', ' starred', ' Paul', ' Newman']
+3328 769 Name of mother of x -1 Name of mother of Paul Newman Theresa Fetzer Paul Newman "[',' ' who' ' was' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' actor' ',' ' and' ' a' ' great' ' man' '.' '\n' '\n' 'I']" ", who was a great actor , and a great actor , and a great man .
+
+ I" False Rosenberg, starring Paul Newman and featuring George 4 [' Rosenberg', ',', ' starring', ' Paul', ' Newman']
+3329 769 Name of mother of x -1 Name of mother of Paul Newman Theresa Fetzer Paul Newman "[',' ' who' ' was' ' a' ' great' ' actor' ',' ' and' ' a' ' great'
+ ' actor' ',' ' and' ' a' ' great' ' man' '.' '\n' '\n' 'I']" ", who was a great actor , and a great actor , and a great man .
+
+ I" False actor and director Paul Newman was his celebrity 4 [' actor', ' and', ' director', ' Paul', ' Newman']
+3330 770 Name of mother of x -1 Name of mother of Burt Bacharach Irma M. Freeman Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Rain' 'drops' ' Keep'
+ ' Fall' 'in' ""'"" ' on' ' My' ' Head' '""' ' and' ' ""' 'Do']" ", the composer of the song "" Rain drops Keep Fall in ' on My Head "" and "" Do" False 1960s songwriter Burt Bacharach by sending him 8 [' 1960', 's', ' song', 'writer', ' B', 'urt', ' Bach', 'ar', 'ach']
+3331 770 Name of mother of x -1 Name of mother of Burt Bacharach Irma M. Freeman Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Rain' 'drops' ' Keep'
+ ' Fall' 'in' ""'"" ' on' ' My' ' Head' '""' ' and' ' ""' 'Do']" ", the composer of the song "" Rain drops Keep Fall in ' on My Head "" and "" Do" False Grammy award winner Burt Bacharach studied music at McGill, 7 [' Grammy', ' award', ' winner', ' B', 'urt', ' Bach', 'ar', 'ach']
+3332 770 Name of mother of x -1 Name of mother of Burt Bacharach Irma M. Freeman Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Rain' 'drops' ' Keep'
+ ' Fall' 'in' ""'"" ' on' ' My' ' Head' '""' ' and' ' ""' 'Do']" ", the composer of the song "" Rain drops Keep Fall in ' on My Head "" and "" Do" False Feldman asked Burt Bacharach and Hal David to 6 [' Feldman', ' asked', ' B', 'urt', ' Bach', 'ar', 'ach']
+3333 770 Name of mother of x -1 Name of mother of Burt Bacharach Irma M. Freeman Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Rain' 'drops' ' Keep'
+ ' Fall' 'in' ""'"" ' on' ' My' ' Head' '""' ' and' ' ""' 'Do']" ", the composer of the song "" Rain drops Keep Fall in ' on My Head "" and "" Do" False Grammy award winner Burt Bacharach studied music 7 [' Grammy', ' award', ' winner', ' B', 'urt', ' Bach', 'ar', 'ach']
+3334 770 Name of mother of x -1 Name of mother of Burt Bacharach Irma M. Freeman Burt Bacharach "[',' ' the' ' composer' ' of' ' the' ' song' ' ""' 'Rain' 'drops' ' Keep'
+ ' Fall' 'in' ""'"" ' on' ' My' ' Head' '""' ' and' ' ""' 'Do']" ", the composer of the song "" Rain drops Keep Fall in ' on My Head "" and "" Do" False 1960s songwriter Burt Bacharach by sending him hip 8 [' 1960', 's', ' song', 'writer', ' B', 'urt', ' Bach', 'ar', 'ach']
+3335 771 Name of mother of x -1 Name of mother of Moses Jochebed Moses "[',' ' the' ' son' ' of' ' Am' 'mon' ',' ' and' ' the' ' son' ' of' ' Am'
+ 'mon' ',' ' and' ' the' ' son' ' of' ' Am' 'mon']" , the son of Am mon , and the son of Am mon , and the son of Am mon False specimens from Moses Ashley Curtis, 2 [' specimens', ' from', ' Moses']
+3336 771 Name of mother of x -1 Name of mother of Moses Jochebed Moses "[',' ' the' ' son' ' of' ' Am' 'mon' ',' ' and' ' the' ' son' ' of' ' Am'
+ 'mon' ',' ' and' ' the' ' son' ' of' ' Am' 'mon']" , the son of Am mon , and the son of Am mon , and the son of Am mon False ruling of Moses Isserles who 2 [' ruling', ' of', ' Moses']
+3337 771 Name of mother of x -1 Name of mother of Moses Jochebed Moses "[',' ' the' ' son' ' of' ' Am' 'mon' ',' ' and' ' the' ' son' ' of' ' Am'
+ 'mon' ',' ' and' ' the' ' son' ' of' ' Am' 'mon']" , the son of Am mon , and the son of Am mon , and the son of Am mon False cabinet decision to expel Moses Gaster, a renowned 4 [' cabinet', ' decision', ' to', ' expel', ' Moses']
+3338 771 Name of mother of x -1 Name of mother of Moses Jochebed Moses "[',' ' the' ' son' ' of' ' Am' 'mon' ',' ' and' ' the' ' son' ' of' ' Am'
+ 'mon' ',' ' and' ' the' ' son' ' of' ' Am' 'mon']" , the son of Am mon , and the son of Am mon , and the son of Am mon False 1 ['M', 'oses']
+3339 771 Name of mother of x -1 Name of mother of Moses Jochebed Moses "[',' ' the' ' son' ' of' ' Am' 'mon' ',' ' and' ' the' ' son' ' of' ' Am'
+ 'mon' ',' ' and' ' the' ' son' ' of' ' Am' 'mon']" , the son of Am mon , and the son of Am mon , and the son of Am mon False Baseball. His brother Moses Fleetwood Walker, 4 [' Baseball', '.', ' His', ' brother', ' Moses']
+3340 772 Name of mother of x -1 Name of mother of Lorenzo de' Medici Lucrezia Tornabuoni Lorenzo de' Medici "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False marriage to Piero di Lorenzo de' Medici was arranged by 9 "[' marriage', ' to', ' Pier', 'o', ' di', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3341 772 Name of mother of x -1 Name of mother of Lorenzo de' Medici Lucrezia Tornabuoni Lorenzo de' Medici "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False a horse's head. Lorenzo de' Medici sent Leonardo 9 "[' a', ' horse', ""'s"", ' head', '.', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3342 772 Name of mother of x -1 Name of mother of Lorenzo de' Medici Lucrezia Tornabuoni Lorenzo de' Medici "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False marriage to Piero di Lorenzo de' Medici was arranged 9 "[' marriage', ' to', ' Pier', 'o', ' di', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3343 772 Name of mother of x -1 Name of mother of Lorenzo de' Medici Lucrezia Tornabuoni Lorenzo de' Medici "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False horse's head. Lorenzo de' Medici sent Leonardo to 8 "[' horse', ""'s"", ' head', '.', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3344 772 Name of mother of x -1 Name of mother of Lorenzo de' Medici Lucrezia Tornabuoni Lorenzo de' Medici "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False marriage to Piero di Lorenzo de' Medici was arranged by 9 "[' marriage', ' to', ' Pier', 'o', ' di', ' Lorenzo', ' de', ""'"", ' Medic', 'i']"
+3345 773 Name of mother of x -1 Name of mother of Olivia de Havilland Lilian Fontaine Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False from the film: Olivia de Havilland who played Melanie 8 [' from', ' the', ' film', ':', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3346 773 Name of mother of x -1 Name of mother of Olivia de Havilland Lilian Fontaine Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False overprotective father of Olivia de Havilland in The Heiress, 9 [' over', 'prot', 'ective', ' father', ' of', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3347 773 Name of mother of x -1 Name of mother of Olivia de Havilland Lilian Fontaine Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False anything if we had Olivia de Havilland under contract to 8 [' anything', ' if', ' we', ' had', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3348 773 Name of mother of x -1 Name of mother of Olivia de Havilland Lilian Fontaine Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Flynn, and Olivia de Havilland for lead roles 7 [' Flynn', ',', ' and', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3349 773 Name of mother of x -1 Name of mother of Olivia de Havilland Lilian Fontaine Olivia de Havilland "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False collection of Olivia de Havilland is held at the 6 [' collection', ' of', ' Olivia', ' de', ' Hav', 'ill', 'and']
+3350 774 Name of mother of x -1 Name of mother of Nikita Mikhalkov Natalia Konchalovskaïa Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Nat' 'alia' ' V' 'od']" , the Russian director of the film , and his wife , the actress and singer Nat alia V od False 6 – 18, director Nikita Mikhalkov observes that to 10 [' 6', ' –', ' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3351 774 Name of mother of x -1 Name of mother of Nikita Mikhalkov Natalia Konchalovskaïa Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Nat' 'alia' ' V' 'od']" , the Russian director of the film , and his wife , the actress and singer Nat alia V od False – 18, director Nikita Mikhalkov observes that to be 9 [' –', ' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3352 774 Name of mother of x -1 Name of mother of Nikita Mikhalkov Natalia Konchalovskaïa Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Nat' 'alia' ' V' 'od']" , the Russian director of the film , and his wife , the actress and singer Nat alia V od False Anna: 6 – 18, director Nikita Mikhalkov observes that 12 [' Anna', ':', ' 6', ' –', ' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3353 774 Name of mother of x -1 Name of mother of Nikita Mikhalkov Natalia Konchalovskaïa Nikita Mikhalkov "[',' ' the' ' Russian' ' director' ' of' ' the' ' film' ',' ' and' ' his'
+ ' wife' ',' ' the' ' actress' ' and' ' singer' ' Nat' 'alia' ' V' 'od']" , the Russian director of the film , and his wife , the actress and singer Nat alia V od False Anna: 6 – 18, director Nikita Mikhalkov observes that 12 [' Anna', ':', ' 6', ' –', ' 18', ',', ' director', ' Nik', 'ita', ' M', 'ikh', 'alk', 'ov']
+3354 776 Name of mother of x -1 Name of mother of Boris Johnson Charlotte Johnson Wahl Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False while Mayor Boris Johnson arranged for 3 [' while', ' Mayor', ' Boris', ' Johnson']
+3355 776 Name of mother of x -1 Name of mother of Boris Johnson Charlotte Johnson Wahl Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False Osborne and Boris Johnson were former 3 [' Osborne', ' and', ' Boris', ' Johnson']
+3356 776 Name of mother of x -1 Name of mother of Boris Johnson Charlotte Johnson Wahl Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False Mayor of London Boris Johnson proposed that long 4 [' Mayor', ' of', ' London', ' Boris', ' Johnson']
+3357 776 Name of mother of x -1 Name of mother of Boris Johnson Charlotte Johnson Wahl Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False in May 2012, Boris Johnson delayed the final 5 [' in', ' May', ' 2012', ',', ' Boris', ' Johnson']
+3358 776 Name of mother of x -1 Name of mother of Boris Johnson Charlotte Johnson Wahl Boris Johnson "[',' ' the' ' former' ' foreign' ' secretary' ',' ' said' ':' ' �' '�' 'I'
+ ' think' ' it' '�' '�' 's' ' a' ' very' ' good' ' idea']" , the former foreign secretary , said : � � I think it � � s a very good idea False mayor of London, Boris Johnson announced plans 5 [' mayor', ' of', ' London', ',', ' Boris', ' Johnson']
+3359 780 Name of mother of x -1 Name of mother of Bjørnstjerne Bjørnson Inger Elise Nordraach Bjørnstjerne Bjørnson "['\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Bj' 'ø' 'rn' 'st'
+ 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' '18']" "
+
+ The name of the mother of Bj ø rn st jer ne Bj ø rn son ( 18" False lèse majesté. Bjørnstjerne Bjørnson and Lars Holst were 16 [' l', 'è', 'se', ' maj', 'est', 'é', '.', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3360 780 Name of mother of x -1 Name of mother of Bjørnstjerne Bjørnson Inger Elise Nordraach Bjørnstjerne Bjørnson "['\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Bj' 'ø' 'rn' 'st'
+ 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' '18']" "
+
+ The name of the mother of Bj ø rn st jer ne Bj ø rn son ( 18" False Norwegian poet Bjørnstjerne Bjørnson and Icelandic sagas. 11 [' Norwegian', ' poet', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3361 780 Name of mother of x -1 Name of mother of Bjørnstjerne Bjørnson Inger Elise Nordraach Bjørnstjerne Bjørnson "['\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Bj' 'ø' 'rn' 'st'
+ 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' '18']" "
+
+ The name of the mother of Bj ø rn st jer ne Bj ø rn son ( 18" False Norwegian poet Bjørnstjerne Bjørnson and Icelandic sagas. 11 [' Norwegian', ' poet', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3362 780 Name of mother of x -1 Name of mother of Bjørnstjerne Bjørnson Inger Elise Nordraach Bjørnstjerne Bjørnson "['\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Bj' 'ø' 'rn' 'st'
+ 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' '18']" "
+
+ The name of the mother of Bj ø rn st jer ne Bj ø rn son ( 18" False the Norwegian poet Bjørnstjerne Bjørnson and Icelandic 12 [' the', ' Norwegian', ' poet', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3363 780 Name of mother of x -1 Name of mother of Bjørnstjerne Bjørnson Inger Elise Nordraach Bjørnstjerne Bjørnson "['\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Bj' 'ø' 'rn' 'st'
+ 'jer' 'ne' ' Bj' 'ø' 'rn' 'son' ' (' '18']" "
+
+ The name of the mother of Bj ø rn st jer ne Bj ø rn son ( 18" False and novelist Bjørnstjerne Bjørnson was the guest of 11 [' and', ' novelist', ' Bj', 'ø', 'rn', 'st', 'jer', 'ne', ' Bj', 'ø', 'rn', 'son']
+3364 782 Name of mother of x -1 Name of mother of Heinrich Schliemann Luise Therese Sophie Schliemann Heinrich Schliemann "[',' ' the' ' German' ' archae' 'ologist' ' who' ' discovered' ' Troy' ','
+ ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Troy']" ", the German archae ologist who discovered Troy , and the
+
+ The name of the city of Troy" False the archaeologist Heinrich Schliemann in 1876, following 8 [' the', ' archae', 'ologist', ' Hein', 'rich', ' Sch', 'li', 'em', 'ann']
+3365 782 Name of mother of x -1 Name of mother of Heinrich Schliemann Luise Therese Sophie Schliemann Heinrich Schliemann "[',' ' the' ' German' ' archae' 'ologist' ' who' ' discovered' ' Troy' ','
+ ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Troy']" ", the German archae ologist who discovered Troy , and the
+
+ The name of the city of Troy" False archaeologist Heinrich Schliemann undertook the first 7 [' archae', 'ologist', ' Hein', 'rich', ' Sch', 'li', 'em', 'ann']
+3366 782 Name of mother of x -1 Name of mother of Heinrich Schliemann Luise Therese Sophie Schliemann Heinrich Schliemann "[',' ' the' ' German' ' archae' 'ologist' ' who' ' discovered' ' Troy' ','
+ ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Troy']" ", the German archae ologist who discovered Troy , and the
+
+ The name of the city of Troy" False archaeologists Heinrich Schliemann and Arthur 6 [' archaeologists', ' Hein', 'rich', ' Sch', 'li', 'em', 'ann']
+3367 782 Name of mother of x -1 Name of mother of Heinrich Schliemann Luise Therese Sophie Schliemann Heinrich Schliemann "[',' ' the' ' German' ' archae' 'ologist' ' who' ' discovered' ' Troy' ','
+ ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Troy']" ", the German archae ologist who discovered Troy , and the
+
+ The name of the city of Troy" False amateur archaeologist Heinrich Schliemann in the nineteenth 8 [' amateur', ' archae', 'ologist', ' Hein', 'rich', ' Sch', 'li', 'em', 'ann']
+3368 782 Name of mother of x -1 Name of mother of Heinrich Schliemann Luise Therese Sophie Schliemann Heinrich Schliemann "[',' ' the' ' German' ' archae' 'ologist' ' who' ' discovered' ' Troy' ','
+ ' and' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' city' ' of' ' Troy']" ", the German archae ologist who discovered Troy , and the
+
+ The name of the city of Troy" False archaeologists Heinrich Schliemann and Arthur 6 [' archaeologists', ' Hein', 'rich', ' Sch', 'li', 'em', 'ann']
+3369 783 Name of mother of x -1 Name of mother of Taras Shevchenko Kateryna Y. Boiko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False and works of Taras Shevchenko brilliantly reflected 7 [' and', ' works', ' of', ' Tar', 'as', ' She', 'v', 'chenko']
+3370 783 Name of mother of x -1 Name of mother of Taras Shevchenko Kateryna Y. Boiko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False on a poem by Taras Shevchenko titled “ Зоре моя 8 [' on', ' a', ' poem', ' by', ' Tar', 'as', ' She', 'v', 'chenko']
+3371 783 Name of mother of x -1 Name of mother of Taras Shevchenko Kateryna Y. Boiko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False and works of Taras Shevchenko brilliantly 7 [' and', ' works', ' of', ' Tar', 'as', ' She', 'v', 'chenko']
+3372 783 Name of mother of x -1 Name of mother of Taras Shevchenko Kateryna Y. Boiko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False with establishing the Taras Shevchenko Memorial Committee 7 [' with', ' establishing', ' the', ' Tar', 'as', ' She', 'v', 'chenko']
+3373 783 Name of mother of x -1 Name of mother of Taras Shevchenko Kateryna Y. Boiko Taras Shevchenko "[',' ' the' ' Ukrainian' ' poet' ',' ' was' ' born' ' in' ' the'
+ ' village' ' of' ' K' 'ras' 'n' 'oh' 'or' 'iv' 'ka' ',' ' in']" , the Ukrainian poet , was born in the village of K ras n oh or iv ka , in False Prize of Ukraine Taras Shevchenko (posthumously, 7 [' Prize', ' of', ' Ukraine', ' Tar', 'as', ' She', 'v', 'chenko']
+3374 784 Name of mother of x -1 Name of mother of Arthur Balfour Lady Blanche Gascoyne-Cecil Arthur Balfour "[',' ' the' '\n' '\n' '1' '.' '\n' '\n' '2' '.' '\n' '\n' '3' '.' '\n'
+ '\n' '4' '.' '\n' '\n']" ", the
+
+ 1 .
+
+ 2 .
+
+ 3 .
+
+ 4 .
+
+" False Foreign Secretary Arthur Balfour contacted Baron 5 [' Foreign', ' Secretary', ' Arthur', ' B', 'alf', 'our']
+3375 784 Name of mother of x -1 Name of mother of Arthur Balfour Lady Blanche Gascoyne-Cecil Arthur Balfour "[',' ' the' '\n' '\n' '1' '.' '\n' '\n' '2' '.' '\n' '\n' '3' '.' '\n'
+ '\n' '4' '.' '\n' '\n']" ", the
+
+ 1 .
+
+ 2 .
+
+ 3 .
+
+ 4 .
+
+" False Conservative leaders Arthur Balfour and Lord Lansdowne 5 [' Conservative', ' leaders', ' Arthur', ' B', 'alf', 'our']
+3376 785 Name of mother of x -1 Name of mother of Jacques Cousteau Elizabeth Cousteau Jacques Cousteau "[',' ' the' ' French' ' ocean' 'ographer' ',' ' who' ' was' ' a'
+ ' pioneer' ' of' ' ocean' 'ographic' ' research' ' and' ' exploration'
+ '.' '\n' '\n' 'The']" ", the French ocean ographer , who was a pioneer of ocean ographic research and exploration .
+
+ The" False underwater explorer Jacques Cousteau began diving 5 [' underwater', ' explorer', ' Jacques', ' Cou', 'ste', 'au']
+3377 785 Name of mother of x -1 Name of mother of Jacques Cousteau Elizabeth Cousteau Jacques Cousteau "[',' ' the' ' French' ' ocean' 'ographer' ',' ' who' ' was' ' a'
+ ' pioneer' ' of' ' ocean' 'ographic' ' research' ' and' ' exploration'
+ '.' '\n' '\n' 'The']" ", the French ocean ographer , who was a pioneer of ocean ographic research and exploration .
+
+ The" False and, in the 1940s, Jacques Cousteau helped develop the 10 [' and', ',', ' in', ' the', ' 1940', 's', ',', ' Jacques', ' Cou', 'ste', 'au']
+3378 785 Name of mother of x -1 Name of mother of Jacques Cousteau Elizabeth Cousteau Jacques Cousteau "[',' ' the' ' French' ' ocean' 'ographer' ',' ' who' ' was' ' a'
+ ' pioneer' ' of' ' ocean' 'ographic' ' research' ' and' ' exploration'
+ '.' '\n' '\n' 'The']" ", the French ocean ographer , who was a pioneer of ocean ographic research and exploration .
+
+ The" False oceanographic researcher Jacques Cousteau described the 6 [' ocean', 'ographic', ' researcher', ' Jacques', ' Cou', 'ste', 'au']
+3379 785 Name of mother of x -1 Name of mother of Jacques Cousteau Elizabeth Cousteau Jacques Cousteau "[',' ' the' ' French' ' ocean' 'ographer' ',' ' who' ' was' ' a'
+ ' pioneer' ' of' ' ocean' 'ographic' ' research' ' and' ' exploration'
+ '.' '\n' '\n' 'The']" ", the French ocean ographer , who was a pioneer of ocean ographic research and exploration .
+
+ The" False scuba dive with Jacques Cousteau in 1953 provided 7 [' sc', 'uba', ' dive', ' with', ' Jacques', ' Cou', 'ste', 'au']
+3380 785 Name of mother of x -1 Name of mother of Jacques Cousteau Elizabeth Cousteau Jacques Cousteau "[',' ' the' ' French' ' ocean' 'ographer' ',' ' who' ' was' ' a'
+ ' pioneer' ' of' ' ocean' 'ographic' ' research' ' and' ' exploration'
+ '.' '\n' '\n' 'The']" ", the French ocean ographer , who was a pioneer of ocean ographic research and exploration .
+
+ The" False the 1940s, Jacques Cousteau helped develop 7 [' the', ' 1940', 's', ',', ' Jacques', ' Cou', 'ste', 'au']
+3381 786 Name of mother of x -1 Name of mother of Jayne Mansfield Vera Jeffrey Palmer Peers Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False widespread rumor that Jayne Mansfield had a breast-flashing 6 [' widespread', ' rumor', ' that', ' Jay', 'ne', ' Mans', 'field']
+3382 786 Name of mother of x -1 Name of mother of Jayne Mansfield Vera Jeffrey Palmer Peers Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False Goddess: The Jayne Mansfield Story. W H Allen. 6 [' Goddess', ':', ' The', ' Jay', 'ne', ' Mans', 'field']
+3383 786 Name of mother of x -1 Name of mother of Jayne Mansfield Vera Jeffrey Palmer Peers Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False album titled Jayne Mansfield Busts Up Las 5 [' album', ' titled', ' Jay', 'ne', ' Mans', 'field']
+3384 786 Name of mother of x -1 Name of mother of Jayne Mansfield Vera Jeffrey Palmer Peers Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False land on the planet Jayne Mansfield. In the episode of 7 [' land', ' on', ' the', ' planet', ' Jay', 'ne', ' Mans', 'field']
+3385 786 Name of mother of x -1 Name of mother of Jayne Mansfield Vera Jeffrey Palmer Peers Jayne Mansfield "[',' ' the' ' actress' ',' ' who' ' died' ' in' ' 1967' '.' '\n' '\n'
+ 'The' ' first' ' time' ' I' ' saw' ' her' ',' ' I' ' was']" ", the actress , who died in 1967 .
+
+ The first time I saw her , I was" False Benny on first violin, Jayne Mansfield on violin, Dick Powell 8 [' Benny', ' on', ' first', ' violin', ',', ' Jay', 'ne', ' Mans', 'field']
+3386 787 Name of mother of x -1 Name of mother of Genghis Khan Hoelun Genghis Khan "[',' ' the' ' Mongol' 'ian' ' ruler' ' who' ' conquered' ' the' ' world'
+ '.' '\n' '\n' 'The' ' Mongol' 'ian' ' people' ' are' ' the'
+ ' descendants' ' of']" ", the Mongol ian ruler who conquered the world .
+
+ The Mongol ian people are the descendants of" False Jin and Song. Genghis Khan had died in 7 [' Jin', ' and', ' Song', '.', ' Gen', 'gh', 'is', ' Khan']
+3387 787 Name of mother of x -1 Name of mother of Genghis Khan Hoelun Genghis Khan "[',' ' the' ' Mongol' 'ian' ' ruler' ' who' ' conquered' ' the' ' world'
+ '.' '\n' '\n' 'The' ' Mongol' 'ian' ' people' ' are' ' the'
+ ' descendants' ' of']" ", the Mongol ian ruler who conquered the world .
+
+ The Mongol ian people are the descendants of" False Mongol leader Genghis Khan, once a vassal 5 [' Mongol', ' leader', ' Gen', 'gh', 'is', ' Khan']
+3388 787 Name of mother of x -1 Name of mother of Genghis Khan Hoelun Genghis Khan "[',' ' the' ' Mongol' 'ian' ' ruler' ' who' ' conquered' ' the' ' world'
+ '.' '\n' '\n' 'The' ' Mongol' 'ian' ' people' ' are' ' the'
+ ' descendants' ' of']" ", the Mongol ian ruler who conquered the world .
+
+ The Mongol ian people are the descendants of" False campaign led by Genghis Khan was launched. His 6 [' campaign', ' led', ' by', ' Gen', 'gh', 'is', ' Khan']
+3389 787 Name of mother of x -1 Name of mother of Genghis Khan Hoelun Genghis Khan "[',' ' the' ' Mongol' 'ian' ' ruler' ' who' ' conquered' ' the' ' world'
+ '.' '\n' '\n' 'The' ' Mongol' 'ian' ' people' ' are' ' the'
+ ' descendants' ' of']" ", the Mongol ian ruler who conquered the world .
+
+ The Mongol ian people are the descendants of" False Jin and Song. Genghis Khan had died in 1227 during 7 [' Jin', ' and', ' Song', '.', ' Gen', 'gh', 'is', ' Khan']
+3390 787 Name of mother of x -1 Name of mother of Genghis Khan Hoelun Genghis Khan "[',' ' the' ' Mongol' 'ian' ' ruler' ' who' ' conquered' ' the' ' world'
+ '.' '\n' '\n' 'The' ' Mongol' 'ian' ' people' ' are' ' the'
+ ' descendants' ' of']" ", the Mongol ian ruler who conquered the world .
+
+ The Mongol ian people are the descendants of" False " Siberia: ""While Genghis Khan was holding an" 7 "[' Siberia', ':', ' ""', 'While', ' Gen', 'gh', 'is', ' Khan']"
+3391 788 Name of mother of x -1 Name of mother of Trajan Marcia Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False largest expanse under Trajan (reigned 98 – 5 [' largest', ' exp', 'anse', ' under', ' Tra', 'jan']
+3392 788 Name of mother of x -1 Name of mother of Trajan Marcia Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False leading French ship Trajan but her fire had 4 [' leading', ' French', ' ship', ' Tra', 'jan']
+3393 788 Name of mother of x -1 Name of mother of Trajan Marcia Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False (96 – 98) and Trajan (98 – 117). Unfortunately, 7 [' (', '96', ' –', ' 98', ')', ' and', ' Tra', 'jan']
+3394 788 Name of mother of x -1 Name of mother of Trajan Marcia Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False 198 AD, failing as Trajan once did to 6 [' 198', ' AD', ',', ' failing', ' as', ' Tra', 'jan']
+3395 788 Name of mother of x -1 Name of mother of Trajan Marcia Trajan "[',' ' the' ' Roman' ' emperor' ',' ' and' ' the' '\n' '\n' 'The' ' Roman'
+ ' Empire' ' was' ' the' ' largest' ' and' ' most' ' powerful' ' empire'
+ ' in']" ", the Roman emperor , and the
+
+ The Roman Empire was the largest and most powerful empire in" False Bythinia, to the emperor Trajan describes his persecution 8 [' By', 'thin', 'ia', ',', ' to', ' the', ' emperor', ' Tra', 'jan']
+3396 789 Name of mother of x -1 Name of mother of Sylvia Plath Aurelia Plath Sylvia Plath "[',' ' who' ' was' ' a' ' poet' ',' ' and' ' a' ' poet' 'ess' ',' ' and'
+ ' a' ' poet' 'ess' ',' ' and' ' a' ' poet' 'ess']" , who was a poet , and a poet ess , and a poet ess , and a poet ess False interest in poetry, with Sylvia Plath being her favourite. 7 [' interest', ' in', ' poetry', ',', ' with', ' Sylvia', ' Pl', 'ath']
+3397 789 Name of mother of x -1 Name of mother of Sylvia Plath Aurelia Plath Sylvia Plath "[',' ' who' ' was' ' a' ' poet' ',' ' and' ' a' ' poet' 'ess' ',' ' and'
+ ' a' ' poet' 'ess' ',' ' and' ' a' ' poet' 'ess']" , who was a poet , and a poet ess , and a poet ess , and a poet ess False writings of Sylvia Plath at the time. Commenting 4 [' writings', ' of', ' Sylvia', ' Pl', 'ath']
+3398 789 Name of mother of x -1 Name of mother of Sylvia Plath Aurelia Plath Sylvia Plath "[',' ' who' ' was' ' a' ' poet' ',' ' and' ' a' ' poet' 'ess' ',' ' and'
+ ' a' ' poet' 'ess' ',' ' and' ' a' ' poet' 'ess']" , who was a poet , and a poet ess , and a poet ess , and a poet ess False " belonged in"" Becoming Sylvia Plath "". Later tracks" 7 "[' belonged', ' in', '""', ' Bec', 'oming', ' Sylvia', ' Pl', 'ath']"
+3399 789 Name of mother of x -1 Name of mother of Sylvia Plath Aurelia Plath Sylvia Plath "[',' ' who' ' was' ' a' ' poet' ',' ' and' ' a' ' poet' 'ess' ',' ' and'
+ ' a' ' poet' 'ess' ',' ' and' ' a' ' poet' 'ess']" , who was a poet , and a poet ess , and a poet ess , and a poet ess False poetry, with Sylvia Plath being her favourite. 5 [' poetry', ',', ' with', ' Sylvia', ' Pl', 'ath']
+3400 789 Name of mother of x -1 Name of mother of Sylvia Plath Aurelia Plath Sylvia Plath "[',' ' who' ' was' ' a' ' poet' ',' ' and' ' a' ' poet' 'ess' ',' ' and'
+ ' a' ' poet' 'ess' ',' ' and' ' a' ' poet' 'ess']" , who was a poet , and a poet ess , and a poet ess , and a poet ess False typically read works by Sylvia Plath that would make her 6 [' typically', ' read', ' works', ' by', ' Sylvia', ' Pl', 'ath']
+3401 792 Name of mother of x -1 Name of mother of André Previn Lotte Priwin André Previn "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Wood's death, André Previn recounted a story 7 "[' Wood', ""'s"", ' death', ',', ' And', 'ré', ' Pre', 'vin']"
+3402 792 Name of mother of x -1 Name of mother of André Previn Lotte Priwin André Previn "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False after Wood's death, André Previn recounted a story 8 "[' after', ' Wood', ""'s"", ' death', ',', ' And', 'ré', ' Pre', 'vin']"
+3403 792 Name of mother of x -1 Name of mother of André Previn Lotte Priwin André Previn "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Wood's death, André Previn recounted a story 7 "[' Wood', ""'s"", ' death', ',', ' And', 'ré', ' Pre', 'vin']"
+3404 792 Name of mother of x -1 Name of mother of André Previn Lotte Priwin André Previn "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False recorded with pianist André Previn and with members 7 [' recorded', ' with', ' pian', 'ist', ' And', 'ré', ' Pre', 'vin']
+3405 792 Name of mother of x -1 Name of mother of André Previn Lotte Priwin André Previn "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Wood's death, André Previn recounted a story 7 "[' Wood', ""'s"", ' death', ',', ' And', 'ré', ' Pre', 'vin']"
+3406 794 Name of mother of x -1 Name of mother of Francis I of France Louise of Savoy Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False service of King Francis I of France – was the first 6 [' service', ' of', ' King', ' Francis', ' I', ' of', ' France']
+3407 794 Name of mother of x -1 Name of mother of Francis I of France Louise of Savoy Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False alliance between King Francis I of France and Pope Leo 6 [' alliance', ' between', ' King', ' Francis', ' I', ' of', ' France']
+3408 794 Name of mother of x -1 Name of mother of Francis I of France Louise of Savoy Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False " VIII wrote to Francis I of France that ""Divine Providence" 6 [' VIII', ' wrote', ' to', ' Francis', ' I', ' of', ' France']
+3409 794 Name of mother of x -1 Name of mother of Francis I of France Louise of Savoy Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False vying with King Francis I of France to bribe the most 6 [' vying', ' with', ' King', ' Francis', ' I', ' of', ' France']
+3410 794 Name of mother of x -1 Name of mother of Francis I of France Louise of Savoy Francis I of France "[',' ' and' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ' ' ']" ", and the
+ " False service of King Francis I of France – was the first 6 [' service', ' of', ' King', ' Francis', ' I', ' of', ' France']
+3411 795 Name of mother of x -1 Name of mother of Robert F. Kennedy Rose Kennedy Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False McCarthy's casket. Robert F. Kennedy quietly attended 8 "[' McCarthy', ""'s"", ' c', 'asket', '.', ' Robert', ' F', '.', ' Kennedy']"
+3412 795 Name of mother of x -1 Name of mother of Robert F. Kennedy Rose Kennedy Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False Luther King, Jr. and Robert F. Kennedy on the latter station. 9 [' Luther', ' King', ',', ' Jr', '.', ' and', ' Robert', ' F', '.', ' Kennedy']
+3413 795 Name of mother of x -1 Name of mother of Robert F. Kennedy Rose Kennedy Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False Attorney General Robert F. Kennedy condemned the 5 [' Attorney', ' General', ' Robert', ' F', '.', ' Kennedy']
+3414 795 Name of mother of x -1 Name of mother of Robert F. Kennedy Rose Kennedy Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False " identified Robert F. Kennedy as her hero.
+" 4 [' identified', ' Robert', ' F', '.', ' Kennedy']
+3415 795 Name of mother of x -1 Name of mother of Robert F. Kennedy Rose Kennedy Robert F. Kennedy "[',' ' Jr' '.' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister']" ", Jr .
+
+ I am a mother of two , a wife , a daughter , a sister" False counsel and 27-year-old Robert F. Kennedy as an assistant 10 [' counsel', ' and', ' 27', '-', 'year', '-', 'old', ' Robert', ' F', '.', ' Kennedy']
+3416 796 Name of mother of x -1 Name of mother of Louis XVI of France Marie Josèphe of Saxony Louis XVI of France "[',' ' and' ' the' '\n' '\n' 'The' ' French' ' Revolution' '.' '\n' '\n'
+ 'The' ' French' ' Revolution' ' was' ' a' ' revolution' ' in' ' the'
+ ' history']" ", and the
+
+ The French Revolution .
+
+ The French Revolution was a revolution in the history" False project, but King Louis XVI of France was personally interested, 7 [' project', ',', ' but', ' King', ' Louis', ' XVI', ' of', ' France']
+3417 796 Name of mother of x -1 Name of mother of Louis XVI of France Marie Josèphe of Saxony Louis XVI of France "[',' ' and' ' the' '\n' '\n' 'The' ' French' ' Revolution' '.' '\n' '\n'
+ 'The' ' French' ' Revolution' ' was' ' a' ' revolution' ' in' ' the'
+ ' history']" ", and the
+
+ The French Revolution .
+
+ The French Revolution was a revolution in the history" False project, but King Louis XVI of France was personally interested, 7 [' project', ',', ' but', ' King', ' Louis', ' XVI', ' of', ' France']
+3418 797 Name of mother of x -1 Name of mother of Edward Gibbon Judith Porten Edward Gibbon "[',' ' the' ' author' ' of' ' the' ' _' 'Decl' 'ine' ' and' ' Fall' ' of'
+ ' the' ' Roman' ' Empire' '_' ',' ' and' ' the' '\n' '\n']" ", the author of the _ Decl ine and Fall of the Roman Empire _ , and the
+
+" False " historian Edward Gibbon for ""so frequently" 3 [' historian', ' Edward', ' Gib', 'bon']
+3419 797 Name of mother of x -1 Name of mother of Edward Gibbon Judith Porten Edward Gibbon "[',' ' the' ' author' ' of' ' the' ' _' 'Decl' 'ine' ' and' ' Fall' ' of'
+ ' the' ' Roman' ' Empire' '_' ',' ' and' ' the' '\n' '\n']" ", the author of the _ Decl ine and Fall of the Roman Empire _ , and the
+
+" False the Roman Empire, Edward Gibbon notes that the 6 [' the', ' Roman', ' Empire', ',', ' Edward', ' Gib', 'bon']
+3420 797 Name of mother of x -1 Name of mother of Edward Gibbon Judith Porten Edward Gibbon "[',' ' the' ' author' ' of' ' the' ' _' 'Decl' 'ine' ' and' ' Fall' ' of'
+ ' the' ' Roman' ' Empire' '_' ',' ' and' ' the' '\n' '\n']" ", the author of the _ Decl ine and Fall of the Roman Empire _ , and the
+
+" False Roman Empire by Edward Gibbon (1737 – 94) further 5 [' Roman', ' Empire', ' by', ' Edward', ' Gib', 'bon']
+3421 797 Name of mother of x -1 Name of mother of Edward Gibbon Judith Porten Edward Gibbon "[',' ' the' ' author' ' of' ' the' ' _' 'Decl' 'ine' ' and' ' Fall' ' of'
+ ' the' ' Roman' ' Empire' '_' ',' ' and' ' the' '\n' '\n']" ", the author of the _ Decl ine and Fall of the Roman Empire _ , and the
+
+" False and, most famously, Edward Gibbon questioned traditional 7 [' and', ',', ' most', ' famously', ',', ' Edward', ' Gib', 'bon']
+3422 797 Name of mother of x -1 Name of mother of Edward Gibbon Judith Porten Edward Gibbon "[',' ' the' ' author' ' of' ' the' ' _' 'Decl' 'ine' ' and' ' Fall' ' of'
+ ' the' ' Roman' ' Empire' '_' ',' ' and' ' the' '\n' '\n']" ", the author of the _ Decl ine and Fall of the Roman Empire _ , and the
+
+" False 2 ['Edward', ' Gib', 'bon']
+3423 798 Name of mother of x -1 Name of mother of Alexander Hamilton Rachel Faucitt Lavien Alexander Hamilton "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Alexander']" ", the first president of the United States .
+
+ The first president of the United States , Alexander" False statue honoring Alexander Hamilton in Chicago was 3 [' statue', ' honoring', ' Alexander', ' Hamilton']
+3424 798 Name of mother of x -1 Name of mother of Alexander Hamilton Rachel Faucitt Lavien Alexander Hamilton "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Alexander']" ", the first president of the United States .
+
+ The first president of the United States , Alexander" False " individual freedom."" As Alexander Hamilton wrote in Federalist" 5 "[' individual', ' freedom', '.""', ' As', ' Alexander', ' Hamilton']"
+3425 798 Name of mother of x -1 Name of mother of Alexander Hamilton Rachel Faucitt Lavien Alexander Hamilton "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Alexander']" ", the first president of the United States .
+
+ The first president of the United States , Alexander" False independence, Alexander Hamilton cited Blackstone 3 [' independence', ',', ' Alexander', ' Hamilton']
+3426 798 Name of mother of x -1 Name of mother of Alexander Hamilton Rachel Faucitt Lavien Alexander Hamilton "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Alexander']" ", the first president of the United States .
+
+ The first president of the United States , Alexander" False a dissertation on Alexander Hamilton and constitutional 4 [' a', ' dissertation', ' on', ' Alexander', ' Hamilton']
+3427 798 Name of mother of x -1 Name of mother of Alexander Hamilton Rachel Faucitt Lavien Alexander Hamilton "[',' ' the' ' first' ' president' ' of' ' the' ' United' ' States' '.'
+ '\n' '\n' 'The' ' first' ' president' ' of' ' the' ' United' ' States'
+ ',' ' Alexander']" ", the first president of the United States .
+
+ The first president of the United States , Alexander" False of Confederation. Alexander Hamilton also offered a 4 [' of', ' Confederation', '.', ' Alexander', ' Hamilton']
+3428 799 Name of mother of x -1 Name of mother of Elizabeth Warren Polly L. Herring (Reed) Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' slave' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' slave' ' owner' '.' '\n' '\n' 'The' ' first']" ", the daughter of a former slave , and the daughter of a slave owner .
+
+ The first" False that listed Elizabeth Warren ’ s great-great-great 3 [' that', ' listed', ' Elizabeth', ' Warren']
+3429 799 Name of mother of x -1 Name of mother of Elizabeth Warren Polly L. Herring (Reed) Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' slave' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' slave' ' owner' '.' '\n' '\n' 'The' ' first']" ", the daughter of a former slave , and the daughter of a slave owner .
+
+ The first" False application that listed Elizabeth Warren ’ s great-great-great 4 [' application', ' that', ' listed', ' Elizabeth', ' Warren']
+3430 799 Name of mother of x -1 Name of mother of Elizabeth Warren Polly L. Herring (Reed) Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' slave' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' slave' ' owner' '.' '\n' '\n' 'The' ' first']" ", the daughter of a former slave , and the daughter of a slave owner .
+
+ The first" False application that listed Elizabeth Warren ’ s great-great-great 4 [' application', ' that', ' listed', ' Elizabeth', ' Warren']
+3431 799 Name of mother of x -1 Name of mother of Elizabeth Warren Polly L. Herring (Reed) Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' slave' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' slave' ' owner' '.' '\n' '\n' 'The' ' first']" ", the daughter of a former slave , and the daughter of a slave owner .
+
+ The first" False Barack Obama, Elizabeth Warren announced that 4 [' Barack', ' Obama', ',', ' Elizabeth', ' Warren']
+3432 799 Name of mother of x -1 Name of mother of Elizabeth Warren Polly L. Herring (Reed) Elizabeth Warren "[',' ' the' ' daughter' ' of' ' a' ' former' ' slave' ',' ' and' ' the'
+ ' daughter' ' of' ' a' ' slave' ' owner' '.' '\n' '\n' 'The' ' first']" ", the daughter of a former slave , and the daughter of a slave owner .
+
+ The first" False " Elizabeth Warren =
+" 1 [' Elizabeth', ' Warren']
+3433 801 Name of mother of x -1 Name of mother of Philip IV of Spain Margaret of Austria, Queen of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' Queen' ' of' ' Spain' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the Queen of Spain ,
+" False December 15, 1637, Philip IV of Spain imposed an 9 [' December', ' 15', ',', ' 16', '37', ',', ' Philip', ' IV', ' of', ' Spain']
+3434 801 Name of mother of x -1 Name of mother of Philip IV of Spain Margaret of Austria, Queen of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' Queen' ' of' ' Spain' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the Queen of Spain ,
+" False Between 1636 and 1637, Philip IV of Spain imposed a tax 10 [' Between', ' 16', '36', ' and', ' 16', '37', ',', ' Philip', ' IV', ' of', ' Spain']
+3435 801 Name of mother of x -1 Name of mother of Philip IV of Spain Margaret of Austria, Queen of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' Queen' ' of' ' Spain' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the Queen of Spain ,
+" False Between 1636 and 1637, Philip IV of Spain imposed a tax 10 [' Between', ' 16', '36', ' and', ' 16', '37', ',', ' Philip', ' IV', ' of', ' Spain']
+3436 801 Name of mother of x -1 Name of mother of Philip IV of Spain Margaret of Austria, Queen of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' Queen' ' of' ' Spain' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the Queen of Spain ,
+" False 1636 and 1637, Philip IV of Spain imposed a tax 9 [' 16', '36', ' and', ' 16', '37', ',', ' Philip', ' IV', ' of', ' Spain']
+3437 801 Name of mother of x -1 Name of mother of Philip IV of Spain Margaret of Austria, Queen of Spain Philip IV of Spain "[',' ' and' ' the' '\n' '\n' '|' '\n' '\n' 'The' ' King' ' of' ' Spain'
+ ',' ' and' ' the' ' Queen' ' of' ' Spain' ',' '\n']" ", and the
+
+ |
+
+ The King of Spain , and the Queen of Spain ,
+" False Between 1636 and 1637, Philip IV of Spain imposed a tax which 10 [' Between', ' 16', '36', ' and', ' 16', '37', ',', ' Philip', ' IV', ' of', ' Spain']
+3438 802 Name of mother of x -1 Name of mother of Leo X Clarice Orsini Leo X "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False to Raphael's Pope Leo X with Cardinals 5 "[' to', ' Raphael', ""'s"", ' Pope', ' Leo', ' X']"
+3439 802 Name of mother of x -1 Name of mother of Leo X Clarice Orsini Leo X "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False proclaimed by Pope Leo X against the Ottomans 4 [' proclaimed', ' by', ' Pope', ' Leo', ' X']
+3440 802 Name of mother of x -1 Name of mother of Leo X Clarice Orsini Leo X "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False the Medici popes, Leo X and Clement VII. 7 [' the', ' Medic', 'i', ' pop', 'es', ',', ' Leo', ' X']
+3441 802 Name of mother of x -1 Name of mother of Leo X Clarice Orsini Leo X "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False British Crown. Pope Leo X first granted the 5 [' British', ' Crown', '.', ' Pope', ' Leo', ' X']
+3442 802 Name of mother of x -1 Name of mother of Leo X Clarice Orsini Leo X "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False letter from Pope Leo X to Henry VIII of England 4 [' letter', ' from', ' Pope', ' Leo', ' X']
+3443 803 Name of mother of x -1 Name of mother of Marianne Faithfull Eva von Sacher-Masoch Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Memory Remains"" with Marianne Faithfull on NBC's Saturday" 8 "[' Memory', ' Rem', 'ains', '""', ' with', ' Marian', 'ne', ' Faith', 'full']"
+3444 803 Name of mother of x -1 Name of mother of Marianne Faithfull Eva von Sacher-Masoch Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 28, Mick Jagger and Marianne Faithfull were arrested at their 9 [' 28', ',', ' Mick', ' J', 'agger', ' and', ' Marian', 'ne', ' Faith', 'full']
+3445 803 Name of mother of x -1 Name of mother of Marianne Faithfull Eva von Sacher-Masoch Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Joan Baez, Marianne Faithfull and Bob Neuwirth have 7 [' Joan', ' B', 'aez', ',', ' Marian', 'ne', ' Faith', 'full']
+3446 803 Name of mother of x -1 Name of mother of Marianne Faithfull Eva von Sacher-Masoch Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " Williamson as Hamlet and Marianne Faithfull as Ophelia.
+" 8 [' Williamson', ' as', ' Ham', 'let', ' and', ' Marian', 'ne', ' Faith', 'full']
+3447 803 Name of mother of x -1 Name of mother of Marianne Faithfull Eva von Sacher-Masoch Marianne Faithfull "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False " as Hamlet and Marianne Faithfull as Ophelia.
+" 7 [' as', ' Ham', 'let', ' and', ' Marian', 'ne', ' Faith', 'full']
+3448 804 Name of mother of x -1 Name of mother of Jacob Burckhardt Susanna Maria Burckhardt-Schorndorff Jacob Burckhardt "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' Bur' 'ck' 'hardt' ','
+ ' who' ' was' ' a' '\n' ' ' ' ' ' ' ' ']" ", the son of the late Mr . Bur ck hardt , who was a
+ " False The Swiss historian Jacob Burckhardt (1818 – 1897) 6 [' The', ' Swiss', ' historian', ' Jacob', ' Bur', 'ck', 'hardt']
+3449 804 Name of mother of x -1 Name of mother of Jacob Burckhardt Susanna Maria Burckhardt-Schorndorff Jacob Burckhardt "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' Bur' 'ck' 'hardt' ','
+ ' who' ' was' ' a' '\n' ' ' ' ' ' ' ' ']" ", the son of the late Mr . Bur ck hardt , who was a
+ " False cultural historian Jacob Burckhardt in the mid-nineteenth 5 [' cultural', ' historian', ' Jacob', ' Bur', 'ck', 'hardt']
+3450 804 Name of mother of x -1 Name of mother of Jacob Burckhardt Susanna Maria Burckhardt-Schorndorff Jacob Burckhardt "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' Bur' 'ck' 'hardt' ','
+ ' who' ' was' ' a' '\n' ' ' ' ' ' ' ' ']" ", the son of the late Mr . Bur ck hardt , who was a
+ " False Swiss historian Jacob Burckhardt (1818 – 1897) in 5 [' Swiss', ' historian', ' Jacob', ' Bur', 'ck', 'hardt']
+3451 804 Name of mother of x -1 Name of mother of Jacob Burckhardt Susanna Maria Burckhardt-Schorndorff Jacob Burckhardt "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' Bur' 'ck' 'hardt' ','
+ ' who' ' was' ' a' '\n' ' ' ' ' ' ' ' ']" ", the son of the late Mr . Bur ck hardt , who was a
+ " False The Swiss historian Jacob Burckhardt (1818 – 1897) in his 6 [' The', ' Swiss', ' historian', ' Jacob', ' Bur', 'ck', 'hardt']
+3452 804 Name of mother of x -1 Name of mother of Jacob Burckhardt Susanna Maria Burckhardt-Schorndorff Jacob Burckhardt "[',' ' the' ' son' ' of' ' the' ' late' ' Mr' '.' ' Bur' 'ck' 'hardt' ','
+ ' who' ' was' ' a' '\n' ' ' ' ' ' ' ' ']" ", the son of the late Mr . Bur ck hardt , who was a
+ " False cultural historian Jacob Burckhardt in the mid-nineteenth 5 [' cultural', ' historian', ' Jacob', ' Bur', 'ck', 'hardt']
+3453 805 Name of mother of x -1 Name of mother of John Huston Rhea Gore John Huston "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ' in' ' the' ' United'
+ ' States' ',' ' and' ' a' ' member' ' of' ' the' ' New' ' York' ' Stock']" , the son of a wealthy family in the United States , and a member of the New York Stock False mid-1950s director John Huston came to town with 7 [' mid', '-', '1950', 's', ' director', ' John', ' Hust', 'on']
+3454 805 Name of mother of x -1 Name of mother of John Huston Rhea Gore John Huston "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ' in' ' the' ' United'
+ ' States' ',' ' and' ' a' ' member' ' of' ' the' ' New' ' York' ' Stock']" , the son of a wealthy family in the United States , and a member of the New York Stock False with film director John Huston while making 5 [' with', ' film', ' director', ' John', ' Hust', 'on']
+3455 805 Name of mother of x -1 Name of mother of John Huston Rhea Gore John Huston "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ' in' ' the' ' United'
+ ' States' ',' ' and' ' a' ' member' ' of' ' the' ' New' ' York' ' Stock']" , the son of a wealthy family in the United States , and a member of the New York Stock False with film director John Huston while making In This 5 [' with', ' film', ' director', ' John', ' Hust', 'on']
+3456 805 Name of mother of x -1 Name of mother of John Huston Rhea Gore John Huston "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ' in' ' the' ' United'
+ ' States' ',' ' and' ' a' ' member' ' of' ' the' ' New' ' York' ' Stock']" , the son of a wealthy family in the United States , and a member of the New York Stock False Paul Henreid and John Huston joined other members 7 [' Paul', ' Hen', 're', 'id', ' and', ' John', ' Hust', 'on']
+3457 805 Name of mother of x -1 Name of mother of John Huston Rhea Gore John Huston "[',' ' the' ' son' ' of' ' a' ' wealthy' ' family' ' in' ' the' ' United'
+ ' States' ',' ' and' ' a' ' member' ' of' ' the' ' New' ' York' ' Stock']" , the son of a wealthy family in the United States , and a member of the New York Stock False directed by John Huston and adapted 4 [' directed', ' by', ' John', ' Hust', 'on']
+3458 806 Name of mother of x -1 Name of mother of Herbert Hoover Hulda Randall Minthorn Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the first lady of the United States , and the False Clewiston, but the Herbert Hoover Dike remained intact, 7 [' Cle', 'w', 'iston', ',', ' but', ' the', ' Herbert', ' Hoover']
+3459 806 Name of mother of x -1 Name of mother of Herbert Hoover Hulda Randall Minthorn Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the first lady of the United States , and the False area, including the Herbert Hoover Building, 5 [' area', ',', ' including', ' the', ' Herbert', ' Hoover']
+3460 806 Name of mother of x -1 Name of mother of Herbert Hoover Hulda Randall Minthorn Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the first lady of the United States , and the False Clewiston, but the Herbert Hoover Dike remained intact, 7 [' Cle', 'w', 'iston', ',', ' but', ' the', ' Herbert', ' Hoover']
+3461 806 Name of mother of x -1 Name of mother of Herbert Hoover Hulda Randall Minthorn Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the first lady of the United States , and the False with Presidents Herbert Hoover and Franklin Roosevelt, 3 [' with', ' Presidents', ' Herbert', ' Hoover']
+3462 806 Name of mother of x -1 Name of mother of Herbert Hoover Hulda Randall Minthorn Herbert Hoover "[',' ' the' ' former' ' president' ' of' ' the' ' United' ' States' ','
+ ' and' ' the' ' first' ' lady' ' of' ' the' ' United' ' States' ','
+ ' and' ' the']" , the former president of the United States , and the first lady of the United States , and the False President-elect Herbert Hoover on the Pacific 4 [' President', '-', 'elect', ' Herbert', ' Hoover']
+3463 807 Name of mother of x -1 Name of mother of Dean Martin Angela Barra Dean Martin "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False album, as did Dean Martin on his 1970 album 5 [' album', ',', ' as', ' did', ' Dean', ' Martin']
+3464 807 Name of mother of x -1 Name of mother of Dean Martin Angela Barra Dean Martin "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False 1 ['Dean', ' Martin']
+3465 807 Name of mother of x -1 Name of mother of Dean Martin Angela Barra Dean Martin "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False comedy partners Dean Martin and Jerry 3 [' comedy', ' partners', ' Dean', ' Martin']
+3466 807 Name of mother of x -1 Name of mother of Dean Martin Angela Barra Dean Martin "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False Frank Sinatra and Dean Martin often visited 5 [' Frank', ' Sin', 'atra', ' and', ' Dean', ' Martin']
+3467 807 Name of mother of x -1 Name of mother of Dean Martin Angela Barra Dean Martin "[',' ' the' ' father' ' of' ' the' ' bride' ',' ' and' ' the' ' father'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' wedding' ' was' ' held']" ", the father of the bride , and the father of the groom .
+
+ The wedding was held" False 1 ['Dean', ' Martin']
+3468 808 Name of mother of x -1 Name of mother of Antoine Lavoisier Emilie Punctis Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False decades after Antoine Lavoisier developed the first 6 [' decades', ' after', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+3469 808 Name of mother of x -1 Name of mother of Antoine Lavoisier Emilie Punctis Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False formulated by Antoine Lavoisier in 1789, which 6 [' formulated', ' by', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+3470 808 Name of mother of x -1 Name of mother of Antoine Lavoisier Emilie Punctis Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False the decades after Antoine Lavoisier developed the 7 [' the', ' decades', ' after', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+3471 808 Name of mother of x -1 Name of mother of Antoine Lavoisier Emilie Punctis Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False then, and thus Antoine Lavoisier did include 8 [' then', ',', ' and', ' thus', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+3472 808 Name of mother of x -1 Name of mother of Antoine Lavoisier Emilie Punctis Antoine Lavoisier "[',' ' the' ' French' ' chemist' ' who' ' discovered' ' oxygen' ',' ' and'
+ ' the' ' father' ' of' ' modern' ' chemistry' '.' '\n' '\n' 'The'
+ ' first' ' of']" ", the French chemist who discovered oxygen , and the father of modern chemistry .
+
+ The first of" False be carbon. In 1772, Antoine Lavoisier showed that diamonds 11 [' be', ' carbon', '.', ' In', ' 17', '72', ',', ' Ant', 'oine', ' Lav', 'ois', 'ier']
+3473 809 Name of mother of x -1 Name of mother of Artemisia Gentileschi Prudenzia di Ottaviano Montoni Artemisia Gentileschi "[',' ' the' ' painter' ',' ' was' ' born' ' in' ' 15' '93' ' in'
+ ' Florence' ',' ' Italy' '.' ' She' ' was' ' the' ' daughter' ' of' ' O']" , the painter , was born in 15 93 in Florence , Italy . She was the daughter of O False " paintings by Artemisia Gentileschi (17th century).
+" 6 [' paintings', ' by', ' Artemis', 'ia', ' Gent', 'iles', 'chi']
+3474 810 Name of mother of x -1 Name of mother of Alicia Vikander Maria Fahl Vikander Alicia Vikander "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False the film; Alicia Vikander was discussed as 5 [' the', ' film', ';', ' Alicia', ' Vik', 'ander']
+3475 810 Name of mother of x -1 Name of mother of Alicia Vikander Maria Fahl Vikander Alicia Vikander "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False return for the film; Alicia Vikander was discussed 7 [' return', ' for', ' the', ' film', ';', ' Alicia', ' Vik', 'ander']
+3476 810 Name of mother of x -1 Name of mother of Alicia Vikander Maria Fahl Vikander Alicia Vikander "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False April 2016, Alicia Vikander was cast as 5 [' April', ' 2016', ',', ' Alicia', ' Vik', 'ander']
+3477 810 Name of mother of x -1 Name of mother of Alicia Vikander Maria Fahl Vikander Alicia Vikander "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False return for the film; Alicia Vikander was discussed 7 [' return', ' for', ' the', ' film', ';', ' Alicia', ' Vik', 'ander']
+3478 810 Name of mother of x -1 Name of mother of Alicia Vikander Maria Fahl Vikander Alicia Vikander "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' movie']" ", who is a very good friend of mine .
+
+ I am a big fan of the movie" False return for the film; Alicia Vikander was discussed 7 [' return', ' for', ' the', ' film', ';', ' Alicia', ' Vik', 'ander']
+3479 811 Name of mother of x -1 Name of mother of Simón Bolívar María Concepción Palacios Aguirre Ariztía - Sojo Blanco Herrera Simón Bolívar "[',' ' the' ' Liber' 'ator' ' of' ' Colombia' ',' ' and' ' the' ' Liber'
+ 'ator' ' of' ' Peru' ',' ' and' ' the' ' Liber' 'ator' ' of' ' Bolivia']" , the Liber ator of Colombia , and the Liber ator of Peru , and the Liber ator of Bolivia False rest of Venezuela, Simón Bolívar is considered 8 [' rest', ' of', ' Venezuela', ',', ' Sim', 'ón', ' Bol', 'í', 'var']
+3480 811 Name of mother of x -1 Name of mother of Simón Bolívar María Concepción Palacios Aguirre Ariztía - Sojo Blanco Herrera Simón Bolívar "[',' ' the' ' Liber' 'ator' ' of' ' Colombia' ',' ' and' ' the' ' Liber'
+ 'ator' ' of' ' Peru' ',' ' and' ' the' ' Liber' 'ator' ' of' ' Bolivia']" , the Liber ator of Colombia , and the Liber ator of Peru , and the Liber ator of Bolivia False long term lover of Simón Bolívar and acted as his spy 8 [' long', ' term', ' lover', ' of', ' Sim', 'ón', ' Bol', 'í', 'var']
+3481 811 Name of mother of x -1 Name of mother of Simón Bolívar María Concepción Palacios Aguirre Ariztía - Sojo Blanco Herrera Simón Bolívar "[',' ' the' ' Liber' 'ator' ' of' ' Colombia' ',' ' and' ' the' ' Liber'
+ 'ator' ' of' ' Peru' ',' ' and' ' the' ' Liber' 'ator' ' of' ' Bolivia']" , the Liber ator of Colombia , and the Liber ator of Peru , and the Liber ator of Bolivia False San Martín met with Simón Bolívar in Guayaquil on 9 [' San', ' Mart', 'ín', ' met', ' with', ' Sim', 'ón', ' Bol', 'í', 'var']
+3482 811 Name of mother of x -1 Name of mother of Simón Bolívar María Concepción Palacios Aguirre Ariztía - Sojo Blanco Herrera Simón Bolívar "[',' ' the' ' Liber' 'ator' ' of' ' Colombia' ',' ' and' ' the' ' Liber'
+ 'ator' ' of' ' Peru' ',' ' and' ' the' ' Liber' 'ator' ' of' ' Bolivia']" , the Liber ator of Colombia , and the Liber ator of Peru , and the Liber ator of Bolivia False he wrote essays on Simón Bolívar and Francisco 8 [' he', ' wrote', ' essays', ' on', ' Sim', 'ón', ' Bol', 'í', 'var']
+3483 811 Name of mother of x -1 Name of mother of Simón Bolívar María Concepción Palacios Aguirre Ariztía - Sojo Blanco Herrera Simón Bolívar "[',' ' the' ' Liber' 'ator' ' of' ' Colombia' ',' ' and' ' the' ' Liber'
+ 'ator' ' of' ' Peru' ',' ' and' ' the' ' Liber' 'ator' ' of' ' Bolivia']" , the Liber ator of Colombia , and the Liber ator of Peru , and the Liber ator of Bolivia False one employed by Simón Bolívar in the North shortly 7 [' one', ' employed', ' by', ' Sim', 'ón', ' Bol', 'í', 'var']
+3484 812 Name of mother of x -1 Name of mother of Józef Piłsudski Maria Piłsudska Józef Piłsudski "[',' ' the' ' Polish' ' general' ' who' ' led' ' the' ' Polish'
+ ' resistance' ' to' ' the' ' Soviet' ' invasion' ' of' ' Poland' ' in'
+ ' 1939' '.' '\n' '\n']" ", the Polish general who led the Polish resistance to the Soviet invasion of Poland in 1939 .
+
+" False Poland's Chief of State Józef Piłsudski sought a union 13 "[' Poland', ""'s"", ' Chief', ' of', ' State', ' J', 'ó', 'z', 'ef', ' Pi', 'ł', 's', 'ud', 'ski']"
+3485 812 Name of mother of x -1 Name of mother of Józef Piłsudski Maria Piłsudska Józef Piłsudski "[',' ' the' ' Polish' ' general' ' who' ' led' ' the' ' Polish'
+ ' resistance' ' to' ' the' ' Soviet' ' invasion' ' of' ' Poland' ' in'
+ ' 1939' '.' '\n' '\n']" ", the Polish general who led the Polish resistance to the Soviet invasion of Poland in 1939 .
+
+" False Poland's Marshal Józef Piłsudski to drive a war-winning 11 "[' Poland', ""'s"", ' Marshal', ' J', 'ó', 'z', 'ef', ' Pi', 'ł', 's', 'ud', 'ski']"
+3486 812 Name of mother of x -1 Name of mother of Józef Piłsudski Maria Piłsudska Józef Piłsudski "[',' ' the' ' Polish' ' general' ' who' ' led' ' the' ' Polish'
+ ' resistance' ' to' ' the' ' Soviet' ' invasion' ' of' ' Poland' ' in'
+ ' 1939' '.' '\n' '\n']" ", the Polish general who led the Polish resistance to the Soviet invasion of Poland in 1939 .
+
+" False " resolution: ""Józef Piłsudski will remain, in our" 11 "[' resolution', ':', ' ""', 'J', 'ó', 'z', 'ef', ' Pi', 'ł', 's', 'ud', 'ski']"
+3487 812 Name of mother of x -1 Name of mother of Józef Piłsudski Maria Piłsudska Józef Piłsudski "[',' ' the' ' Polish' ' general' ' who' ' led' ' the' ' Polish'
+ ' resistance' ' to' ' the' ' Soviet' ' invasion' ' of' ' Poland' ' in'
+ ' 1939' '.' '\n' '\n']" ", the Polish general who led the Polish resistance to the Soviet invasion of Poland in 1939 .
+
+" False Polish statesman Józef Piłsudski to overthrow the existing 11 [' Polish', ' states', 'man', ' J', 'ó', 'z', 'ef', ' Pi', 'ł', 's', 'ud', 'ski']
+3488 812 Name of mother of x -1 Name of mother of Józef Piłsudski Maria Piłsudska Józef Piłsudski "[',' ' the' ' Polish' ' general' ' who' ' led' ' the' ' Polish'
+ ' resistance' ' to' ' the' ' Soviet' ' invasion' ' of' ' Poland' ' in'
+ ' 1939' '.' '\n' '\n']" ", the Polish general who led the Polish resistance to the Soviet invasion of Poland in 1939 .
+
+" False Chief of State Józef Piłsudski sought a union with 11 [' Chief', ' of', ' State', ' J', 'ó', 'z', 'ef', ' Pi', 'ł', 's', 'ud', 'ski']
+3489 814 Name of mother of x -1 Name of mother of Abraham Amasla Abraham "[' Lincoln' ',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list'
+ ' of' ' the' ' names' ' of' ' the' ' children' ' of' ' Abraham'
+ ' Lincoln' ',']" " Lincoln , the
+
+ The following is a list of the names of the children of Abraham Lincoln ," False the Plains of Abraham in 1759. She 3 [' the', ' Plains', ' of', ' Abraham']
+3490 814 Name of mother of x -1 Name of mother of Abraham Amasla Abraham "[' Lincoln' ',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list'
+ ' of' ' the' ' names' ' of' ' the' ' children' ' of' ' Abraham'
+ ' Lincoln' ',']" " Lincoln , the
+
+ The following is a list of the names of the children of Abraham Lincoln ," False slavery. His support of Abraham Lincoln for president 5 [' slavery', '.', ' His', ' support', ' of', ' Abraham']
+3491 814 Name of mother of x -1 Name of mother of Abraham Amasla Abraham "[' Lincoln' ',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list'
+ ' of' ' the' ' names' ' of' ' the' ' children' ' of' ' Abraham'
+ ' Lincoln' ',']" " Lincoln , the
+
+ The following is a list of the names of the children of Abraham Lincoln ," False commission. Abraham de Peyster, the 2 [' commission', '.', ' Abraham']
+3492 814 Name of mother of x -1 Name of mother of Abraham Amasla Abraham "[' Lincoln' ',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list'
+ ' of' ' the' ' names' ' of' ' the' ' children' ' of' ' Abraham'
+ ' Lincoln' ',']" " Lincoln , the
+
+ The following is a list of the names of the children of Abraham Lincoln ," False 1 ['Ab', 'raham']
+3493 814 Name of mother of x -1 Name of mother of Abraham Amasla Abraham "[' Lincoln' ',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list'
+ ' of' ' the' ' names' ' of' ' the' ' children' ' of' ' Abraham'
+ ' Lincoln' ',']" " Lincoln , the
+
+ The following is a list of the names of the children of Abraham Lincoln ," False 1 ['Ab', 'raham']
+3494 815 Name of mother of x -1 Name of mother of Francis Xavier María de Azpilcueta Francis Xavier "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' Xavier' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Francis Xavier , the
+
+ The name of the" False consulted with Dr. Francis Xavier Dercum, a specialist 5 [' consulted', ' with', ' Dr', '.', ' Francis', ' Xavier']
+3495 815 Name of mother of x -1 Name of mother of Francis Xavier María de Azpilcueta Francis Xavier "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' Xavier' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Francis Xavier , the
+
+ The name of the" False Church of St. Francis Xavier in Penang and 5 [' Church', ' of', ' St', '.', ' Francis', ' Xavier']
+3496 815 Name of mother of x -1 Name of mother of Francis Xavier María de Azpilcueta Francis Xavier "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' Xavier' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Francis Xavier , the
+
+ The name of the" False and Saint Francis Xavier, patron against 3 [' and', ' Saint', ' Francis', ' Xavier']
+3497 815 Name of mother of x -1 Name of mother of Francis Xavier María de Azpilcueta Francis Xavier "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' Xavier' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Francis Xavier , the
+
+ The name of the" False Spanish Jesuit Francis Xavier evangelised in India, 3 [' Spanish', ' Jesuit', ' Francis', ' Xavier']
+3498 815 Name of mother of x -1 Name of mother of Francis Xavier María de Azpilcueta Francis Xavier "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' Xavier' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Francis Xavier , the
+
+ The name of the" False Leger of Saint Francis Xavier Catholic Church asserted 5 [' Le', 'ger', ' of', ' Saint', ' Francis', ' Xavier']
+3499 816 Name of mother of x -1 Name of mother of P. G. Wodehouse Eleanor Deane P. G. Wodehouse "[',' ' the' ' author' ' of' ' the' ' J' 'ee' 'ves' ' stories' ',' ' and'
+ ' the' ' creator' ' of' ' Bert' 'ie' ' Wo' 'oster' ',' ' the']" , the author of the J ee ves stories , and the creator of Bert ie Wo oster , the False 6 ['P', '.', ' G', '.', ' W', 'ode', 'house']
+3500 816 Name of mother of x -1 Name of mother of P. G. Wodehouse Eleanor Deane P. G. Wodehouse "[',' ' the' ' author' ' of' ' the' ' J' 'ee' 'ves' ' stories' ',' ' and'
+ ' the' ' creator' ' of' ' Bert' 'ie' ' Wo' 'oster' ',' ' the']" , the author of the J ee ves stories , and the creator of Bert ie Wo oster , the False the stories by P. G. Wodehouse regularly holidays 9 [' the', ' stories', ' by', ' P', '.', ' G', '.', ' W', 'ode', 'house']
+3501 816 Name of mother of x -1 Name of mother of P. G. Wodehouse Eleanor Deane P. G. Wodehouse "[',' ' the' ' author' ' of' ' the' ' J' 'ee' 'ves' ' stories' ',' ' and'
+ ' the' ' creator' ' of' ' Bert' 'ie' ' Wo' 'oster' ',' ' the']" , the author of the J ee ves stories , and the creator of Bert ie Wo oster , the False writers such as P. G. Wodehouse or Agatha Christie, 9 [' writers', ' such', ' as', ' P', '.', ' G', '.', ' W', 'ode', 'house']
+3502 816 Name of mother of x -1 Name of mother of P. G. Wodehouse Eleanor Deane P. G. Wodehouse "[',' ' the' ' author' ' of' ' the' ' J' 'ee' 'ves' ' stories' ',' ' and'
+ ' the' ' creator' ' of' ' Bert' 'ie' ' Wo' 'oster' ',' ' the']" , the author of the J ee ves stories , and the creator of Bert ie Wo oster , the False 6 ['P', '.', ' G', '.', ' W', 'ode', 'house']
+3503 816 Name of mother of x -1 Name of mother of P. G. Wodehouse Eleanor Deane P. G. Wodehouse "[',' ' the' ' author' ' of' ' the' ' J' 'ee' 'ves' ' stories' ',' ' and'
+ ' the' ' creator' ' of' ' Bert' 'ie' ' Wo' 'oster' ',' ' the']" , the author of the J ee ves stories , and the creator of Bert ie Wo oster , the False best-selling writers such as P. G. Wodehouse or Agatha Christie, 12 [' best', '-', 'selling', ' writers', ' such', ' as', ' P', '.', ' G', '.', ' W', 'ode', 'house']
+3504 817 Name of mother of x -1 Name of mother of Clement Attlee Ellen Bravery Watson Clement Attlee "[',' ' the' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n' 'Prime'
+ ' Minister' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the' ' Prime'
+ ' Minister']" ", the Prime Minister , and the
+
+ Prime Minister of the United Kingdom , and the Prime Minister" False Prime Minister Clement Attlee for two hours in 4 [' Prime', ' Minister', ' Clement', ' Att', 'lee']
+3505 817 Name of mother of x -1 Name of mother of Clement Attlee Ellen Bravery Watson Clement Attlee "[',' ' the' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n' 'Prime'
+ ' Minister' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the' ' Prime'
+ ' Minister']" ", the Prime Minister , and the
+
+ Prime Minister of the United Kingdom , and the Prime Minister" False impressed Prime Minister Clement Attlee as Army member 5 [' impressed', ' Prime', ' Minister', ' Clement', ' Att', 'lee']
+3506 817 Name of mother of x -1 Name of mother of Clement Attlee Ellen Bravery Watson Clement Attlee "[',' ' the' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n' 'Prime'
+ ' Minister' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the' ' Prime'
+ ' Minister']" ", the Prime Minister , and the
+
+ Prime Minister of the United Kingdom , and the Prime Minister" False Prime Minister Clement Attlee announced that 4 [' Prime', ' Minister', ' Clement', ' Att', 'lee']
+3507 817 Name of mother of x -1 Name of mother of Clement Attlee Ellen Bravery Watson Clement Attlee "[',' ' the' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n' 'Prime'
+ ' Minister' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the' ' Prime'
+ ' Minister']" ", the Prime Minister , and the
+
+ Prime Minister of the United Kingdom , and the Prime Minister" False majority of 146, Clement Attlee formed the 6 [' majority', ' of', ' 146', ',', ' Clement', ' Att', 'lee']
+3508 817 Name of mother of x -1 Name of mother of Clement Attlee Ellen Bravery Watson Clement Attlee "[',' ' the' ' Prime' ' Minister' ',' ' and' ' the' '\n' '\n' 'Prime'
+ ' Minister' ' of' ' the' ' United' ' Kingdom' ',' ' and' ' the' ' Prime'
+ ' Minister']" ", the Prime Minister , and the
+
+ Prime Minister of the United Kingdom , and the Prime Minister" False Prime Minister, Clement Attlee visited 10 5 [' Prime', ' Minister', ',', ' Clement', ' Att', 'lee']
+3509 818 Name of mother of x -1 Name of mother of J. J. Thomson Emma Swindells J. J. Thomson "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' J' '.' ' J' '.'
+ ' Thomson' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of J . J . Thomson , the
+
+ Name of mother" False done. He met J. J. Thomson of the Cavendish Laboratory 8 [' done', '.', ' He', ' met', ' J', '.', ' J', '.', ' Thomson']
+3510 818 Name of mother of x -1 Name of mother of J. J. Thomson Emma Swindells J. J. Thomson "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' J' '.' ' J' '.'
+ ' Thomson' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of J . J . Thomson , the
+
+ Name of mother" False The physicist J. J. Thomson measured the mass of 6 [' The', ' physicist', ' J', '.', ' J', '.', ' Thomson']
+3511 818 Name of mother of x -1 Name of mother of J. J. Thomson Emma Swindells J. J. Thomson "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' J' '.' ' J' '.'
+ ' Thomson' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of J . J . Thomson , the
+
+ Name of mother" False John O'Keefe). J. J. Thomson was awarded the Physics 9 "[' John', ' O', ""'"", 'Keefe', ').', ' J', '.', ' J', '.', ' Thomson']"
+3512 818 Name of mother of x -1 Name of mother of J. J. Thomson Emma Swindells J. J. Thomson "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' J' '.' ' J' '.'
+ ' Thomson' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of J . J . Thomson , the
+
+ Name of mother" False being done. He met J. J. Thomson of the Cavendish Laboratory 9 [' being', ' done', '.', ' He', ' met', ' J', '.', ' J', '.', ' Thomson']
+3513 818 Name of mother of x -1 Name of mother of J. J. Thomson Emma Swindells J. J. Thomson "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' J' '.' ' J' '.'
+ ' Thomson' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of J . J . Thomson , the
+
+ Name of mother" False ultimately accepted by J. J. Thomson on condition that 7 [' ultimately', ' accepted', ' by', ' J', '.', ' J', '.', ' Thomson']
+3514 819 Name of mother of x -1 Name of mother of Lenny Kravitz Roxie Roker Lenny Kravitz "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False featuring Lenny Kravitz, released on the 5 [' featuring', ' L', 'enny', ' K', 'rav', 'itz']
+3515 819 Name of mother of x -1 Name of mother of Lenny Kravitz Roxie Roker Lenny Kravitz "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False that singer Lenny Kravitz had been approached 6 [' that', ' singer', ' L', 'enny', ' K', 'rav', 'itz']
+3516 819 Name of mother of x -1 Name of mother of Lenny Kravitz Roxie Roker Lenny Kravitz "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False " Light"" featuring Lenny Kravitz on guitar. ""The" 7 "[' Light', '""', ' featuring', ' L', 'enny', ' K', 'rav', 'itz']"
+3517 819 Name of mother of x -1 Name of mother of Lenny Kravitz Roxie Roker Lenny Kravitz "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False 2015, starring Lenny Kravitz and James Franco. 7 [' 2015', ',', ' starring', ' L', 'enny', ' K', 'rav', 'itz']
+3518 819 Name of mother of x -1 Name of mother of Lenny Kravitz Roxie Roker Lenny Kravitz "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Jackson featuring Lenny Kravitz, released on the 6 [' Jackson', ' featuring', ' L', 'enny', ' K', 'rav', 'itz']
+3519 820 Name of mother of x -1 Name of mother of Eugene O'Neill Ella O'Neill Eugene O'Neill "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' and' ' the' ' play' 'wright' ""'s"" ' brother' ',' ' the']" , the play wright , and his wife , the actress , and the play wright 's brother , the False direct debt to Shaw, Eugene O'Neill became an admirer at 8 "[' direct', ' debt', ' to', ' Shaw', ',', ' Eugene', ' O', ""'"", 'Neill']"
+3520 820 Name of mother of x -1 Name of mother of Eugene O'Neill Ella O'Neill Eugene O'Neill "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' and' ' the' ' play' 'wright' ""'s"" ' brother' ',' ' the']" , the play wright , and his wife , the actress , and the play wright 's brother , the False River premiered at Eugene O'Neill Theatre in New York 6 "[' River', ' premiered', ' at', ' Eugene', ' O', ""'"", 'Neill']"
+3521 820 Name of mother of x -1 Name of mother of Eugene O'Neill Ella O'Neill Eugene O'Neill "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' and' ' the' ' play' 'wright' ""'s"" ' brother' ',' ' the']" , the play wright , and his wife , the actress , and the play wright 's brother , the False River premiered at Eugene O'Neill Theatre in New York 6 "[' River', ' premiered', ' at', ' Eugene', ' O', ""'"", 'Neill']"
+3522 820 Name of mother of x -1 Name of mother of Eugene O'Neill Ella O'Neill Eugene O'Neill "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' and' ' the' ' play' 'wright' ""'s"" ' brother' ',' ' the']" , the play wright , and his wife , the actress , and the play wright 's brother , the False premiered at Eugene O'Neill Theatre in New 5 "[' premiered', ' at', ' Eugene', ' O', ""'"", 'Neill']"
+3523 820 Name of mother of x -1 Name of mother of Eugene O'Neill Ella O'Neill Eugene O'Neill "[',' ' the' ' play' 'wright' ',' ' and' ' his' ' wife' ',' ' the'
+ ' actress' ',' ' and' ' the' ' play' 'wright' ""'s"" ' brother' ',' ' the']" , the play wright , and his wife , the actress , and the play wright 's brother , the False obvious clue was the Eugene O'Neill $ 1 stamps used 7 "[' obvious', ' clue', ' was', ' the', ' Eugene', ' O', ""'"", 'Neill']"
+3524 823 Name of mother of x -1 Name of mother of Edward Snowden Elizabeth Snowden Edward Snowden "[',' ' the' ' NSA' ' whistleblower' ' who' ' leaked' ' the' ' information'
+ ' about' ' the' ' NSA' '�' '�' 's' ' mass' ' surveillance' ' program' '.'
+ '\n' '\n']" ", the NSA whistleblower who leaked the information about the NSA � � s mass surveillance program .
+
+" False suspicions that Edward Snowden was on board the 3 [' suspicions', ' that', ' Edward', ' Snowden']
+3525 823 Name of mother of x -1 Name of mother of Edward Snowden Elizabeth Snowden Edward Snowden "[',' ' the' ' NSA' ' whistleblower' ' who' ' leaked' ' the' ' information'
+ ' about' ' the' ' NSA' '�' '�' 's' ' mass' ' surveillance' ' program' '.'
+ '\n' '\n']" ", the NSA whistleblower who leaked the information about the NSA � � s mass surveillance program .
+
+" False a satirical Edward Snowden theme following 3 [' a', ' satirical', ' Edward', ' Snowden']
+3526 823 Name of mother of x -1 Name of mother of Edward Snowden Elizabeth Snowden Edward Snowden "[',' ' the' ' NSA' ' whistleblower' ' who' ' leaked' ' the' ' information'
+ ' about' ' the' ' NSA' '�' '�' 's' ' mass' ' surveillance' ' program' '.'
+ '\n' '\n']" ", the NSA whistleblower who leaked the information about the NSA � � s mass surveillance program .
+
+" False India) In June 2013 Edward Snowden revealed details 6 [' India', ')', ' In', ' June', ' 2013', ' Edward', ' Snowden']
+3527 823 Name of mother of x -1 Name of mother of Edward Snowden Elizabeth Snowden Edward Snowden "[',' ' the' ' NSA' ' whistleblower' ' who' ' leaked' ' the' ' information'
+ ' about' ' the' ' NSA' '�' '�' 's' ' mass' ' surveillance' ' program' '.'
+ '\n' '\n']" ", the NSA whistleblower who leaked the information about the NSA � � s mass surveillance program .
+
+" False suspicions that Edward Snowden was on board the 3 [' suspicions', ' that', ' Edward', ' Snowden']
+3528 823 Name of mother of x -1 Name of mother of Edward Snowden Elizabeth Snowden Edward Snowden "[',' ' the' ' NSA' ' whistleblower' ' who' ' leaked' ' the' ' information'
+ ' about' ' the' ' NSA' '�' '�' 's' ' mass' ' surveillance' ' program' '.'
+ '\n' '\n']" ", the NSA whistleblower who leaked the information about the NSA � � s mass surveillance program .
+
+" False program and praised Edward Snowden for having 4 [' program', ' and', ' praised', ' Edward', ' Snowden']
+3529 825 Name of mother of x -1 Name of mother of Emperor Huizong of Song Chen Shi Emperor Huizong of Song "[',' ' the' ' Emperor' ' of' ' the' ' Song' ' dynasty' '.' '\n' '\n' 'The'
+ ' Emperor' ' of' ' the' ' Song' ' dynasty' ' was' ' a' ' man' ' of']" ", the Emperor of the Song dynasty .
+
+ The Emperor of the Song dynasty was a man of" False dynasty. In 1127, Emperor Huizong of Song and the capital Bianjing 11 [' dynasty', '.', ' In', ' 112', '7', ',', ' Emperor', ' Hu', 'iz', 'ong', ' of', ' Song']
+3530 825 Name of mother of x -1 Name of mother of Emperor Huizong of Song Chen Shi Emperor Huizong of Song "[',' ' the' ' Emperor' ' of' ' the' ' Song' ' dynasty' '.' '\n' '\n' 'The'
+ ' Emperor' ' of' ' the' ' Song' ' dynasty' ' was' ' a' ' man' ' of']" ", the Emperor of the Song dynasty .
+
+ The Emperor of the Song dynasty was a man of" False not buried there are Emperor Huizong of Song and Emperor Qinzong 9 [' not', ' buried', ' there', ' are', ' Emperor', ' Hu', 'iz', 'ong', ' of', ' Song']
+3531 825 Name of mother of x -1 Name of mother of Emperor Huizong of Song Chen Shi Emperor Huizong of Song "[',' ' the' ' Emperor' ' of' ' the' ' Song' ' dynasty' '.' '\n' '\n' 'The'
+ ' Emperor' ' of' ' the' ' Song' ' dynasty' ' was' ' a' ' man' ' of']" ", the Emperor of the Song dynasty .
+
+ The Emperor of the Song dynasty was a man of" False buried there are Emperor Huizong of Song and Emperor Qinzong 8 [' buried', ' there', ' are', ' Emperor', ' Hu', 'iz', 'ong', ' of', ' Song']
+3532 825 Name of mother of x -1 Name of mother of Emperor Huizong of Song Chen Shi Emperor Huizong of Song "[',' ' the' ' Emperor' ' of' ' the' ' Song' ' dynasty' '.' '\n' '\n' 'The'
+ ' Emperor' ' of' ' the' ' Song' ' dynasty' ' was' ' a' ' man' ' of']" ", the Emperor of the Song dynasty .
+
+ The Emperor of the Song dynasty was a man of" False dynasty. In 1127, Emperor Huizong of Song and the capital 11 [' dynasty', '.', ' In', ' 112', '7', ',', ' Emperor', ' Hu', 'iz', 'ong', ' of', ' Song']
+3533 825 Name of mother of x -1 Name of mother of Emperor Huizong of Song Chen Shi Emperor Huizong of Song "[',' ' the' ' Emperor' ' of' ' the' ' Song' ' dynasty' '.' '\n' '\n' 'The'
+ ' Emperor' ' of' ' the' ' Song' ' dynasty' ' was' ' a' ' man' ' of']" ", the Emperor of the Song dynasty .
+
+ The Emperor of the Song dynasty was a man of" False year of 1101, the Emperor Huizong of Song decided to lessen 11 [' year', ' of', ' 11', '01', ',', ' the', ' Emperor', ' Hu', 'iz', 'ong', ' of', ' Song']
+3534 827 Name of mother of x -1 Name of mother of Spider-Man Mary Parker Spider-Man "[',' ' and' ' the' ' Spider' '-' 'Man' ' of' ' the' ' future' '.' '\n'
+ '\n' 'The' ' Spider' '-' 'Man' ' of' ' the' ' future' ' is']" ", and the Spider - Man of the future .
+
+ The Spider - Man of the future is" False 2 ['Spider', '-', 'Man']
+3535 827 Name of mother of x -1 Name of mother of Spider-Man Mary Parker Spider-Man "[',' ' and' ' the' ' Spider' '-' 'Man' ' of' ' the' ' future' '.' '\n'
+ '\n' 'The' ' Spider' '-' 'Man' ' of' ' the' ' future' ' is']" ", and the Spider - Man of the future .
+
+ The Spider - Man of the future is" False crime scene due to Spider-Man's apprehension of 6 [' crime', ' scene', ' due', ' to', ' Spider', '-', 'Man']
+3536 827 Name of mother of x -1 Name of mother of Spider-Man Mary Parker Spider-Man "[',' ' and' ' the' ' Spider' '-' 'Man' ' of' ' the' ' future' '.' '\n'
+ '\n' 'The' ' Spider' '-' 'Man' ' of' ' the' ' future' ' is']" ", and the Spider - Man of the future .
+
+ The Spider - Man of the future is" False next AAA title, a Spider-Man video game developed 7 [' next', ' AAA', ' title', ',', ' a', ' Spider', '-', 'Man']
+3537 827 Name of mother of x -1 Name of mother of Spider-Man Mary Parker Spider-Man "[',' ' and' ' the' ' Spider' '-' 'Man' ' of' ' the' ' future' '.' '\n'
+ '\n' 'The' ' Spider' '-' 'Man' ' of' ' the' ' future' ' is']" ", and the Spider - Man of the future .
+
+ The Spider - Man of the future is" False unveiled a massive Spider-Man (PA 1040) 5 [' unveiled', ' a', ' massive', ' Spider', '-', 'Man']
+3538 827 Name of mother of x -1 Name of mother of Spider-Man Mary Parker Spider-Man "[',' ' and' ' the' ' Spider' '-' 'Man' ' of' ' the' ' future' '.' '\n'
+ '\n' 'The' ' Spider' '-' 'Man' ' of' ' the' ' future' ' is']" ", and the Spider - Man of the future .
+
+ The Spider - Man of the future is" False goes after Connors as Spider-Man. While trying to communicate 7 [' goes', ' after', ' Conn', 'ors', ' as', ' Spider', '-', 'Man']
+3539 828 Name of mother of x -1 Name of mother of Jewel Lenedra Carroll Jewel "['ry' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" "ry
+
+ I am a mother of two , a wife , a daughter , a sister , a" False 1 ['Jew', 'el']
+3540 828 Name of mother of x -1 Name of mother of Jewel Lenedra Carroll Jewel "['ry' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" "ry
+
+ I am a mother of two , a wife , a daughter , a sister , a" False address, called the Jewel Voice Broadcast 4 [' address', ',', ' called', ' the', ' Jewel']
+3541 828 Name of mother of x -1 Name of mother of Jewel Lenedra Carroll Jewel "['ry' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" "ry
+
+ I am a mother of two , a wife , a daughter , a sister , a" False 1 ['Jew', 'el']
+3542 828 Name of mother of x -1 Name of mother of Jewel Lenedra Carroll Jewel "['ry' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" "ry
+
+ I am a mother of two , a wife , a daughter , a sister , a" False School of the Sacred Jewel or the School 4 [' School', ' of', ' the', ' Sacred', ' Jewel']
+3543 828 Name of mother of x -1 Name of mother of Jewel Lenedra Carroll Jewel "['ry' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ','
+ ' a' ' daughter' ',' ' a' ' sister' ',' ' a']" "ry
+
+ I am a mother of two , a wife , a daughter , a sister , a" False different creatures. Jewel Summoners can also 3 [' different', ' creatures', '.', ' Jewel']
+3544 831 Name of mother of x -1 Name of mother of Haakon VII of Norway Louise of Sweden Haakon VII of Norway "[',' ' the' ' King' ' of' ' Norway' ',' ' and' ' the' ' Queen' ' of'
+ ' Norway' '.' '\n' '\n' 'The' ' King' ' of' ' Norway' ' is' ' the']" ", the King of Norway , and the Queen of Norway .
+
+ The King of Norway is the" False 1928 and 1929, and Haakon VII of Norway in 1943. He also 10 [' 1928', ' and', ' 1929', ',', ' and', ' Ha', 'ak', 'on', ' VII', ' of', ' Norway']
+3545 831 Name of mother of x -1 Name of mother of Haakon VII of Norway Louise of Sweden Haakon VII of Norway "[',' ' the' ' King' ' of' ' Norway' ',' ' and' ' the' ' Queen' ' of'
+ ' Norway' '.' '\n' '\n' 'The' ' King' ' of' ' Norway' ' is' ' the']" ", the King of Norway , and the Queen of Norway .
+
+ The King of Norway is the" False at Oslo, where King Haakon VII of Norway inspected the crew 10 [' at', ' Oslo', ',', ' where', ' King', ' Ha', 'ak', 'on', ' VII', ' of', ' Norway']
+3546 831 Name of mother of x -1 Name of mother of Haakon VII of Norway Louise of Sweden Haakon VII of Norway "[',' ' the' ' King' ' of' ' Norway' ',' ' and' ' the' ' Queen' ' of'
+ ' Norway' '.' '\n' '\n' 'The' ' King' ' of' ' Norway' ' is' ' the']" ", the King of Norway , and the Queen of Norway .
+
+ The King of Norway is the" False letter to King Haakon VII of Norway (which Amundsen politely 8 [' letter', ' to', ' King', ' Ha', 'ak', 'on', ' VII', ' of', ' Norway']
+3547 831 Name of mother of x -1 Name of mother of Haakon VII of Norway Louise of Sweden Haakon VII of Norway "[',' ' the' ' King' ' of' ' Norway' ',' ' and' ' the' ' Queen' ' of'
+ ' Norway' '.' '\n' '\n' 'The' ' King' ' of' ' Norway' ' is' ' the']" ", the King of Norway , and the Queen of Norway .
+
+ The King of Norway is the" False supplies, a letter to King Haakon VII of Norway (which Amundsen 11 [' supplies', ',', ' a', ' letter', ' to', ' King', ' Ha', 'ak', 'on', ' VII', ' of', ' Norway']
+3548 831 Name of mother of x -1 Name of mother of Haakon VII of Norway Louise of Sweden Haakon VII of Norway "[',' ' the' ' King' ' of' ' Norway' ',' ' and' ' the' ' Queen' ' of'
+ ' Norway' '.' '\n' '\n' 'The' ' King' ' of' ' Norway' ' is' ' the']" ", the King of Norway , and the Queen of Norway .
+
+ The King of Norway is the" False where King Haakon VII of Norway inspected the crew 7 [' where', ' King', ' Ha', 'ak', 'on', ' VII', ' of', ' Norway']
+3549 832 Name of mother of x -1 Name of mother of Queen Elizabeth, The Queen Mother Cecilia Bowes-Lyon, Countess of Strathmore and Kinghorne Queen Elizabeth, The Queen Mother "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Queen'
+ ' Mother' ' was' ' born' ' Princess' ' Elizabeth' ',' ' the' ' eldest'
+ ' child']" ", and the Queen of England .
+
+ The Queen Mother was born Princess Elizabeth , the eldest child" False styles and 437 titles of Queen Elizabeth, The Queen Mother at the end of her funeral 11 [' styles', ' and', ' 4', '37', ' titles', ' of', ' Queen', ' Elizabeth', ',', ' The', ' Queen', ' Mother']
+3550 832 Name of mother of x -1 Name of mother of Queen Elizabeth, The Queen Mother Cecilia Bowes-Lyon, Countess of Strathmore and Kinghorne Queen Elizabeth, The Queen Mother "[',' ' and' ' the' ' Queen' ' of' ' England' '.' '\n' '\n' 'The' ' Queen'
+ ' Mother' ' was' ' born' ' Princess' ' Elizabeth' ',' ' the' ' eldest'
+ ' child']" ", and the Queen of England .
+
+ The Queen Mother was born Princess Elizabeth , the eldest child" False 437 titles of Queen Elizabeth, The Queen Mother at the end of her 9 [' 4', '37', ' titles', ' of', ' Queen', ' Elizabeth', ',', ' The', ' Queen', ' Mother']
+3551 833 Name of mother of x -1 Name of mother of Horatio Nelson Catherine Suckling Horatio Nelson "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' ships' ' of' ' the' ' Royal' ' Navy' ' that' ' have' ' been' ' named']" ", the
+
+ The following is a list of the ships of the Royal Navy that have been named" False of the death of Horatio Nelson — Baldwin was 6 [' of', ' the', ' death', ' of', ' Hor', 'atio', ' Nelson']
+3552 833 Name of mother of x -1 Name of mother of Horatio Nelson Catherine Suckling Horatio Nelson "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' ships' ' of' ' the' ' Royal' ' Navy' ' that' ' have' ' been' ' named']" ", the
+
+ The following is a list of the ships of the Royal Navy that have been named" False appointed Vice-Admiral Horatio Nelson to take command of 8 [' appointed', ' Vice', '-', 'Ad', 'mir', 'al', ' Hor', 'atio', ' Nelson']
+3553 833 Name of mother of x -1 Name of mother of Horatio Nelson Catherine Suckling Horatio Nelson "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' ships' ' of' ' the' ' Royal' ' Navy' ' that' ' have' ' been' ' named']" ", the
+
+ The following is a list of the ships of the Royal Navy that have been named" False of the death of Horatio Nelson — Baldwin was 6 [' of', ' the', ' death', ' of', ' Hor', 'atio', ' Nelson']
+3554 833 Name of mother of x -1 Name of mother of Horatio Nelson Catherine Suckling Horatio Nelson "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' ships' ' of' ' the' ' Royal' ' Navy' ' that' ' have' ' been' ' named']" ", the
+
+ The following is a list of the ships of the Royal Navy that have been named" False Rear-Admiral Sir Horatio Nelson had destroyed a French 8 [' Rear', '-', 'Ad', 'mir', 'al', ' Sir', ' Hor', 'atio', ' Nelson']
+3555 833 Name of mother of x -1 Name of mother of Horatio Nelson Catherine Suckling Horatio Nelson "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' ships' ' of' ' the' ' Royal' ' Navy' ' that' ' have' ' been' ' named']" ", the
+
+ The following is a list of the ships of the Royal Navy that have been named" False 2 ['Hor', 'atio', ' Nelson']
+3556 834 Name of mother of x -1 Name of mother of Christina Applegate Nancy Priddy Christina Applegate "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' former' ' Miss' ' America']" ", who is a former Miss America , and the mother of two .
+
+ The former Miss America" False Kristen Bell and Christina Applegate in the comedy 5 [' Kristen', ' Bell', ' and', ' Christina', ' App', 'legate']
+3557 834 Name of mother of x -1 Name of mother of Christina Applegate Nancy Priddy Christina Applegate "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' former' ' Miss' ' America']" ", who is a former Miss America , and the mother of two .
+
+ The former Miss America" False Kristen Bell and Christina Applegate in the comedy 5 [' Kristen', ' Bell', ' and', ' Christina', ' App', 'legate']
+3558 834 Name of mother of x -1 Name of mother of Christina Applegate Nancy Priddy Christina Applegate "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' former' ' Miss' ' America']" ", who is a former Miss America , and the mother of two .
+
+ The former Miss America" False Kristen Bell and Christina Applegate in the comedy Bad 5 [' Kristen', ' Bell', ' and', ' Christina', ' App', 'legate']
+3559 834 Name of mother of x -1 Name of mother of Christina Applegate Nancy Priddy Christina Applegate "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' former' ' Miss' ' America']" ", who is a former Miss America , and the mother of two .
+
+ The former Miss America" False Kristen Bell and Christina Applegate in the comedy 5 [' Kristen', ' Bell', ' and', ' Christina', ' App', 'legate']
+3560 834 Name of mother of x -1 Name of mother of Christina Applegate Nancy Priddy Christina Applegate "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' and' ' the'
+ ' mother' ' of' ' two' '.' '\n' '\n' 'The' ' former' ' Miss' ' America']" ", who is a former Miss America , and the mother of two .
+
+ The former Miss America" False 3 ['Christ', 'ina', ' App', 'legate']
+3561 835 Name of mother of x -1 Name of mother of Jean Renoir Aline Charigot Jean Renoir "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' star' ',' ' and']" , the director of the film , and the film 's producer , and the film 's star , and False the film. Jean Renoir was the first 5 [' the', ' film', '.', ' Jean', ' Ren', 'oir']
+3562 835 Name of mother of x -1 Name of mother of Jean Renoir Aline Charigot Jean Renoir "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' star' ',' ' and']" , the director of the film , and the film 's producer , and the film 's star , and False " Spielberg ""a space age Jean Renoir .... [F] or the" 7 "[' Spielberg', ' ""', 'a', ' space', ' age', ' Jean', ' Ren', 'oir']"
+3563 835 Name of mother of x -1 Name of mother of Jean Renoir Aline Charigot Jean Renoir "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' star' ',' ' and']" , the director of the film , and the film 's producer , and the film 's star , and False " Spielberg ""a space age Jean Renoir .... [F] or" 7 "[' Spielberg', ' ""', 'a', ' space', ' age', ' Jean', ' Ren', 'oir']"
+3564 835 Name of mother of x -1 Name of mother of Jean Renoir Aline Charigot Jean Renoir "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' star' ',' ' and']" , the director of the film , and the film 's producer , and the film 's star , and False French director Jean Renoir came to Calcutta 4 [' French', ' director', ' Jean', ' Ren', 'oir']
+3565 835 Name of mother of x -1 Name of mother of Jean Renoir Aline Charigot Jean Renoir "[',' ' the' ' director' ' of' ' the' ' film' ',' ' and' ' the' ' film'
+ ""'s"" ' producer' ',' ' and' ' the' ' film' ""'s"" ' star' ',' ' and']" , the director of the film , and the film 's producer , and the film 's star , and False French filmmaker Jean Renoir and viewing Vittorio 4 [' French', ' filmmaker', ' Jean', ' Ren', 'oir']
+3566 836 Name of mother of x -1 Name of mother of Luchino Visconti Carla Erba Luchino Visconti "[',' ' the' ' Italian' ' film' ' director' ',' ' and' ' his' ' wife' ','
+ ' L' 'uch' 'ino' ""'s"" ' mother' ',' ' the' ' actress' ' Maria' ' G']" , the Italian film director , and his wife , L uch ino 's mother , the actress Maria G False Italian director Luchino Visconti adapted Cain's The 8 [' Italian', ' director', ' L', 'uch', 'ino', ' V', 'isc', 'ont', 'i']
+3567 836 Name of mother of x -1 Name of mother of Luchino Visconti Carla Erba Luchino Visconti "[',' ' the' ' Italian' ' film' ' director' ',' ' and' ' his' ' wife' ','
+ ' L' 'uch' 'ino' ""'s"" ' mother' ',' ' the' ' actress' ' Maria' ' G']" , the Italian film director , and his wife , L uch ino 's mother , the actress Maria G False Elsewhere, Italian director Luchino Visconti adapted Cain's 11 [' Else', 'where', ',', ' Italian', ' director', ' L', 'uch', 'ino', ' V', 'isc', 'ont', 'i']
+3568 837 Name of mother of x -1 Name of mother of Richard Burton Edith Mawd Thomas Richard Burton "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False producer also approached Richard Burton for the role, 4 [' producer', ' also', ' approached', ' Richard', ' Burton']
+3569 837 Name of mother of x -1 Name of mother of Richard Burton Edith Mawd Thomas Richard Burton "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False Retro Gamer editor Richard Burton described 4 [' Retro', ' Gamer', ' editor', ' Richard', ' Burton']
+3570 837 Name of mother of x -1 Name of mother of Richard Burton Edith Mawd Thomas Richard Burton "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False (1969), with Richard Burton and Geneviève Bujold, 5 [' (', '1969', '),', ' with', ' Richard', ' Burton']
+3571 837 Name of mother of x -1 Name of mother of Richard Burton Edith Mawd Thomas Richard Burton "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False described in the sources. Richard Burton and Commander 6 [' described', ' in', ' the', ' sources', '.', ' Richard', ' Burton']
+3572 837 Name of mother of x -1 Name of mother of Richard Burton Edith Mawd Thomas Richard Burton "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' bastard' '!""' '\n' '\n' '""' 'I']" ", the son of a bitch !""
+
+ "" I 'm not a bastard !""
+
+ "" I" False Redgrave (1950), Richard Burton (1954), Derek Jacobi 6 [' Red', 'grave', ' (', '1950', '),', ' Richard', ' Burton']
+3573 838 Name of mother of x -1 Name of mother of Ozzy Osbourne Lillian Unitt Ozzy Osbourne "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ' Sharon' ' Os' 'bourne'
+ ',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' Sharon' ' Os']" , the singer , and his wife Sharon Os bourne , the singer , and his wife , Sharon Os False Lead singer Ozzy Osbourne's wife, Sharon, 5 [' Lead', ' singer', ' O', 'zzy', ' Os', 'bourne']
+3574 838 Name of mother of x -1 Name of mother of Ozzy Osbourne Lillian Unitt Ozzy Osbourne "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ' Sharon' ' Os' 'bourne'
+ ',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' Sharon' ' Os']" , the singer , and his wife Sharon Os bourne , the singer , and his wife , Sharon Os False Metallica supported Ozzy Osbourne on a U.S. tour. 6 [' Metall', 'ica', ' supported', ' O', 'zzy', ' Os', 'bourne']
+3575 838 Name of mother of x -1 Name of mother of Ozzy Osbourne Lillian Unitt Ozzy Osbourne "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ' Sharon' ' Os' 'bourne'
+ ',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' Sharon' ' Os']" , the singer , and his wife Sharon Os bourne , the singer , and his wife , Sharon Os False Runaways' Lita Ford, and Ozzy Osbourne provide character 11 "[' Run', 'aways', ""'"", ' L', 'ita', ' Ford', ',', ' and', ' O', 'zzy', ' Os', 'bourne']"
+3576 838 Name of mother of x -1 Name of mother of Ozzy Osbourne Lillian Unitt Ozzy Osbourne "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ' Sharon' ' Os' 'bourne'
+ ',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' Sharon' ' Os']" , the singer , and his wife Sharon Os bourne , the singer , and his wife , Sharon Os False drug abuse. Ozzy Osbourne could be heard preaching 6 [' drug', ' abuse', '.', ' O', 'zzy', ' Os', 'bourne']
+3577 838 Name of mother of x -1 Name of mother of Ozzy Osbourne Lillian Unitt Ozzy Osbourne "[',' ' the' ' singer' ',' ' and' ' his' ' wife' ' Sharon' ' Os' 'bourne'
+ ',' ' the' ' singer' ',' ' and' ' his' ' wife' ',' ' Sharon' ' Os']" , the singer , and his wife Sharon Os bourne , the singer , and his wife , Sharon Os False created by having Ozzy Osbourne sing from behind a 6 [' created', ' by', ' having', ' O', 'zzy', ' Os', 'bourne']
+3578 840 Name of mother of x -1 Name of mother of Charles IV Elizabeth of Bohemia Charles IV "[',' ' King' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and']" ", King of France , and the
+
+ King of France , and the King of France , and" False between Emperor Charles IV and Stephen II, 3 [' between', ' Emperor', ' Charles', ' IV']
+3579 840 Name of mother of x -1 Name of mother of Charles IV Elizabeth of Bohemia Charles IV "[',' ' King' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and']" ", King of France , and the
+
+ King of France , and the King of France , and" False Peninsular War. King Charles IV of Spain abdicated 7 [' Pen', 'ins', 'ular', ' War', '.', ' King', ' Charles', ' IV']
+3580 840 Name of mother of x -1 Name of mother of Charles IV Elizabeth of Bohemia Charles IV "[',' ' King' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and']" ", King of France , and the
+
+ King of France , and the King of France , and" False resulted in Charles IV declaring the duchy 3 [' resulted', ' in', ' Charles', ' IV']
+3581 840 Name of mother of x -1 Name of mother of Charles IV Elizabeth of Bohemia Charles IV "[',' ' King' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and']" ", King of France , and the
+
+ King of France , and the King of France , and" False resulted in Charles IV declaring the duchy 3 [' resulted', ' in', ' Charles', ' IV']
+3582 840 Name of mother of x -1 Name of mother of Charles IV Elizabeth of Bohemia Charles IV "[',' ' King' ' of' ' France' ',' ' and' ' the' '\n' '\n' 'King' ' of'
+ ' France' ',' ' and' ' the' ' King' ' of' ' France' ',' ' and']" ", King of France , and the
+
+ King of France , and the King of France , and" False between Emperor Charles IV and Stephen II, 3 [' between', ' Emperor', ' Charles', ' IV']
+3583 841 Name of mother of x -1 Name of mother of Hadrian Domitia Paulina Hadrian "[""'s"" ' Wall' ',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of'
+ ' mother' ' of' ' Had' 'rian' ""'s"" ' Wall' ',' ' and']" "'s Wall , and the
+
+ The
+
+ Name of mother of Had rian 's Wall , and" False the emperor Hadrian the proportion of 3 [' the', ' emperor', ' Had', 'rian']
+3584 841 Name of mother of x -1 Name of mother of Hadrian Domitia Paulina Hadrian "[""'s"" ' Wall' ',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of'
+ ' mother' ' of' ' Had' 'rian' ""'s"" ' Wall' ',' ' and']" "'s Wall , and the
+
+ The
+
+ Name of mother of Had rian 's Wall , and" False authority, from Pope Hadrian I. In one surviving 5 [' authority', ',', ' from', ' Pope', ' Had', 'rian']
+3585 841 Name of mother of x -1 Name of mother of Hadrian Domitia Paulina Hadrian "[""'s"" ' Wall' ',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of'
+ ' mother' ' of' ' Had' 'rian' ""'s"" ' Wall' ',' ' and']" "'s Wall , and the
+
+ The
+
+ Name of mother of Had rian 's Wall , and" False by the Emperor Hadrian during the 4 [' by', ' the', ' Emperor', ' Had', 'rian']
+3586 841 Name of mother of x -1 Name of mother of Hadrian Domitia Paulina Hadrian "[""'s"" ' Wall' ',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of'
+ ' mother' ' of' ' Had' 'rian' ""'s"" ' Wall' ',' ' and']" "'s Wall , and the
+
+ The
+
+ Name of mother of Had rian 's Wall , and" False time of Emperor Hadrian (2nd century 4 [' time', ' of', ' Emperor', ' Had', 'rian']
+3587 841 Name of mother of x -1 Name of mother of Hadrian Domitia Paulina Hadrian "[""'s"" ' Wall' ',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of'
+ ' mother' ' of' ' Had' 'rian' ""'s"" ' Wall' ',' ' and']" "'s Wall , and the
+
+ The
+
+ Name of mother of Had rian 's Wall , and" False abilities. Hadrian succeeded Trajan 3 [' abilities', '.', ' Had', 'rian']
+3588 842 Name of mother of x -1 Name of mother of Alfonso XIII of Spain Maria Christina of Austria Alfonso XIII of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' and' ' the' ' King']" , the King of Spain , and the Queen of Spain , the Queen of Spain , and the King False received King Alfonso XIII of Spain aboard while visiting 6 [' received', ' King', ' Alf', 'onso', ' XIII', ' of', ' Spain']
+3589 842 Name of mother of x -1 Name of mother of Alfonso XIII of Spain Maria Christina of Austria Alfonso XIII of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' and' ' the' ' King']" , the King of Spain , and the Queen of Spain , the Queen of Spain , and the King False to move again. Alfonso XIII of Spain had approached the 8 [' to', ' move', ' again', '.', ' Alf', 'onso', ' XIII', ' of', ' Spain']
+3590 842 Name of mother of x -1 Name of mother of Alfonso XIII of Spain Maria Christina of Austria Alfonso XIII of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' and' ' the' ' King']" , the King of Spain , and the Queen of Spain , the Queen of Spain , and the King False soon to move again. Alfonso XIII of Spain had approached the 9 [' soon', ' to', ' move', ' again', '.', ' Alf', 'onso', ' XIII', ' of', ' Spain']
+3591 842 Name of mother of x -1 Name of mother of Alfonso XIII of Spain Maria Christina of Austria Alfonso XIII of Spain "[',' ' the' ' King' ' of' ' Spain' ',' ' and' ' the' ' Queen' ' of'
+ ' Spain' ',' ' the' ' Queen' ' of' ' Spain' ',' ' and' ' the' ' King']" , the King of Spain , and the Queen of Spain , the Queen of Spain , and the King False soon to move again. Alfonso XIII of Spain had approached 9 [' soon', ' to', ' move', ' again', '.', ' Alf', 'onso', ' XIII', ' of', ' Spain']
+3592 843 Name of mother of x -1 Name of mother of Gregory of Nyssa Emmelia of Caesarea Gregory of Nyssa "[',' ' the' ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' and' ' the'
+ ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' the' ' mother' ' of'
+ ' the']" , the mother of the Virgin Mary , and the mother of the Virgin Mary , the mother of the False by Bishop Gregory of Nyssa which was a preserved 6 [' by', ' Bishop', ' Gregory', ' of', ' Ny', 'ss', 'a']
+3593 843 Name of mother of x -1 Name of mother of Gregory of Nyssa Emmelia of Caesarea Gregory of Nyssa "[',' ' the' ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' and' ' the'
+ ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' the' ' mother' ' of'
+ ' the']" , the mother of the Virgin Mary , and the mother of the Virgin Mary , the mother of the False book by Bishop Gregory of Nyssa which was a preserved 7 [' book', ' by', ' Bishop', ' Gregory', ' of', ' Ny', 'ss', 'a']
+3594 843 Name of mother of x -1 Name of mother of Gregory of Nyssa Emmelia of Caesarea Gregory of Nyssa "[',' ' the' ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' and' ' the'
+ ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' the' ' mother' ' of'
+ ' the']" , the mother of the Virgin Mary , and the mother of the Virgin Mary , the mother of the False " drink the good."" Gregory of Nyssa (died 395) made" 8 "[' drink', ' the', ' good', '.""', ' Gregory', ' of', ' Ny', 'ss', 'a']"
+3595 843 Name of mother of x -1 Name of mother of Gregory of Nyssa Emmelia of Caesarea Gregory of Nyssa "[',' ' the' ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' and' ' the'
+ ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' the' ' mother' ' of'
+ ' the']" , the mother of the Virgin Mary , and the mother of the Virgin Mary , the mother of the False " of Nyssa =
+" 10 [' of', ' Ny', 'ss', 'a', ' =', 'Greg', 'ory', ' of', ' Ny', 'ss', 'a']
+3596 843 Name of mother of x -1 Name of mother of Gregory of Nyssa Emmelia of Caesarea Gregory of Nyssa "[',' ' the' ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' and' ' the'
+ ' mother' ' of' ' the' ' Virgin' ' Mary' ',' ' the' ' mother' ' of'
+ ' the']" , the mother of the Virgin Mary , and the mother of the Virgin Mary , the mother of the False " of Nyssa =
+" 10 [' of', ' Ny', 'ss', 'a', ' =', 'Greg', 'ory', ' of', ' Ny', 'ss', 'a']
+3597 844 Name of mother of x -1 Name of mother of William Herschel Anna Ilse Moritzen William Herschel "[',' ' the' ' astronomer' ',' ' and' ' his' ' wife' ',' ' Caroline'
+ ' Hers' 'chel' ',' ' who' ' was' ' a' ' sister' ' of' ' Sir' ' William'
+ ' Hers']" , the astronomer , and his wife , Caroline Hers chel , who was a sister of Sir William Hers False Accordingly, in 1802, William Herschel suggested they 8 [' Accordingly', ',', ' in', ' 18', '02', ',', ' William', ' Hers', 'chel']
+3598 844 Name of mother of x -1 Name of mother of William Herschel Anna Ilse Moritzen William Herschel "[',' ' the' ' astronomer' ',' ' and' ' his' ' wife' ',' ' Caroline'
+ ' Hers' 'chel' ',' ' who' ' was' ' a' ' sister' ' of' ' Sir' ' William'
+ ' Hers']" , the astronomer , and his wife , Caroline Hers chel , who was a sister of Sir William Hers False until 1789 when William Herschel discovered two 6 [' until', ' 17', '89', ' when', ' William', ' Hers', 'chel']
+3599 844 Name of mother of x -1 Name of mother of William Herschel Anna Ilse Moritzen William Herschel "[',' ' the' ' astronomer' ',' ' and' ' his' ' wife' ',' ' Caroline'
+ ' Hers' 'chel' ',' ' who' ' was' ' a' ' sister' ' of' ' Sir' ' William'
+ ' Hers']" , the astronomer , and his wife , Caroline Hers chel , who was a sister of Sir William Hers False carried out by William Herschel in 1785 by 5 [' carried', ' out', ' by', ' William', ' Hers', 'chel']
+3600 844 Name of mother of x -1 Name of mother of William Herschel Anna Ilse Moritzen William Herschel "[',' ' the' ' astronomer' ',' ' and' ' his' ' wife' ',' ' Caroline'
+ ' Hers' 'chel' ',' ' who' ' was' ' a' ' sister' ' of' ' Sir' ' William'
+ ' Hers']" , the astronomer , and his wife , Caroline Hers chel , who was a sister of Sir William Hers False English astronomer William Herschel in 1790. Prior 4 [' English', ' astronomer', ' William', ' Hers', 'chel']
+3601 844 Name of mother of x -1 Name of mother of William Herschel Anna Ilse Moritzen William Herschel "[',' ' the' ' astronomer' ',' ' and' ' his' ' wife' ',' ' Caroline'
+ ' Hers' 'chel' ',' ' who' ' was' ' a' ' sister' ' of' ' Sir' ' William'
+ ' Hers']" , the astronomer , and his wife , Caroline Hers chel , who was a sister of Sir William Hers False carried out by William Herschel in 1785 by 5 [' carried', ' out', ' by', ' William', ' Hers', 'chel']
+3602 846 Name of mother of x -1 Name of mother of Francis Galton Frances Anne Violetta Darwin Francis Galton "[',' ' the' ' famous' ' English' ' scientist' ',' ' who' ' was' ' a'
+ ' pioneer' ' in' ' the' ' field' ' of' ' e' 'ugen' 'ics' '.' '\n' '\n']" ", the famous English scientist , who was a pioneer in the field of e ugen ics .
+
+" False coined in 1877 by Francis Galton to indicate an 7 [' coined', ' in', ' 18', '77', ' by', ' Francis', ' Gal', 'ton']
+3603 846 Name of mother of x -1 Name of mother of Francis Galton Frances Anne Violetta Darwin Francis Galton "[',' ' the' ' famous' ' English' ' scientist' ',' ' who' ' was' ' a'
+ ' pioneer' ' in' ' the' ' field' ' of' ' e' 'ugen' 'ics' '.' '\n' '\n']" ", the famous English scientist , who was a pioneer in the field of e ugen ics .
+
+" False 3 ['Franc', 'is', ' Gal', 'ton']
+3604 846 Name of mother of x -1 Name of mother of Francis Galton Frances Anne Violetta Darwin Francis Galton "[',' ' the' ' famous' ' English' ' scientist' ',' ' who' ' was' ' a'
+ ' pioneer' ' in' ' the' ' field' ' of' ' e' 'ugen' 'ics' '.' '\n' '\n']" ", the famous English scientist , who was a pioneer in the field of e ugen ics .
+
+" False coined in 1877 by Francis Galton to indicate an area 7 [' coined', ' in', ' 18', '77', ' by', ' Francis', ' Gal', 'ton']
+3605 846 Name of mother of x -1 Name of mother of Francis Galton Frances Anne Violetta Darwin Francis Galton "[',' ' the' ' famous' ' English' ' scientist' ',' ' who' ' was' ' a'
+ ' pioneer' ' in' ' the' ' field' ' of' ' e' 'ugen' 'ics' '.' '\n' '\n']" ", the famous English scientist , who was a pioneer in the field of e ugen ics .
+
+" False with the work of Francis Galton and the biometricians. 6 [' with', ' the', ' work', ' of', ' Francis', ' Gal', 'ton']
+3606 846 Name of mother of x -1 Name of mother of Francis Galton Frances Anne Violetta Darwin Francis Galton "[',' ' the' ' famous' ' English' ' scientist' ',' ' who' ' was' ' a'
+ ' pioneer' ' in' ' the' ' field' ' of' ' e' 'ugen' 'ics' '.' '\n' '\n']" ", the famous English scientist , who was a pioneer in the field of e ugen ics .
+
+" False " by his cousin Francis Galton and the ""biometric""" 5 [' by', ' his', ' cousin', ' Francis', ' Gal', 'ton']
+3607 847 Name of mother of x -1 Name of mother of Wilkie Collins Harriet Geddes Wilkie Collins "[',' ' the' ' author' ' of' ' _' 'The' ' Woman' ' in' ' White' '_' ','
+ ' and' ' _' 'The' ' Moon' 'stone' '_' ',' ' and' ' _']" , the author of _ The Woman in White _ , and _ The Moon stone _ , and _ False written by Wilkie Collins with assistance 4 [' written', ' by', ' Wil', 'kie', ' Collins']
+3608 847 Name of mother of x -1 Name of mother of Wilkie Collins Harriet Geddes Wilkie Collins "[',' ' the' ' author' ' of' ' _' 'The' ' Woman' ' in' ' White' '_' ','
+ ' and' ' _' 'The' ' Moon' 'stone' '_' ',' ' and' ' _']" , the author of _ The Woman in White _ , and _ The Moon stone _ , and _ False of Trollope, Wilkie Collins and Henry Kingsley. 6 [' of', ' Troll', 'ope', ',', ' Wil', 'kie', ' Collins']
+3609 847 Name of mother of x -1 Name of mother of Wilkie Collins Harriet Geddes Wilkie Collins "[',' ' the' ' author' ' of' ' _' 'The' ' Woman' ' in' ' White' '_' ','
+ ' and' ' _' 'The' ' Moon' 'stone' '_' ',' ' and' ' _']" , the author of _ The Woman in White _ , and _ The Moon stone _ , and _ False Trollope, Wilkie Collins and Henry Kingsley. 5 [' Troll', 'ope', ',', ' Wil', 'kie', ' Collins']
+3610 847 Name of mother of x -1 Name of mother of Wilkie Collins Harriet Geddes Wilkie Collins "[',' ' the' ' author' ' of' ' _' 'The' ' Woman' ' in' ' White' '_' ','
+ ' and' ' _' 'The' ' Moon' 'stone' '_' ',' ' and' ' _']" , the author of _ The Woman in White _ , and _ The Moon stone _ , and _ False 2 ['Wil', 'kie', ' Collins']
+3611 847 Name of mother of x -1 Name of mother of Wilkie Collins Harriet Geddes Wilkie Collins "[',' ' the' ' author' ' of' ' _' 'The' ' Woman' ' in' ' White' '_' ','
+ ' and' ' _' 'The' ' Moon' 'stone' '_' ',' ' and' ' _']" , the author of _ The Woman in White _ , and _ The Moon stone _ , and _ False 2 ['Wil', 'kie', ' Collins']
+3612 848 Name of mother of x -1 Name of mother of Paul von Hindenburg Luise Schwickart Paul von Hindenburg "[',' ' the' ' German' ' general' ' who' ' had' ' been' ' the' ' first'
+ ' to' ' use' ' the' ' term' ' ""' 'Bl' 'itz' 'k' 'rieg' '""' ' to']" ", the German general who had been the first to use the term "" Bl itz k rieg "" to" False replaced by Generals Paul von Hindenburg and Erich Ludendorff. 7 [' replaced', ' by', ' Gener', 'als', ' Paul', ' von', ' Hind', 'enburg']
+3613 848 Name of mother of x -1 Name of mother of Paul von Hindenburg Luise Schwickart Paul von Hindenburg "[',' ' the' ' German' ' general' ' who' ' had' ' been' ' the' ' first'
+ ' to' ' use' ' the' ' term' ' ""' 'Bl' 'itz' 'k' 'rieg' '""' ' to']" ", the German general who had been the first to use the term "" Bl itz k rieg "" to" False by President Paul von Hindenburg on 30 January 1933, 5 [' by', ' President', ' Paul', ' von', ' Hind', 'enburg']
+3614 848 Name of mother of x -1 Name of mother of Paul von Hindenburg Luise Schwickart Paul von Hindenburg "[',' ' the' ' German' ' general' ' who' ' had' ' been' ' the' ' first'
+ ' to' ' use' ' the' ' term' ' ""' 'Bl' 'itz' 'k' 'rieg' '""' ' to']" ", the German general who had been the first to use the term "" Bl itz k rieg "" to" False the remains of Paul von Hindenburg and his wife to 6 [' the', ' remains', ' of', ' Paul', ' von', ' Hind', 'enburg']
+3615 848 Name of mother of x -1 Name of mother of Paul von Hindenburg Luise Schwickart Paul von Hindenburg "[',' ' the' ' German' ' general' ' who' ' had' ' been' ' the' ' first'
+ ' to' ' use' ' the' ' term' ' ""' 'Bl' 'itz' 'k' 'rieg' '""' ' to']" ", the German general who had been the first to use the term "" Bl itz k rieg "" to" False remains of Paul von Hindenburg from East Prussia 5 [' remains', ' of', ' Paul', ' von', ' Hind', 'enburg']
+3616 848 Name of mother of x -1 Name of mother of Paul von Hindenburg Luise Schwickart Paul von Hindenburg "[',' ' the' ' German' ' general' ' who' ' had' ' been' ' the' ' first'
+ ' to' ' use' ' the' ' term' ' ""' 'Bl' 'itz' 'k' 'rieg' '""' ' to']" ", the German general who had been the first to use the term "" Bl itz k rieg "" to" False Weimar Republic Paul von Hindenburg on 30 January 1933. 6 [' We', 'imar', ' Republic', ' Paul', ' von', ' Hind', 'enburg']
+3617 849 Name of mother of x -1 Name of mother of Langston Hughes Carrie Langston Langston Hughes "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False " photograph or a Langston Hughes poem.""
+" 5 [' photograph', ' or', ' a', ' Lang', 'ston', ' Hughes']
+3618 849 Name of mother of x -1 Name of mother of Langston Hughes Carrie Langston Langston Hughes "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False Armstrong, and Langston Hughes were among the many 5 [' Armstrong', ',', ' and', ' Lang', 'ston', ' Hughes']
+3619 849 Name of mother of x -1 Name of mother of Langston Hughes Carrie Langston Langston Hughes "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False media and literature. Langston Hughes dedicated an untitled 6 [' media', ' and', ' literature', '.', ' Lang', 'ston', ' Hughes']
+3620 849 Name of mother of x -1 Name of mother of Langston Hughes Carrie Langston Langston Hughes "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False her friend Langston Hughes on her first RCA 4 [' her', ' friend', ' Lang', 'ston', ' Hughes']
+3621 849 Name of mother of x -1 Name of mother of Langston Hughes Carrie Langston Langston Hughes "[',' ' the' ' poet' ',' ' and' ' the' ' poet' ""'s"" ' mother' ',' ' the'
+ ' poet' ""'s"" ' mother' ',' ' the' ' poet' ""'s"" ' mother' ',']" , the poet , and the poet 's mother , the poet 's mother , the poet 's mother , False directorate of Newport Langston Hughes spontaneously wrote 6 [' director', 'ate', ' of', ' Newport', ' Lang', 'ston', ' Hughes']
+3622 850 Name of mother of x -1 Name of mother of Edith Wharton Lucretia Stevens Rhinelander Edith Wharton "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False a novel by Edith Wharton and helmed 6 [' a', ' novel', ' by', ' Ed', 'ith', ' Wh', 'arton']
+3623 850 Name of mother of x -1 Name of mother of Edith Wharton Lucretia Stevens Rhinelander Edith Wharton "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False American writer Edith Wharton visited in 1920 5 [' American', ' writer', ' Ed', 'ith', ' Wh', 'arton']
+3624 850 Name of mother of x -1 Name of mother of Edith Wharton Lucretia Stevens Rhinelander Edith Wharton "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False a novel by Edith Wharton and helmed by director 6 [' a', ' novel', ' by', ' Ed', 'ith', ' Wh', 'arton']
+3625 850 Name of mother of x -1 Name of mother of Edith Wharton Lucretia Stevens Rhinelander Edith Wharton "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False The American writer Edith Wharton visited in 6 [' The', ' American', ' writer', ' Ed', 'ith', ' Wh', 'arton']
+3626 850 Name of mother of x -1 Name of mother of Edith Wharton Lucretia Stevens Rhinelander Edith Wharton "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False The American writer Edith Wharton visited in 1920 6 [' The', ' American', ' writer', ' Ed', 'ith', ' Wh', 'arton']
+3627 851 Name of mother of x -1 Name of mother of Walter Crane Marie Kearsley Walter Crane "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Walter'
+ ' Crane' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Walter Crane , the
+
+ The name of the" False painted panels by Walter Crane were unveiled in Octavia 4 [' painted', ' panels', ' by', ' Walter', ' Crane']
+3628 851 Name of mother of x -1 Name of mother of Walter Crane Marie Kearsley Walter Crane "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Walter'
+ ' Crane' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Walter Crane , the
+
+ The name of the" False Faversham, Kent, Walter Crane (Kevin Fuller) 7 [' Fa', 'vers', 'ham', ',', ' Kent', ',', ' Walter', ' Crane']
+3629 851 Name of mother of x -1 Name of mother of Walter Crane Marie Kearsley Walter Crane "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Walter'
+ ' Crane' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Walter Crane , the
+
+ The name of the" False " Crane ===
+" 4 [' Crane', ' ===', 'Wal', 'ter', ' Crane']
+3630 851 Name of mother of x -1 Name of mother of Walter Crane Marie Kearsley Walter Crane "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Walter'
+ ' Crane' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Walter Crane , the
+
+ The name of the" False panels by Walter Crane were unveiled in 3 [' panels', ' by', ' Walter', ' Crane']
+3631 851 Name of mother of x -1 Name of mother of Walter Crane Marie Kearsley Walter Crane "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Walter'
+ ' Crane' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Walter Crane , the
+
+ The name of the" False Evans employed Walter Crane to illustrate 3 [' Evans', ' employed', ' Walter', ' Crane']
+3632 854 Name of mother of x -1 Name of mother of William III of England Mary Henrietta, Princess Royal William III of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' William' ' III' ' of' ' England' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of William III of England , and the" False Peter Lely and William III of England by Godfrey Kneller, 7 [' Peter', ' Le', 'ly', ' and', ' William', ' III', ' of', ' England']
+3633 854 Name of mother of x -1 Name of mother of William III of England Mary Henrietta, Princess Royal William III of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' William' ' III' ' of' ' England' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of William III of England , and the" False godfather became William III of England not long after his 6 [' god', 'father', ' became', ' William', ' III', ' of', ' England']
+3634 854 Name of mother of x -1 Name of mother of William III of England Mary Henrietta, Princess Royal William III of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' William' ' III' ' of' ' England' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of William III of England , and the" False godfather became William III of England not long after his 6 [' god', 'father', ' became', ' William', ' III', ' of', ' England']
+3635 854 Name of mother of x -1 Name of mother of William III of England Mary Henrietta, Princess Royal William III of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' William' ' III' ' of' ' England' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of William III of England , and the" False Peter Lely and William III of England by Godfrey Kneller, 7 [' Peter', ' Le', 'ly', ' and', ' William', ' III', ' of', ' England']
+3636 854 Name of mother of x -1 Name of mother of William III of England Mary Henrietta, Princess Royal William III of England "[',' ' and' ' the' '\n' '\n' 'The' '\n' '\n' 'Name' ' of' ' the' ' father'
+ ' of' ' William' ' III' ' of' ' England' ',' ' and' ' the']" ", and the
+
+ The
+
+ Name of the father of William III of England , and the" False " III of England =
+" 7 [' III', ' of', ' England', ' =', 'William', ' III', ' of', ' England']
+3637 855 Name of mother of x -1 Name of mother of Sammy Davis Jr. Elvera Sanchez Sammy Davis Jr. "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False won Soul Train's Sammy Davis Jr. Entertainer of 7 "[' won', ' Soul', ' Train', ""'s"", ' Sammy', ' Davis', ' Jr', '.']"
+3638 855 Name of mother of x -1 Name of mother of Sammy Davis Jr. Elvera Sanchez Sammy Davis Jr. "['\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a' ' wife' ',' ' a'
+ ' daughter' ',' ' a' ' sister' ',' ' a' ' friend']" "
+
+ I am a mother of two , a wife , a daughter , a sister , a friend" False singer winning a Sammy Davis Jr. Award in 1989 6 [' singer', ' winning', ' a', ' Sammy', ' Davis', ' Jr', '.']
+3639 856 Name of mother of x -1 Name of mother of Daphne du Maurier Muriel Beaumont Daphne du Maurier "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' Rebecca' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' Gerald' 'ine'
+ ' Fitzgerald' '.' '\n']" ", the author of the famous novel Rebecca , and the mother of the actress Gerald ine Fitzgerald .
+" False Myself, I like the Daphne du Maurier touch and prefer 11 [' My', 'self', ',', ' I', ' like', ' the', ' D', 'aph', 'ne', ' du', ' Maur', 'ier']
+3640 856 Name of mother of x -1 Name of mother of Daphne du Maurier Muriel Beaumont Daphne du Maurier "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' Rebecca' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' Gerald' 'ine'
+ ' Fitzgerald' '.' '\n']" ", the author of the famous novel Rebecca , and the mother of the actress Gerald ine Fitzgerald .
+" False 5 ['D', 'aph', 'ne', ' du', ' Maur', 'ier']
+3641 856 Name of mother of x -1 Name of mother of Daphne du Maurier Muriel Beaumont Daphne du Maurier "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' Rebecca' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' Gerald' 'ine'
+ ' Fitzgerald' '.' '\n']" ", the author of the famous novel Rebecca , and the mother of the actress Gerald ine Fitzgerald .
+" False 5 ['D', 'aph', 'ne', ' du', ' Maur', 'ier']
+3642 856 Name of mother of x -1 Name of mother of Daphne du Maurier Muriel Beaumont Daphne du Maurier "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' Rebecca' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' Gerald' 'ine'
+ ' Fitzgerald' '.' '\n']" ", the author of the famous novel Rebecca , and the mother of the actress Gerald ine Fitzgerald .
+" False " prose as ""a mix of Daphne du Maurier nostalgia and surreal" 11 "[' prose', ' as', ' ""', 'a', ' mix', ' of', ' D', 'aph', 'ne', ' du', ' Maur', 'ier']"
+3643 856 Name of mother of x -1 Name of mother of Daphne du Maurier Muriel Beaumont Daphne du Maurier "[',' ' the' ' author' ' of' ' the' ' famous' ' novel' ' Rebecca' ','
+ ' and' ' the' ' mother' ' of' ' the' ' actress' ' Gerald' 'ine'
+ ' Fitzgerald' '.' '\n']" ", the author of the famous novel Rebecca , and the mother of the actress Gerald ine Fitzgerald .
+" False " finished tenth. He married Daphne du Maurier in July 1932.
+" 10 [' finished', ' tenth', '.', ' He', ' married', ' D', 'aph', 'ne', ' du', ' Maur', 'ier']
+3644 857 Name of mother of x -1 Name of mother of Wernher von Braun Emmy von Braun Wernher von Braun "[',' ' the' ' German' ' rocket' ' scientist' ' who' ' was' ' the'
+ ' father' ' of' ' the' ' V' '-' '2' ' rocket' ',' ' and' ' the' ' father'
+ ' of']" , the German rocket scientist who was the father of the V - 2 rocket , and the father of False program began, Wernher von Braun and his team 7 [' program', ' began', ',', ' W', 'ern', 'her', ' von', ' Braun']
+3645 857 Name of mother of x -1 Name of mother of Wernher von Braun Emmy von Braun Wernher von Braun "[',' ' the' ' German' ' rocket' ' scientist' ' who' ' was' ' the'
+ ' father' ' of' ' the' ' V' '-' '2' ' rocket' ',' ' and' ' the' ' father'
+ ' of']" , the German rocket scientist who was the father of the V - 2 rocket , and the father of False Richard Metzger, Wernher von Braun — who was nicknamed 9 [' Richard', ' Met', 'z', 'ger', ',', ' W', 'ern', 'her', ' von', ' Braun']
+3646 857 Name of mother of x -1 Name of mother of Wernher von Braun Emmy von Braun Wernher von Braun "[',' ' the' ' German' ' rocket' ' scientist' ' who' ' was' ' the'
+ ' father' ' of' ' the' ' V' '-' '2' ' rocket' ',' ' and' ' the' ' father'
+ ' of']" , the German rocket scientist who was the father of the V - 2 rocket , and the father of False approached that accorded to Wernher von Braun or Chris Kraft. Kraft 9 [' approached', ' that', ' accord', 'ed', ' to', ' W', 'ern', 'her', ' von', ' Braun']
+3647 857 Name of mother of x -1 Name of mother of Wernher von Braun Emmy von Braun Wernher von Braun "[',' ' the' ' German' ' rocket' ' scientist' ' who' ' was' ' the'
+ ' father' ' of' ' the' ' V' '-' '2' ' rocket' ',' ' and' ' the' ' father'
+ ' of']" , the German rocket scientist who was the father of the V - 2 rocket , and the father of False Apollo program began, Wernher von Braun and his team 8 [' Apollo', ' program', ' began', ',', ' W', 'ern', 'her', ' von', ' Braun']
+3648 857 Name of mother of x -1 Name of mother of Wernher von Braun Emmy von Braun Wernher von Braun "[',' ' the' ' German' ' rocket' ' scientist' ' who' ' was' ' the'
+ ' father' ' of' ' the' ' V' '-' '2' ' rocket' ',' ' and' ' the' ' father'
+ ' of']" , the German rocket scientist who was the father of the V - 2 rocket , and the father of False program began, Wernher von Braun and his team 7 [' program', ' began', ',', ' W', 'ern', 'her', ' von', ' Braun']
+3649 859 Name of mother of x -1 Name of mother of Jawaharlal Nehru Swarup Rani Nehru Jawaharlal Nehru "[' University' ',' ' Delhi' ',' ' India' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' children' ',' ' a' ' daughter' ' and' ' a'
+ ' son']" " University , Delhi , India
+
+ I am a mother of two children , a daughter and a son" False affiliated with either Jawaharlal Nehru Technological 8 [' affiliated', ' with', ' either', ' Jaw', 'ah', 'arl', 'al', ' Neh', 'ru']
+3650 859 Name of mother of x -1 Name of mother of Jawaharlal Nehru Swarup Rani Nehru Jawaharlal Nehru "[' University' ',' ' Delhi' ',' ' India' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' children' ',' ' a' ' daughter' ' and' ' a'
+ ' son']" " University , Delhi , India
+
+ I am a mother of two children , a daughter and a son" False after the death of Jawaharlal Nehru and ensured 9 [' after', ' the', ' death', ' of', ' Jaw', 'ah', 'arl', 'al', ' Neh', 'ru']
+3651 859 Name of mother of x -1 Name of mother of Jawaharlal Nehru Swarup Rani Nehru Jawaharlal Nehru "[' University' ',' ' Delhi' ',' ' India' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' children' ',' ' a' ' daughter' ' and' ' a'
+ ' son']" " University , Delhi , India
+
+ I am a mother of two children , a daughter and a son" False historian from the Jawaharlal Nehru University in Delhi. 8 [' historian', ' from', ' the', ' Jaw', 'ah', 'arl', 'al', ' Neh', 'ru']
+3652 859 Name of mother of x -1 Name of mother of Jawaharlal Nehru Swarup Rani Nehru Jawaharlal Nehru "[' University' ',' ' Delhi' ',' ' India' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' children' ',' ' a' ' daughter' ' and' ' a'
+ ' son']" " University , Delhi , India
+
+ I am a mother of two children , a daughter and a son" False Indian Super League. Jawaharlal Nehru Stadium is 9 [' Indian', ' Super', ' League', '.', ' Jaw', 'ah', 'arl', 'al', ' Neh', 'ru']
+3653 859 Name of mother of x -1 Name of mother of Jawaharlal Nehru Swarup Rani Nehru Jawaharlal Nehru "[' University' ',' ' Delhi' ',' ' India' '\n' '\n' 'I' ' am' ' a'
+ ' mother' ' of' ' two' ' children' ',' ' a' ' daughter' ' and' ' a'
+ ' son']" " University , Delhi , India
+
+ I am a mother of two children , a daughter and a son" False " independence"". Jawaharlal Nehru described in his autobiography" 7 "[' independence', '"".', ' Jaw', 'ah', 'arl', 'al', ' Neh', 'ru']"
+3654 860 Name of mother of x -1 Name of mother of Carl Larsson Johanna Carolina Erika Larsson Carl Larsson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' was' ' born']" ", the father of the child , and the mother of the child .
+
+ The child was born" False place where Carl Larsson intended it 4 [' place', ' where', ' Carl', ' Lar', 'sson']
+3655 860 Name of mother of x -1 Name of mother of Carl Larsson Johanna Carolina Erika Larsson Carl Larsson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' was' ' born']" ", the father of the child , and the mother of the child .
+
+ The child was born" False permanently where Carl Larsson had intended it to 4 [' permanently', ' where', ' Carl', ' Lar', 'sson']
+3656 860 Name of mother of x -1 Name of mother of Carl Larsson Johanna Carolina Erika Larsson Carl Larsson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' was' ' born']" ", the father of the child , and the mother of the child .
+
+ The child was born" False wall remain empty. Carl Larsson received the ruling 6 [' wall', ' remain', ' empty', '.', ' Carl', ' Lar', 'sson']
+3657 860 Name of mother of x -1 Name of mother of Carl Larsson Johanna Carolina Erika Larsson Carl Larsson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' was' ' born']" ", the father of the child , and the mother of the child .
+
+ The child was born" False wall remain empty. Carl Larsson received the ruling 6 [' wall', ' remain', ' empty', '.', ' Carl', ' Lar', 'sson']
+3658 860 Name of mother of x -1 Name of mother of Carl Larsson Johanna Carolina Erika Larsson Carl Larsson "[',' ' the' ' father' ' of' ' the' ' child' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' child' '.' '\n' '\n' 'The' ' child' ' was' ' born']" ", the father of the child , and the mother of the child .
+
+ The child was born" False permanently where Carl Larsson had intended it 4 [' permanently', ' where', ' Carl', ' Lar', 'sson']
+3659 861 Name of mother of x -1 Name of mother of Robert Peel Ellen Yates Robert Peel "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False mill was built for Robert Peel on the River Irwell 5 [' mill', ' was', ' built', ' for', ' Robert', ' Peel']
+3660 861 Name of mother of x -1 Name of mother of Robert Peel Ellen Yates Robert Peel "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False 1875, while Sir Robert Peel lived in another 6 [' 18', '75', ',', ' while', ' Sir', ' Robert', ' Peel']
+3661 861 Name of mother of x -1 Name of mother of Robert Peel Ellen Yates Robert Peel "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False agriculture. In 1780 Robert Peel built the first 6 [' agriculture', '.', ' In', ' 17', '80', ' Robert', ' Peel']
+3662 861 Name of mother of x -1 Name of mother of Robert Peel Ellen Yates Robert Peel "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False the Tories under Sir Robert Peel were not able to win 5 [' the', ' Tories', ' under', ' Sir', ' Robert', ' Peel']
+3663 861 Name of mother of x -1 Name of mother of Robert Peel Ellen Yates Robert Peel "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' a'
+ ' wife' ',' ' a' ' daughter' ',' ' a' ' sister' ',']" ", the
+
+ I am a mother of two , a wife , a daughter , a sister ," False to Tory statesman Robert Peel and readily agreed 5 [' to', ' Tory', ' states', 'man', ' Robert', ' Peel']
+3664 862 Name of mother of x -1 Name of mother of Robert Boyle Catherine Fenton Robert Boyle "[',' ' the' ' son' ' of' ' the' ' late' ' Robert' ' Boyle' ',' ' the'
+ ' famous' '\n' '\n' 'B' 'oyle' ',' ' who' ' was' ' a' ' great']" ", the son of the late Robert Boyle , the famous
+
+ B oyle , who was a great" False " Boyle =
+" 3 [' Boyle', ' =', 'Robert', ' Boyle']
+3665 862 Name of mother of x -1 Name of mother of Robert Boyle Catherine Fenton Robert Boyle "[',' ' the' ' son' ' of' ' the' ' late' ' Robert' ' Boyle' ',' ' the'
+ ' famous' '\n' '\n' 'B' 'oyle' ',' ' who' ' was' ' a' ' great']" ", the son of the late Robert Boyle , the famous
+
+ B oyle , who was a great" False observation with Robert Boyle in a letter written 3 [' observation', ' with', ' Robert', ' Boyle']
+3666 862 Name of mother of x -1 Name of mother of Robert Boyle Catherine Fenton Robert Boyle "[',' ' the' ' son' ' of' ' the' ' late' ' Robert' ' Boyle' ',' ' the'
+ ' famous' '\n' '\n' 'B' 'oyle' ',' ' who' ' was' ' a' ' great']" ", the son of the late Robert Boyle , the famous
+
+ B oyle , who was a great" False " (1654 – 1691)
+" 8 [' (', '16', '54', ' –', ' 16', '91', ')', 'Robert', ' Boyle']
+3667 862 Name of mother of x -1 Name of mother of Robert Boyle Catherine Fenton Robert Boyle "[',' ' the' ' son' ' of' ' the' ' late' ' Robert' ' Boyle' ',' ' the'
+ ' famous' '\n' '\n' 'B' 'oyle' ',' ' who' ' was' ' a' ' great']" ", the son of the late Robert Boyle , the famous
+
+ B oyle , who was a great" False around 1500, Robert Boyle (1670), and Joseph 4 [' around', ' 1500', ',', ' Robert', ' Boyle']
+3668 862 Name of mother of x -1 Name of mother of Robert Boyle Catherine Fenton Robert Boyle "[',' ' the' ' son' ' of' ' the' ' late' ' Robert' ' Boyle' ',' ' the'
+ ' famous' '\n' '\n' 'B' 'oyle' ',' ' who' ' was' ' a' ' great']" ", the son of the late Robert Boyle , the famous
+
+ B oyle , who was a great" False 1670 – Sir Robert Boyle performed an experiment 5 [' 16', '70', ' –', ' Sir', ' Robert', ' Boyle']
+3669 863 Name of mother of x -1 Name of mother of Matthew Arnold Mary Penrose Matthew Arnold "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '22' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 22 , and died in 18 92 .
+
+ The first" False contemporary poet Matthew Arnold was early in observing 3 [' contemporary', ' poet', ' Matthew', ' Arnold']
+3670 863 Name of mother of x -1 Name of mother of Matthew Arnold Mary Penrose Matthew Arnold "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '22' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 22 , and died in 18 92 .
+
+ The first" False writer and critic Matthew Arnold scoffed at Fuller's 4 [' writer', ' and', ' critic', ' Matthew', ' Arnold']
+3671 863 Name of mother of x -1 Name of mother of Matthew Arnold Mary Penrose Matthew Arnold "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '22' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 22 , and died in 18 92 .
+
+ The first" False English poet Matthew Arnold (1822 – 1888), 3 [' English', ' poet', ' Matthew', ' Arnold']
+3672 863 Name of mother of x -1 Name of mother of Matthew Arnold Mary Penrose Matthew Arnold "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '22' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 22 , and died in 18 92 .
+
+ The first" False contemporary poet Matthew Arnold was early in 3 [' contemporary', ' poet', ' Matthew', ' Arnold']
+3673 863 Name of mother of x -1 Name of mother of Matthew Arnold Mary Penrose Matthew Arnold "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' 18' '22' ',' ' and' ' died'
+ ' in' ' 18' '92' '.' '\n' '\n' 'The' ' first']" ", the poet , was born in 18 22 , and died in 18 92 .
+
+ The first" False " suggestive."" In 1865, Matthew Arnold singled out" 6 "[' suggestive', '.""', ' In', ' 1865', ',', ' Matthew', ' Arnold']"
+3674 865 Name of mother of x -1 Name of mother of Alec Guinness Agnes Cuff Alec Guinness "[',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the' ' family' ','
+ ' and' ' a' ' great' ' friend' ' of' ' the' ' family' ' of' ' the'
+ ' family']" , who was a great friend of the family , and a great friend of the family of the family False made famous by Sir Alec Guinness in the original Star 5 [' made', ' famous', ' by', ' Sir', ' Alec', ' Guinness']
+3675 865 Name of mother of x -1 Name of mother of Alec Guinness Agnes Cuff Alec Guinness "[',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the' ' family' ','
+ ' and' ' a' ' great' ' friend' ' of' ' the' ' family' ' of' ' the'
+ ' family']" , who was a great friend of the family , and a great friend of the family of the family False viewed footage of Alec Guinness as Fagin in Oliver 4 [' viewed', ' footage', ' of', ' Alec', ' Guinness']
+3676 865 Name of mother of x -1 Name of mother of Alec Guinness Agnes Cuff Alec Guinness "[',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the' ' family' ','
+ ' and' ' a' ' great' ' friend' ' of' ' the' ' family' ' of' ' the'
+ ' family']" , who was a great friend of the family , and a great friend of the family of the family False and a very young Alec Guinness outrageous and more 5 [' and', ' a', ' very', ' young', ' Alec', ' Guinness']
+3677 865 Name of mother of x -1 Name of mother of Alec Guinness Agnes Cuff Alec Guinness "[',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the' ' family' ','
+ ' and' ' a' ' great' ' friend' ' of' ' the' ' family' ' of' ' the'
+ ' family']" , who was a great friend of the family , and a great friend of the family of the family False years later. Alec Guinness first performed 4 [' years', ' later', '.', ' Alec', ' Guinness']
+3678 865 Name of mother of x -1 Name of mother of Alec Guinness Agnes Cuff Alec Guinness "[',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the' ' family' ','
+ ' and' ' a' ' great' ' friend' ' of' ' the' ' family' ' of' ' the'
+ ' family']" , who was a great friend of the family , and a great friend of the family of the family False colonel, played by Alec Guinness in an Oscar-winning 5 [' colonel', ',', ' played', ' by', ' Alec', ' Guinness']
+3679 866 Name of mother of x -1 Name of mother of Arthur Sullivan Mary Clementina Coghlan Arthur Sullivan "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur'
+ ' Sullivan' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Arthur Sullivan , the
+
+ The name of the" False unquestionably funny .... Mr. Arthur Sullivan has entered with 7 [' unquestion', 'ably', ' funny', '....', ' Mr', '.', ' Arthur', ' Sullivan']
+3680 866 Name of mother of x -1 Name of mother of Arthur Sullivan Mary Clementina Coghlan Arthur Sullivan "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur'
+ ' Sullivan' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Arthur Sullivan , the
+
+ The name of the" False variations depicting Arthur Sullivan and Hubert Parry, 3 [' variations', ' depicting', ' Arthur', ' Sullivan']
+3681 866 Name of mother of x -1 Name of mother of Arthur Sullivan Mary Clementina Coghlan Arthur Sullivan "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur'
+ ' Sullivan' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Arthur Sullivan , the
+
+ The name of the" False Robert, and Arthur Sullivan was in the audience 4 [' Robert', ',', ' and', ' Arthur', ' Sullivan']
+3682 866 Name of mother of x -1 Name of mother of Arthur Sullivan Mary Clementina Coghlan Arthur Sullivan "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur'
+ ' Sullivan' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Arthur Sullivan , the
+
+ The name of the" False Box by composer Arthur Sullivan and dramatist F. C. 4 [' Box', ' by', ' composer', ' Arthur', ' Sullivan']
+3683 866 Name of mother of x -1 Name of mother of Arthur Sullivan Mary Clementina Coghlan Arthur Sullivan "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Arthur'
+ ' Sullivan' ',' ' the' '\n' '\n' 'The' ' name' ' of' ' the']" ", the
+
+ The name of the mother of Arthur Sullivan , the
+
+ The name of the" False contributed by Mr. Arthur Sullivan so pretty and 5 [' contributed', ' by', ' Mr', '.', ' Arthur', ' Sullivan']
+3684 867 Name of mother of x -1 Name of mother of Alexander II of Russia Alexandra Feodorovna Alexander II of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Empress' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of' '\n' '\n']" ", and the
+
+ Russian Empire , and the Empress of Austria , and the Empress of
+
+" False claimed that Tsar Alexander II of Russia in 1870 inducted 7 [' claimed', ' that', ' Ts', 'ar', ' Alexander', ' II', ' of', ' Russia']
+3685 867 Name of mother of x -1 Name of mother of Alexander II of Russia Alexandra Feodorovna Alexander II of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Empress' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of' '\n' '\n']" ", and the
+
+ Russian Empire , and the Empress of Austria , and the Empress of
+
+" False " daughter of Alexander II of Russia and Marie of Hesse.
+" 5 [' daughter', ' of', ' Alexander', ' II', ' of', ' Russia']
+3686 867 Name of mother of x -1 Name of mother of Alexander II of Russia Alexandra Feodorovna Alexander II of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Empress' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of' '\n' '\n']" ", and the
+
+ Russian Empire , and the Empress of Austria , and the Empress of
+
+" False claimed that Tsar Alexander II of Russia in 1870 inducted him 7 [' claimed', ' that', ' Ts', 'ar', ' Alexander', ' II', ' of', ' Russia']
+3687 867 Name of mother of x -1 Name of mother of Alexander II of Russia Alexandra Feodorovna Alexander II of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Empress' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of' '\n' '\n']" ", and the
+
+ Russian Empire , and the Empress of Austria , and the Empress of
+
+" False led by a brother of Alexander II of Russia for the bicentennial 8 [' led', ' by', ' a', ' brother', ' of', ' Alexander', ' II', ' of', ' Russia']
+3688 867 Name of mother of x -1 Name of mother of Alexander II of Russia Alexandra Feodorovna Alexander II of Russia "[',' ' and' ' the' '\n' '\n' 'Russian' ' Empire' ',' ' and' ' the'
+ ' Empress' ' of' ' Austria' ',' ' and' ' the' ' Empress' ' of' '\n' '\n']" ", and the
+
+ Russian Empire , and the Empress of Austria , and the Empress of
+
+" False where Napoleon and Alexander II of Russia conferred on the 6 [' where', ' Napoleon', ' and', ' Alexander', ' II', ' of', ' Russia']
+3689 869 Name of mother of x -1 Name of mother of Maximilian I Eleanor of Portugal, Holy Roman Empress Maximilian I "[',' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman'
+ ' Emperor' ' Max' 'imil' 'ian' ' I' ' (' '14' '93' '–' '15']" , the Holy Roman Emperor , and the Holy Roman Emperor Max imil ian I ( 14 93 – 15 False Duke of Teschen, Maximilian I of Mexico 9 [' Duke', ' of', ' T', 'esc', 'hen', ',', ' Max', 'imil', 'ian', ' I']
+3690 869 Name of mother of x -1 Name of mother of Maximilian I Eleanor of Portugal, Holy Roman Empress Maximilian I "[',' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman'
+ ' Emperor' ' Max' 'imil' 'ian' ' I' ' (' '14' '93' '–' '15']" , the Holy Roman Emperor , and the Holy Roman Emperor Max imil ian I ( 14 93 – 15 False decisive victory, king Maximilian I had no choice 7 [' decisive', ' victory', ',', ' king', ' Max', 'imil', 'ian', ' I']
+3691 869 Name of mother of x -1 Name of mother of Maximilian I Eleanor of Portugal, Holy Roman Empress Maximilian I "[',' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman'
+ ' Emperor' ' Max' 'imil' 'ian' ' I' ' (' '14' '93' '–' '15']" , the Holy Roman Emperor , and the Holy Roman Emperor Max imil ian I ( 14 93 – 15 False hired by Emperor Maximilian I to escort 6 [' hired', ' by', ' Emperor', ' Max', 'imil', 'ian', ' I']
+3692 869 Name of mother of x -1 Name of mother of Maximilian I Eleanor of Portugal, Holy Roman Empress Maximilian I "[',' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman'
+ ' Emperor' ' Max' 'imil' 'ian' ' I' ' (' '14' '93' '–' '15']" , the Holy Roman Emperor , and the Holy Roman Emperor Max imil ian I ( 14 93 – 15 False was also gained by Maximilian I of Bavaria 7 [' was', ' also', ' gained', ' by', ' Max', 'imil', 'ian', ' I']
+3693 869 Name of mother of x -1 Name of mother of Maximilian I Eleanor of Portugal, Holy Roman Empress Maximilian I "[',' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman'
+ ' Emperor' ' Max' 'imil' 'ian' ' I' ' (' '14' '93' '–' '15']" , the Holy Roman Emperor , and the Holy Roman Emperor Max imil ian I ( 14 93 – 15 False by Emperor Maximilian I (reigned 1508 5 [' by', ' Emperor', ' Max', 'imil', 'ian', ' I']
+3694 870 Name of mother of x -1 Name of mother of Catherine de' Medici Madeleine de La Tour d'Auvergne Catherine de' Medici "[',' ' the' ' daughter' ' of' ' Lorenzo' ' de' ""'"" ' Medic' 'i' ',' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the daughter of Lorenzo de ' Medic i , the
+ " False (1553 – 1610), Catherine de' Medici (1519 – 89) and her 11 "[' (', '15', '53', ' –', ' 16', '10', '),', ' Catherine', ' de', ""'"", ' Medic', 'i']"
+3695 870 Name of mother of x -1 Name of mother of Catherine de' Medici Madeleine de La Tour d'Auvergne Catherine de' Medici "[',' ' the' ' daughter' ' of' ' Lorenzo' ' de' ""'"" ' Medic' 'i' ',' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the daughter of Lorenzo de ' Medic i , the
+ " False summoned by Queen Catherine de' Medici to Paris in 7 "[' summoned', ' by', ' Queen', ' Catherine', ' de', ""'"", ' Medic', 'i']"
+3696 870 Name of mother of x -1 Name of mother of Catherine de' Medici Madeleine de La Tour d'Auvergne Catherine de' Medici "[',' ' the' ' daughter' ' of' ' Lorenzo' ' de' ""'"" ' Medic' 'i' ',' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the daughter of Lorenzo de ' Medic i , the
+ " False France (1553 – 1610), Catherine de' Medici (1519 – 89) 12 "[' France', ' (', '15', '53', ' –', ' 16', '10', '),', ' Catherine', ' de', ""'"", ' Medic', 'i']"
+3697 870 Name of mother of x -1 Name of mother of Catherine de' Medici Madeleine de La Tour d'Auvergne Catherine de' Medici "[',' ' the' ' daughter' ' of' ' Lorenzo' ' de' ""'"" ' Medic' 'i' ',' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the daughter of Lorenzo de ' Medic i , the
+ " False (1553 – 1610), Catherine de' Medici (1519 – 89) and 11 "[' (', '15', '53', ' –', ' 16', '10', '),', ' Catherine', ' de', ""'"", ' Medic', 'i']"
+3698 870 Name of mother of x -1 Name of mother of Catherine de' Medici Madeleine de La Tour d'Auvergne Catherine de' Medici "[',' ' the' ' daughter' ' of' ' Lorenzo' ' de' ""'"" ' Medic' 'i' ',' ' the'
+ '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the daughter of Lorenzo de ' Medic i , the
+ " False summoned by Queen Catherine de' Medici to Paris in 7 "[' summoned', ' by', ' Queen', ' Catherine', ' de', ""'"", ' Medic', 'i']"
+3699 871 Name of mother of x -1 Name of mother of Sofia Coppola Eleanor Coppola Sofia Coppola "[""'s"" ' The' ' Virgin' ' Su' 'icides' ',' ' and' ' the' ' film' ""'s""
+ ' director' ',' ' Sof' 'ia' ' Co' 'pp' 'ola' ',' ' is' ' a']" 's The Virgin Su icides , and the film 's director , Sof ia Co pp ola , is a False track to the 2013 Sofia Coppola film The Bling 8 [' track', ' to', ' the', ' 2013', ' Sof', 'ia', ' Co', 'pp', 'ola']
+3700 871 Name of mother of x -1 Name of mother of Sofia Coppola Eleanor Coppola Sofia Coppola "[""'s"" ' The' ' Virgin' ' Su' 'icides' ',' ' and' ' the' ' film' ""'s""
+ ' director' ',' ' Sof' 'ia' ' Co' 'pp' 'ola' ',' ' is' ' a']" 's The Virgin Su icides , and the film 's director , Sof ia Co pp ola , is a False the tenacity of a Sofia Coppola to produce frames where 9 [' the', ' ten', 'acity', ' of', ' a', ' Sof', 'ia', ' Co', 'pp', 'ola']
+3701 871 Name of mother of x -1 Name of mother of Sofia Coppola Eleanor Coppola Sofia Coppola "[""'s"" ' The' ' Virgin' ' Su' 'icides' ',' ' and' ' the' ' film' ""'s""
+ ' director' ',' ' Sof' 'ia' ' Co' 'pp' 'ola' ',' ' is' ' a']" 's The Virgin Su icides , and the film 's director , Sof ia Co pp ola , is a False Coppola's daughter Sofia Coppola was cast in 9 "[' Co', 'pp', 'ola', ""'s"", ' daughter', ' Sof', 'ia', ' Co', 'pp', 'ola']"
+3702 871 Name of mother of x -1 Name of mother of Sofia Coppola Eleanor Coppola Sofia Coppola "[""'s"" ' The' ' Virgin' ' Su' 'icides' ',' ' and' ' the' ' film' ""'s""
+ ' director' ',' ' Sof' 'ia' ' Co' 'pp' 'ola' ',' ' is' ' a']" 's The Virgin Su icides , and the film 's director , Sof ia Co pp ola , is a False collaborated with Sofia Coppola again and starred 6 [' collaborated', ' with', ' Sof', 'ia', ' Co', 'pp', 'ola']
+3703 871 Name of mother of x -1 Name of mother of Sofia Coppola Eleanor Coppola Sofia Coppola "[""'s"" ' The' ' Virgin' ' Su' 'icides' ',' ' and' ' the' ' film' ""'s""
+ ' director' ',' ' Sof' 'ia' ' Co' 'pp' 'ola' ',' ' is' ' a']" 's The Virgin Su icides , and the film 's director , Sof ia Co pp ola , is a False Coppola's daughter Sofia Coppola was cast in 9 "[' Co', 'pp', 'ola', ""'s"", ' daughter', ' Sof', 'ia', ' Co', 'pp', 'ola']"
+3704 872 Name of mother of x -1 Name of mother of Cardinal Mazarin Ortensia Buffalini Cardinal Mazarin "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the'
+ ' other' ' being' ' the' '\n' '\n' 'The' ' first' ' of' ' the']" ", the
+
+ The first of the two , the other being the
+
+ The first of the" False of the powerful Cardinal Mazarin was celebrated in a 5 [' of', ' the', ' powerful', ' Cardinal', ' Maz', 'arin']
+3705 872 Name of mother of x -1 Name of mother of Cardinal Mazarin Ortensia Buffalini Cardinal Mazarin "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the'
+ ' other' ' being' ' the' '\n' '\n' 'The' ' first' ' of' ' the']" ", the
+
+ The first of the two , the other being the
+
+ The first of the" False Mediterranean, Cardinal Mazarin planned a naval 4 [' Mediterranean', ',', ' Cardinal', ' Maz', 'arin']
+3706 872 Name of mother of x -1 Name of mother of Cardinal Mazarin Ortensia Buffalini Cardinal Mazarin "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the'
+ ' other' ' being' ' the' '\n' '\n' 'The' ' first' ' of' ' the']" ", the
+
+ The first of the two , the other being the
+
+ The first of the" False Mediterranean, Cardinal Mazarin planned a naval 4 [' Mediterranean', ',', ' Cardinal', ' Maz', 'arin']
+3707 872 Name of mother of x -1 Name of mother of Cardinal Mazarin Ortensia Buffalini Cardinal Mazarin "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the'
+ ' other' ' being' ' the' '\n' '\n' 'The' ' first' ' of' ' the']" ", the
+
+ The first of the two , the other being the
+
+ The first of the" False fall of the powerful Cardinal Mazarin was celebrated in 6 [' fall', ' of', ' the', ' powerful', ' Cardinal', ' Maz', 'arin']
+3708 872 Name of mother of x -1 Name of mother of Cardinal Mazarin Ortensia Buffalini Cardinal Mazarin "[',' ' the' '\n' '\n' 'The' ' first' ' of' ' the' ' two' ',' ' the'
+ ' other' ' being' ' the' '\n' '\n' 'The' ' first' ' of' ' the']" ", the
+
+ The first of the two , the other being the
+
+ The first of the" False the powerful Cardinal Mazarin was celebrated 4 [' the', ' powerful', ' Cardinal', ' Maz', 'arin']
+3709 873 Name of mother of x -1 Name of mother of Robert Southey Margaret Hill Robert Southey "[',' ' the' ' poet' ',' ' was' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He' ' was']" , the poet , was a man of the people , and a man of the people . He was False Vassall-Fox and Robert Southey expressed their 9 [' V', 'ass', 'all', '-', 'Fox', ' and', ' Robert', ' S', 'out', 'hey']
+3710 873 Name of mother of x -1 Name of mother of Robert Southey Margaret Hill Robert Southey "[',' ' the' ' poet' ',' ' was' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He' ' was']" , the poet , was a man of the people , and a man of the people . He was False The Romantic poet Robert Southey based his 1794 play 6 [' The', ' Romantic', ' poet', ' Robert', ' S', 'out', 'hey']
+3711 873 Name of mother of x -1 Name of mother of Robert Southey Margaret Hill Robert Southey "[',' ' the' ' poet' ',' ' was' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He' ' was']" , the poet , was a man of the people , and a man of the people . He was False Vassall-Fox and Robert Southey expressed their admiration 9 [' V', 'ass', 'all', '-', 'Fox', ' and', ' Robert', ' S', 'out', 'hey']
+3712 873 Name of mother of x -1 Name of mother of Robert Southey Margaret Hill Robert Southey "[',' ' the' ' poet' ',' ' was' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He' ' was']" , the poet , was a man of the people , and a man of the people . He was False Vassall-Fox and Robert Southey expressed their admiration 9 [' V', 'ass', 'all', '-', 'Fox', ' and', ' Robert', ' S', 'out', 'hey']
+3713 873 Name of mother of x -1 Name of mother of Robert Southey Margaret Hill Robert Southey "[',' ' the' ' poet' ',' ' was' ' a' ' man' ' of' ' the' ' people' ','
+ ' and' ' a' ' man' ' of' ' the' ' people' '.' ' He' ' was']" , the poet , was a man of the people , and a man of the people . He was False 18th-century poets Robert Southey and Thomas Chatterton. 8 [' 18', 'th', '-', 'century', ' poets', ' Robert', ' S', 'out', 'hey']
+3714 874 Name of mother of x -1 Name of mother of James Madison Eleanor Rose Conway James Madison "[',' ' the' ' father' ' of' ' the' ' Constitution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Constitution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Constitution']" , the father of the Constitution , and the father of the Constitution , and the father of the Constitution False teams. In 2010, James Madison defeated 13th-ranked 6 [' teams', '.', ' In', ' 2010', ',', ' James', ' Madison']
+3715 874 Name of mother of x -1 Name of mother of James Madison Eleanor Rose Conway James Madison "[',' ' the' ' father' ' of' ' the' ' Constitution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Constitution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Constitution']" , the father of the Constitution , and the father of the Constitution , and the father of the Constitution False citizens led by James Madison Porter met 4 [' citizens', ' led', ' by', ' James', ' Madison']
+3716 874 Name of mother of x -1 Name of mother of James Madison Eleanor Rose Conway James Madison "[',' ' the' ' father' ' of' ' the' ' Constitution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Constitution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Constitution']" , the father of the Constitution , and the father of the Constitution , and the father of the Constitution False future U.S. presidents James Madison and James Monroe. 7 [' future', ' U', '.', 'S', '.', ' presidents', ' James', ' Madison']
+3717 874 Name of mother of x -1 Name of mother of James Madison Eleanor Rose Conway James Madison "[',' ' the' ' father' ' of' ' the' ' Constitution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Constitution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Constitution']" , the father of the Constitution , and the father of the Constitution , and the father of the Constitution False Americans as James Madison and Alexander Hamilton. 3 [' Americans', ' as', ' James', ' Madison']
+3718 874 Name of mother of x -1 Name of mother of James Madison Eleanor Rose Conway James Madison "[',' ' the' ' father' ' of' ' the' ' Constitution' ',' ' and' ' the'
+ ' father' ' of' ' the' ' Constitution' ',' ' and' ' the' ' father' ' of'
+ ' the' ' Constitution']" , the father of the Constitution , and the father of the Constitution , and the father of the Constitution False argued that James Madison could have drafted 3 [' argued', ' that', ' James', ' Madison']
+3719 875 Name of mother of x -1 Name of mother of Evelyn Waugh Catherine Charlotte Raban Evelyn Waugh "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False whose creation both Evelyn Waugh and George Orwell 6 [' whose', ' creation', ' both', ' Eve', 'lyn', ' W', 'augh']
+3720 875 Name of mother of x -1 Name of mother of Evelyn Waugh Catherine Charlotte Raban Evelyn Waugh "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False War II. People like Evelyn Waugh and Graham Greene 8 [' War', ' II', '.', ' People', ' like', ' Eve', 'lyn', ' W', 'augh']
+3721 875 Name of mother of x -1 Name of mother of Evelyn Waugh Catherine Charlotte Raban Evelyn Waugh "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False whose creation both Evelyn Waugh and George Orwell 6 [' whose', ' creation', ' both', ' Eve', 'lyn', ' W', 'augh']
+3722 875 Name of mother of x -1 Name of mother of Evelyn Waugh Catherine Charlotte Raban Evelyn Waugh "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False the novelist Evelyn Waugh who made it 5 [' the', ' novelist', ' Eve', 'lyn', ' W', 'augh']
+3723 875 Name of mother of x -1 Name of mother of Evelyn Waugh Catherine Charlotte Raban Evelyn Waugh "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False The novelist Evelyn Waugh wrote a preface 5 [' The', ' novelist', ' Eve', 'lyn', ' W', 'augh']
+3724 876 Name of mother of x -1 Name of mother of Humphry Davy Grace Millett Humphry Davy "[',' ' the' ' son' ' of' ' the' ' late' ' Humph' 'ry' ' Dav' 'y' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the son of the late Humph ry Dav y , the
+ " False 19th-century chemist Humphry Davy (who worked in Hotwells). 8 [' 19', 'th', '-', 'century', ' chemist', ' Humph', 'ry', ' Dav', 'y']
+3725 876 Name of mother of x -1 Name of mother of Humphry Davy Grace Millett Humphry Davy "[',' ' the' ' son' ' of' ' the' ' late' ' Humph' 'ry' ' Dav' 'y' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the son of the late Humph ry Dav y , the
+ " False century, after Humphry Davy discovered the 6 [' century', ',', ' after', ' Humph', 'ry', ' Dav', 'y']
+3726 876 Name of mother of x -1 Name of mother of Humphry Davy Grace Millett Humphry Davy "[',' ' the' ' son' ' of' ' the' ' late' ' Humph' 'ry' ' Dav' 'y' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the son of the late Humph ry Dav y , the
+ " False 19th-century chemist Humphry Davy (who worked in Hotwells). 8 [' 19', 'th', '-', 'century', ' chemist', ' Humph', 'ry', ' Dav', 'y']
+3727 876 Name of mother of x -1 Name of mother of Humphry Davy Grace Millett Humphry Davy "[',' ' the' ' son' ' of' ' the' ' late' ' Humph' 'ry' ' Dav' 'y' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the son of the late Humph ry Dav y , the
+ " False arc in 1800 by Humphry Davy and of the continuous 7 [' arc', ' in', ' 1800', ' by', ' Humph', 'ry', ' Dav', 'y']
+3728 876 Name of mother of x -1 Name of mother of Humphry Davy Grace Millett Humphry Davy "[',' ' the' ' son' ' of' ' the' ' late' ' Humph' 'ry' ' Dav' 'y' ','
+ ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ']" ", the son of the late Humph ry Dav y , the
+ " False 1808, by Sir Humphry Davy in England. 8 [' 18', '08', ',', ' by', ' Sir', ' Humph', 'ry', ' Dav', 'y']
+3729 877 Name of mother of x -1 Name of mother of William Holman Hunt Sarah Holman William Holman Hunt "[',' ' the' ' painter' ',' ' and' ' his' ' wife' ',' ' Mary' ',' ' who'
+ ' was' ' the' ' daughter' ' of' ' the' ' Rev' '.' ' John' ' Hunt']" , the painter , and his wife , Mary , who was the daughter of the Rev . John Hunt False Pre-Raphaelite Brotherhood. William Holman Hunt and John Everett 10 [' Pre', '-', 'R', 'aphael', 'ite', ' Brotherhood', '.', ' William', ' Hol', 'man', ' Hunt']
+3730 879 Name of mother of x -1 Name of mother of Edgar Rice Burroughs Mary Evaline Zieger Edgar Rice Burroughs "[',' ' the' ' author' ' of' ' Tar' 'zan' ' of' ' the' ' Ap' 'es' ','
+ ' and' ' the' ' creator' ' of' ' Tar' 'zan' ' of' ' the' ' Ap']" , the author of Tar zan of the Ap es , and the creator of Tar zan of the Ap False Press's line of Edgar Rice Burroughs books (inked by Crandall); 8 "[' Press', ""'s"", ' line', ' of', ' Edgar', ' Rice', ' Bur', 'rough', 's']"
+3731 879 Name of mother of x -1 Name of mother of Edgar Rice Burroughs Mary Evaline Zieger Edgar Rice Burroughs "[',' ' the' ' author' ' of' ' Tar' 'zan' ' of' ' the' ' Ap' 'es' ','
+ ' and' ' the' ' creator' ' of' ' Tar' 'zan' ' of' ' the' ' Ap']" , the author of Tar zan of the Ap es , and the creator of Tar zan of the Ap False the works of Edgar Rice Burroughs as pioneers. The 7 [' the', ' works', ' of', ' Edgar', ' Rice', ' Bur', 'rough', 's']
+3732 879 Name of mother of x -1 Name of mother of Edgar Rice Burroughs Mary Evaline Zieger Edgar Rice Burroughs "[',' ' the' ' author' ' of' ' Tar' 'zan' ' of' ' the' ' Ap' 'es' ','
+ ' and' ' the' ' creator' ' of' ' Tar' 'zan' ' of' ' the' ' Ap']" , the author of Tar zan of the Ap es , and the creator of Tar zan of the Ap False Tarzan novels by Edgar Rice Burroughs and recently Alan 8 [' Tar', 'zan', ' novels', ' by', ' Edgar', ' Rice', ' Bur', 'rough', 's']
+3733 879 Name of mother of x -1 Name of mother of Edgar Rice Burroughs Mary Evaline Zieger Edgar Rice Burroughs "[',' ' the' ' author' ' of' ' Tar' 'zan' ' of' ' the' ' Ap' 'es' ','
+ ' and' ' the' ' creator' ' of' ' Tar' 'zan' ' of' ' the' ' Ap']" , the author of Tar zan of the Ap es , and the creator of Tar zan of the Ap False Raymond Z. Gallun, and Edgar Rice Burroughs who seem to 11 [' Raymond', ' Z', '.', ' Gall', 'un', ',', ' and', ' Edgar', ' Rice', ' Bur', 'rough', 's']
+3734 879 Name of mother of x -1 Name of mother of Edgar Rice Burroughs Mary Evaline Zieger Edgar Rice Burroughs "[',' ' the' ' author' ' of' ' Tar' 'zan' ' of' ' the' ' Ap' 'es' ','
+ ' and' ' the' ' creator' ' of' ' Tar' 'zan' ' of' ' the' ' Ap']" , the author of Tar zan of the Ap es , and the creator of Tar zan of the Ap False adventure novels by Edgar Rice Burroughs and H. Rider Haggard. 7 [' adventure', ' novels', ' by', ' Edgar', ' Rice', ' Bur', 'rough', 's']
+3735 880 Name of mother of x -1 Name of mother of James Stewart Elizabeth Ruth Stewart James Stewart "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' James' ' Stewart' ',' ' the' ' actor' ',' ' and' ' his' ' wife']" , who was a member of the family of the late James Stewart , the actor , and his wife False Crump, and featured James Stewart as a dashing pilot. 6 [' Cr', 'ump', ',', ' and', ' featured', ' James', ' Stewart']
+3736 880 Name of mother of x -1 Name of mother of James Stewart Elizabeth Ruth Stewart James Stewart "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' James' ' Stewart' ',' ' the' ' actor' ',' ' and' ' his' ' wife']" , who was a member of the family of the late James Stewart , the actor , and his wife False Autumn (1964) James Stewart as Wyatt Earp 5 [' Autumn', ' (', '1964', ')', ' James', ' Stewart']
+3737 880 Name of mother of x -1 Name of mother of James Stewart Elizabeth Ruth Stewart James Stewart "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' James' ' Stewart' ',' ' the' ' actor' ',' ' and' ' his' ' wife']" , who was a member of the family of the late James Stewart , the actor , and his wife False uncle, Sir James Stewart of Durrisdeer, who 4 [' uncle', ',', ' Sir', ' James', ' Stewart']
+3738 880 Name of mother of x -1 Name of mother of James Stewart Elizabeth Ruth Stewart James Stewart "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' James' ' Stewart' ',' ' the' ' actor' ',' ' and' ' his' ' wife']" , who was a member of the family of the late James Stewart , the actor , and his wife False Two plays later, James Stewart broke free 5 [' Two', ' plays', ' later', ',', ' James', ' Stewart']
+3739 880 Name of mother of x -1 Name of mother of James Stewart Elizabeth Ruth Stewart James Stewart "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' family' ' of' ' the'
+ ' late' ' James' ' Stewart' ',' ' the' ' actor' ',' ' and' ' his' ' wife']" , who was a member of the family of the late James Stewart , the actor , and his wife False Washington in which James Stewart tries to change 4 [' Washington', ' in', ' which', ' James', ' Stewart']
+3740 881 Name of mother of x -1 Name of mother of Julius II Teodora Marinola Julius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False 2 ['Jul', 'ius', ' II']
+3741 881 Name of mother of x -1 Name of mother of Julius II Teodora Marinola Julius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False return, Pope Julius II honoured Zwingli by 4 [' return', ',', ' Pope', ' Julius', ' II']
+3742 881 Name of mother of x -1 Name of mother of Julius II Teodora Marinola Julius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False " during the papacy of Julius II (1503 to 1513).
+" 6 [' during', ' the', ' pap', 'acy', ' of', ' Julius', ' II']
+3743 881 Name of mother of x -1 Name of mother of Julius II Teodora Marinola Julius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False " the papacy of Julius II (1503 to 1513).
+" 5 [' the', ' pap', 'acy', ' of', ' Julius', ' II']
+3744 881 Name of mother of x -1 Name of mother of Julius II Teodora Marinola Julius II "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ ' Pope' ""'s"" ' mother' ',' ' and' ' the' ' Pope' ""'s"" ' mother']" , the Pope , and the Pope 's mother , the Pope 's mother , and the Pope 's mother False was elected Pope Julius II in 1503, he issued 4 [' was', ' elected', ' Pope', ' Julius', ' II']
+3745 882 Name of mother of x -1 Name of mother of Freeman Dyson Mildred Atkey Freeman Dyson "[',' ' the' ' father' ' of' ' the' ' modern' ' atomic' ' bomb' '.' '\n'
+ '\n' 'The' ' first' ' atomic' ' bomb' ' was' ' dropped' ' on'
+ ' Hiroshima' ' on']" ", the father of the modern atomic bomb .
+
+ The first atomic bomb was dropped on Hiroshima on" False renormalization. Freeman Dyson was able to prove 6 [' ren', 'ormal', 'ization', '.', ' Freeman', ' D', 'yson']
+3746 882 Name of mother of x -1 Name of mother of Freeman Dyson Mildred Atkey Freeman Dyson "[',' ' the' ' father' ' of' ' the' ' modern' ' atomic' ' bomb' '.' '\n'
+ '\n' 'The' ' first' ' atomic' ' bomb' ' was' ' dropped' ' on'
+ ' Hiroshima' ' on']" ", the father of the modern atomic bomb .
+
+ The first atomic bomb was dropped on Hiroshima on" False Richard Feynman and Freeman Dyson, it was finally 7 [' Richard', ' Fe', 'yn', 'man', ' and', ' Freeman', ' D', 'yson']
+3747 882 Name of mother of x -1 Name of mother of Freeman Dyson Mildred Atkey Freeman Dyson "[',' ' the' ' father' ' of' ' the' ' modern' ' atomic' ' bomb' '.' '\n'
+ '\n' 'The' ' first' ' atomic' ' bomb' ' was' ' dropped' ' on'
+ ' Hiroshima' ' on']" ", the father of the modern atomic bomb .
+
+ The first atomic bomb was dropped on Hiroshima on" False An argument by Freeman Dyson shows that the radius 5 [' An', ' argument', ' by', ' Freeman', ' D', 'yson']
+3748 882 Name of mother of x -1 Name of mother of Freeman Dyson Mildred Atkey Freeman Dyson "[',' ' the' ' father' ' of' ' the' ' modern' ' atomic' ' bomb' '.' '\n'
+ '\n' 'The' ' first' ' atomic' ' bomb' ' was' ' dropped' ' on'
+ ' Hiroshima' ' on']" ", the father of the modern atomic bomb .
+
+ The first atomic bomb was dropped on Hiroshima on" False argument by Freeman Dyson shows that the 4 [' argument', ' by', ' Freeman', ' D', 'yson']
+3749 882 Name of mother of x -1 Name of mother of Freeman Dyson Mildred Atkey Freeman Dyson "[',' ' the' ' father' ' of' ' the' ' modern' ' atomic' ' bomb' '.' '\n'
+ '\n' 'The' ' first' ' atomic' ' bomb' ' was' ' dropped' ' on'
+ ' Hiroshima' ' on']" ", the father of the modern atomic bomb .
+
+ The first atomic bomb was dropped on Hiroshima on" False projects. He worked with Freeman Dyson on Project Orion, 7 [' projects', '.', ' He', ' worked', ' with', ' Freeman', ' D', 'yson']
+3750 883 Name of mother of x -1 Name of mother of Sam Shepard Jane Elaine Schook Sam Shepard "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' and' ' the' ' film' ""'s"" ' director' ',' ' and' ' the' ' film']" , the author of the book , and the film , and the film 's director , and the film False the performance of Sam Shepard in Days of Heaven. 4 [' the', ' performance', ' of', ' Sam', ' Shepard']
+3751 883 Name of mother of x -1 Name of mother of Sam Shepard Jane Elaine Schook Sam Shepard "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' and' ' the' ' film' ""'s"" ' director' ',' ' and' ' the' ' film']" , the author of the book , and the film , and the film 's director , and the film False performance of Sam Shepard in Days of Heaven. 3 [' performance', ' of', ' Sam', ' Shepard']
+3752 883 Name of mother of x -1 Name of mother of Sam Shepard Jane Elaine Schook Sam Shepard "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' and' ' the' ' film' ""'s"" ' director' ',' ' and' ' the' ' film']" , the author of the book , and the film , and the film 's director , and the film False saw me perform in the Sam Shepard play Seduced 6 [' saw', ' me', ' perform', ' in', ' the', ' Sam', ' Shepard']
+3753 883 Name of mother of x -1 Name of mother of Sam Shepard Jane Elaine Schook Sam Shepard "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' and' ' the' ' film' ""'s"" ' director' ',' ' and' ' the' ' film']" , the author of the book , and the film , and the film 's director , and the film False performance of Sam Shepard in Days of Heaven. 3 [' performance', ' of', ' Sam', ' Shepard']
+3754 883 Name of mother of x -1 Name of mother of Sam Shepard Jane Elaine Schook Sam Shepard "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' film' ','
+ ' and' ' the' ' film' ""'s"" ' director' ',' ' and' ' the' ' film']" , the author of the book , and the film , and the film 's director , and the film False (Tom Petty, Sam Shepard and Carole Bayer 5 [' (', 'Tom', ' Petty', ',', ' Sam', ' Shepard']
+3755 884 Name of mother of x -1 Name of mother of Francis de Sales Françoise de Sionnaz de Vallières Francis de Sales "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' de' ' Sales' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Francis de Sales , the
+
+ The name of" False 3 ['Franc', 'is', ' de', ' Sales']
+3756 884 Name of mother of x -1 Name of mother of Francis de Sales Françoise de Sionnaz de Vallières Francis de Sales "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' de' ' Sales' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Francis de Sales , the
+
+ The name of" False Anderson's St. Francis de Sales High School. Mike 6 "[' Anderson', ""'s"", ' St', '.', ' Francis', ' de', ' Sales']"
+3757 884 Name of mother of x -1 Name of mother of Francis de Sales Françoise de Sionnaz de Vallières Francis de Sales "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' de' ' Sales' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Francis de Sales , the
+
+ The name of" False confirmed at St. Francis de Sales Roman Catholic 6 [' confirmed', ' at', ' St', '.', ' Francis', ' de', ' Sales']
+3758 884 Name of mother of x -1 Name of mother of Francis de Sales Françoise de Sionnaz de Vallières Francis de Sales "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' de' ' Sales' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Francis de Sales , the
+
+ The name of" False " ===
+" 4 [' ===', 'Franc', 'is', ' de', ' Sales']
+3759 884 Name of mother of x -1 Name of mother of Francis de Sales Françoise de Sionnaz de Vallières Francis de Sales "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of'
+ ' Francis' ' de' ' Sales' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Francis de Sales , the
+
+ The name of" False " === Saint Francis de Sales ===
+" 4 [' ===', ' Saint', ' Francis', ' de', ' Sales']
+3760 885 Name of mother of x -1 Name of mother of Meret Becker Monika Hansen Meret Becker "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' ',' ' and' ' the' ' mother' ' of' ' the' ' groom']" , the mother of the bride , and the mother of the groom , and the mother of the groom False instrument in U2. Meret Becker is the lead actress 7 [' instrument', ' in', ' U', '2', '.', ' M', 'eret', ' Becker']
+3761 885 Name of mother of x -1 Name of mother of Meret Becker Monika Hansen Meret Becker "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' ',' ' and' ' the' ' mother' ' of' ' the' ' groom']" , the mother of the bride , and the mother of the groom , and the mother of the groom False instrument in U2. Meret Becker is the lead 7 [' instrument', ' in', ' U', '2', '.', ' M', 'eret', ' Becker']
+3762 885 Name of mother of x -1 Name of mother of Meret Becker Monika Hansen Meret Becker "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' ',' ' and' ' the' ' mother' ' of' ' the' ' groom']" , the mother of the bride , and the mother of the groom , and the mother of the groom False instrument in U2. Meret Becker is the lead actress 7 [' instrument', ' in', ' U', '2', '.', ' M', 'eret', ' Becker']
+3763 886 Name of mother of x -1 Name of mother of James II of England Henrietta Maria of France James II of England "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False Elector Palatine. James II of England put forward his brother-in-law, 9 [' Elect', 'or', ' Pal', 'at', 'ine', '.', ' James', ' II', ' of', ' England']
+3764 886 Name of mother of x -1 Name of mother of James II of England Henrietta Maria of France James II of England "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False brother, who became James II of England and Ireland 7 [' brother', ',', ' who', ' became', ' James', ' II', ' of', ' England']
+3765 886 Name of mother of x -1 Name of mother of James II of England Henrietta Maria of France James II of England "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False the future James II of England and James VII of 5 [' the', ' future', ' James', ' II', ' of', ' England']
+3766 886 Name of mother of x -1 Name of mother of James II of England Henrietta Maria of France James II of England "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False request of king James II of England to king Charles 6 [' request', ' of', ' king', ' James', ' II', ' of', ' England']
+3767 886 Name of mother of x -1 Name of mother of James II of England Henrietta Maria of France James II of England "[',' ' and' ' the' '\n' '\n' 'King' ' of' ' Scots' ',' ' and' ' the'
+ ' King' ' of' ' England' ',' ' and' ' the' ' King' ' of' ' France']" ", and the
+
+ King of Scots , and the King of England , and the King of France" False Duke of York, later James II of England (James VII of 8 [' Duke', ' of', ' York', ',', ' later', ' James', ' II', ' of', ' England']
+3768 887 Name of mother of x -1 Name of mother of David Nitzevet David "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' David' ',' ' and' ' the'
+ ' Name' ' of' ' David' ',' ' and' ' the' ' Name' ' of' ' David']" ", and the
+
+ Name of David , and the Name of David , and the Name of David" False was founded by David Chen, Andrew Walsh 3 [' was', ' founded', ' by', ' David']
+3769 887 Name of mother of x -1 Name of mother of David Nitzevet David "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' David' ',' ' and' ' the'
+ ' Name' ' of' ' David' ',' ' and' ' the' ' Name' ' of' ' David']" ", and the
+
+ Name of David , and the Name of David , and the Name of David" False dropped out, and David Lean came to direct 4 [' dropped', ' out', ',', ' and', ' David']
+3770 887 Name of mother of x -1 Name of mother of David Nitzevet David "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' David' ',' ' and' ' the'
+ ' Name' ' of' ' David' ',' ' and' ' the' ' Name' ' of' ' David']" ", and the
+
+ Name of David , and the Name of David , and the Name of David" False " of mission"". David McComb of Empire" 3 "[' of', ' mission', '"".', ' David']"
+3771 887 Name of mother of x -1 Name of mother of David Nitzevet David "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' David' ',' ' and' ' the'
+ ' Name' ' of' ' David' ',' ' and' ' the' ' Name' ' of' ' David']" ", and the
+
+ Name of David , and the Name of David , and the Name of David" False 0 ['David']
+3772 887 Name of mother of x -1 Name of mother of David Nitzevet David "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' David' ',' ' and' ' the'
+ ' Name' ' of' ' David' ',' ' and' ' the' ' Name' ' of' ' David']" ", and the
+
+ Name of David , and the Name of David , and the Name of David" False 0 ['David']
+3773 888 Name of mother of x -1 Name of mother of Petronius Plautia Petronius "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Pet' 'ron'
+ 'ius' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Pet ron ius , and the
+
+ Name of mother" False 94, with Titus Petronius Secundus as his colleague. 6 [' 94', ',', ' with', ' Titus', ' Pet', 'ron', 'ius']
+3774 888 Name of mother of x -1 Name of mother of Petronius Plautia Petronius "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Pet' 'ron'
+ 'ius' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Pet ron ius , and the
+
+ Name of mother" False poem; and the novelist Petronius (Satyricon) 7 [' poem', ';', ' and', ' the', ' novelist', ' Pet', 'ron', 'ius']
+3775 888 Name of mother of x -1 Name of mother of Petronius Plautia Petronius "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Pet' 'ron'
+ 'ius' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Pet ron ius , and the
+
+ Name of mother" False Satyricon of Petronius and The Golden 7 [' Sat', 'y', 'ric', 'on', ' of', ' Pet', 'ron', 'ius']
+3776 888 Name of mother of x -1 Name of mother of Petronius Plautia Petronius "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Pet' 'ron'
+ 'ius' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Pet ron ius , and the
+
+ Name of mother" False Praetorians. Titus Petronius Secundus and Parthenius 7 [' Pra', 'et', 'orians', '.', ' Titus', ' Pet', 'ron', 'ius']
+3777 888 Name of mother of x -1 Name of mother of Petronius Plautia Petronius "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Pet' 'ron'
+ 'ius' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of Pet ron ius , and the
+
+ Name of mother" False Satyricon of Petronius and The Golden Ass 7 [' Sat', 'y', 'ric', 'on', ' of', ' Pet', 'ron', 'ius']
+3778 889 Name of mother of x -1 Name of mother of Roger Federer Lynette Federer Roger Federer "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Mir' 'ka'
+ ',' ' who' ' is' ' a' ' former' ' world' ' number' ' one' ' tennis']" , the tennis player , and his wife , Mir ka , who is a former world number one tennis False joined Novak Djokovic, Roger Federer and Rafael Nadal 9 [' joined', ' Nov', 'ak', ' Dj', 'ok', 'ovic', ',', ' Roger', ' Fed', 'erer']
+3779 889 Name of mother of x -1 Name of mother of Roger Federer Lynette Federer Roger Federer "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Mir' 'ka'
+ ',' ' who' ' is' ' a' ' former' ' world' ' number' ' one' ' tennis']" , the tennis player , and his wife , Mir ka , who is a former world number one tennis False " golden era, Roger Federer remains skeptical:
+" 5 [' golden', ' era', ',', ' Roger', ' Fed', 'erer']
+3780 889 Name of mother of x -1 Name of mother of Roger Federer Lynette Federer Roger Federer "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Mir' 'ka'
+ ',' ' who' ' is' ' a' ' former' ' world' ' number' ' one' ' tennis']" , the tennis player , and his wife , Mir ka , who is a former world number one tennis False 2 ['Roger', ' Fed', 'erer']
+3781 889 Name of mother of x -1 Name of mother of Roger Federer Lynette Federer Roger Federer "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Mir' 'ka'
+ ',' ' who' ' is' ' a' ' former' ' world' ' number' ' one' ' tennis']" , the tennis player , and his wife , Mir ka , who is a former world number one tennis False Rafael Nadal and Roger Federer then world number's 6 [' Rafael', ' Nad', 'al', ' and', ' Roger', ' Fed', 'erer']
+3782 889 Name of mother of x -1 Name of mother of Roger Federer Lynette Federer Roger Federer "[',' ' the' ' tennis' ' player' ',' ' and' ' his' ' wife' ',' ' Mir' 'ka'
+ ',' ' who' ' is' ' a' ' former' ' world' ' number' ' one' ' tennis']" , the tennis player , and his wife , Mir ka , who is a former world number one tennis False pitted world No. 3 Roger Federer against world 7 [' pitted', ' world', ' No', '.', ' 3', ' Roger', ' Fed', 'erer']
+3783 890 Name of mother of x -1 Name of mother of Paul Langevin Marie-Adelaide Pinel Paul Langevin "[',' ' the' ' father' ' of' ' the' ' famous' ' French' ' poet' ',' ' was'
+ ' a' ' native' ' of' ' the' ' same' ' place' '.' '\n' '\n' 'The']" ", the father of the famous French poet , was a native of the same place .
+
+ The" False transducer used by Paul Langevin in early sonar 7 [' trans', 'du', 'cer', ' used', ' by', ' Paul', ' Lange', 'vin']
+3784 890 Name of mother of x -1 Name of mother of Paul Langevin Marie-Adelaide Pinel Paul Langevin "[',' ' the' ' father' ' of' ' the' ' famous' ' French' ' poet' ',' ' was'
+ ' a' ' native' ' of' ' the' ' same' ' place' '.' '\n' '\n' 'The']" ", the father of the famous French poet , was a native of the same place .
+
+ The" False transducer used by Paul Langevin in early sonar research. 7 [' trans', 'du', 'cer', ' used', ' by', ' Paul', ' Lange', 'vin']
+3785 890 Name of mother of x -1 Name of mother of Paul Langevin Marie-Adelaide Pinel Paul Langevin "[',' ' the' ' father' ' of' ' the' ' famous' ' French' ' poet' ',' ' was'
+ ' a' ' native' ' of' ' the' ' same' ' place' '.' '\n' '\n' 'The']" ", the father of the famous French poet , was a native of the same place .
+
+ The" False transducer used by Paul Langevin in early sonar 7 [' trans', 'du', 'cer', ' used', ' by', ' Paul', ' Lange', 'vin']
+3786 890 Name of mother of x -1 Name of mother of Paul Langevin Marie-Adelaide Pinel Paul Langevin "[',' ' the' ' father' ' of' ' the' ' famous' ' French' ' poet' ',' ' was'
+ ' a' ' native' ' of' ' the' ' same' ' place' '.' '\n' '\n' 'The']" ", the father of the famous French poet , was a native of the same place .
+
+ The" False transducer used by Paul Langevin in early sonar research. 7 [' trans', 'du', 'cer', ' used', ' by', ' Paul', ' Lange', 'vin']
+3787 890 Name of mother of x -1 Name of mother of Paul Langevin Marie-Adelaide Pinel Paul Langevin "[',' ' the' ' father' ' of' ' the' ' famous' ' French' ' poet' ',' ' was'
+ ' a' ' native' ' of' ' the' ' same' ' place' '.' '\n' '\n' 'The']" ", the father of the famous French poet , was a native of the same place .
+
+ The" False transducer used by Paul Langevin in early sonar 7 [' trans', 'du', 'cer', ' used', ' by', ' Paul', ' Lange', 'vin']
+3788 891 Name of mother of x -1 Name of mother of Anthony Eden Sybil Frances Grey Anthony Eden "[',' ' the' ' former' ' prime' ' minister' ' of' ' Great' ' Britain' ','
+ ' and' ' the' ' former' ' prime' ' minister' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former prime minister of Great Britain , and the former prime minister of the United States , and False ministerial post: Anthony Eden remained in 4 [' ministerial', ' post', ':', ' Anthony', ' Eden']
+3789 891 Name of mother of x -1 Name of mother of Anthony Eden Sybil Frances Grey Anthony Eden "[',' ' the' ' former' ' prime' ' minister' ' of' ' Great' ' Britain' ','
+ ' and' ' the' ' former' ' prime' ' minister' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former prime minister of Great Britain , and the former prime minister of the United States , and False " Eden hat =
+" 4 [' Eden', ' hat', ' =', 'Anthony', ' Eden']
+3790 891 Name of mother of x -1 Name of mother of Anthony Eden Sybil Frances Grey Anthony Eden "[',' ' the' ' former' ' prime' ' minister' ' of' ' Great' ' Britain' ','
+ ' and' ' the' ' former' ' prime' ' minister' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former prime minister of Great Britain , and the former prime minister of the United States , and False Charles de Gaulle, Sir Anthony Eden and Lord Mountbatten 7 [' Charles', ' de', ' Gaul', 'le', ',', ' Sir', ' Anthony', ' Eden']
+3791 891 Name of mother of x -1 Name of mother of Anthony Eden Sybil Frances Grey Anthony Eden "[',' ' the' ' former' ' prime' ' minister' ' of' ' Great' ' Britain' ','
+ ' and' ' the' ' former' ' prime' ' minister' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former prime minister of Great Britain , and the former prime minister of the United States , and False " disdain, ""who wears an Anthony Eden hat today? Only Mr" 7 "[' disdain', ',', ' ""', 'who', ' wears', ' an', ' Anthony', ' Eden']"
+3792 891 Name of mother of x -1 Name of mother of Anthony Eden Sybil Frances Grey Anthony Eden "[',' ' the' ' former' ' prime' ' minister' ' of' ' Great' ' Britain' ','
+ ' and' ' the' ' former' ' prime' ' minister' ' of' ' the' ' United'
+ ' States' ',' ' and']" , the former prime minister of Great Britain , and the former prime minister of the United States , and False " Anthony Eden hat =
+" 1 [' Anthony', ' Eden']
+3793 892 Name of mother of x -1 Name of mother of Neville Chamberlain Florence Kenrick Neville Chamberlain "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' and'
+ ' I' ' am' ' a' ' wife' ' of' ' a' ' wonderful' ' husband']" ", the
+
+ I am a mother of two , and I am a wife of a wonderful husband" False British Prime Minister Neville Chamberlain in early August 4 [' British', ' Prime', ' Minister', ' Neville', ' Chamberlain']
+3794 892 Name of mother of x -1 Name of mother of Neville Chamberlain Florence Kenrick Neville Chamberlain "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' and'
+ ' I' ' am' ' a' ' wife' ' of' ' a' ' wonderful' ' husband']" ", the
+
+ I am a mother of two , and I am a wife of a wonderful husband" False MacDonald and Neville Chamberlain had done so. 3 [' MacDonald', ' and', ' Neville', ' Chamberlain']
+3795 892 Name of mother of x -1 Name of mother of Neville Chamberlain Florence Kenrick Neville Chamberlain "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' and'
+ ' I' ' am' ' a' ' wife' ' of' ' a' ' wonderful' ' husband']" ", the
+
+ I am a mother of two , and I am a wife of a wonderful husband" False " his death, Neville Chamberlain wrote,
+" 4 [' his', ' death', ',', ' Neville', ' Chamberlain']
+3796 892 Name of mother of x -1 Name of mother of Neville Chamberlain Florence Kenrick Neville Chamberlain "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' and'
+ ' I' ' am' ' a' ' wife' ' of' ' a' ' wonderful' ' husband']" ", the
+
+ I am a mother of two , and I am a wife of a wonderful husband" False additional committee, Neville Chamberlain informed his half-brother 4 [' additional', ' committee', ',', ' Neville', ' Chamberlain']
+3797 892 Name of mother of x -1 Name of mother of Neville Chamberlain Florence Kenrick Neville Chamberlain "[',' ' the' '\n' '\n' 'I' ' am' ' a' ' mother' ' of' ' two' ',' ' and'
+ ' I' ' am' ' a' ' wife' ' of' ' a' ' wonderful' ' husband']" ", the
+
+ I am a mother of two , and I am a wife of a wonderful husband" False Law (Glasgow), and Neville Chamberlain (Mason Science College, 8 [' Law', ' (', 'Gl', 'as', 'gow', '),', ' and', ' Neville', ' Chamberlain']
+3798 893 Name of mother of x -1 Name of mother of H. Rider Haggard Ella Doveton H. Rider Haggard "[',' ' the' ' author' ' of' ' _' 'King' ' Solomon' ""'s"" ' Mines' '_' ','
+ ' _' 'She' '_' ',' ' and' ' _' 'King' ' Solomon' ""'s""]" , the author of _ King Solomon 's Mines _ , _ She _ , and _ King Solomon 's False 5 ['H', '.', ' Rider', ' H', 'agg', 'ard']
+3799 893 Name of mother of x -1 Name of mother of H. Rider Haggard Ella Doveton H. Rider Haggard "[',' ' the' ' author' ' of' ' _' 'King' ' Solomon' ""'s"" ' Mines' '_' ','
+ ' _' 'She' '_' ',' ' and' ' _' 'King' ' Solomon' ""'s""]" , the author of _ King Solomon 's Mines _ , _ She _ , and _ King Solomon 's False Bulwer-Lytton, H. Rider Haggard and E. T. A. Hoffman. 12 [' Bul', 'wer', '-', 'Ly', 'tt', 'on', ',', ' H', '.', ' Rider', ' H', 'agg', 'ard']
+3800 893 Name of mother of x -1 Name of mother of H. Rider Haggard Ella Doveton H. Rider Haggard "[',' ' the' ' author' ' of' ' _' 'King' ' Solomon' ""'s"" ' Mines' '_' ','
+ ' _' 'She' '_' ',' ' and' ' _' 'King' ' Solomon' ""'s""]" , the author of _ King Solomon 's Mines _ , _ She _ , and _ King Solomon 's False is a novel by H. Rider Haggard (1856 – 1925), first 9 [' is', ' a', ' novel', ' by', ' H', '.', ' Rider', ' H', 'agg', 'ard']
+3801 893 Name of mother of x -1 Name of mother of H. Rider Haggard Ella Doveton H. Rider Haggard "[',' ' the' ' author' ' of' ' _' 'King' ' Solomon' ""'s"" ' Mines' '_' ','
+ ' _' 'She' '_' ',' ' and' ' _' 'King' ' Solomon' ""'s""]" , the author of _ King Solomon 's Mines _ , _ She _ , and _ King Solomon 's False 5 ['H', '.', ' Rider', ' H', 'agg', 'ard']
+3802 893 Name of mother of x -1 Name of mother of H. Rider Haggard Ella Doveton H. Rider Haggard "[',' ' the' ' author' ' of' ' _' 'King' ' Solomon' ""'s"" ' Mines' '_' ','
+ ' _' 'She' '_' ',' ' and' ' _' 'King' ' Solomon' ""'s""]" , the author of _ King Solomon 's Mines _ , _ She _ , and _ King Solomon 's False Bulwer-Lytton, H. Rider Haggard and E. T. A. Hoffman. 12 [' Bul', 'wer', '-', 'Ly', 'tt', 'on', ',', ' H', '.', ' Rider', ' H', 'agg', 'ard']
+3803 894 Name of mother of x -1 Name of mother of Ben Affleck Christine Anne Boldt Ben Affleck "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False " Throughout the film, Ben Affleck had to wear ""cataract" 6 [' Throughout', ' the', ' film', ',', ' Ben', ' Aff', 'leck']
+3804 894 Name of mother of x -1 Name of mother of Ben Affleck Christine Anne Boldt Ben Affleck "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False he stated that Ben Affleck and Jennifer 5 [' he', ' stated', ' that', ' Ben', ' Aff', 'leck']
+3805 894 Name of mother of x -1 Name of mother of Ben Affleck Christine Anne Boldt Ben Affleck "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False stated that Ben Affleck and Jennifer Garner 4 [' stated', ' that', ' Ben', ' Aff', 'leck']
+3806 894 Name of mother of x -1 Name of mother of Ben Affleck Christine Anne Boldt Ben Affleck "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False In August 2013, Ben Affleck was announced 6 [' In', ' August', ' 2013', ',', ' Ben', ' Aff', 'leck']
+3807 894 Name of mother of x -1 Name of mother of Ben Affleck Christine Anne Boldt Ben Affleck "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Initiative founder Ben Affleck to testify before 4 [' Initiative', ' founder', ' Ben', ' Aff', 'leck']
+3808 896 Name of mother of x -1 Name of mother of Menander Hegesistrata Menander "[',' ' the' ' son' ' of' ' Tim' 'arch' 'us' ',' ' and' ' the' ' son' ' of'
+ ' Men' 'ander' ',' ' and' ' the' ' son' ' of' ' Tim']" , the son of Tim arch us , and the son of Men ander , and the son of Tim False Agathias and Menander Protector offer many 5 [' Ag', 'ath', 'ias', ' and', ' Men', 'ander']
+3809 896 Name of mother of x -1 Name of mother of Menander Hegesistrata Menander "[',' ' the' ' son' ' of' ' Tim' 'arch' 'us' ',' ' and' ' the' ' son' ' of'
+ ' Men' 'ander' ',' ' and' ' the' ' son' ' of' ' Tim']" , the son of Tim arch us , and the son of Men ander , and the son of Tim False as a source by Menander Protector. Until 5 [' as', ' a', ' source', ' by', ' Men', 'ander']
+3810 896 Name of mother of x -1 Name of mother of Menander Hegesistrata Menander "[',' ' the' ' son' ' of' ' Tim' 'arch' 'us' ',' ' and' ' the' ' son' ' of'
+ ' Men' 'ander' ',' ' and' ' the' ' son' ' of' ' Tim']" , the son of Tim arch us , and the son of Men ander , and the son of Tim False continuators Agathias and Menander Protector offer 7 [' continu', 'ators', ' Ag', 'ath', 'ias', ' and', ' Men', 'ander']
+3811 896 Name of mother of x -1 Name of mother of Menander Hegesistrata Menander "[',' ' the' ' son' ' of' ' Tim' 'arch' 'us' ',' ' and' ' the' ' son' ' of'
+ ' Men' 'ander' ',' ' and' ' the' ' son' ' of' ' Tim']" , the son of Tim arch us , and the son of Men ander , and the son of Tim False " spirit; thus, as Menander stated, ""be" 6 [' spirit', ';', ' thus', ',', ' as', ' Men', 'ander']
+3812 896 Name of mother of x -1 Name of mother of Menander Hegesistrata Menander "[',' ' the' ' son' ' of' ' Tim' 'arch' 'us' ',' ' and' ' the' ' son' ' of'
+ ' Men' 'ander' ',' ' and' ' the' ' son' ' of' ' Tim']" , the son of Tim arch us , and the son of Men ander , and the son of Tim False 6th-century historian Menander Protector, 6 [' 6', 'th', '-', 'century', ' historian', ' Men', 'ander']
+3813 897 Name of mother of x -1 Name of mother of Mikhail Glinka Evgenia Andreevna Glinka-Zemelka Mikhail Glinka "[',' ' the' ' Russian' ' composer' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Mikhail' ' Gl' 'inka' ',' ' the' ' Russian'
+ ' composer']" ", the Russian composer , and the
+
+ Name of mother of Mikhail Gl inka , the Russian composer" False the influences of Mikhail Glinka and fellow members 5 [' the', ' influences', ' of', ' Mikhail', ' Gl', 'inka']
+3814 897 Name of mother of x -1 Name of mother of Mikhail Glinka Evgenia Andreevna Glinka-Zemelka Mikhail Glinka "[',' ' the' ' Russian' ' composer' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Mikhail' ' Gl' 'inka' ',' ' the' ' Russian'
+ ' composer']" ", the Russian composer , and the
+
+ Name of mother of Mikhail Gl inka , the Russian composer" False Russian composer Mikhail Glinka (1804 – 1857) in collaboration 4 [' Russian', ' composer', ' Mikhail', ' Gl', 'inka']
+3815 897 Name of mother of x -1 Name of mother of Mikhail Glinka Evgenia Andreevna Glinka-Zemelka Mikhail Glinka "[',' ' the' ' Russian' ' composer' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Mikhail' ' Gl' 'inka' ',' ' the' ' Russian'
+ ' composer']" ", the Russian composer , and the
+
+ Name of mother of Mikhail Gl inka , the Russian composer" False influences of Mikhail Glinka and fellow members 4 [' influences', ' of', ' Mikhail', ' Gl', 'inka']
+3816 897 Name of mother of x -1 Name of mother of Mikhail Glinka Evgenia Andreevna Glinka-Zemelka Mikhail Glinka "[',' ' the' ' Russian' ' composer' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Mikhail' ' Gl' 'inka' ',' ' the' ' Russian'
+ ' composer']" ", the Russian composer , and the
+
+ Name of mother of Mikhail Gl inka , the Russian composer" False operas of Mikhail Glinka as a model and 5 [' oper', 'as', ' of', ' Mikhail', ' Gl', 'inka']
+3817 897 Name of mother of x -1 Name of mother of Mikhail Glinka Evgenia Andreevna Glinka-Zemelka Mikhail Glinka "[',' ' the' ' Russian' ' composer' ',' ' and' ' the' '\n' '\n' 'Name'
+ ' of' ' mother' ' of' ' Mikhail' ' Gl' 'inka' ',' ' the' ' Russian'
+ ' composer']" ", the Russian composer , and the
+
+ Name of mother of Mikhail Gl inka , the Russian composer" False Russian composer Mikhail Glinka (1804 – 1857). 4 [' Russian', ' composer', ' Mikhail', ' Gl', 'inka']
+3818 898 Name of mother of x -1 Name of mother of René Goscinny Anna Goscinny René Goscinny "[',' ' the' ' author' ' of' ' the' ' famous' ' comic' ' book' ' series'
+ ' Aster' 'ix' ',' ' and' ' his' ' wife' ',' ' Albert' ' U' 'der' 'zo']" , the author of the famous comic book series Aster ix , and his wife , Albert U der zo False at the success of René Goscinny and Albert Uderzo's 9 [' at', ' the', ' success', ' of', ' Ren', 'é', ' G', 'osc', 'in', 'ny']
+3819 898 Name of mother of x -1 Name of mother of René Goscinny Anna Goscinny René Goscinny "[',' ' the' ' author' ' of' ' the' ' famous' ' comic' ' book' ' series'
+ ' Aster' 'ix' ',' ' and' ' his' ' wife' ',' ' Albert' ' U' 'der' 'zo']" , the author of the famous comic book series Aster ix , and his wife , Albert U der zo False group including René Goscinny and Albert Uderzo 7 [' group', ' including', ' Ren', 'é', ' G', 'osc', 'in', 'ny']
+3820 898 Name of mother of x -1 Name of mother of René Goscinny Anna Goscinny René Goscinny "[',' ' the' ' author' ' of' ' the' ' famous' ' comic' ' book' ' series'
+ ' Aster' 'ix' ',' ' and' ' his' ' wife' ',' ' Albert' ' U' 'der' 'zo']" , the author of the famous comic book series Aster ix , and his wife , Albert U der zo False at the success of René Goscinny and Albert Uderzo's 9 [' at', ' the', ' success', ' of', ' Ren', 'é', ' G', 'osc', 'in', 'ny']
+3821 898 Name of mother of x -1 Name of mother of René Goscinny Anna Goscinny René Goscinny "[',' ' the' ' author' ' of' ' the' ' famous' ' comic' ' book' ' series'
+ ' Aster' 'ix' ',' ' and' ' his' ' wife' ',' ' Albert' ' U' 'der' 'zo']" , the author of the famous comic book series Aster ix , and his wife , Albert U der zo False Pilote magazine to see René Goscinny and Jean-Michel Charlier. 10 [' Pil', 'ote', ' magazine', ' to', ' see', ' Ren', 'é', ' G', 'osc', 'in', 'ny']
+3822 898 Name of mother of x -1 Name of mother of René Goscinny Anna Goscinny René Goscinny "[',' ' the' ' author' ' of' ' the' ' famous' ' comic' ' book' ' series'
+ ' Aster' 'ix' ',' ' and' ' his' ' wife' ',' ' Albert' ' U' 'der' 'zo']" , the author of the famous comic book series Aster ix , and his wife , Albert U der zo False at the success of René Goscinny and Albert Uderzo's 9 [' at', ' the', ' success', ' of', ' Ren', 'é', ' G', 'osc', 'in', 'ny']
+3823 900 Name of mother of x -1 Name of mother of Hirohito Empress Teimei Hirohito "[',' ' the' ' Emperor' ' of' ' Japan' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' United' ' States' ' of' ' America' '.' '\n' '\n' 'The'
+ ' Emperor']" ", the Emperor of Japan , and the Emperor of the United States of America .
+
+ The Emperor" False in briefing Emperor Hirohito on Japan's 5 [' in', ' briefing', ' Emperor', ' Hiro', 'h', 'ito']
+3824 900 Name of mother of x -1 Name of mother of Hirohito Empress Teimei Hirohito "[',' ' the' ' Emperor' ' of' ' Japan' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' United' ' States' ' of' ' America' '.' '\n' '\n' 'The'
+ ' Emperor']" ", the Emperor of Japan , and the Emperor of the United States of America .
+
+ The Emperor" False victory. Only Emperor Hirohito and the highest 6 [' victory', '.', ' Only', ' Emperor', ' Hiro', 'h', 'ito']
+3825 900 Name of mother of x -1 Name of mother of Hirohito Empress Teimei Hirohito "[',' ' the' ' Emperor' ' of' ' Japan' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' United' ' States' ' of' ' America' '.' '\n' '\n' 'The'
+ ' Emperor']" ", the Emperor of Japan , and the Emperor of the United States of America .
+
+ The Emperor" False flagship of Emperor Hirohito during the 1927 naval 5 [' flagship', ' of', ' Emperor', ' Hiro', 'h', 'ito']
+3826 900 Name of mother of x -1 Name of mother of Hirohito Empress Teimei Hirohito "[',' ' the' ' Emperor' ' of' ' Japan' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' United' ' States' ' of' ' America' '.' '\n' '\n' 'The'
+ ' Emperor']" ", the Emperor of Japan , and the Emperor of the United States of America .
+
+ The Emperor" False review by Emperor Hirohito on 11 October 5 [' review', ' by', ' Emperor', ' Hiro', 'h', 'ito']
+3827 900 Name of mother of x -1 Name of mother of Hirohito Empress Teimei Hirohito "[',' ' the' ' Emperor' ' of' ' Japan' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' United' ' States' ' of' ' America' '.' '\n' '\n' 'The'
+ ' Emperor']" ", the Emperor of Japan , and the Emperor of the United States of America .
+
+ The Emperor" False his declaration, Hirohito referred to 5 [' his', ' declaration', ',', ' Hiro', 'h', 'ito']
+3828 902 Name of mother of x -1 Name of mother of Osama bin Laden Hamida al-Attas Osama bin Laden "[',' ' the' ' terrorist' ' leader' ',' ' and' ' the' ' man' ' who' ' was'
+ ' responsible' ' for' ' the' ' 9' '/' '11' ' attacks' '.' '\n' '\n']" ", the terrorist leader , and the man who was responsible for the 9 / 11 attacks .
+
+" False " Time referred to Osama bin Laden as ""a geopolitical" 5 [' Time', ' referred', ' to', ' Osama', ' bin', ' Laden']
+3829 902 Name of mother of x -1 Name of mother of Osama bin Laden Hamida al-Attas Osama bin Laden "[',' ' the' ' terrorist' ' leader' ',' ' and' ' the' ' man' ' who' ' was'
+ ' responsible' ' for' ' the' ' 9' '/' '11' ' attacks' '.' '\n' '\n']" ", the terrorist leader , and the man who was responsible for the 9 / 11 attacks .
+
+" False if they 're in an Osama bin Laden costume if they 8 "[' if', ' they', "" '"", 're', ' in', ' an', ' Osama', ' bin', ' Laden']"
+3830 902 Name of mother of x -1 Name of mother of Osama bin Laden Hamida al-Attas Osama bin Laden "[',' ' the' ' terrorist' ' leader' ',' ' and' ' the' ' man' ' who' ' was'
+ ' responsible' ' for' ' the' ' 9' '/' '11' ' attacks' '.' '\n' '\n']" ", the terrorist leader , and the man who was responsible for the 9 / 11 attacks .
+
+" False " Time referred to Osama bin Laden as ""a geopolitical" 5 [' Time', ' referred', ' to', ' Osama', ' bin', ' Laden']
+3831 902 Name of mother of x -1 Name of mother of Osama bin Laden Hamida al-Attas Osama bin Laden "[',' ' the' ' terrorist' ' leader' ',' ' and' ' the' ' man' ' who' ' was'
+ ' responsible' ' for' ' the' ' 9' '/' '11' ' attacks' '.' '\n' '\n']" ", the terrorist leader , and the man who was responsible for the 9 / 11 attacks .
+
+" False first presented it to Osama bin Laden in 1996. At 6 [' first', ' presented', ' it', ' to', ' Osama', ' bin', ' Laden']
+3832 902 Name of mother of x -1 Name of mother of Osama bin Laden Hamida al-Attas Osama bin Laden "[',' ' the' ' terrorist' ' leader' ',' ' and' ' the' ' man' ' who' ' was'
+ ' responsible' ' for' ' the' ' 9' '/' '11' ' attacks' '.' '\n' '\n']" ", the terrorist leader , and the man who was responsible for the 9 / 11 attacks .
+
+" False first believed it was Osama bin Laden but later revised 6 [' first', ' believed', ' it', ' was', ' Osama', ' bin', ' Laden']
+3833 903 Name of mother of x -1 Name of mother of Martin Amis Hilary Bardwell Martin Amis "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " ""cultural book"". Martin Amis from The Observer" 6 "[' ""', 'cultural', ' book', '"".', ' Martin', ' Am', 'is']"
+3834 903 Name of mother of x -1 Name of mother of Martin Amis Hilary Bardwell Martin Amis "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False the novelist Martin Amis in The War Against 4 [' the', ' novelist', ' Martin', ' Am', 'is']
+3835 903 Name of mother of x -1 Name of mother of Martin Amis Hilary Bardwell Martin Amis "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False " ""cultural book"". Martin Amis from The Observer" 6 "[' ""', 'cultural', ' book', '"".', ' Martin', ' Am', 'is']"
+3836 903 Name of mother of x -1 Name of mother of Martin Amis Hilary Bardwell Martin Amis "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False Granta. The novelist Martin Amis described 7 [' Gr', 'anta', '.', ' The', ' novelist', ' Martin', ' Am', 'is']
+3837 903 Name of mother of x -1 Name of mother of Martin Amis Hilary Bardwell Martin Amis "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False by the novelist Martin Amis in The War Against 5 [' by', ' the', ' novelist', ' Martin', ' Am', 'is']
+3838 905 Name of mother of x -1 Name of mother of Clara Schumann Mariane Bargiel Clara Schumann "[',' ' the' ' daughter' ' of' ' the' ' late' ' Dr' '.' ' Sch' 'umann' ','
+ ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n']" ", the daughter of the late Dr . Sch umann , who was a friend of the family .
+" False husband, much as Clara Schumann had been with her own 6 [' husband', ',', ' much', ' as', ' Clara', ' Sch', 'umann']
+3839 905 Name of mother of x -1 Name of mother of Clara Schumann Mariane Bargiel Clara Schumann "[',' ' the' ' daughter' ' of' ' the' ' late' ' Dr' '.' ' Sch' 'umann' ','
+ ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n']" ", the daughter of the late Dr . Sch umann , who was a friend of the family .
+" False her husband, much as Clara Schumann had been with her own 7 [' her', ' husband', ',', ' much', ' as', ' Clara', ' Sch', 'umann']
+3840 905 Name of mother of x -1 Name of mother of Clara Schumann Mariane Bargiel Clara Schumann "[',' ' the' ' daughter' ' of' ' the' ' late' ' Dr' '.' ' Sch' 'umann' ','
+ ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n']" ", the daughter of the late Dr . Sch umann , who was a friend of the family .
+" False Paderewski and Clara Schumann expressed a preference 7 [' Pad', 'ere', 'ws', 'ki', ' and', ' Clara', ' Sch', 'umann']
+3841 905 Name of mother of x -1 Name of mother of Clara Schumann Mariane Bargiel Clara Schumann "[',' ' the' ' daughter' ' of' ' the' ' late' ' Dr' '.' ' Sch' 'umann' ','
+ ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n']" ", the daughter of the late Dr . Sch umann , who was a friend of the family .
+" False visiting artists such as Clara Schumann and Franz Liszt but 6 [' visiting', ' artists', ' such', ' as', ' Clara', ' Sch', 'umann']
+3842 905 Name of mother of x -1 Name of mother of Clara Schumann Mariane Bargiel Clara Schumann "[',' ' the' ' daughter' ' of' ' the' ' late' ' Dr' '.' ' Sch' 'umann' ','
+ ' who' ' was' ' a' ' friend' ' of' ' the' ' family' '.' '\n']" ", the daughter of the late Dr . Sch umann , who was a friend of the family .
+" False tenure, until 1892, of Clara Schumann as head of 9 [' tenure', ',', ' until', ' 18', '92', ',', ' of', ' Clara', ' Sch', 'umann']
+3843 906 Name of mother of x -1 Name of mother of Yuri Gagarin Anna Gagarina Yuri Gagarin "[',' ' the' ' first' ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first'
+ ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first' ' man']" ", the first man in space .
+
+ The first man in space .
+
+ The first man" False given a tour of the Yuri Gagarin Cosmonaut Training 8 [' given', ' a', ' tour', ' of', ' the', ' Yuri', ' G', 'ag', 'arin']
+3844 906 Name of mother of x -1 Name of mother of Yuri Gagarin Anna Gagarina Yuri Gagarin "[',' ' the' ' first' ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first'
+ ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first' ' man']" ", the first man in space .
+
+ The first man in space .
+
+ The first man" False given a tour of the Yuri Gagarin Cosmonaut Training 8 [' given', ' a', ' tour', ' of', ' the', ' Yuri', ' G', 'ag', 'arin']
+3845 906 Name of mother of x -1 Name of mother of Yuri Gagarin Anna Gagarina Yuri Gagarin "[',' ' the' ' first' ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first'
+ ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first' ' man']" ", the first man in space .
+
+ The first man in space .
+
+ The first man" False given a tour of the Yuri Gagarin Cosmonaut Training 8 [' given', ' a', ' tour', ' of', ' the', ' Yuri', ' G', 'ag', 'arin']
+3846 906 Name of mother of x -1 Name of mother of Yuri Gagarin Anna Gagarina Yuri Gagarin "[',' ' the' ' first' ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first'
+ ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first' ' man']" ", the first man in space .
+
+ The first man in space .
+
+ The first man" False launching cosmonaut Yuri Gagarin into a single 7 [' launching', ' cos', 'mon', 'aut', ' Yuri', ' G', 'ag', 'arin']
+3847 906 Name of mother of x -1 Name of mother of Yuri Gagarin Anna Gagarina Yuri Gagarin "[',' ' the' ' first' ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first'
+ ' man' ' in' ' space' '.' '\n' '\n' 'The' ' first' ' man']" ", the first man in space .
+
+ The first man in space .
+
+ The first man" False achieved by Yuri Gagarin of the Soviet 5 [' achieved', ' by', ' Yuri', ' G', 'ag', 'arin']
+3848 909 Name of mother of x -1 Name of mother of Joseph Priestley Mary Swift Joseph Priestley "[',' ' the' ' father' ' of' ' modern' ' chemistry' ',' ' was' ' born'
+ ' in' ' 17' '33' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great']" ", the father of modern chemistry , was born in 17 33 .
+
+ The first of the great" False administrator for the Joseph Priestley House in the 1990s, 5 [' administrator', ' for', ' the', ' Joseph', ' Priest', 'ley']
+3849 909 Name of mother of x -1 Name of mother of Joseph Priestley Mary Swift Joseph Priestley "[',' ' the' ' father' ' of' ' modern' ' chemistry' ',' ' was' ' born'
+ ' in' ' 17' '33' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great']" ", the father of modern chemistry , was born in 17 33 .
+
+ The first of the great" False " Friends of the Joseph Priestley House"" (FJPH), who" 5 [' Friends', ' of', ' the', ' Joseph', ' Priest', 'ley']
+3850 909 Name of mother of x -1 Name of mother of Joseph Priestley Mary Swift Joseph Priestley "[',' ' the' ' father' ' of' ' modern' ' chemistry' ',' ' was' ' born'
+ ' in' ' 17' '33' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great']" ", the father of modern chemistry , was born in 17 33 .
+
+ The first of the great" False chloride gas. Joseph Priestley of Leeds, England 5 [' chloride', ' gas', '.', ' Joseph', ' Priest', 'ley']
+3851 909 Name of mother of x -1 Name of mother of Joseph Priestley Mary Swift Joseph Priestley "[',' ' the' ' father' ' of' ' modern' ' chemistry' ',' ' was' ' born'
+ ' in' ' 17' '33' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great']" ", the father of modern chemistry , was born in 17 33 .
+
+ The first of the great" False had died in 1791, and Joseph Priestley had been forced 9 [' had', ' died', ' in', ' 17', '91', ',', ' and', ' Joseph', ' Priest', 'ley']
+3852 909 Name of mother of x -1 Name of mother of Joseph Priestley Mary Swift Joseph Priestley "[',' ' the' ' father' ' of' ' modern' ' chemistry' ',' ' was' ' born'
+ ' in' ' 17' '33' '.' '\n' '\n' 'The' ' first' ' of' ' the' ' great']" ", the father of modern chemistry , was born in 17 33 .
+
+ The first of the great" False Postal Service's Joseph Priestley commemorative 5 "[' Postal', ' Service', ""'s"", ' Joseph', ' Priest', 'ley']"
+3853 910 Name of mother of x -1 Name of mother of Anne, Queen of Great Britain Anne, Duchess of York Anne, Queen of Great Britain "[',' ' and' ' the' ' mother' ' of' ' the' ' King' ' of' ' England' ','
+ ' and' ' the' ' mother' ' of' ' the' ' King' ' of' ' England' ',' ' and']" , and the mother of the King of England , and the mother of the King of England , and False " Britain =
+" 7 [' Britain', ' =', 'Anne', ',', ' Queen', ' of', ' Great', ' Britain']
+3854 910 Name of mother of x -1 Name of mother of Anne, Queen of Great Britain Anne, Duchess of York Anne, Queen of Great Britain "[',' ' and' ' the' ' mother' ' of' ' the' ' King' ' of' ' England' ','
+ ' and' ' the' ' mother' ' of' ' the' ' King' ' of' ' England' ',' ' and']" , and the mother of the King of England , and the mother of the King of England , and False " Great Britain =
+" 8 [' Great', ' Britain', ' =', 'Anne', ',', ' Queen', ' of', ' Great', ' Britain']
+3855 910 Name of mother of x -1 Name of mother of Anne, Queen of Great Britain Anne, Duchess of York Anne, Queen of Great Britain "[',' ' and' ' the' ' mother' ' of' ' the' ' King' ' of' ' England' ','
+ ' and' ' the' ' mother' ' of' ' the' ' King' ' of' ' England' ',' ' and']" , and the mother of the King of England , and the mother of the King of England , and False " Great Britain =
+" 8 [' Great', ' Britain', ' =', 'Anne', ',', ' Queen', ' of', ' Great', ' Britain']
+3856 910 Name of mother of x -1 Name of mother of Anne, Queen of Great Britain Anne, Duchess of York Anne, Queen of Great Britain "[',' ' and' ' the' ' mother' ' of' ' the' ' King' ' of' ' England' ','
+ ' and' ' the' ' mother' ' of' ' the' ' King' ' of' ' England' ',' ' and']" , and the mother of the King of England , and the mother of the King of England , and False " Great Britain =
+" 8 [' Great', ' Britain', ' =', 'Anne', ',', ' Queen', ' of', ' Great', ' Britain']
+3857 911 Name of mother of x -1 Name of mother of Michael Douglas Diana Douglas Michael Douglas "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False screenplay for the Michael Douglas film, It's My Turn. 4 [' screenplay', ' for', ' the', ' Michael', ' Douglas']
+3858 911 Name of mother of x -1 Name of mother of Michael Douglas Diana Douglas Michael Douglas "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False pre-production. Michael Douglas reprised his 5 [' pre', '-', 'production', '.', ' Michael', ' Douglas']
+3859 911 Name of mother of x -1 Name of mother of Michael Douglas Diana Douglas Michael Douglas "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False abuse starring Michael Douglas and Benicio del 3 [' abuse', ' starring', ' Michael', ' Douglas']
+3860 911 Name of mother of x -1 Name of mother of Michael Douglas Diana Douglas Michael Douglas "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False wife of homicidal Michael Douglas in the thriller 5 [' wife', ' of', ' hom', 'icidal', ' Michael', ' Douglas']
+3861 911 Name of mother of x -1 Name of mother of Michael Douglas Diana Douglas Michael Douglas "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Bud's naivete. Michael Douglas had just come off 6 "[' Bud', ""'s"", ' naive', 'te', '.', ' Michael', ' Douglas']"
+3862 913 Name of mother of x -1 Name of mother of Hilaire Belloc Bessie Rayner Parkes Hilaire Belloc "[',' ' the' ' author' ' of' ' the' ' famous' ' ""' 'The' ' Serv' 'ile'
+ ' State' '""' ' and' ' ""' 'The' ' Serv' 'ile' ' State' '""' ' (']" ", the author of the famous "" The Serv ile State "" and "" The Serv ile State "" (" False 4 ['H', 'il', 'aire', ' Bell', 'oc']
+3863 913 Name of mother of x -1 Name of mother of Hilaire Belloc Bessie Rayner Parkes Hilaire Belloc "[',' ' the' ' author' ' of' ' the' ' famous' ' ""' 'The' ' Serv' 'ile'
+ ' State' '""' ' and' ' ""' 'The' ' Serv' 'ile' ' State' '""' ' (']" ", the author of the famous "" The Serv ile State "" and "" The Serv ile State "" (" False (1874 – 1936) and Hilaire Belloc (1870 – 1953). 10 [' (', '18', '74', ' –', ' 1936', ')', ' and', ' Hil', 'aire', ' Bell', 'oc']
+3864 913 Name of mother of x -1 Name of mother of Hilaire Belloc Bessie Rayner Parkes Hilaire Belloc "[',' ' the' ' author' ' of' ' the' ' famous' ' ""' 'The' ' Serv' 'ile'
+ ' State' '""' ' and' ' ""' 'The' ' Serv' 'ile' ' State' '""' ' (']" ", the author of the famous "" The Serv ile State "" and "" The Serv ile State "" (" False (1874 – 1936) and Hilaire Belloc (1870 – 1953). 10 [' (', '18', '74', ' –', ' 1936', ')', ' and', ' Hil', 'aire', ' Bell', 'oc']
+3865 913 Name of mother of x -1 Name of mother of Hilaire Belloc Bessie Rayner Parkes Hilaire Belloc "[',' ' the' ' author' ' of' ' the' ' famous' ' ""' 'The' ' Serv' 'ile'
+ ' State' '""' ' and' ' ""' 'The' ' Serv' 'ile' ' State' '""' ' (']" ", the author of the famous "" The Serv ile State "" and "" The Serv ile State "" (" False – 1936) and Hilaire Belloc (1870 – 1953). He 7 [' –', ' 1936', ')', ' and', ' Hil', 'aire', ' Bell', 'oc']
+3866 913 Name of mother of x -1 Name of mother of Hilaire Belloc Bessie Rayner Parkes Hilaire Belloc "[',' ' the' ' author' ' of' ' the' ' famous' ' ""' 'The' ' Serv' 'ile'
+ ' State' '""' ' and' ' ""' 'The' ' Serv' 'ile' ' State' '""' ' (']" ", the author of the famous "" The Serv ile State "" and "" The Serv ile State "" (" False Catholic writers Hilaire Belloc and G. K. Chesterton 5 [' Catholic', ' writers', ' Hil', 'aire', ' Bell', 'oc']
+3867 914 Name of mother of x -1 Name of mother of David Foster Eleanor Foster David Foster "[' Wallace' ',' ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and'
+ ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and' ' the' ' author']" Wallace , the author of Infinite J est , and the author of Infinite J est , and the author False its near neighbours; David Foster commented in 5 [' its', ' near', ' neighbours', ';', ' David', ' Foster']
+3868 914 Name of mother of x -1 Name of mother of David Foster Eleanor Foster David Foster "[' Wallace' ',' ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and'
+ ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and' ' the' ' author']" Wallace , the author of Infinite J est , and the author of Infinite J est , and the author False the singer alongside David Foster and Nathan Chapman. 4 [' the', ' singer', ' alongside', ' David', ' Foster']
+3869 914 Name of mother of x -1 Name of mother of David Foster Eleanor Foster David Foster "[' Wallace' ',' ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and'
+ ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and' ' the' ' author']" Wallace , the author of Infinite J est , and the author of Infinite J est , and the author False 1 ['David', ' Foster']
+3870 914 Name of mother of x -1 Name of mother of David Foster Eleanor Foster David Foster "[' Wallace' ',' ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and'
+ ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and' ' the' ' author']" Wallace , the author of Infinite J est , and the author of Infinite J est , and the author False " string arrangement
+" 3 [' string', ' arrangement', 'David', ' Foster']
+3871 914 Name of mother of x -1 Name of mother of David Foster Eleanor Foster David Foster "[' Wallace' ',' ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and'
+ ' the' ' author' ' of' ' Infinite' ' J' 'est' ',' ' and' ' the' ' author']" Wallace , the author of Infinite J est , and the author of Infinite J est , and the author False with support from David Foster on piano, and sparse, 4 [' with', ' support', ' from', ' David', ' Foster']
+3872 915 Name of mother of x -1 Name of mother of Paul III Giovanna Caetani Paul III "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False " Grandsons =
+" 5 [' Gr', 'ands', 'ons', ' =', 'Paul', ' III']
+3873 915 Name of mother of x -1 Name of mother of Paul III Giovanna Caetani Paul III "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False doctrine. Pope Paul III convened a council 4 [' doctrine', '.', ' Pope', ' Paul', ' III']
+3874 915 Name of mother of x -1 Name of mother of Paul III Giovanna Caetani Paul III "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False broke with Pope Paul III in order to annul 4 [' broke', ' with', ' Pope', ' Paul', ' III']
+3875 915 Name of mother of x -1 Name of mother of Paul III Giovanna Caetani Paul III "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False broke with Pope Paul III over the annulment 4 [' broke', ' with', ' Pope', ' Paul', ' III']
+3876 915 Name of mother of x -1 Name of mother of Paul III Giovanna Caetani Paul III "['.' '\n' '\n' 'The' ' first' ' thing' ' I' ' noticed' ' was' ' the'
+ ' smell' '.' ' It' ' was' ' a' ' smell' ' that' ' I' ' had' ' never']" ".
+
+ The first thing I noticed was the smell . It was a smell that I had never" False 1 ['Paul', ' III']
+3877 917 Name of mother of x -1 Name of mother of Joseph Dalton Hooker Maria Turner Joseph Dalton Hooker "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' Joseph' ' Dalton' ' Hook' 'er' ',' ' born' ' in']" ", the
+
+ The following is a list of the children of Joseph Dalton Hook er , born in" False with the botanist Joseph Dalton Hooker in January 1844, and 8 [' with', ' the', ' bot', 'an', 'ist', ' Joseph', ' Dalton', ' Hook', 'er']
+3878 917 Name of mother of x -1 Name of mother of Joseph Dalton Hooker Maria Turner Joseph Dalton Hooker "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' Joseph' ' Dalton' ' Hook' 'er' ',' ' born' ' in']" ", the
+
+ The following is a list of the children of Joseph Dalton Hook er , born in" False correspondence with Joseph Dalton Hooker affirmed that 5 [' correspondence', ' with', ' Joseph', ' Dalton', ' Hook', 'er']
+3879 917 Name of mother of x -1 Name of mother of Joseph Dalton Hooker Maria Turner Joseph Dalton Hooker "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' Joseph' ' Dalton' ' Hook' 'er' ',' ' born' ' in']" ", the
+
+ The following is a list of the children of Joseph Dalton Hook er , born in" False 1855 by Sir Joseph Dalton Hooker and Edward William 7 [' 18', '55', ' by', ' Sir', ' Joseph', ' Dalton', ' Hook', 'er']
+3880 917 Name of mother of x -1 Name of mother of Joseph Dalton Hooker Maria Turner Joseph Dalton Hooker "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' Joseph' ' Dalton' ' Hook' 'er' ',' ' born' ' in']" ", the
+
+ The following is a list of the children of Joseph Dalton Hook er , born in" False 3 ['Joseph', ' Dalton', ' Hook', 'er']
+3881 917 Name of mother of x -1 Name of mother of Joseph Dalton Hooker Maria Turner Joseph Dalton Hooker "[',' ' the' '\n' '\n' 'The' ' following' ' is' ' a' ' list' ' of' ' the'
+ ' children' ' of' ' Joseph' ' Dalton' ' Hook' 'er' ',' ' born' ' in']" ", the
+
+ The following is a list of the children of Joseph Dalton Hook er , born in" False scientists, Joseph Dalton Hooker and Edward 5 [' scientists', ',', ' Joseph', ' Dalton', ' Hook', 'er']
+3882 918 Name of mother of x -1 Name of mother of Moby Elizabeth McBride Warner Moby "[' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mob'
+ 'y' ' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" " Dick , and the
+
+ Name of mother of Mob y Dick , and the
+
+ Name of" False " create ""a superhero Moby Dick; something" 5 "[' create', ' ""', 'a', ' superhero', ' Mob', 'y']"
+3883 918 Name of mother of x -1 Name of mother of Moby Elizabeth McBride Warner Moby "[' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mob'
+ 'y' ' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" " Dick , and the
+
+ Name of mother of Mob y Dick , and the
+
+ Name of" False filming of Moby Dick (1956). 3 [' filming', ' of', ' Mob', 'y']
+3884 918 Name of mother of x -1 Name of mother of Moby Elizabeth McBride Warner Moby "[' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mob'
+ 'y' ' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" " Dick , and the
+
+ Name of mother of Mob y Dick , and the
+
+ Name of" False remixes by Moby and Oscar The Punk. 4 [' remix', 'es', ' by', ' Mob', 'y']
+3885 918 Name of mother of x -1 Name of mother of Moby Elizabeth McBride Warner Moby "[' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mob'
+ 'y' ' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" " Dick , and the
+
+ Name of mother of Mob y Dick , and the
+
+ Name of" False instrumentation performed by Moby himself, including 5 [' instrument', 'ation', ' performed', ' by', ' Mob', 'y']
+3886 918 Name of mother of x -1 Name of mother of Moby Elizabeth McBride Warner Moby "[' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Mob'
+ 'y' ' Dick' ',' ' and' ' the' '\n' '\n' 'Name' ' of']" " Dick , and the
+
+ Name of mother of Mob y Dick , and the
+
+ Name of" False (T-002 and T-103), Moby Dick Toys in 10 [' (', 'T', '-', '002', ' and', ' T', '-', '103', '),', ' Mob', 'y']
+3887 919 Name of mother of x -1 Name of mother of John Wayne Mary Brown John Wayne "[' G' 'acy' ',' ' the' ' serial' ' killer' ' who' ' murdered' ' 33'
+ ' people' ' and' ' buried' ' them' ' in' ' the' ' crawl' ' space' ' of'
+ ' his' ' home']" G acy , the serial killer who murdered 33 people and buried them in the crawl space of his home False onetime hero John Wayne embraced the Vietnam 4 [' on', 'etime', ' hero', ' John', ' Wayne']
+3888 919 Name of mother of x -1 Name of mother of John Wayne Mary Brown John Wayne "[' G' 'acy' ',' ' the' ' serial' ' killer' ' who' ' murdered' ' 33'
+ ' people' ' and' ' buried' ' them' ' in' ' the' ' crawl' ' space' ' of'
+ ' his' ' home']" G acy , the serial killer who murdered 33 people and buried them in the crawl space of his home False Stanwyck, and John Wayne — pressured the 6 [' Stan', 'wy', 'ck', ',', ' and', ' John', ' Wayne']
+3889 919 Name of mother of x -1 Name of mother of John Wayne Mary Brown John Wayne "[' G' 'acy' ',' ' the' ' serial' ' killer' ' who' ' murdered' ' 33'
+ ' people' ' and' ' buried' ' them' ' in' ' the' ' crawl' ' space' ' of'
+ ' his' ' home']" G acy , the serial killer who murdered 33 people and buried them in the crawl space of his home False northward as John Wayne Parkway, a four-lane 4 [' north', 'ward', ' as', ' John', ' Wayne']
+3890 919 Name of mother of x -1 Name of mother of John Wayne Mary Brown John Wayne "[' G' 'acy' ',' ' the' ' serial' ' killer' ' who' ' murdered' ' 33'
+ ' people' ' and' ' buried' ' them' ' in' ' the' ' crawl' ' space' ' of'
+ ' his' ' home']" G acy , the serial killer who murdered 33 people and buried them in the crawl space of his home False extended contracts. John Wayne appeared in 1943's 4 [' extended', ' contracts', '.', ' John', ' Wayne']
+3891 919 Name of mother of x -1 Name of mother of John Wayne Mary Brown John Wayne "[' G' 'acy' ',' ' the' ' serial' ' killer' ' who' ' murdered' ' 33'
+ ' people' ' and' ' buried' ' them' ' in' ' the' ' crawl' ' space' ' of'
+ ' his' ' home']" G acy , the serial killer who murdered 33 people and buried them in the crawl space of his home False appearances by John Wayne, Earp used 3 [' appearances', ' by', ' John', ' Wayne']
+3892 920 Name of mother of x -1 Name of mother of Jane Goodall Margaret Myfanwe Joseph Jane Goodall "[',' ' the' ' famous' ' prim' 'at' 'ologist' ',' ' who' ' has' ' been'
+ ' studying' ' chimpanzees' ' in' ' the' ' wild' ' for' ' more' ' than'
+ ' 50' ' years']" , the famous prim at ologist , who has been studying chimpanzees in the wild for more than 50 years False 1958, primatologist Jane Goodall studied primate 7 [' 1958', ',', ' prim', 'at', 'ologist', ' Jane', ' Good', 'all']
+3893 920 Name of mother of x -1 Name of mother of Jane Goodall Margaret Myfanwe Joseph Jane Goodall "[',' ' the' ' famous' ' prim' 'at' 'ologist' ',' ' who' ' has' ' been'
+ ' studying' ' chimpanzees' ' in' ' the' ' wild' ' for' ' more' ' than'
+ ' 50' ' years']" , the famous prim at ologist , who has been studying chimpanzees in the wild for more than 50 years False researchers such as Jane Goodall and George Schaller 5 [' researchers', ' such', ' as', ' Jane', ' Good', 'all']
+3894 920 Name of mother of x -1 Name of mother of Jane Goodall Margaret Myfanwe Joseph Jane Goodall "[',' ' the' ' famous' ' prim' 'at' 'ologist' ',' ' who' ' has' ' been'
+ ' studying' ' chimpanzees' ' in' ' the' ' wild' ' for' ' more' ' than'
+ ' 50' ' years']" , the famous prim at ologist , who has been studying chimpanzees in the wild for more than 50 years False 1958, primatologist Jane Goodall studied primate 7 [' 1958', ',', ' prim', 'at', 'ologist', ' Jane', ' Good', 'all']
+3895 920 Name of mother of x -1 Name of mother of Jane Goodall Margaret Myfanwe Joseph Jane Goodall "[',' ' the' ' famous' ' prim' 'at' 'ologist' ',' ' who' ' has' ' been'
+ ' studying' ' chimpanzees' ' in' ' the' ' wild' ' for' ' more' ' than'
+ ' 50' ' years']" , the famous prim at ologist , who has been studying chimpanzees in the wild for more than 50 years False first reported by Jane Goodall in 1966, has however 5 [' first', ' reported', ' by', ' Jane', ' Good', 'all']
+3896 920 Name of mother of x -1 Name of mother of Jane Goodall Margaret Myfanwe Joseph Jane Goodall "[',' ' the' ' famous' ' prim' 'at' 'ologist' ',' ' who' ' has' ' been'
+ ' studying' ' chimpanzees' ' in' ' the' ' wild' ' for' ' more' ' than'
+ ' 50' ' years']" , the famous prim at ologist , who has been studying chimpanzees in the wild for more than 50 years False 2 ['Jane', ' Good', 'all']
+3897 921 Name of mother of x -1 Name of mother of Joseph II, Holy Roman Emperor Maria Theresa of Austria Joseph II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False Europe as Joseph II, Holy Roman Emperor viewed the 7 [' Europe', ' as', ' Joseph', ' II', ',', ' Holy', ' Roman', ' Emperor']
+3898 921 Name of mother of x -1 Name of mother of Joseph II, Holy Roman Emperor Maria Theresa of Austria Joseph II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False rulers of Europe as Joseph II, Holy Roman Emperor viewed the revolution 9 [' rulers', ' of', ' Europe', ' as', ' Joseph', ' II', ',', ' Holy', ' Roman', ' Emperor']
+3899 921 Name of mother of x -1 Name of mother of Joseph II, Holy Roman Emperor Maria Theresa of Austria Joseph II, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False rulers of Europe as Joseph II, Holy Roman Emperor viewed the revolution 9 [' rulers', ' of', ' Europe', ' as', ' Joseph', ' II', ',', ' Holy', ' Roman', ' Emperor']
+3900 923 Name of mother of x -1 Name of mother of Javier Bardem Pilar Bardem Javier Bardem "[',' ' the' ' actor' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' is' ' a' ' former' ' member' ' of' ' the' ' Spanish' ' National'
+ ' Team']" , the actor who plays the lead in the film , is a former member of the Spanish National Team False Olga Kurylenko, Javier Bardem and Rachel 9 [' Ol', 'ga', ' K', 'ury', 'len', 'ko', ',', ' Javier', ' Bard', 'em']
+3901 923 Name of mother of x -1 Name of mother of Javier Bardem Pilar Bardem Javier Bardem "[',' ' the' ' actor' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' is' ' a' ' former' ' member' ' of' ' the' ' Spanish' ' National'
+ ' Team']" , the actor who plays the lead in the film , is a former member of the Spanish National Team False 3 ['J', 'avier', ' Bard', 'em']
+3902 923 Name of mother of x -1 Name of mother of Javier Bardem Pilar Bardem Javier Bardem "[',' ' the' ' actor' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' is' ' a' ' former' ' member' ' of' ' the' ' Spanish' ' National'
+ ' Team']" , the actor who plays the lead in the film , is a former member of the Spanish National Team False appearing opposite Javier Bardem and Penélope 4 [' appearing', ' opposite', ' Javier', ' Bard', 'em']
+3903 923 Name of mother of x -1 Name of mother of Javier Bardem Pilar Bardem Javier Bardem "[',' ' the' ' actor' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' is' ' a' ' former' ' member' ' of' ' the' ' Spanish' ' National'
+ ' Team']" , the actor who plays the lead in the film , is a former member of the Spanish National Team False Gekko in the film. Javier Bardem was in final 9 [' G', 'ek', 'ko', ' in', ' the', ' film', '.', ' Javier', ' Bard', 'em']
+3904 923 Name of mother of x -1 Name of mother of Javier Bardem Pilar Bardem Javier Bardem "[',' ' the' ' actor' ' who' ' plays' ' the' ' lead' ' in' ' the' ' film'
+ ',' ' is' ' a' ' former' ' member' ' of' ' the' ' Spanish' ' National'
+ ' Team']" , the actor who plays the lead in the film , is a former member of the Spanish National Team False 3 ['J', 'avier', ' Bard', 'em']
+3905 924 Name of mother of x -1 Name of mother of Lindsay Lohan Chela Lora Lindsay Lohan "[',' ' who' ' is' ' a' ' former' ' porn' ' star' ',' ' and' ' her'
+ ' daughter' ',' ' who' ' is' ' a' ' former' ' porn' ' star' '.' '\n']" ", who is a former porn star , and her daughter , who is a former porn star .
+" False Cynthia Nixon and Lindsay Lohan commenting openly on 5 [' Cynthia', ' Nixon', ' and', ' Lindsay', ' L', 'ohan']
+3906 924 Name of mother of x -1 Name of mother of Lindsay Lohan Chela Lora Lindsay Lohan "[',' ' who' ' is' ' a' ' former' ' porn' ' star' ',' ' and' ' her'
+ ' daughter' ',' ' who' ' is' ' a' ' former' ' porn' ' star' '.' '\n']" ", who is a former porn star , and her daughter , who is a former porn star .
+" False 3 ['Lind', 'say', ' L', 'ohan']
+3907 924 Name of mother of x -1 Name of mother of Lindsay Lohan Chela Lora Lindsay Lohan "[',' ' who' ' is' ' a' ' former' ' porn' ' star' ',' ' and' ' her'
+ ' daughter' ',' ' who' ' is' ' a' ' former' ' porn' ' star' '.' '\n']" ", who is a former porn star , and her daughter , who is a former porn star .
+" False Gravano and actress Lindsay Lohan have both filed 6 [' Grav', 'ano', ' and', ' actress', ' Lindsay', ' L', 'ohan']
+3908 924 Name of mother of x -1 Name of mother of Lindsay Lohan Chela Lora Lindsay Lohan "[',' ' who' ' is' ' a' ' former' ' porn' ' star' ',' ' and' ' her'
+ ' daughter' ',' ' who' ' is' ' a' ' former' ' porn' ' star' '.' '\n']" ", who is a former porn star , and her daughter , who is a former porn star .
+" False of talking to Lindsay Lohan and that guy from 5 [' of', ' talking', ' to', ' Lindsay', ' L', 'ohan']
+3909 924 Name of mother of x -1 Name of mother of Lindsay Lohan Chela Lora Lindsay Lohan "[',' ' who' ' is' ' a' ' former' ' porn' ' star' ',' ' and' ' her'
+ ' daughter' ',' ' who' ' is' ' a' ' former' ' porn' ' star' '.' '\n']" ", who is a former porn star , and her daughter , who is a former porn star .
+" False " Lohan =
+" 6 [' L', 'ohan', ' =', 'Lind', 'say', ' L', 'ohan']
+3910 925 Name of mother of x -1 Name of mother of Horace Walpole Catherine, Lady Walpole Horace Walpole "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' _' 'Spect'
+ 'ator' '_' ',' ' and' ' I' ' find' ' it' ' very' ' entertaining']" ", the
+
+ I have been reading the _ Spect ator _ , and I find it very entertaining" False " in Paris by Horace Walpole for the ""Gallery""" 6 [' in', ' Paris', ' by', ' Hor', 'ace', ' Wal', 'pole']
+3911 925 Name of mother of x -1 Name of mother of Horace Walpole Catherine, Lady Walpole Horace Walpole "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' _' 'Spect'
+ 'ator' '_' ',' ' and' ' I' ' find' ' it' ' very' ' entertaining']" ", the
+
+ I have been reading the _ Spect ator _ , and I find it very entertaining" False " entitled"". While Horace Walpole praised the accompanying" 6 "[' entitled', '"".', ' While', ' Hor', 'ace', ' Wal', 'pole']"
+3912 925 Name of mother of x -1 Name of mother of Horace Walpole Catherine, Lady Walpole Horace Walpole "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' _' 'Spect'
+ 'ator' '_' ',' ' and' ' I' ' find' ' it' ' very' ' entertaining']" ", the
+
+ I have been reading the _ Spect ator _ , and I find it very entertaining" False " works entitled"". While Horace Walpole praised the accompanying" 7 "[' works', ' entitled', '"".', ' While', ' Hor', 'ace', ' Wal', 'pole']"
+3913 925 Name of mother of x -1 Name of mother of Horace Walpole Catherine, Lady Walpole Horace Walpole "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' _' 'Spect'
+ 'ator' '_' ',' ' and' ' I' ' find' ' it' ' very' ' entertaining']" ", the
+
+ I have been reading the _ Spect ator _ , and I find it very entertaining" False Italian for Horace Walpole to remark in 1740, 5 [' Italian', ' for', ' Hor', 'ace', ' Wal', 'pole']
+3914 925 Name of mother of x -1 Name of mother of Horace Walpole Catherine, Lady Walpole Horace Walpole "[',' ' the' '\n' '\n' 'I' ' have' ' been' ' reading' ' the' ' _' 'Spect'
+ 'ator' '_' ',' ' and' ' I' ' find' ' it' ' very' ' entertaining']" ", the
+
+ I have been reading the _ Spect ator _ , and I find it very entertaining" False further popularised by Horace Walpole during the 18th 7 [' further', ' popular', 'ised', ' by', ' Hor', 'ace', ' Wal', 'pole']
+3915 926 Name of mother of x -1 Name of mother of Jay-Z Gloria Carter Jay-Z "[',' ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay' '-' 'Z' ','
+ ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay']" , the rapper , and the rapper , Jay - Z , the rapper , and the rapper , Jay False aspiring rapper Jay-Z was recruited 4 [' aspiring', ' rapper', ' Jay', '-', 'Z']
+3916 926 Name of mother of x -1 Name of mother of Jay-Z Gloria Carter Jay-Z "[',' ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay' '-' 'Z' ','
+ ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay']" , the rapper , and the rapper , Jay - Z , the rapper , and the rapper , Jay False 2 ['Jay', '-', 'Z']
+3917 926 Name of mother of x -1 Name of mother of Jay-Z Gloria Carter Jay-Z "[',' ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay' '-' 'Z' ','
+ ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay']" , the rapper , and the rapper , Jay - Z , the rapper , and the rapper , Jay False filmed quickly as Jay-Z had to travel to London 5 [' filmed', ' quickly', ' as', ' Jay', '-', 'Z']
+3918 926 Name of mother of x -1 Name of mother of Jay-Z Gloria Carter Jay-Z "[',' ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay' '-' 'Z' ','
+ ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay']" , the rapper , and the rapper , Jay - Z , the rapper , and the rapper , Jay False " Pimpin '"" by Jay-Z and ""Push It"" by Salt-n-Pepa." 8 "[' P', 'imp', 'in', "" '"", '""', ' by', ' Jay', '-', 'Z']"
+3919 926 Name of mother of x -1 Name of mother of Jay-Z Gloria Carter Jay-Z "[',' ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay' '-' 'Z' ','
+ ' the' ' rapper' ',' ' and' ' the' ' rapper' ',' ' Jay']" , the rapper , and the rapper , Jay - Z , the rapper , and the rapper , Jay False Glastonbury 5 [' Gl', 'aston', 'bur', 'Jay', '-', 'Z']
+3920 927 Name of mother of x -1 Name of mother of Alexander Graham Bell Eliza Bell Alexander Graham Bell "[',' ' the' ' inventor' ' of' ' the' ' telephone' ',' ' and' ' the'
+ ' first' ' person' ' to' ' make' ' a' ' telephone' ' call' ' over' ' a'
+ ' distance' ' of']" , the inventor of the telephone , and the first person to make a telephone call over a distance of False dedicated as the Alexander Graham Bell Parkway on April 5 [' dedicated', ' as', ' the', ' Alexander', ' Graham', ' Bell']
+3921 927 Name of mother of x -1 Name of mother of Alexander Graham Bell Eliza Bell Alexander Graham Bell "[',' ' the' ' inventor' ' of' ' the' ' telephone' ',' ' and' ' the'
+ ' first' ' person' ' to' ' make' ' a' ' telephone' ' call' ' over' ' a'
+ ' distance' ' of']" , the inventor of the telephone , and the first person to make a telephone call over a distance of False and his teenage son. Alexander Graham Bell obtained a copy of 7 [' and', ' his', ' teenage', ' son', '.', ' Alexander', ' Graham', ' Bell']
+3922 927 Name of mother of x -1 Name of mother of Alexander Graham Bell Eliza Bell Alexander Graham Bell "[',' ' the' ' inventor' ' of' ' the' ' telephone' ',' ' and' ' the'
+ ' first' ' person' ' to' ' make' ' a' ' telephone' ' call' ' over' ' a'
+ ' distance' ' of']" , the inventor of the telephone , and the first person to make a telephone call over a distance of False " my oatmeal."" Alexander Graham Bell tried to locate" 6 "[' my', ' o', 'atmeal', '.""', ' Alexander', ' Graham', ' Bell']"
+3923 927 Name of mother of x -1 Name of mother of Alexander Graham Bell Eliza Bell Alexander Graham Bell "[',' ' the' ' inventor' ' of' ' the' ' telephone' ',' ' and' ' the'
+ ' first' ' person' ' to' ' make' ' a' ' telephone' ' call' ' over' ' a'
+ ' distance' ' of']" , the inventor of the telephone , and the first person to make a telephone call over a distance of False 2 ['Alexander', ' Graham', ' Bell']
+3924 927 Name of mother of x -1 Name of mother of Alexander Graham Bell Eliza Bell Alexander Graham Bell "[',' ' the' ' inventor' ' of' ' the' ' telephone' ',' ' and' ' the'
+ ' first' ' person' ' to' ' make' ' a' ' telephone' ' call' ' over' ' a'
+ ' distance' ' of']" , the inventor of the telephone , and the first person to make a telephone call over a distance of False shrink. Independently, Alexander Graham Bell also suggested the 8 [' shrink', '.', ' Ind', 'epend', 'ently', ',', ' Alexander', ' Graham', ' Bell']
+3925 928 Name of mother of x -1 Name of mother of Ulysses S. Grant Hannah Simpson Grant Ulysses S. Grant "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' of' ' the' ' United' ' States' ' of' ' America' ',' ' to' '\n']" ", the
+
+ President of the United States , and of the United States of America , to
+" False 1864 by promoting Ulysses S. Grant to commander of 9 [' 18', '64', ' by', ' promoting', ' U', 'lyss', 'es', ' S', '.', ' Grant']
+3926 928 Name of mother of x -1 Name of mother of Ulysses S. Grant Hannah Simpson Grant Ulysses S. Grant "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' of' ' the' ' United' ' States' ' of' ' America' ',' ' to' '\n']" ", the
+
+ President of the United States , and of the United States of America , to
+" False Union General Ulysses S. Grant ordered an assault 7 [' Union', ' General', ' U', 'lyss', 'es', ' S', '.', ' Grant']
+3927 928 Name of mother of x -1 Name of mother of Ulysses S. Grant Hannah Simpson Grant Ulysses S. Grant "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' of' ' the' ' United' ' States' ' of' ' America' ',' ' to' '\n']" ", the
+
+ President of the United States , and of the United States of America , to
+" False Republican Presidents Ulysses S. Grant and Rutherford B. Hayes, 7 [' Republican', ' Presidents', ' U', 'lyss', 'es', ' S', '.', ' Grant']
+3928 928 Name of mother of x -1 Name of mother of Ulysses S. Grant Hannah Simpson Grant Ulysses S. Grant "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' of' ' the' ' United' ' States' ' of' ' America' ',' ' to' '\n']" ", the
+
+ President of the United States , and of the United States of America , to
+" False 1871, U.S. President Ulysses S. Grant had named Hall 13 [' 18', '71', ',', ' U', '.', 'S', '.', ' President', ' U', 'lyss', 'es', ' S', '.', ' Grant']
+3929 928 Name of mother of x -1 Name of mother of Ulysses S. Grant Hannah Simpson Grant Ulysses S. Grant "[',' ' the' '\n' '\n' 'President' ' of' ' the' ' United' ' States' ','
+ ' and' ' of' ' the' ' United' ' States' ' of' ' America' ',' ' to' '\n']" ", the
+
+ President of the United States , and of the United States of America , to
+" False " stated that Ulysses S. Grant was ""incapable of" 7 [' stated', ' that', ' U', 'lyss', 'es', ' S', '.', ' Grant']
+3930 929 Name of mother of x -1 Name of mother of Dick Grayson Mary Grayson Dick Grayson "[',' ' the' ' son' ' of' ' Bruce' ' Wayne' ' and' ' Diana' '.' '\n' '\n'
+ 'The' ' first' ' thing' ' I' ' noticed' ' about' ' the' ' new' ' Batman']" ", the son of Bruce Wayne and Diana .
+
+ The first thing I noticed about the new Batman" False circus acrobat named Dick Grayson (O 'Donnell), who 5 [' circus', ' ac', 'robat', ' named', ' Dick', ' Grayson']
+3931 929 Name of mother of x -1 Name of mother of Dick Grayson Mary Grayson Dick Grayson "[',' ' the' ' son' ' of' ' Bruce' ' Wayne' ' and' ' Diana' '.' '\n' '\n'
+ 'The' ' first' ' thing' ' I' ' noticed' ' about' ' the' ' new' ' Batman']" ", the son of Bruce Wayne and Diana .
+
+ The first thing I noticed about the new Batman" False " little work."" Dick Grayson appeared in the shooting" 4 "[' little', ' work', '.""', ' Dick', ' Grayson']"
+3932 929 Name of mother of x -1 Name of mother of Dick Grayson Mary Grayson Dick Grayson "[',' ' the' ' son' ' of' ' Bruce' ' Wayne' ' and' ' Diana' '.' '\n' '\n'
+ 'The' ' first' ' thing' ' I' ' noticed' ' about' ' the' ' new' ' Batman']" ", the son of Bruce Wayne and Diana .
+
+ The first thing I noticed about the new Batman" False Bruce Wayne, Dick Grayson (having recovered 4 [' Bruce', ' Wayne', ',', ' Dick', ' Grayson']
+3933 929 Name of mother of x -1 Name of mother of Dick Grayson Mary Grayson Dick Grayson "[',' ' the' ' son' ' of' ' Bruce' ' Wayne' ' and' ' Diana' '.' '\n' '\n'
+ 'The' ' first' ' thing' ' I' ' noticed' ' about' ' the' ' new' ' Batman']" ", the son of Bruce Wayne and Diana .
+
+ The first thing I noticed about the new Batman" False The Dark Knight. Dick Grayson returns to the mantle 5 [' The', ' Dark', ' Knight', '.', ' Dick', ' Grayson']
+3934 929 Name of mother of x -1 Name of mother of Dick Grayson Mary Grayson Dick Grayson "[',' ' the' ' son' ' of' ' Bruce' ' Wayne' ' and' ' Diana' '.' '\n' '\n'
+ 'The' ' first' ' thing' ' I' ' noticed' ' about' ' the' ' new' ' Batman']" ", the son of Bruce Wayne and Diana .
+
+ The first thing I noticed about the new Batman" False Barbara Gordon and Dick Grayson is also cut short 4 [' Barbara', ' Gordon', ' and', ' Dick', ' Grayson']
+3935 932 Name of mother of x -1 Name of mother of John Gielgud Kate Terry-Lewis John Gielgud "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False Jeanne Moreau and John Gielgud being available 8 [' Jeanne', ' More', 'au', ' and', ' John', ' Gi', 'el', 'g', 'ud']
+3936 932 Name of mother of x -1 Name of mother of John Gielgud Kate Terry-Lewis John Gielgud "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False contemporaries John Gielgud and Laurence 5 [' contemporaries', ' John', ' Gi', 'el', 'g', 'ud']
+3937 932 Name of mother of x -1 Name of mother of John Gielgud Kate Terry-Lewis John Gielgud "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False Alec Guinness in a John Gielgud production of 8 [' Alec', ' Guinness', ' in', ' a', ' John', ' Gi', 'el', 'g', 'ud']
+3938 932 Name of mother of x -1 Name of mother of John Gielgud Kate Terry-Lewis John Gielgud "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False letters from actor John Gielgud and literary 7 [' letters', ' from', ' actor', ' John', ' Gi', 'el', 'g', 'ud']
+3939 932 Name of mother of x -1 Name of mother of John Gielgud Kate Terry-Lewis John Gielgud "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' and' ' the' ' two' ' of' ' them' ' were' ' in' ' the' ' audience']" , the actor , and his wife , the actress , and the two of them were in the audience False Shakespeare experts Sir John Gielgud and Kenneth Branagh 7 [' Shakespeare', ' experts', ' Sir', ' John', ' Gi', 'el', 'g', 'ud']
+3940 933 Name of mother of x -1 Name of mother of Milan Kundera Milada Kunderová Milan Kundera "[',' ' the' ' Czech' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the Czech writer , who was a friend of the family , and who had been a friend of False The writer Milan Kundera suggests that 5 [' The', ' writer', ' Milan', ' K', 'under', 'a']
+3941 933 Name of mother of x -1 Name of mother of Milan Kundera Milada Kunderová Milan Kundera "[',' ' the' ' Czech' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the Czech writer , who was a friend of the family , and who had been a friend of False promises for reform, Milan Kundera published the article 7 [' promises', ' for', ' reform', ',', ' Milan', ' K', 'under', 'a']
+3942 933 Name of mother of x -1 Name of mother of Milan Kundera Milada Kunderová Milan Kundera "[',' ' the' ' Czech' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the Czech writer , who was a friend of the family , and who had been a friend of False prose. The writer Milan Kundera suggests that 7 [' prose', '.', ' The', ' writer', ' Milan', ' K', 'under', 'a']
+3943 933 Name of mother of x -1 Name of mother of Milan Kundera Milada Kunderová Milan Kundera "[',' ' the' ' Czech' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the Czech writer , who was a friend of the family , and who had been a friend of False " ""Intimate Letters"". Milan Kundera called these" 8 "[' ""', 'Int', 'imate', ' Letters', '"".', ' Milan', ' K', 'under', 'a']"
+3944 933 Name of mother of x -1 Name of mother of Milan Kundera Milada Kunderová Milan Kundera "[',' ' the' ' Czech' ' writer' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' ',' ' and' ' who' ' had' ' been' ' a' ' friend' ' of']" , the Czech writer , who was a friend of the family , and who had been a friend of False for reform, Milan Kundera published the article 6 [' for', ' reform', ',', ' Milan', ' K', 'under', 'a']
+3945 934 Name of mother of x -1 Name of mother of Hedy Lamarr Gertrud Kiesler Hedy Lamarr "[',' ' the' ' actress' ',' ' and' ' the' ' inventor' ' of' ' the'
+ ' frequency' ' hopping' ' spread' ' spectrum' ' technology' ' that'
+ ' made' ' the' ' cell' ' phone' ' possible']" , the actress , and the inventor of the frequency hopping spread spectrum technology that made the cell phone possible False Ladd as Rick, Hedy Lamarr as Ilsa, and John 8 [' L', 'add', ' as', ' Rick', ',', ' H', 'edy', ' Lam', 'arr']
+3946 934 Name of mother of x -1 Name of mother of Hedy Lamarr Gertrud Kiesler Hedy Lamarr "[',' ' the' ' actress' ',' ' and' ' the' ' inventor' ' of' ' the'
+ ' frequency' ' hopping' ' spread' ' spectrum' ' technology' ' that'
+ ' made' ' the' ' cell' ' phone' ' possible']" , the actress , and the inventor of the frequency hopping spread spectrum technology that made the cell phone possible False Alan Ladd as Rick, Hedy Lamarr as Ilsa, and John 9 [' Alan', ' L', 'add', ' as', ' Rick', ',', ' H', 'edy', ' Lam', 'arr']
+3947 934 Name of mother of x -1 Name of mother of Hedy Lamarr Gertrud Kiesler Hedy Lamarr "[',' ' the' ' actress' ',' ' and' ' the' ' inventor' ' of' ' the'
+ ' frequency' ' hopping' ' spread' ' spectrum' ' technology' ' that'
+ ' made' ' the' ' cell' ' phone' ' possible']" , the actress , and the inventor of the frequency hopping spread spectrum technology that made the cell phone possible False This Woman with Hedy Lamarr was a critical 6 [' This', ' Woman', ' with', ' H', 'edy', ' Lam', 'arr']
+3948 934 Name of mother of x -1 Name of mother of Hedy Lamarr Gertrud Kiesler Hedy Lamarr "[',' ' the' ' actress' ',' ' and' ' the' ' inventor' ' of' ' the'
+ ' frequency' ' hopping' ' spread' ' spectrum' ' technology' ' that'
+ ' made' ' the' ' cell' ' phone' ' possible']" , the actress , and the inventor of the frequency hopping spread spectrum technology that made the cell phone possible False Alan Ladd as Rick, Hedy Lamarr as Ilsa, and John 9 [' Alan', ' L', 'add', ' as', ' Rick', ',', ' H', 'edy', ' Lam', 'arr']
+3949 934 Name of mother of x -1 Name of mother of Hedy Lamarr Gertrud Kiesler Hedy Lamarr "[',' ' the' ' actress' ',' ' and' ' the' ' inventor' ' of' ' the'
+ ' frequency' ' hopping' ' spread' ' spectrum' ' technology' ' that'
+ ' made' ' the' ' cell' ' phone' ' possible']" , the actress , and the inventor of the frequency hopping spread spectrum technology that made the cell phone possible False and looked to Hedy Lamarr — who was 6 [' and', ' looked', ' to', ' H', 'edy', ' Lam', 'arr']
+3950 935 Name of mother of x -1 Name of mother of Jiang Zemin Wang Zhelan Jiang Zemin "[',' ' the' ' former' ' president' ' of' ' China' ',' ' and' ' the'
+ ' former' ' president' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' and']" , the former president of China , and the former president of the People 's Republic of China , and False " economic activity"". Jiang Zemin supported" 5 "[' economic', ' activity', '"".', ' Jiang', ' Z', 'emin']"
+3951 935 Name of mother of x -1 Name of mother of Jiang Zemin Wang Zhelan Jiang Zemin "[',' ' the' ' former' ' president' ' of' ' China' ',' ' and' ' the'
+ ' former' ' president' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' and']" , the former president of China , and the former president of the People 's Republic of China , and False On 7 June 1999, Jiang Zemin convened a meeting 7 [' On', ' 7', ' June', ' 1999', ',', ' Jiang', ' Z', 'emin']
+3952 935 Name of mother of x -1 Name of mother of Jiang Zemin Wang Zhelan Jiang Zemin "[',' ' the' ' former' ' president' ' of' ' China' ',' ' and' ' the'
+ ' former' ' president' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' and']" , the former president of China , and the former president of the People 's Republic of China , and False 2009, however, Jiang Zemin and Luo Gan were 6 [' 2009', ',', ' however', ',', ' Jiang', ' Z', 'emin']
+3953 935 Name of mother of x -1 Name of mother of Jiang Zemin Wang Zhelan Jiang Zemin "[',' ' the' ' former' ' president' ' of' ' China' ',' ' and' ' the'
+ ' former' ' president' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' and']" , the former president of China , and the former president of the People 's Republic of China , and False General Secretary Jiang Zemin succeeded Deng as “ 4 [' General', ' Secretary', ' Jiang', ' Z', 'emin']
+3954 935 Name of mother of x -1 Name of mother of Jiang Zemin Wang Zhelan Jiang Zemin "[',' ' the' ' former' ' president' ' of' ' China' ',' ' and' ' the'
+ ' former' ' president' ' of' ' the' ' People' ""'s"" ' Republic' ' of'
+ ' China' ',' ' and']" , the former president of China , and the former president of the People 's Republic of China , and False general secretary Jiang Zemin in preparing to force 4 [' general', ' secretary', ' Jiang', ' Z', 'emin']
+3955 936 Name of mother of x -1 Name of mother of Robert Hooke Cecily Gyles Robert Hooke "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Robert'
+ ' Ho' 'oke' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Robert Ho oke , the
+
+ The name of" False John Wilkins and Robert Hooke at John Comstock's 6 [' John', ' Wil', 'kins', ' and', ' Robert', ' Ho', 'oke']
+3956 936 Name of mother of x -1 Name of mother of Robert Hooke Cecily Gyles Robert Hooke "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Robert'
+ ' Ho' 'oke' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Robert Ho oke , the
+
+ The name of" False microscope, Polymath Robert Hooke discovered 6 [' microscope', ',', ' Poly', 'math', ' Robert', ' Ho', 'oke']
+3957 936 Name of mother of x -1 Name of mother of Robert Hooke Cecily Gyles Robert Hooke "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Robert'
+ ' Ho' 'oke' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Robert Ho oke , the
+
+ The name of" False and calculus, and Robert Hooke his eponymously 6 [' and', ' calculus', ',', ' and', ' Robert', ' Ho', 'oke']
+3958 936 Name of mother of x -1 Name of mother of Robert Hooke Cecily Gyles Robert Hooke "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Robert'
+ ' Ho' 'oke' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Robert Ho oke , the
+
+ The name of" False 1675, Huygens and Robert Hooke invented the 10 [' 16', '75', ',', ' H', 'uy', 'g', 'ens', ' and', ' Robert', ' Ho', 'oke']
+3959 936 Name of mother of x -1 Name of mother of Robert Hooke Cecily Gyles Robert Hooke "[',' ' the' '\n' '\n' 'The' ' name' ' of' ' the' ' mother' ' of' ' Robert'
+ ' Ho' 'oke' ',' ' the' '\n' '\n' 'The' ' name' ' of']" ", the
+
+ The name of the mother of Robert Ho oke , the
+
+ The name of" False Descartes (1637), Robert Hooke (1665), and Christiaan 9 [' Des', 'cart', 'es', ' (', '16', '37', '),', ' Robert', ' Ho', 'oke']
+3960 937 Name of mother of x -1 Name of mother of Judi Dench Eleonora Olive Jones Judi Dench "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Le Chiffre; Judi Dench returned for her 8 [' Le', ' Ch', 'iff', 're', ';', ' Jud', 'i', ' Den', 'ch']
+3961 937 Name of mother of x -1 Name of mother of Judi Dench Eleonora Olive Jones Judi Dench "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False York actress Dame Judi Dench acted in the 6 [' York', ' actress', ' Dame', ' Jud', 'i', ' Den', 'ch']
+3962 937 Name of mother of x -1 Name of mother of Judi Dench Eleonora Olive Jones Judi Dench "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False The head of 6 [' The', ' head', ' o', 'Jud', 'i', ' Den', 'ch']
+3963 937 Name of mother of x -1 Name of mother of Judi Dench Eleonora Olive Jones Judi Dench "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False something like the Judi Dench [as M or] Samuel 6 [' something', ' like', ' the', ' Jud', 'i', ' Den', 'ch']
+3964 937 Name of mother of x -1 Name of mother of Judi Dench Eleonora Olive Jones Judi Dench "[',' ' who' ' is' ' a' ' very' ' good' ' friend' ' of' ' mine' '.' '\n'
+ '\n' 'I' ' am' ' a' ' big' ' fan' ' of' ' the' ' show']" ", who is a very good friend of mine .
+
+ I am a big fan of the show" False Drama, but Judi Dench was awarded the 6 [' Drama', ',', ' but', ' Jud', 'i', ' Den', 'ch']
+3965 938 Name of mother of x -1 Name of mother of Daryl Hannah Susan Metzger Daryl Hannah "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' her' ' husband' ',' ' actor' ' John' ' Goodman' ',' ' who' ' is' ' a']" , who is a former Miss Universe winner , and her husband , actor John Goodman , who is a False character. Stone cast Daryl Hannah as Bud Fox's materialistic 5 [' character', '.', ' Stone', ' cast', ' Daryl', ' Hannah']
+3966 938 Name of mother of x -1 Name of mother of Daryl Hannah Susan Metzger Daryl Hannah "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' her' ' husband' ',' ' actor' ' John' ' Goodman' ',' ' who' ' is' ' a']" , who is a former Miss Universe winner , and her husband , actor John Goodman , who is a False Cindy Crawford and Daryl Hannah successfully fought 4 [' Cindy', ' Crawford', ' and', ' Daryl', ' Hannah']
+3967 938 Name of mother of x -1 Name of mother of Daryl Hannah Susan Metzger Daryl Hannah "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' her' ' husband' ',' ' actor' ' John' ' Goodman' ',' ' who' ' is' ' a']" , who is a former Miss Universe winner , and her husband , actor John Goodman , who is a False starring actors Daryl Hannah and Jon Favreau, 3 [' starring', ' actors', ' Daryl', ' Hannah']
+3968 938 Name of mother of x -1 Name of mother of Daryl Hannah Susan Metzger Daryl Hannah "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' her' ' husband' ',' ' actor' ' John' ' Goodman' ',' ' who' ' is' ' a']" , who is a former Miss Universe winner , and her husband , actor John Goodman , who is a False Stone cast Daryl Hannah as Bud Fox's 3 [' Stone', ' cast', ' Daryl', ' Hannah']
+3969 938 Name of mother of x -1 Name of mother of Daryl Hannah Susan Metzger Daryl Hannah "[',' ' who' ' is' ' a' ' former' ' Miss' ' Universe' ' winner' ',' ' and'
+ ' her' ' husband' ',' ' actor' ' John' ' Goodman' ',' ' who' ' is' ' a']" , who is a former Miss Universe winner , and her husband , actor John Goodman , who is a False Cindy Crawford and Daryl Hannah successfully fought 4 [' Cindy', ' Crawford', ' and', ' Daryl', ' Hannah']
+3970 939 Name of mother of x -1 Name of mother of Tom Stoppard Martha Stoppard Tom Stoppard "[""'s"" ' _' 'The' ' Invention' ' of' ' Love' '_' ',' ' and' ' the' ' film'
+ ' _' 'The' ' Last' ' of' ' the' ' Moh' 'icans' '_' ',']" 's _ The Invention of Love _ , and the film _ The Last of the Moh icans _ , False Gaghan and Tom Stoppard were also 6 [' G', 'aghan', ' and', ' Tom', ' St', 'opp', 'ard']
+3971 939 Name of mother of x -1 Name of mother of Tom Stoppard Martha Stoppard Tom Stoppard "[""'s"" ' _' 'The' ' Invention' ' of' ' Love' '_' ',' ' and' ' the' ' film'
+ ' _' 'The' ' Last' ' of' ' the' ' Moh' 'icans' '_' ',']" 's _ The Invention of Love _ , and the film _ The Last of the Moh icans _ , False " Crawford marks Tom Stoppard as ""the most Shavian" 5 [' Crawford', ' marks', ' Tom', ' St', 'opp', 'ard']
+3972 939 Name of mother of x -1 Name of mother of Tom Stoppard Martha Stoppard Tom Stoppard "[""'s"" ' _' 'The' ' Invention' ' of' ' Love' '_' ',' ' and' ' the' ' film'
+ ' _' 'The' ' Last' ' of' ' the' ' Moh' 'icans' '_' ',']" 's _ The Invention of Love _ , and the film _ The Last of the Moh icans _ , False and a rewrite by Tom Stoppard (under the pen 7 [' and', ' a', ' rewrite', ' by', ' Tom', ' St', 'opp', 'ard']
+3973 939 Name of mother of x -1 Name of mother of Tom Stoppard Martha Stoppard Tom Stoppard "[""'s"" ' _' 'The' ' Invention' ' of' ' Love' '_' ',' ' and' ' the' ' film'
+ ' _' 'The' ' Last' ' of' ' the' ' Moh' 'icans' '_' ',']" 's _ The Invention of Love _ , and the film _ The Last of the Moh icans _ , False revision and a rewrite by Tom Stoppard (under the pen 8 [' revision', ' and', ' a', ' rewrite', ' by', ' Tom', ' St', 'opp', 'ard']
+3974 939 Name of mother of x -1 Name of mother of Tom Stoppard Martha Stoppard Tom Stoppard "[""'s"" ' _' 'The' ' Invention' ' of' ' Love' '_' ',' ' and' ' the' ' film'
+ ' _' 'The' ' Last' ' of' ' the' ' Moh' 'icans' '_' ',']" 's _ The Invention of Love _ , and the film _ The Last of the Moh icans _ , False 3 ['Tom', ' St', 'opp', 'ard']
+3975 940 Name of mother of x -1 Name of mother of Angélique Kidjo Yvonne Kidjo Angélique Kidjo "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' the' ' musician'
+ ',' ' musician' ',' ' and' ' producer' ',' ' Ang' 'é' 'lique' ' Kid']" , the singer , and her husband , the musician , musician , and producer , Ang é lique Kid False announced via Twitter that Angélique Kidjo will be featured 8 [' announced', ' via', ' Twitter', ' that', ' Ang', 'é', 'lique', ' Kid', 'jo']
+3976 940 Name of mother of x -1 Name of mother of Angélique Kidjo Yvonne Kidjo Angélique Kidjo "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' the' ' musician'
+ ',' ' musician' ',' ' and' ' producer' ',' ' Ang' 'é' 'lique' ' Kid']" , the singer , and her husband , the musician , musician , and producer , Ang é lique Kid False Twitter that Angélique Kidjo will be featured on 6 [' Twitter', ' that', ' Ang', 'é', 'lique', ' Kid', 'jo']
+3977 940 Name of mother of x -1 Name of mother of Angélique Kidjo Yvonne Kidjo Angélique Kidjo "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' the' ' musician'
+ ',' ' musician' ',' ' and' ' producer' ',' ' Ang' 'é' 'lique' ' Kid']" , the singer , and her husband , the musician , musician , and producer , Ang é lique Kid False via Twitter that Angélique Kidjo will be featured on 7 [' via', ' Twitter', ' that', ' Ang', 'é', 'lique', ' Kid', 'jo']
+3978 940 Name of mother of x -1 Name of mother of Angélique Kidjo Yvonne Kidjo Angélique Kidjo "[',' ' the' ' singer' ',' ' and' ' her' ' husband' ',' ' the' ' musician'
+ ',' ' musician' ',' ' and' ' producer' ',' ' Ang' 'é' 'lique' ' Kid']" , the singer , and her husband , the musician , musician , and producer , Ang é lique Kid False announced via Twitter that Angélique Kidjo will be featured 8 [' announced', ' via', ' Twitter', ' that', ' Ang', 'é', 'lique', ' Kid', 'jo']
+3979 941 Name of mother of x -1 Name of mother of Oceanus Gaia Oceanus "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False afloat on the river of Oceanus and overlooked 6 [' afloat', ' on', ' the', ' river', ' of', ' Ocean', 'us']
+3980 941 Name of mother of x -1 Name of mother of Oceanus Gaia Oceanus "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False afloat on the river of Oceanus and overlooked by 6 [' afloat', ' on', ' the', ' river', ' of', ' Ocean', 'us']
+3981 941 Name of mother of x -1 Name of mother of Oceanus Gaia Oceanus "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False the river of Oceanus and overlooked 4 [' the', ' river', ' of', ' Ocean', 'us']
+3982 941 Name of mother of x -1 Name of mother of Oceanus Gaia Oceanus "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False daughters of the Titans Oceanus and Tethys. The 5 [' daughters', ' of', ' the', ' Titans', ' Ocean', 'us']
+3983 941 Name of mother of x -1 Name of mother of Oceanus Gaia Oceanus "[',' ' the' ' god' ' of' ' the' ' sea' ',' ' and' ' the' ' god' ' of'
+ ' the' ' sea' ',' ' and' ' the' ' god' ' of' ' the' ' sea']" , the god of the sea , and the god of the sea , and the god of the sea False of the Titans Oceanus and Tethys. The 4 [' of', ' the', ' Titans', ' Ocean', 'us']
+3984 942 Name of mother of x -1 Name of mother of Qianlong Emperor Empress Xiaoshengxian Qianlong Emperor "[',' ' the' ' Emperor' ' of' ' China' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' Qing' ' dynasty' '.' '\n' '\n' 'The' ' Emperor' ' of' ' China']" ", the Emperor of China , and the Emperor of the Qing dynasty .
+
+ The Emperor of China" False erected by the Qianlong Emperor (r. 1735 – 96) during 5 [' erected', ' by', ' the', ' Qian', 'long', ' Emperor']
+3985 942 Name of mother of x -1 Name of mother of Qianlong Emperor Empress Xiaoshengxian Qianlong Emperor "[',' ' the' ' Emperor' ' of' ' China' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' Qing' ' dynasty' '.' '\n' '\n' 'The' ' Emperor' ' of' ' China']" ", the Emperor of China , and the Emperor of the Qing dynasty .
+
+ The Emperor of China" False 18th century the Qianlong Emperor of the Qing dynasty, 6 [' 18', 'th', ' century', ' the', ' Qian', 'long', ' Emperor']
+3986 942 Name of mother of x -1 Name of mother of Qianlong Emperor Empress Xiaoshengxian Qianlong Emperor "[',' ' the' ' Emperor' ' of' ' China' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' Qing' ' dynasty' '.' '\n' '\n' 'The' ' Emperor' ' of' ' China']" ", the Emperor of China , and the Emperor of the Qing dynasty .
+
+ The Emperor of China" False complex built by the Qianlong Emperor in anticipation 6 [' complex', ' built', ' by', ' the', ' Qian', 'long', ' Emperor']
+3987 942 Name of mother of x -1 Name of mother of Qianlong Emperor Empress Xiaoshengxian Qianlong Emperor "[',' ' the' ' Emperor' ' of' ' China' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' Qing' ' dynasty' '.' '\n' '\n' 'The' ' Emperor' ' of' ' China']" ", the Emperor of China , and the Emperor of the Qing dynasty .
+
+ The Emperor of China" False was erected by the Qianlong Emperor (r. 1735 – 96) 6 [' was', ' erected', ' by', ' the', ' Qian', 'long', ' Emperor']
+3988 942 Name of mother of x -1 Name of mother of Qianlong Emperor Empress Xiaoshengxian Qianlong Emperor "[',' ' the' ' Emperor' ' of' ' China' ',' ' and' ' the' ' Emperor' ' of'
+ ' the' ' Qing' ' dynasty' '.' '\n' '\n' 'The' ' Emperor' ' of' ' China']" ", the Emperor of China , and the Emperor of the Qing dynasty .
+
+ The Emperor of China" False erected by the Qianlong Emperor (r. 1735 – 96) during 5 [' erected', ' by', ' the', ' Qian', 'long', ' Emperor']
+3989 943 Name of mother of x -1 Name of mother of Wonder Woman Hippolyta Wonder Woman "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of' ' the'
+ ' gods' '.' '\n' '\n' 'The' ' first' ' is' ' the' ' mother' ' of']" ", and the other two are the children of the gods .
+
+ The first is the mother of" False later cast as Wonder Woman in the film's 4 [' later', ' cast', ' as', ' Wonder', ' Woman']
+3990 943 Name of mother of x -1 Name of mother of Wonder Woman Hippolyta Wonder Woman "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of' ' the'
+ ' gods' '.' '\n' '\n' 'The' ' first' ' is' ' the' ' mother' ' of']" ", and the other two are the children of the gods .
+
+ The first is the mother of" False relationship with the Wonder Woman movie in a Whedonesque 4 [' relationship', ' with', ' the', ' Wonder', ' Woman']
+3991 943 Name of mother of x -1 Name of mother of Wonder Woman Hippolyta Wonder Woman "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of' ' the'
+ ' gods' '.' '\n' '\n' 'The' ' first' ' is' ' the' ' mother' ' of']" ", and the other two are the children of the gods .
+
+ The first is the mother of" False Superman, Batman, and Wonder Woman based on their 6 [' Superman', ',', ' Batman', ',', ' and', ' Wonder', ' Woman']
+3992 943 Name of mother of x -1 Name of mother of Wonder Woman Hippolyta Wonder Woman "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of' ' the'
+ ' gods' '.' '\n' '\n' 'The' ' first' ' is' ' the' ' mother' ' of']" ", and the other two are the children of the gods .
+
+ The first is the mother of" False to stand trial. Wonder Woman dethrones her counterpart. 5 [' to', ' stand', ' trial', '.', ' Wonder', ' Woman']
+3993 943 Name of mother of x -1 Name of mother of Wonder Woman Hippolyta Wonder Woman "[',' ' and' ' the' ' other' ' two' ' are' ' the' ' children' ' of' ' the'
+ ' gods' '.' '\n' '\n' 'The' ' first' ' is' ' the' ' mother' ' of']" ", and the other two are the children of the gods .
+
+ The first is the mother of" False 1 ['Wonder', ' Woman']
+3994 944 Name of mother of x -1 Name of mother of Richard Brinsley Sheridan Frances Sheridan Richard Brinsley Sheridan "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' a' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the actor , and his wife , the actress , who was a
+
+ Name of mother of" False and to the playwright Richard Brinsley Sheridan for £ 22,000. They 9 [' and', ' to', ' the', ' play', 'wright', ' Richard', ' Br', 'ins', 'ley', ' Sheridan']
+3995 944 Name of mother of x -1 Name of mother of Richard Brinsley Sheridan Frances Sheridan Richard Brinsley Sheridan "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' a' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the actor , and his wife , the actress , who was a
+
+ Name of mother of" False Gibbon, the playwright Richard Brinsley Sheridan and the painter 10 [' Gib', 'bon', ',', ' the', ' play', 'wright', ' Richard', ' Br', 'ins', 'ley', ' Sheridan']
+3996 944 Name of mother of x -1 Name of mother of Richard Brinsley Sheridan Frances Sheridan Richard Brinsley Sheridan "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' a' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the actor , and his wife , the actress , who was a
+
+ Name of mother of" False Gibbon, the playwright Richard Brinsley Sheridan and the painter Joshua 10 [' Gib', 'bon', ',', ' the', ' play', 'wright', ' Richard', ' Br', 'ins', 'ley', ' Sheridan']
+3997 944 Name of mother of x -1 Name of mother of Richard Brinsley Sheridan Frances Sheridan Richard Brinsley Sheridan "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' a' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the actor , and his wife , the actress , who was a
+
+ Name of mother of" False " provoked the suggestion by Richard Brinsley Sheridan that ""music should" 8 [' provoked', ' the', ' suggestion', ' by', ' Richard', ' Br', 'ins', 'ley', ' Sheridan']
+3998 944 Name of mother of x -1 Name of mother of Richard Brinsley Sheridan Frances Sheridan Richard Brinsley Sheridan "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' the' ' actress' ','
+ ' who' ' was' ' a' '\n' '\n' 'Name' ' of' ' mother' ' of']" ", the actor , and his wife , the actress , who was a
+
+ Name of mother of" False " playwright and MP, Richard Brinsley Sheridan were residents.
+" 9 [' play', 'wright', ' and', ' MP', ',', ' Richard', ' Br', 'ins', 'ley', ' Sheridan']
+3999 946 Name of mother of x -1 Name of mother of Dag Hammarskjöld Agnes Maria Carolina Almquist Dag Hammarskjöld "[',' ' the' ' UN' ' Secretary' ' General' ',' ' who' ' was' ' killed'
+ ' in' ' a' ' plane' ' crash' ' in' ' 1961' '.' '\n' '\n' 'The' ' plane']" ", the UN Secretary General , who was killed in a plane crash in 1961 .
+
+ The plane" False Secretary General Dag Hammarskjöld to try and achieve 7 [' Secretary', ' General', ' Dag', ' Hamm', 'ars', 'kj', 'ö', 'ld']
+4000 946 Name of mother of x -1 Name of mother of Dag Hammarskjöld Agnes Maria Carolina Almquist Dag Hammarskjöld "[',' ' the' ' UN' ' Secretary' ' General' ',' ' who' ' was' ' killed'
+ ' in' ' a' ' plane' ' crash' ' in' ' 1961' '.' '\n' '\n' 'The' ' plane']" ", the UN Secretary General , who was killed in a plane crash in 1961 .
+
+ The plane" False " memorial ===
+" 8 [' memorial', ' ===', 'D', 'ag', ' Hamm', 'ars', 'kj', 'ö', 'ld']
+4001 946 Name of mother of x -1 Name of mother of Dag Hammarskjöld Agnes Maria Carolina Almquist Dag Hammarskjöld "[',' ' the' ' UN' ' Secretary' ' General' ',' ' who' ' was' ' killed'
+ ' in' ' a' ' plane' ' crash' ' in' ' 1961' '.' '\n' '\n' 'The' ' plane']" ", the UN Secretary General , who was killed in a plane crash in 1961 .
+
+ The plane" False the United Nations, Dag Hammarskjöld (1953 – 61), said 9 [' the', ' United', ' Nations', ',', ' Dag', ' Hamm', 'ars', 'kj', 'ö', 'ld']
+4002 946 Name of mother of x -1 Name of mother of Dag Hammarskjöld Agnes Maria Carolina Almquist Dag Hammarskjöld "[',' ' the' ' UN' ' Secretary' ' General' ',' ' who' ' was' ' killed'
+ ' in' ' a' ' plane' ' crash' ' in' ' 1961' '.' '\n' '\n' 'The' ' plane']" ", the UN Secretary General , who was killed in a plane crash in 1961 .
+
+ The plane" False 10, 2008: Dag Hammarskjöld Library, United 9 [' 10', ',', ' 2008', ':', ' Dag', ' Hamm', 'ars', 'kj', 'ö', 'ld']
+4003 946 Name of mother of x -1 Name of mother of Dag Hammarskjöld Agnes Maria Carolina Almquist Dag Hammarskjöld "[',' ' the' ' UN' ' Secretary' ' General' ',' ' who' ' was' ' killed'
+ ' in' ' a' ' plane' ' crash' ' in' ' 1961' '.' '\n' '\n' 'The' ' plane']" ", the UN Secretary General , who was killed in a plane crash in 1961 .
+
+ The plane" False Secretary General Dag Hammarskjöld called the U.S. position 7 [' Secretary', ' General', ' Dag', ' Hamm', 'ars', 'kj', 'ö', 'ld']
+4004 948 Name of mother of x -1 Name of mother of Pius XI Teresa Ratti Pius XI "[',' ' the' ' Pope' ',' ' and' ' the' '\n' '\n' 'Pope' ' P' 'ius' ' XI'
+ ',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the']" ", the Pope , and the
+
+ Pope P ius XI , who was a great friend of the" False Kulwicki attended Pius XI High School, 6 [' Kul', 'wick', 'i', ' attended', ' P', 'ius', ' XI']
+4005 948 Name of mother of x -1 Name of mother of Pius XI Teresa Ratti Pius XI "[',' ' the' ' Pope' ',' ' and' ' the' '\n' '\n' 'Pope' ' P' 'ius' ' XI'
+ ',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the']" ", the Pope , and the
+
+ Pope P ius XI , who was a great friend of the" False 2 ['P', 'ius', ' XI']
+4006 948 Name of mother of x -1 Name of mother of Pius XI Teresa Ratti Pius XI "[',' ' the' ' Pope' ',' ' and' ' the' '\n' '\n' 'Pope' ' P' 'ius' ' XI'
+ ',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the']" ", the Pope , and the
+
+ Pope P ius XI , who was a great friend of the" False Kulwicki attended Pius XI High School, a Roman 6 [' Kul', 'wick', 'i', ' attended', ' P', 'ius', ' XI']
+4007 948 Name of mother of x -1 Name of mother of Pius XI Teresa Ratti Pius XI "[',' ' the' ' Pope' ',' ' and' ' the' '\n' '\n' 'Pope' ' P' 'ius' ' XI'
+ ',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the']" ", the Pope , and the
+
+ Pope P ius XI , who was a great friend of the" False 2 ['P', 'ius', ' XI']
+4008 948 Name of mother of x -1 Name of mother of Pius XI Teresa Ratti Pius XI "[',' ' the' ' Pope' ',' ' and' ' the' '\n' '\n' 'Pope' ' P' 'ius' ' XI'
+ ',' ' who' ' was' ' a' ' great' ' friend' ' of' ' the']" ", the Pope , and the
+
+ Pope P ius XI , who was a great friend of the" False 2 ['P', 'ius', ' XI']
+4009 949 Name of mother of x -1 Name of mother of Richard Francis Burton Martha Baker Richard Francis Burton "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' on' ' the']" ", the author of the book , and the
+
+ The book is a collection of essays on the" False " Kaaba in 1853, Richard Francis Burton also noted that:
+" 8 [' Ka', 'aba', ' in', ' 18', '53', ',', ' Richard', ' Francis', ' Burton']
+4010 949 Name of mother of x -1 Name of mother of Richard Francis Burton Martha Baker Richard Francis Burton "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' on' ' the']" ", the author of the book , and the
+
+ The book is a collection of essays on the" False particularly the works of Richard Francis Burton and Percy Bysshe 6 [' particularly', ' the', ' works', ' of', ' Richard', ' Francis', ' Burton']
+4011 949 Name of mother of x -1 Name of mother of Richard Francis Burton Martha Baker Richard Francis Burton "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' on' ' the']" ", the author of the book , and the
+
+ The book is a collection of essays on the" False particularly the works of Richard Francis Burton and Percy Bysshe 6 [' particularly', ' the', ' works', ' of', ' Richard', ' Francis', ' Burton']
+4012 949 Name of mother of x -1 Name of mother of Richard Francis Burton Martha Baker Richard Francis Burton "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' on' ' the']" ", the author of the book , and the
+
+ The book is a collection of essays on the" False for instance Richard Francis Burton who mentions in 4 [' for', ' instance', ' Richard', ' Francis', ' Burton']
+4013 949 Name of mother of x -1 Name of mother of Richard Francis Burton Martha Baker Richard Francis Burton "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' '\n' '\n'
+ 'The' ' book' ' is' ' a' ' collection' ' of' ' essays' ' on' ' the']" ", the author of the book , and the
+
+ The book is a collection of essays on the" False though explorer Sir Richard Francis Burton doubted that they 5 [' though', ' explorer', ' Sir', ' Richard', ' Francis', ' Burton']
+4014 950 Name of mother of x -1 Name of mother of Alexander Scriabin Lyubov Petrovna Shchetinina Alexander Scriabin "[',' ' the' ' composer' ',' ' and' ' the' ' composer' ""'s"" ' mother' ','
+ ' the' ' mother' ' of' ' the' ' composer' ""'s"" ' mother' ',' ' and'
+ ' the']" , the composer , and the composer 's mother , the mother of the composer 's mother , and the False and idioms. Alexander Scriabin was devoted to the 7 [' and', ' idi', 'oms', '.', ' Alexander', ' Sc', 'ri', 'abin']
+4015 953 Name of mother of x -1 Name of mother of Gilbert Stuart Elizabeth Stuart Gilbert Stuart "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Gilbert' ' Stuart' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Gilbert Stuart , the painter , and the" False public as the Gilbert Stuart Birthplace and Museum. 4 [' public', ' as', ' the', ' Gilbert', ' Stuart']
+4016 953 Name of mother of x -1 Name of mother of Gilbert Stuart Elizabeth Stuart Gilbert Stuart "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Gilbert' ' Stuart' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Gilbert Stuart , the painter , and the" False his career, Gilbert Stuart produced portraits 4 [' his', ' career', ',', ' Gilbert', ' Stuart']
+4017 953 Name of mother of x -1 Name of mother of Gilbert Stuart Elizabeth Stuart Gilbert Stuart "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Gilbert' ' Stuart' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Gilbert Stuart , the painter , and the" False educators, and musicians. Gilbert Stuart is found on 6 [' educators', ',', ' and', ' musicians', '.', ' Gilbert', ' Stuart']
+4018 953 Name of mother of x -1 Name of mother of Gilbert Stuart Elizabeth Stuart Gilbert Stuart "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Gilbert' ' Stuart' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Gilbert Stuart , the painter , and the" False " Gilbert Stuart =
+" 1 [' Gilbert', ' Stuart']
+4019 953 Name of mother of x -1 Name of mother of Gilbert Stuart Elizabeth Stuart Gilbert Stuart "[',' ' the' ' painter' ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother'
+ ' of' ' Gilbert' ' Stuart' ',' ' the' ' painter' ',' ' and' ' the']" ", the painter , and the
+
+ Name of mother of Gilbert Stuart , the painter , and the" False his career, Gilbert Stuart produced portraits 4 [' his', ' career', ',', ' Gilbert', ' Stuart']
+4020 954 Name of mother of x -1 Name of mother of Philip de László Johanna Balogh Philip de László "[',' ' the' ' Hungarian' '-' 'born' ' American' ' composer' ',' ' pian'
+ 'ist' ',' ' and' ' conductor' '.' '\n' '\n' 'References' '\n' '\n'
+ 'External']" ", the Hungarian - born American composer , pian ist , and conductor .
+
+ References
+
+ External" False features a painting by Philip de László of Duchess Eileen, 9 [' features', ' a', ' painting', ' by', ' Philip', ' de', ' L', 'ás', 'zl', 'ó']
+4021 955 Name of mother of x -1 Name of mother of Utada Hikaru Keiko Fuji Utada Hikaru "[',' ' the' ' daughter' ' of' ' the' ' late' ' Emperor' ' Hiro' 'h' 'ito'
+ ',' ' and' ' the' ' Empress' ' K' 'ō' 'jun' ',' ' the' ' mother']" , the daughter of the late Emperor Hiro h ito , and the Empress K ō jun , the mother False compilation releases: Utada Hikaru Single Collection 6 [' compilation', ' releases', ':', ' Ut', 'ada', ' Hik', 'aru']
+4022 955 Name of mother of x -1 Name of mother of Utada Hikaru Keiko Fuji Utada Hikaru "[',' ' the' ' daughter' ' of' ' the' ' late' ' Emperor' ' Hiro' 'h' 'ito'
+ ',' ' and' ' the' ' Empress' ' K' 'ō' 'jun' ',' ' the' ' mother']" , the daughter of the late Emperor Hiro h ito , and the Empress K ō jun , the mother False hits album, Utada Hikaru Single Collection 6 [' hits', ' album', ',', ' Ut', 'ada', ' Hik', 'aru']
+4023 955 Name of mother of x -1 Name of mother of Utada Hikaru Keiko Fuji Utada Hikaru "[',' ' the' ' daughter' ' of' ' the' ' late' ' Emperor' ' Hiro' 'h' 'ito'
+ ',' ' and' ' the' ' Empress' ' K' 'ō' 'jun' ',' ' the' ' mother']" , the daughter of the late Emperor Hiro h ito , and the Empress K ō jun , the mother False releases: Utada Hikaru Single Collection 5 [' releases', ':', ' Ut', 'ada', ' Hik', 'aru']
+4024 955 Name of mother of x -1 Name of mother of Utada Hikaru Keiko Fuji Utada Hikaru "[',' ' the' ' daughter' ' of' ' the' ' late' ' Emperor' ' Hiro' 'h' 'ito'
+ ',' ' and' ' the' ' Empress' ' K' 'ō' 'jun' ',' ' the' ' mother']" , the daughter of the late Emperor Hiro h ito , and the Empress K ō jun , the mother False recording artist Utada Hikaru for her fourth studio 5 [' recording', ' artist', ' Ut', 'ada', ' Hik', 'aru']
+4025 955 Name of mother of x -1 Name of mother of Utada Hikaru Keiko Fuji Utada Hikaru "[',' ' the' ' daughter' ' of' ' the' ' late' ' Emperor' ' Hiro' 'h' 'ito'
+ ',' ' and' ' the' ' Empress' ' K' 'ō' 'jun' ',' ' the' ' mother']" , the daughter of the late Emperor Hiro h ito , and the Empress K ō jun , the mother False compilation releases: Utada Hikaru Single Collection 6 [' compilation', ' releases', ':', ' Ut', 'ada', ' Hik', 'aru']
+4026 956 Name of mother of x -1 Name of mother of Alfred Nobel Karolina Andriette Ahlsell Alfred Nobel "[',' ' the' ' inventor' ' of' ' dynam' 'ite' ',' ' and' ' the' ' inventor'
+ ' of' ' the' ' dynam' 'ite' ' bomb' '.' '\n' '\n' 'The' ' Nobel']" ", the inventor of dynam ite , and the inventor of the dynam ite bomb .
+
+ The Nobel" False a panel each on Alfred Nobel and the Nobel 5 [' a', ' panel', ' each', ' on', ' Alfred', ' Nobel']
+4027 956 Name of mother of x -1 Name of mother of Alfred Nobel Karolina Andriette Ahlsell Alfred Nobel "[',' ' the' ' inventor' ' of' ' dynam' 'ite' ',' ' and' ' the' ' inventor'
+ ' of' ' the' ' dynam' 'ite' ' bomb' '.' '\n' '\n' 'The' ' Nobel']" ", the inventor of dynam ite , and the inventor of the dynam ite bomb .
+
+ The Nobel" False 3 ['A', 'lf', 'red', ' Nobel']
+4028 956 Name of mother of x -1 Name of mother of Alfred Nobel Karolina Andriette Ahlsell Alfred Nobel "[',' ' the' ' inventor' ' of' ' dynam' 'ite' ',' ' and' ' the' ' inventor'
+ ' of' ' the' ' dynam' 'ite' ' bomb' '.' '\n' '\n' 'The' ' Nobel']" ", the inventor of dynam ite , and the inventor of the dynam ite bomb .
+
+ The Nobel" False " and a panel each on Alfred Nobel and the Nobel Prizes.
+" 6 [' and', ' a', ' panel', ' each', ' on', ' Alfred', ' Nobel']
+4029 956 Name of mother of x -1 Name of mother of Alfred Nobel Karolina Andriette Ahlsell Alfred Nobel "[',' ' the' ' inventor' ' of' ' dynam' 'ite' ',' ' and' ' the' ' inventor'
+ ' of' ' the' ' dynam' 'ite' ' bomb' '.' '\n' '\n' 'The' ' Nobel']" ", the inventor of dynam ite , and the inventor of the dynam ite bomb .
+
+ The Nobel" False years and that Alfred Nobel should be commemorated 4 [' years', ' and', ' that', ' Alfred', ' Nobel']
+4030 956 Name of mother of x -1 Name of mother of Alfred Nobel Karolina Andriette Ahlsell Alfred Nobel "[',' ' the' ' inventor' ' of' ' dynam' 'ite' ',' ' and' ' the' ' inventor'
+ ' of' ' the' ' dynam' 'ite' ' bomb' '.' '\n' '\n' 'The' ' Nobel']" ", the inventor of dynam ite , and the inventor of the dynam ite bomb .
+
+ The Nobel" False 3 ['A', 'lf', 'red', ' Nobel']
+4031 957 Name of mother of x -1 Name of mother of Giovanni Pico della Mirandola Giulia Boiardo Giovanni Pico della Mirandola "[',' ' the' ' great' ' human' 'ist' ',' ' who' ' was' ' born' ' in' ' 14'
+ '63' ',' ' and' ' died' ' in' ' 14' '94' '.' '\n']" ", the great human ist , who was born in 14 63 , and died in 14 94 .
+" False Renaissance humanism Giovanni Pico della Mirandola wrote the famous 10 [' Renaissance', ' human', 'ism', ' Giovanni', ' P', 'ico', ' de', 'lla', ' Mir', 'and', 'ola']
+4032 958 Name of mother of x -1 Name of mother of Urban VIII Camilla Barbadori Urban VIII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False Barberini as Pope Urban VIII in 1623. Barberini 5 [' Barber', 'ini', ' as', ' Pope', ' Urban', ' VIII']
+4033 958 Name of mother of x -1 Name of mother of Urban VIII Camilla Barbadori Urban VIII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False Galileo put the words of Urban VIII into the mouth 6 [' Galileo', ' put', ' the', ' words', ' of', ' Urban', ' VIII']
+4034 958 Name of mother of x -1 Name of mother of Urban VIII Camilla Barbadori Urban VIII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False loggia. Pope Urban VIII had it enclosed 6 [' log', 'g', 'ia', '.', ' Pope', ' Urban', ' VIII']
+4035 958 Name of mother of x -1 Name of mother of Urban VIII Camilla Barbadori Urban VIII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False opposition against Pope Urban VIII (1623 – 1644), even 4 [' opposition', ' against', ' Pope', ' Urban', ' VIII']
+4036 958 Name of mother of x -1 Name of mother of Urban VIII Camilla Barbadori Urban VIII "[',' ' the' ' Pope' ',' ' and' ' the' ' Pope' ""'s"" ' mother' ',' ' the'
+ '\n' '\n' 'Pope' ""'s"" ' mother' ',' ' the' ' Pope' ""'s""]" ", the Pope , and the Pope 's mother , the
+
+ Pope 's mother , the Pope 's" False Barberini as Pope Urban VIII in 1623. Barberini 5 [' Barber', 'ini', ' as', ' Pope', ' Urban', ' VIII']
+4037 960 Name of mother of x -1 Name of mother of Cole Porter Kate Porter Cole Porter "[',' ' the' ' composer' ' of' ' the' ' song' ',' ' and' ' the' ' song'
+ ' was' ' written' ' by' ' Cole' ' Porter' '.' '\n' '\n' 'The' ' song']" ", the composer of the song , and the song was written by Cole Porter .
+
+ The song" False cover of the 1936 Cole Porter song, recorded 5 [' cover', ' of', ' the', ' 1936', ' Cole', ' Porter']
+4038 960 Name of mother of x -1 Name of mother of Cole Porter Kate Porter Cole Porter "[',' ' the' ' composer' ' of' ' the' ' song' ',' ' and' ' the' ' song'
+ ' was' ' written' ' by' ' Cole' ' Porter' '.' '\n' '\n' 'The' ' song']" ", the composer of the song , and the song was written by Cole Porter .
+
+ The song" False " Fall in Love"" by Cole Porter as he realises that" 6 "[' Fall', ' in', ' Love', '""', ' by', ' Cole', ' Porter']"
+4039 960 Name of mother of x -1 Name of mother of Cole Porter Kate Porter Cole Porter "[',' ' the' ' composer' ' of' ' the' ' song' ',' ' and' ' the' ' song'
+ ' was' ' written' ' by' ' Cole' ' Porter' '.' '\n' '\n' 'The' ' song']" ", the composer of the song , and the song was written by Cole Porter .
+
+ The song" False staging dances for Cole Porter ’ s Broadway 4 [' staging', ' dances', ' for', ' Cole', ' Porter']
+4040 960 Name of mother of x -1 Name of mother of Cole Porter Kate Porter Cole Porter "[',' ' the' ' composer' ' of' ' the' ' song' ',' ' and' ' the' ' song'
+ ' was' ' written' ' by' ' Cole' ' Porter' '.' '\n' '\n' 'The' ' song']" ", the composer of the song , and the song was written by Cole Porter .
+
+ The song" False " Say Goodbye"" by Cole Porter is featured in the" 5 "[' Say', ' Goodbye', '""', ' by', ' Cole', ' Porter']"
+4041 960 Name of mother of x -1 Name of mother of Cole Porter Kate Porter Cole Porter "[',' ' the' ' composer' ' of' ' the' ' song' ',' ' and' ' the' ' song'
+ ' was' ' written' ' by' ' Cole' ' Porter' '.' '\n' '\n' 'The' ' song']" ", the composer of the song , and the song was written by Cole Porter .
+
+ The song" False " reference to the Cole Porter song ""At Long Last" 4 [' reference', ' to', ' the', ' Cole', ' Porter']
+4042 961 Name of mother of x -1 Name of mother of Vincent d'Indy Mathilde d'Indy Vincent d'Indy "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Vincent' ' d' ""'""
+ 'Ind' 'y' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Vincent d ' Ind y , the
+
+ Name of mother" False French composer Vincent d'Indy directed a concert 6 "[' French', ' composer', ' Vincent', ' d', ""'"", 'Ind', 'y']"
+4043 961 Name of mother of x -1 Name of mother of Vincent d'Indy Mathilde d'Indy Vincent d'Indy "[',' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' Vincent' ' d' ""'""
+ 'Ind' 'y' ',' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", the
+
+ Name of mother of Vincent d ' Ind y , the
+
+ Name of mother" False French composer Vincent d'Indy directed a concert 6 "[' French', ' composer', ' Vincent', ' d', ""'"", 'Ind', 'y']"
+4044 962 Name of mother of x -1 Name of mother of Isadora Duncan Mary Dora Gray Isadora Duncan "[',' ' the' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the' ' dancer'
+ ""'s"" ' daughter' ',' ' the' ' dancer' ""'s"" ' daughter' ',' ' and' ' the']" , the dancer , and the mother of the dancer 's daughter , the dancer 's daughter , and the False American dancer Isadora Duncan in 1900, attempted 5 [' American', ' dancer', ' Is', 'ad', 'ora', ' Duncan']
+4045 962 Name of mother of x -1 Name of mother of Isadora Duncan Mary Dora Gray Isadora Duncan "[',' ' the' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the' ' dancer'
+ ""'s"" ' daughter' ',' ' the' ' dancer' ""'s"" ' daughter' ',' ' and' ' the']" , the dancer , and the mother of the dancer 's daughter , the dancer 's daughter , and the False also mentored young Isadora Duncan who later described 7 [' also', ' ment', 'ored', ' young', ' Is', 'ad', 'ora', ' Duncan']
+4046 962 Name of mother of x -1 Name of mother of Isadora Duncan Mary Dora Gray Isadora Duncan "[',' ' the' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the' ' dancer'
+ ""'s"" ' daughter' ',' ' the' ' dancer' ""'s"" ' daughter' ',' ' and' ' the']" , the dancer , and the mother of the dancer 's daughter , the dancer 's daughter , and the False American dancer Isadora Duncan in 1900, attempted 5 [' American', ' dancer', ' Is', 'ad', 'ora', ' Duncan']
+4047 962 Name of mother of x -1 Name of mother of Isadora Duncan Mary Dora Gray Isadora Duncan "[',' ' the' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the' ' dancer'
+ ""'s"" ' daughter' ',' ' the' ' dancer' ""'s"" ' daughter' ',' ' and' ' the']" , the dancer , and the mother of the dancer 's daughter , the dancer 's daughter , and the False Choreographers Isadora Duncan and Martha Graham 6 [' Ch', 'ore', 'ographers', ' Is', 'ad', 'ora', ' Duncan']
+4048 962 Name of mother of x -1 Name of mother of Isadora Duncan Mary Dora Gray Isadora Duncan "[',' ' the' ' dancer' ',' ' and' ' the' ' mother' ' of' ' the' ' dancer'
+ ""'s"" ' daughter' ',' ' the' ' dancer' ""'s"" ' daughter' ',' ' and' ' the']" , the dancer , and the mother of the dancer 's daughter , the dancer 's daughter , and the False American dancer Isadora Duncan in 1900, attempted 5 [' American', ' dancer', ' Is', 'ad', 'ora', ' Duncan']
+4049 963 Name of mother of x -1 Name of mother of Ferruccio Busoni Anna Weiss Ferruccio Busoni "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' pian' 'ist' ',' ' and' ' composer' ',' ' who' ' was']" , the composer , and his wife , the pian ist , pian ist , and composer , who was False " Busoni =
+" 8 [' Bus', 'oni', ' =', 'Fer', 'ru', 'cc', 'io', ' Bus', 'oni']
+4050 963 Name of mother of x -1 Name of mother of Ferruccio Busoni Anna Weiss Ferruccio Busoni "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' pian' 'ist' ',' ' and' ' composer' ',' ' who' ' was']" , the composer , and his wife , the pian ist , pian ist , and composer , who was False 5 ['Fer', 'ru', 'cc', 'io', ' Bus', 'oni']
+4051 963 Name of mother of x -1 Name of mother of Ferruccio Busoni Anna Weiss Ferruccio Busoni "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' pian' 'ist' ',' ' and' ' composer' ',' ' who' ' was']" , the composer , and his wife , the pian ist , pian ist , and composer , who was False so excited Ferruccio Busoni that he drafted 7 [' so', ' excited', ' Fer', 'ru', 'cc', 'io', ' Bus', 'oni']
+4052 963 Name of mother of x -1 Name of mother of Ferruccio Busoni Anna Weiss Ferruccio Busoni "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' pian' 'ist' ',' ' and' ' composer' ',' ' who' ' was']" , the composer , and his wife , the pian ist , pian ist , and composer , who was False recording by Ferruccio Busoni of 1922, in 7 [' recording', ' by', ' Fer', 'ru', 'cc', 'io', ' Bus', 'oni']
+4053 963 Name of mother of x -1 Name of mother of Ferruccio Busoni Anna Weiss Ferruccio Busoni "[',' ' the' ' composer' ',' ' and' ' his' ' wife' ',' ' the' ' pian' 'ist'
+ ',' ' pian' 'ist' ',' ' and' ' composer' ',' ' who' ' was']" , the composer , and his wife , the pian ist , pian ist , and composer , who was False further notes that Ferruccio Busoni repeated the comparison 8 [' further', ' notes', ' that', ' Fer', 'ru', 'cc', 'io', ' Bus', 'oni']
+4054 965 Name of mother of x -1 Name of mother of Nero Agrippina the Younger Nero "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Nero , and the son of the Emperor Nero , and the son of False 1 ['N', 'ero']
+4055 965 Name of mother of x -1 Name of mother of Nero Agrippina the Younger Nero "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Nero , and the son of the Emperor Nero , and the son of False Tiridates I's arrival, Nero came to the Forum 6 "[' Tir', 'idates', ' I', ""'s"", ' arrival', ',', ' Nero']"
+4056 965 Name of mother of x -1 Name of mother of Nero Agrippina the Younger Nero "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Nero , and the son of the Emperor Nero , and the son of False succeeded by his stepson Nero. The Parthian encroachment 5 [' succeeded', ' by', ' his', ' step', 'son', ' Nero']
+4057 965 Name of mother of x -1 Name of mother of Nero Agrippina the Younger Nero "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Nero , and the son of the Emperor Nero , and the son of False sale of slaves, which Nero shifted from 5 [' sale', ' of', ' slaves', ',', ' which', ' Nero']
+4058 965 Name of mother of x -1 Name of mother of Nero Agrippina the Younger Nero "[',' ' the' ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the'
+ ' son' ' of' ' the' ' Emperor' ' Nero' ',' ' and' ' the' ' son' ' of']" , the son of the Emperor Nero , and the son of the Emperor Nero , and the son of False while Mark Edward Nero of About.com 3 [' while', ' Mark', ' Edward', ' Nero']
+4059 966 Name of mother of x -1 Name of mother of Jamie Lee Curtis Janet Leigh Jamie Lee Curtis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'H'
+ 'alloween']" , the actress who played the role of the mother of the bride in the movie � � H alloween False and actress Jamie Lee Curtis is both his 4 [' and', ' actress', ' Jamie', ' Lee', ' Curtis']
+4060 966 Name of mother of x -1 Name of mother of Jamie Lee Curtis Janet Leigh Jamie Lee Curtis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'H'
+ 'alloween']" , the actress who played the role of the mother of the bride in the movie � � H alloween False alongside Jamie Lee Curtis in the 2003 remake 3 [' alongside', ' Jamie', ' Lee', ' Curtis']
+4061 966 Name of mother of x -1 Name of mother of Jamie Lee Curtis Janet Leigh Jamie Lee Curtis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'H'
+ 'alloween']" , the actress who played the role of the mother of the bride in the movie � � H alloween False her life, by Jamie Lee Curtis in Death of a 6 [' her', ' life', ',', ' by', ' Jamie', ' Lee', ' Curtis']
+4062 966 Name of mother of x -1 Name of mother of Jamie Lee Curtis Janet Leigh Jamie Lee Curtis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'H'
+ 'alloween']" , the actress who played the role of the mother of the bride in the movie � � H alloween False Scacchi and Jamie Lee Curtis in Diane Kurys's 6 [' Sc', 'ac', 'chi', ' and', ' Jamie', ' Lee', ' Curtis']
+4063 966 Name of mother of x -1 Name of mother of Jamie Lee Curtis Janet Leigh Jamie Lee Curtis "[',' ' the' ' actress' ' who' ' played' ' the' ' role' ' of' ' the'
+ ' mother' ' of' ' the' ' bride' ' in' ' the' ' movie' ' �' '�' 'H'
+ 'alloween']" , the actress who played the role of the mother of the bride in the movie � � H alloween False godfather, and actress Jamie Lee Curtis is both his literal 7 [' god', 'father', ',', ' and', ' actress', ' Jamie', ' Lee', ' Curtis']
+4064 967 Name of mother of x -1 Name of mother of Humphrey Bogart Maud Humphrey Humphrey Bogart "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lauren' ' Bac' 'all'
+ ',' ' who' ' was' ' the' ' mother' ' of' ' his' ' daughter' ',']" , the actor , and his wife , Lauren Bac all , who was the mother of his daughter , False such as the Rat Pack, Humphrey Bogart and Marlene Dietrich. 9 [' such', ' as', ' the', ' Rat', ' Pack', ',', ' Humph', 'rey', ' Bog', 'art']
+4065 967 Name of mother of x -1 Name of mother of Humphrey Bogart Maud Humphrey Humphrey Bogart "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lauren' ' Bac' 'all'
+ ',' ' who' ' was' ' the' ' mother' ' of' ' his' ' daughter' ',']" , the actor , and his wife , Lauren Bac all , who was the mother of his daughter , False characters played by Humphrey Bogart and Ingrid Bergman, 6 [' characters', ' played', ' by', ' Humph', 'rey', ' Bog', 'art']
+4066 967 Name of mother of x -1 Name of mother of Humphrey Bogart Maud Humphrey Humphrey Bogart "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lauren' ' Bac' 'all'
+ ',' ' who' ' was' ' the' ' mother' ' of' ' his' ' daughter' ',']" , the actor , and his wife , Lauren Bac all , who was the mother of his daughter , False from the club. Humphrey Bogart was banned after 7 [' from', ' the', ' club', '.', ' Humph', 'rey', ' Bog', 'art']
+4067 967 Name of mother of x -1 Name of mother of Humphrey Bogart Maud Humphrey Humphrey Bogart "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lauren' ' Bac' 'all'
+ ',' ' who' ' was' ' the' ' mother' ' of' ' his' ' daughter' ',']" , the actor , and his wife , Lauren Bac all , who was the mother of his daughter , False setting looks like a Humphrey Bogart movie, further 7 [' setting', ' looks', ' like', ' a', ' Humph', 'rey', ' Bog', 'art']
+4068 967 Name of mother of x -1 Name of mother of Humphrey Bogart Maud Humphrey Humphrey Bogart "[',' ' the' ' actor' ',' ' and' ' his' ' wife' ',' ' Lauren' ' Bac' 'all'
+ ',' ' who' ' was' ' the' ' mother' ' of' ' his' ' daughter' ',']" , the actor , and his wife , Lauren Bac all , who was the mother of his daughter , False reference to the Humphrey Bogart and Ingrid Bergman 6 [' reference', ' to', ' the', ' Humph', 'rey', ' Bog', 'art']
+4069 970 Name of mother of x -1 Name of mother of Charles Lindbergh Evangeline Lodge Land Charles Lindbergh "[',' ' the' ' av' 'iator' ',' ' and' ' his' ' wife' ',' ' Anne' ' Morrow'
+ ' Lind' 'ber' 'gh' ',' ' who' ' was' ' the' ' first' ' woman']" , the av iator , and his wife , Anne Morrow Lind ber gh , who was the first woman False investigation. He selected Charles Lindbergh and Paul Tibbets 7 [' investigation', '.', ' He', ' selected', ' Charles', ' Lind', 'ber', 'gh']
+4070 970 Name of mother of x -1 Name of mother of Charles Lindbergh Evangeline Lodge Land Charles Lindbergh "[',' ' the' ' av' 'iator' ',' ' and' ' his' ' wife' ',' ' Anne' ' Morrow'
+ ' Lind' 'ber' 'gh' ',' ' who' ' was' ' the' ' first' ' woman']" , the av iator , and his wife , Anne Morrow Lind ber gh , who was the first woman False D-UKYM. That same day, Charles Lindbergh was visiting 13 [' D', '-', 'UK', 'Y', 'M', '.', ' That', ' same', ' day', ',', ' Charles', ' Lind', 'ber', 'gh']
+4071 970 Name of mother of x -1 Name of mother of Charles Lindbergh Evangeline Lodge Land Charles Lindbergh "[',' ' the' ' av' 'iator' ',' ' and' ' his' ' wife' ',' ' Anne' ' Morrow'
+ ' Lind' 'ber' 'gh' ',' ' who' ' was' ' the' ' first' ' woman']" , the av iator , and his wife , Anne Morrow Lind ber gh , who was the first woman False biography of aviator Charles Lindbergh was published 7 [' biography', ' of', ' av', 'iator', ' Charles', ' Lind', 'ber', 'gh']
+4072 970 Name of mother of x -1 Name of mother of Charles Lindbergh Evangeline Lodge Land Charles Lindbergh "[',' ' the' ' av' 'iator' ',' ' and' ' his' ' wife' ',' ' Anne' ' Morrow'
+ ' Lind' 'ber' 'gh' ',' ' who' ' was' ' the' ' first' ' woman']" , the av iator , and his wife , Anne Morrow Lind ber gh , who was the first woman False investigation. He selected Charles Lindbergh and Paul Tibbets 7 [' investigation', '.', ' He', ' selected', ' Charles', ' Lind', 'ber', 'gh']
+4073 970 Name of mother of x -1 Name of mother of Charles Lindbergh Evangeline Lodge Land Charles Lindbergh "[',' ' the' ' av' 'iator' ',' ' and' ' his' ' wife' ',' ' Anne' ' Morrow'
+ ' Lind' 'ber' 'gh' ',' ' who' ' was' ' the' ' first' ' woman']" , the av iator , and his wife , Anne Morrow Lind ber gh , who was the first woman False biography of aviator Charles Lindbergh was published in 7 [' biography', ' of', ' av', 'iator', ' Charles', ' Lind', 'ber', 'gh']
+4074 972 Name of mother of x -1 Name of mother of Max von Sydow Maria Margareta, Friherrinna Rappe Max von Sydow "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False precog Agatha, and Max von Sydow as Anderton's 9 [' prec', 'og', ' Ag', 'atha', ',', ' and', ' Max', ' von', ' Syd', 'ow']
+4075 972 Name of mother of x -1 Name of mother of Max von Sydow Maria Margareta, Friherrinna Rappe Max von Sydow "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False Lamar 4 [' Lama', 'Max', ' von', ' Syd', 'ow']
+4076 972 Name of mother of x -1 Name of mother of Max von Sydow Maria Margareta, Friherrinna Rappe Max von Sydow "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False 3 ['Max', ' von', ' Syd', 'ow']
+4077 972 Name of mother of x -1 Name of mother of Max von Sydow Maria Margareta, Friherrinna Rappe Max von Sydow "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False ... joke saying that Max von Sydow should play him. 7 ['...', ' joke', ' saying', ' that', ' Max', ' von', ' Syd', 'ow']
+4078 972 Name of mother of x -1 Name of mother of Max von Sydow Maria Margareta, Friherrinna Rappe Max von Sydow "[',' ' who' ' was' ' a' ' friend' ' of' ' the' ' family' ',' ' and' ' who'
+ ' had' ' been' ' a' ' friend' ' of' ' the' ' family' ' for' ' years']" , who was a friend of the family , and who had been a friend of the family for years False same route came Max von Sydow as Ernst Stavro 6 [' same', ' route', ' came', ' Max', ' von', ' Syd', 'ow']
+4079 973 Name of mother of x -1 Name of mother of L. Ron Hubbard Ledora May Hubbard L. Ron Hubbard "[',' ' the' ' founder' ' of' ' Scientology' ',' ' and' ' the' ' founder'
+ ' of' ' Dian' 'etics' '.' '\n' '\n' 'The' ' Church' ' of' ' Scientology'
+ ' is']" ", the founder of Scientology , and the founder of Dian etics .
+
+ The Church of Scientology is" False true story of L. Ron Hubbard is much more 6 [' true', ' story', ' of', ' L', '.', ' Ron', ' Hubbard']
+4080 973 Name of mother of x -1 Name of mother of L. Ron Hubbard Ledora May Hubbard L. Ron Hubbard "[',' ' the' ' founder' ' of' ' Scientology' ',' ' and' ' the' ' founder'
+ ' of' ' Dian' 'etics' '.' '\n' '\n' 'The' ' Church' ' of' ' Scientology'
+ ' is']" ", the founder of Scientology , and the founder of Dian etics .
+
+ The Church of Scientology is" False copyrighted by the L. Ron Hubbard Library. The album 6 [' copyrighted', ' by', ' the', ' L', '.', ' Ron', ' Hubbard']
+4081 973 Name of mother of x -1 Name of mother of L. Ron Hubbard Ledora May Hubbard L. Ron Hubbard "[',' ' the' ' founder' ' of' ' Scientology' ',' ' and' ' the' ' founder'
+ ' of' ' Dian' 'etics' '.' '\n' '\n' 'The' ' Church' ' of' ' Scientology'
+ ' is']" ", the founder of Scientology , and the founder of Dian etics .
+
+ The Church of Scientology is" False well-received novels by L. Ron Hubbard and short stories 8 [' well', '-', 'received', ' novels', ' by', ' L', '.', ' Ron', ' Hubbard']
+4082 973 Name of mother of x -1 Name of mother of L. Ron Hubbard Ledora May Hubbard L. Ron Hubbard "[',' ' the' ' founder' ' of' ' Scientology' ',' ' and' ' the' ' founder'
+ ' of' ' Dian' 'etics' '.' '\n' '\n' 'The' ' Church' ' of' ' Scientology'
+ ' is']" ", the founder of Scientology , and the founder of Dian etics .
+
+ The Church of Scientology is" False predicted, would win L. Ron Hubbard the Nobel Peace Prize. 7 [' predicted', ',', ' would', ' win', ' L', '.', ' Ron', ' Hubbard']
+4083 973 Name of mother of x -1 Name of mother of L. Ron Hubbard Ledora May Hubbard L. Ron Hubbard "[',' ' the' ' founder' ' of' ' Scientology' ',' ' and' ' the' ' founder'
+ ' of' ' Dian' 'etics' '.' '\n' '\n' 'The' ' Church' ' of' ' Scientology'
+ ' is']" ", the founder of Scientology , and the founder of Dian etics .
+
+ The Church of Scientology is" False Scientology founder L. Ron Hubbard included space opera 5 [' Scientology', ' founder', ' L', '.', ' Ron', ' Hubbard']
+4084 974 Name of mother of x -1 Name of mother of Andrei Sakharov Yekaterina Sakharova Andrei Sakharov "[',' ' the' ' Soviet' ' diss' 'ident' ' and' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' who' ' was' ' arrested' ' in' ' the' ' Soviet' ' Union'
+ ' in' ' 1967']" , the Soviet diss ident and Nobel Peace Prize winner , who was arrested in the Soviet Union in 1967 False including Andrei Sakharov and Alexander Solzhenitsyn. 5 [' including', ' Andre', 'i', ' Sak', 'har', 'ov']
+4085 974 Name of mother of x -1 Name of mother of Andrei Sakharov Yekaterina Sakharova Andrei Sakharov "[',' ' the' ' Soviet' ' diss' 'ident' ' and' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' who' ' was' ' arrested' ' in' ' the' ' Soviet' ' Union'
+ ' in' ' 1967']" , the Soviet diss ident and Nobel Peace Prize winner , who was arrested in the Soviet Union in 1967 False dissidents, such as Andrei Sakharov and Aleksandr Solzhenitsyn. 8 [' dissidents', ',', ' such', ' as', ' Andre', 'i', ' Sak', 'har', 'ov']
+4086 974 Name of mother of x -1 Name of mother of Andrei Sakharov Yekaterina Sakharova Andrei Sakharov "[',' ' the' ' Soviet' ' diss' 'ident' ' and' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' who' ' was' ' arrested' ' in' ' the' ' Soviet' ' Union'
+ ' in' ' 1967']" , the Soviet diss ident and Nobel Peace Prize winner , who was arrested in the Soviet Union in 1967 False Big Bang accumulated, Andrei Sakharov realized in 1967 8 [' Big', ' Bang', ' accumulated', ',', ' Andre', 'i', ' Sak', 'har', 'ov']
+4087 974 Name of mother of x -1 Name of mother of Andrei Sakharov Yekaterina Sakharova Andrei Sakharov "[',' ' the' ' Soviet' ' diss' 'ident' ' and' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' who' ' was' ' arrested' ' in' ' the' ' Soviet' ' Union'
+ ' in' ' 1967']" , the Soviet diss ident and Nobel Peace Prize winner , who was arrested in the Soviet Union in 1967 False Bang accumulated, Andrei Sakharov realized in 7 [' Bang', ' accumulated', ',', ' Andre', 'i', ' Sak', 'har', 'ov']
+4088 974 Name of mother of x -1 Name of mother of Andrei Sakharov Yekaterina Sakharova Andrei Sakharov "[',' ' the' ' Soviet' ' diss' 'ident' ' and' ' Nobel' ' Peace' ' Prize'
+ ' winner' ',' ' who' ' was' ' arrested' ' in' ' the' ' Soviet' ' Union'
+ ' in' ' 1967']" , the Soviet diss ident and Nobel Peace Prize winner , who was arrested in the Soviet Union in 1967 False dissidents, such as Andrei Sakharov and Aleksandr 8 [' dissidents', ',', ' such', ' as', ' Andre', 'i', ' Sak', 'har', 'ov']
+4089 975 Name of mother of x -1 Name of mother of Herbert von Karajan Martha von Karajan Herbert von Karajan "[',' ' the' ' famous' ' conductor' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the famous conductor , who was a friend of the family .
+
+ The house was a large" False Kraack was conducted by Herbert von Karajan at the Vienna State 9 [' Kra', 'ack', ' was', ' conducted', ' by', ' Herbert', ' von', ' Kar', 'aj', 'an']
+4090 975 Name of mother of x -1 Name of mother of Herbert von Karajan Martha von Karajan Herbert von Karajan "[',' ' the' ' famous' ' conductor' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the famous conductor , who was a friend of the family .
+
+ The house was a large" False " symphonies,"" and Herbert von Karajan called the Enigma Variations" 9 "[' sym', 'ph', 'onies', ',""', ' and', ' Herbert', ' von', ' Kar', 'aj', 'an']"
+4091 975 Name of mother of x -1 Name of mother of Herbert von Karajan Martha von Karajan Herbert von Karajan "[',' ' the' ' famous' ' conductor' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the famous conductor , who was a friend of the family .
+
+ The house was a large" False documentaries of Herbert von Karajan conducting Verdi's 6 [' documentaries', ' of', ' Herbert', ' von', ' Kar', 'aj', 'an']
+4092 975 Name of mother of x -1 Name of mother of Herbert von Karajan Martha von Karajan Herbert von Karajan "[',' ' the' ' famous' ' conductor' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the famous conductor , who was a friend of the family .
+
+ The house was a large" False " symphonies,"" and Herbert von Karajan called the" 9 "[' sym', 'ph', 'onies', ',""', ' and', ' Herbert', ' von', ' Kar', 'aj', 'an']"
+4093 975 Name of mother of x -1 Name of mother of Herbert von Karajan Martha von Karajan Herbert von Karajan "[',' ' the' ' famous' ' conductor' ',' ' who' ' was' ' a' ' friend' ' of'
+ ' the' ' family' '.' '\n' '\n' 'The' ' house' ' was' ' a' ' large']" ", the famous conductor , who was a friend of the family .
+
+ The house was a large" False " conducted by Herbert von Karajan for EMI in 1956.
+" 6 [' conducted', ' by', ' Herbert', ' von', ' Kar', 'aj', 'an']
+4094 976 Name of mother of x -1 Name of mother of Buster Keaton Myra Keaton Buster Keaton "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Buster' ' Ke' 'aton' ',' ' and' ' the' ' father' ' of' ' the']" , the actor , and the father of the actor , Buster Ke aton , and the father of the False Jackie Chan and Buster Keaton meet Quentin Tarantino 5 [' Jackie', ' Chan', ' and', ' Buster', ' Ke', 'aton']
+4095 976 Name of mother of x -1 Name of mother of Buster Keaton Myra Keaton Buster Keaton "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Buster' ' Ke' 'aton' ',' ' and' ' the' ' father' ' of' ' the']" , the actor , and the father of the actor , Buster Ke aton , and the father of the False to a boat used by Buster Keaton in the movie 7 [' to', ' a', ' boat', ' used', ' by', ' Buster', ' Ke', 'aton']
+4096 976 Name of mother of x -1 Name of mother of Buster Keaton Myra Keaton Buster Keaton "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Buster' ' Ke' 'aton' ',' ' and' ' the' ' father' ' of' ' the']" , the actor , and the father of the actor , Buster Ke aton , and the father of the False to a boat used by Buster Keaton in the movie 7 [' to', ' a', ' boat', ' used', ' by', ' Buster', ' Ke', 'aton']
+4097 976 Name of mother of x -1 Name of mother of Buster Keaton Myra Keaton Buster Keaton "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Buster' ' Ke' 'aton' ',' ' and' ' the' ' father' ' of' ' the']" , the actor , and the father of the actor , Buster Ke aton , and the father of the False Jackie Chan and Buster Keaton meet Quentin Tarantino 5 [' Jackie', ' Chan', ' and', ' Buster', ' Ke', 'aton']
+4098 976 Name of mother of x -1 Name of mother of Buster Keaton Myra Keaton Buster Keaton "[',' ' the' ' actor' ',' ' and' ' the' ' father' ' of' ' the' ' actor' ','
+ ' Buster' ' Ke' 'aton' ',' ' and' ' the' ' father' ' of' ' the']" , the actor , and the father of the actor , Buster Ke aton , and the father of the False Productions. Filmmaker Buster Keaton rode the back 7 [' Productions', '.', ' Fil', 'mm', 'aker', ' Buster', ' Ke', 'aton']
+4099 977 Name of mother of x -1 Name of mother of Indira Gandhi Kamala Nehru Indira Gandhi "[',' ' the' ' daughter' ' of' ' the' ' late' ' Prime' ' Minister' ' of'
+ ' India' ',' ' Ind' 'ira' ' Gandhi' ',' ' was' ' born' ' in' ' the'
+ ' same']" , the daughter of the late Prime Minister of India , Ind ira Gandhi , was born in the same False Kailash, Munirka, Indira Gandhi International 10 [' K', 'ail', 'ash', ',', ' Mun', 'ir', 'ka', ',', ' Ind', 'ira', ' Gandhi']
+4100 977 Name of mother of x -1 Name of mother of Indira Gandhi Kamala Nehru Indira Gandhi "[',' ' the' ' daughter' ' of' ' the' ' late' ' Prime' ' Minister' ' of'
+ ' India' ',' ' Ind' 'ira' ' Gandhi' ',' ' was' ' born' ' in' ' the'
+ ' same']" , the daughter of the late Prime Minister of India , Ind ira Gandhi , was born in the same False 2 ['Ind', 'ira', ' Gandhi']
+4101 977 Name of mother of x -1 Name of mother of Indira Gandhi Kamala Nehru Indira Gandhi "[',' ' the' ' daughter' ' of' ' the' ' late' ' Prime' ' Minister' ' of'
+ ' India' ',' ' Ind' 'ira' ' Gandhi' ',' ' was' ' born' ' in' ' the'
+ ' same']" , the daughter of the late Prime Minister of India , Ind ira Gandhi , was born in the same False at Terminal 3 at Indira Gandhi International 6 [' at', ' Terminal', ' 3', ' at', ' Ind', 'ira', ' Gandhi']
+4102 977 Name of mother of x -1 Name of mother of Indira Gandhi Kamala Nehru Indira Gandhi "[',' ' the' ' daughter' ' of' ' the' ' late' ' Prime' ' Minister' ' of'
+ ' India' ',' ' Ind' 'ira' ' Gandhi' ',' ' was' ' born' ' in' ' the'
+ ' same']" , the daughter of the late Prime Minister of India , Ind ira Gandhi , was born in the same False " Prime Minister Indira Gandhi in 1984.
+" 4 [' Prime', ' Minister', ' Ind', 'ira', ' Gandhi']
+4103 977 Name of mother of x -1 Name of mother of Indira Gandhi Kamala Nehru Indira Gandhi "[',' ' the' ' daughter' ' of' ' the' ' late' ' Prime' ' Minister' ' of'
+ ' India' ',' ' Ind' 'ira' ' Gandhi' ',' ' was' ' born' ' in' ' the'
+ ' same']" , the daughter of the late Prime Minister of India , Ind ira Gandhi , was born in the same False subsequently enrolled at the Indira Gandhi National Open University 6 [' subsequently', ' enrolled', ' at', ' the', ' Ind', 'ira', ' Gandhi']
+4104 980 Name of mother of x -1 Name of mother of Douglas MacArthur Mary Pinckney Hardy Douglas MacArthur "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False the Emperor. General Douglas MacArthur and the State Department 5 [' the', ' Emperor', '.', ' General', ' Douglas', ' MacArthur']
+4105 980 Name of mother of x -1 Name of mother of Douglas MacArthur Mary Pinckney Hardy Douglas MacArthur "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False In October, General Douglas MacArthur gave Kenney a dressing 5 [' In', ' October', ',', ' General', ' Douglas', ' MacArthur']
+4106 980 Name of mother of x -1 Name of mother of Douglas MacArthur Mary Pinckney Hardy Douglas MacArthur "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False U.S. General Douglas MacArthur lived in suites 6 [' U', '.', 'S', '.', ' General', ' Douglas', ' MacArthur']
+4107 980 Name of mother of x -1 Name of mother of Douglas MacArthur Mary Pinckney Hardy Douglas MacArthur "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False — where General Douglas MacArthur told Cooper 4 [' —', ' where', ' General', ' Douglas', ' MacArthur']
+4108 980 Name of mother of x -1 Name of mother of Douglas MacArthur Mary Pinckney Hardy Douglas MacArthur "[',' ' the' ' son' ' of' ' a' ' bitch' '!""' '\n' '\n' '""' 'I' ""'m"" ' not'
+ ' a' ' son' ' of' ' a' ' bitch' '!""' '\n']" ", the son of a bitch !""
+
+ "" I 'm not a son of a bitch !""
+" False landing of General Douglas MacArthur and the US Army's 4 [' landing', ' of', ' General', ' Douglas', ' MacArthur']
+4109 981 Name of mother of x -1 Name of mother of Enrico Caruso Anna Baldini Enrico Caruso "[',' ' the' ' famous' ' opera' ' singer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' C' 'ag' 'li' 'ari' ',' ' Sard' 'inia']" , the famous opera singer , who was born in the town of C ag li ari , Sard inia False Scarpia. The young Enrico Caruso had hoped to create 11 [' Sc', 'arp', 'ia', '.', ' The', ' young', ' En', 'ric', 'o', ' Car', 'us', 'o']
+4110 981 Name of mother of x -1 Name of mother of Enrico Caruso Anna Baldini Enrico Caruso "[',' ' the' ' famous' ' opera' ' singer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' C' 'ag' 'li' 'ari' ',' ' Sard' 'inia']" , the famous opera singer , who was born in the town of C ag li ari , Sard inia False vocalists such as Enrico Caruso and Nellie Melba performed 9 [' vocal', 'ists', ' such', ' as', ' En', 'ric', 'o', ' Car', 'us', 'o']
+4111 981 Name of mother of x -1 Name of mother of Enrico Caruso Anna Baldini Enrico Caruso "[',' ' the' ' famous' ' opera' ' singer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' C' 'ag' 'li' 'ari' ',' ' Sard' 'inia']" , the famous opera singer , who was born in the town of C ag li ari , Sard inia False vocalists such as Enrico Caruso and Nellie Melba 9 [' vocal', 'ists', ' such', ' as', ' En', 'ric', 'o', ' Car', 'us', 'o']
+4112 981 Name of mother of x -1 Name of mother of Enrico Caruso Anna Baldini Enrico Caruso "[',' ' the' ' famous' ' opera' ' singer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' C' 'ag' 'li' 'ari' ',' ' Sard' 'inia']" , the famous opera singer , who was born in the town of C ag li ari , Sard inia False among the jury and Enrico Caruso singing, among many 9 [' among', ' the', ' jury', ' and', ' En', 'ric', 'o', ' Car', 'us', 'o']
+4113 981 Name of mother of x -1 Name of mother of Enrico Caruso Anna Baldini Enrico Caruso "[',' ' the' ' famous' ' opera' ' singer' ',' ' who' ' was' ' born' ' in'
+ ' the' ' town' ' of' ' C' 'ag' 'li' 'ari' ',' ' Sard' 'inia']" , the famous opera singer , who was born in the town of C ag li ari , Sard inia False the jury and Enrico Caruso singing, among 8 [' the', ' jury', ' and', ' En', 'ric', 'o', ' Car', 'us', 'o']
+4114 982 Name of mother of x -1 Name of mother of A. A. Milne Sarah Maria Heginbotham A. A. Milne "[',' ' the' ' author' ' of' ' Winn' 'ie' '-' 'the' '-' 'Po' 'oh' ','
+ ' and' ' the' ' author' ' of' ' the' ' Winn' 'ie' '-']" , the author of Winn ie - the - Po oh , and the author of the Winn ie - False English author A. A. Milne for his son Christopher 7 [' English', ' author', ' A', '.', ' A', '.', ' Mil', 'ne']
+4115 982 Name of mother of x -1 Name of mother of A. A. Milne Sarah Maria Heginbotham A. A. Milne "[',' ' the' ' author' ' of' ' Winn' 'ie' '-' 'the' '-' 'Po' 'oh' ','
+ ' and' ' the' ' author' ' of' ' the' ' Winn' 'ie' '-']" , the author of Winn ie - the - Po oh , and the author of the Winn ie - False " imitated by A. A. Milne in Winnie the Pooh.
+" 8 [' im', 'itated', ' by', ' A', '.', ' A', '.', ' Mil', 'ne']
+4116 982 Name of mother of x -1 Name of mother of A. A. Milne Sarah Maria Heginbotham A. A. Milne "[',' ' the' ' author' ' of' ' Winn' 'ie' '-' 'the' '-' 'Po' 'oh' ','
+ ' and' ' the' ' author' ' of' ' the' ' Winn' 'ie' '-']" , the author of Winn ie - the - Po oh , and the author of the Winn ie - False Winnipeg. A. A. Milne later wrote 7 [' Winnipeg', '.', ' A', '.', ' A', '.', ' Mil', 'ne']
+4117 982 Name of mother of x -1 Name of mother of A. A. Milne Sarah Maria Heginbotham A. A. Milne "[',' ' the' ' author' ' of' ' Winn' 'ie' '-' 'the' '-' 'Po' 'oh' ','
+ ' and' ' the' ' author' ' of' ' the' ' Winn' 'ie' '-']" , the author of Winn ie - the - Po oh , and the author of the Winn ie - False authors including A. A. Milne who called his 7 [' authors', ' including', ' A', '.', ' A', '.', ' Mil', 'ne']
+4118 982 Name of mother of x -1 Name of mother of A. A. Milne Sarah Maria Heginbotham A. A. Milne "[',' ' the' ' author' ' of' ' Winn' 'ie' '-' 'the' '-' 'Po' 'oh' ','
+ ' and' ' the' ' author' ' of' ' the' ' Winn' 'ie' '-']" , the author of Winn ie - the - Po oh , and the author of the Winn ie - False including A. A. Milne who called his 6 [' including', ' A', '.', ' A', '.', ' Mil', 'ne']
+4119 984 Name of mother of x -1 Name of mother of Samuel Barber Maruerite 'Daisy' McLeod Barber (Beatty) Samuel Barber "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Samuel' ' Barber' ',' ' the' ' father' ' of' ' the' '\n' '\n']" ", the father of the
+
+ Name of mother of Samuel Barber , the father of the
+
+" False " — ""a little Samuel Barber meets Giorgio" 5 "[' —', ' ""', 'a', ' little', ' Samuel', ' Barber']"
+4120 984 Name of mother of x -1 Name of mother of Samuel Barber Maruerite 'Daisy' McLeod Barber (Beatty) Samuel Barber "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Samuel' ' Barber' ',' ' the' ' father' ' of' ' the' '\n' '\n']" ", the father of the
+
+ Name of mother of Samuel Barber , the father of the
+
+" False " liked — ""a little Samuel Barber meets Giorgio" 6 "[' liked', ' —', ' ""', 'a', ' little', ' Samuel', ' Barber']"
+4121 984 Name of mother of x -1 Name of mother of Samuel Barber Maruerite 'Daisy' McLeod Barber (Beatty) Samuel Barber "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Samuel' ' Barber' ',' ' the' ' father' ' of' ' the' '\n' '\n']" ", the father of the
+
+ Name of mother of Samuel Barber , the father of the
+
+" False like Arvo Pärt, Samuel Barber and Henryk Górecki; 8 [' like', ' Ar', 'vo', ' P', 'ä', 'rt', ',', ' Samuel', ' Barber']
+4122 984 Name of mother of x -1 Name of mother of Samuel Barber Maruerite 'Daisy' McLeod Barber (Beatty) Samuel Barber "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Samuel' ' Barber' ',' ' the' ' father' ' of' ' the' '\n' '\n']" ", the father of the
+
+ Name of mother of Samuel Barber , the father of the
+
+" False like Arvo Pärt, Samuel Barber and Henryk Górecki; 8 [' like', ' Ar', 'vo', ' P', 'ä', 'rt', ',', ' Samuel', ' Barber']
+4123 984 Name of mother of x -1 Name of mother of Samuel Barber Maruerite 'Daisy' McLeod Barber (Beatty) Samuel Barber "[',' ' the' ' father' ' of' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of'
+ ' Samuel' ' Barber' ',' ' the' ' father' ' of' ' the' '\n' '\n']" ", the father of the
+
+ Name of mother of Samuel Barber , the father of the
+
+" False like Arvo Pärt, Samuel Barber and Henryk 8 [' like', ' Ar', 'vo', ' P', 'ä', 'rt', ',', ' Samuel', ' Barber']
+4124 985 Name of mother of x -1 Name of mother of Ferdinand I, Holy Roman Emperor Juana I of Castile Ferdinand I, Holy Roman Emperor "[',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ',' ' and' ' the' ' Holy'
+ ' Roman' ' Emperor' ',' ' and' ' the' ' Holy' ' Roman' ' Emperor' ','
+ ' and']" , and the Holy Roman Emperor , and the Holy Roman Emperor , and the Holy Roman Emperor , and False the ambassador of Ferdinand I, Holy Roman Emperor to the Sultan of Turkey, 8 [' the', ' ambassador', ' of', ' Ferdinand', ' I', ',', ' Holy', ' Roman', ' Emperor']
+4125 986 Name of mother of x -1 Name of mother of William the Silent Juliana of Stolberg William the Silent "[',' ' the' ' first' ' of' ' the' ' Dutch' ' royal' ' family' ',' ' and'
+ ' the' ' first' ' of' ' the' ' House' ' of' ' Orange' '.' '\n' '\n']" ", the first of the Dutch royal family , and the first of the House of Orange .
+
+" False unofficially assist William the Silent in his struggle 6 [' un', 'offic', 'ially', ' assist', ' William', ' the', ' Silent']
+4126 986 Name of mother of x -1 Name of mother of William the Silent Juliana of Stolberg William the Silent "[',' ' the' ' first' ' of' ' the' ' Dutch' ' royal' ' family' ',' ' and'
+ ' the' ' first' ' of' ' the' ' House' ' of' ' Orange' '.' '\n' '\n']" ", the first of the Dutch royal family , and the first of the House of Orange .
+
+" False assassinated (as William the Silent had been a few 5 [' assassinated', ' (', 'as', ' William', ' the', ' Silent']
+4127 986 Name of mother of x -1 Name of mother of William the Silent Juliana of Stolberg William the Silent "[',' ' the' ' first' ' of' ' the' ' Dutch' ' royal' ' family' ',' ' and'
+ ' the' ' first' ' of' ' the' ' House' ' of' ' Orange' '.' '\n' '\n']" ", the first of the Dutch royal family , and the first of the House of Orange .
+
+" False " sculpture depicts William the Silent standing:
+" 4 [' sculpture', ' depicts', ' William', ' the', ' Silent']
+4128 986 Name of mother of x -1 Name of mother of William the Silent Juliana of Stolberg William the Silent "[',' ' the' ' first' ' of' ' the' ' Dutch' ' royal' ' family' ',' ' and'
+ ' the' ' first' ' of' ' the' ' House' ' of' ' Orange' '.' '\n' '\n']" ", the first of the Dutch royal family , and the first of the House of Orange .
+
+" False assassinated (as William the Silent had been a few 5 [' assassinated', ' (', 'as', ' William', ' the', ' Silent']
+4129 986 Name of mother of x -1 Name of mother of William the Silent Juliana of Stolberg William the Silent "[',' ' the' ' first' ' of' ' the' ' Dutch' ' royal' ' family' ',' ' and'
+ ' the' ' first' ' of' ' the' ' House' ' of' ' Orange' '.' '\n' '\n']" ", the first of the Dutch royal family , and the first of the House of Orange .
+
+" False Campus and the William the Silent statue. It was 5 [' Campus', ' and', ' the', ' William', ' the', ' Silent']
+4130 987 Name of mother of x -1 Name of mother of Kangxi Emperor Empress Xiaokangzhang Kangxi Emperor "[',' ' the' ' Emperor' ' of' ' China' '.' '\n' '\n' 'The' ' Emperor' ' of'
+ ' China' ' is' ' the' ' most' ' powerful' ' ruler' ' in' ' the' ' world']" ", the Emperor of China .
+
+ The Emperor of China is the most powerful ruler in the world" False 3 ['K', 'ang', 'xi', ' Emperor']
+4131 987 Name of mother of x -1 Name of mother of Kangxi Emperor Empress Xiaokangzhang Kangxi Emperor "[',' ' the' ' Emperor' ' of' ' China' '.' '\n' '\n' 'The' ' Emperor' ' of'
+ ' China' ' is' ' the' ' most' ' powerful' ' ruler' ' in' ' the' ' world']" ", the Emperor of China .
+
+ The Emperor of China is the most powerful ruler in the world" False the minority of the Kangxi Emperor (r. 1661 – 1722). Under 6 [' the', ' minority', ' of', ' the', ' Kang', 'xi', ' Emperor']
+4132 987 Name of mother of x -1 Name of mother of Kangxi Emperor Empress Xiaokangzhang Kangxi Emperor "[',' ' the' ' Emperor' ' of' ' China' '.' '\n' '\n' 'The' ' Emperor' ' of'
+ ' China' ' is' ' the' ' most' ' powerful' ' ruler' ' in' ' the' ' world']" ", the Emperor of China .
+
+ The Emperor of China is the most powerful ruler in the world" False 3 ['K', 'ang', 'xi', ' Emperor']
+4133 987 Name of mother of x -1 Name of mother of Kangxi Emperor Empress Xiaokangzhang Kangxi Emperor "[',' ' the' ' Emperor' ' of' ' China' '.' '\n' '\n' 'The' ' Emperor' ' of'
+ ' China' ' is' ' the' ' most' ' powerful' ' ruler' ' in' ' the' ' world']" ", the Emperor of China .
+
+ The Emperor of China is the most powerful ruler in the world" False Daoist study. The Kangxi Emperor was especially 8 [' D', 'ao', 'ist', ' study', '.', ' The', ' Kang', 'xi', ' Emperor']
+4134 987 Name of mother of x -1 Name of mother of Kangxi Emperor Empress Xiaokangzhang Kangxi Emperor "[',' ' the' ' Emperor' ' of' ' China' '.' '\n' '\n' 'The' ' Emperor' ' of'
+ ' China' ' is' ' the' ' most' ' powerful' ' ruler' ' in' ' the' ' world']" ", the Emperor of China .
+
+ The Emperor of China is the most powerful ruler in the world" False inoculation. The Kangxi Emperor (r. 1661 – 1722) 6 [' inoc', 'ulation', '.', ' The', ' Kang', 'xi', ' Emperor']
+4135 988 Name of mother of x -1 Name of mother of Balthus Baladine Klossowska Balthus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' B' 'alth' 'us'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of B alth us , and the
+
+ Name of mother" False " Reverend Norman Balthus and Brother Justin.
+" 4 [' Reverend', ' Norman', ' B', 'alth', 'us']
+4136 988 Name of mother of x -1 Name of mother of Balthus Baladine Klossowska Balthus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' B' 'alth' 'us'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of B alth us , and the
+
+ Name of mother" False Polish-French artist Balthus drew a wordless 6 [' Polish', '-', 'French', ' artist', ' B', 'alth', 'us']
+4137 988 Name of mother of x -1 Name of mother of Balthus Baladine Klossowska Balthus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' B' 'alth' 'us'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of B alth us , and the
+
+ Name of mother" False Among modern artists, Balthus appears to have incorporated 6 [' Among', ' modern', ' artists', ',', ' B', 'alth', 'us']
+4138 988 Name of mother of x -1 Name of mother of Balthus Baladine Klossowska Balthus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' B' 'alth' 'us'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of B alth us , and the
+
+ Name of mother" False " Reverend Norman Balthus ===
+" 4 [' Reverend', ' Norman', ' B', 'alth', 'us']
+4139 988 Name of mother of x -1 Name of mother of Balthus Baladine Klossowska Balthus "[',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother' ' of' ' B' 'alth' 'us'
+ ',' ' and' ' the' '\n' '\n' 'Name' ' of' ' mother']" ", and the
+
+ Name of mother of B alth us , and the
+
+ Name of mother" False Reverend Norman Balthus later saved 4 [' Reverend', ' Norman', ' B', 'alth', 'us']
+4140 989 Name of mother of x -1 Name of mother of Lawrence Durrell Louisa Durrell Lawrence Durrell "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False according to Lawrence Durrell often had a string 4 [' according', ' to', ' Lawrence', ' Dur', 'rell']
+4141 989 Name of mother of x -1 Name of mother of Lawrence Durrell Louisa Durrell Lawrence Durrell "[',' ' the' ' author' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ' of' ' the' ' book' ',' ' and' ' the' ' author' ' of' ' the' ' book']" , the author of the book , and the author of the book , and the author of the book False the Good Lord Nelson, Lawrence Durrell included the 7 [' the', ' Good', ' Lord', ' Nelson', ',', ' Lawrence', ' Dur', 'rell']
+4142 991 Name of mother of x -1 Name of mother of Bridget Fonda Susan Brewer Bridget Fonda "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False avaricious of souls. And Bridget Fonda offers solid 10 [' av', 'ar', 'icious', ' of', ' souls', '.', ' And', ' Br', 'idget', ' F', 'onda']
+4143 991 Name of mother of x -1 Name of mother of Bridget Fonda Susan Brewer Bridget Fonda "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Ryder and Bridget Fonda were in competition 5 [' Ryder', ' and', ' Br', 'idget', ' F', 'onda']
+4144 991 Name of mother of x -1 Name of mother of Bridget Fonda Susan Brewer Bridget Fonda "[',' ' the' ' mother' ' of' ' the' ' bride' ',' ' and' ' the' ' mother'
+ ' of' ' the' ' groom' '.' '\n' '\n' 'The' ' bride' ' and' ' groom']" ", the mother of the bride , and the mother of the groom .
+
+ The bride and groom" False Robbins. Winona Ryder and Bridget Fonda were in competition 9 [' Robbins', '.', ' Win', 'ona', ' Ryder', ' and', ' Br', 'idget', ' F', 'onda']
+4145 992 Name of mother of x -1 Name of mother of Ludvig Holberg Karen Lem Ludvig Holberg "[',' ' the' ' Danish' ' poet' ',' ' was' ' born' ' in' ' the' ' town'
+ ' of' ' Hol' 'berg' ',' ' in' ' the' ' province' ' of' ' Zealand' ',']" , the Danish poet , was born in the town of Hol berg , in the province of Zealand , False Close by, Professor Ludvig Holberg left his home 8 [' Close', ' by', ',', ' Professor', ' Lud', 'v', 'ig', ' Hol', 'berg']
+4146 992 Name of mother of x -1 Name of mother of Ludvig Holberg Karen Lem Ludvig Holberg "[',' ' the' ' Danish' ' poet' ',' ' was' ' born' ' in' ' the' ' town'
+ ' of' ' Hol' 'berg' ',' ' in' ' the' ' province' ' of' ' Zealand' ',']" , the Danish poet , was born in the town of Hol berg , in the province of Zealand , False by, Professor Ludvig Holberg left his home 7 [' by', ',', ' Professor', ' Lud', 'v', 'ig', ' Hol', 'berg']
+4147 993 Name of mother of x -1 Name of mother of Algernon Charles Swinburne Lady Jane Ashburnham Algernon Charles Swinburne "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' private' ' residence' '.']" ", the poet , was born in this house .
+
+ The house is now a private residence ." False lifetime, counting Algernon Charles Swinburne among his 11 [' lifetime', ',', ' counting', ' Al', 'g', 'ern', 'on', ' Charles', ' Sw', 'in', 'burn', 'e']
+4148 993 Name of mother of x -1 Name of mother of Algernon Charles Swinburne Lady Jane Ashburnham Algernon Charles Swinburne "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' private' ' residence' '.']" ", the poet , was born in this house .
+
+ The house is now a private residence ." False Whistler's friend Algernon Charles Swinburne – titled Before 13 "[' Wh', 'ist', 'ler', ""'s"", ' friend', ' Al', 'g', 'ern', 'on', ' Charles', ' Sw', 'in', 'burn', 'e']"
+4149 993 Name of mother of x -1 Name of mother of Algernon Charles Swinburne Lady Jane Ashburnham Algernon Charles Swinburne "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' private' ' residence' '.']" ", the poet , was born in this house .
+
+ The house is now a private residence ." False " was described by Algernon Charles Swinburne as ""a person of high" 11 [' was', ' described', ' by', ' Al', 'g', 'ern', 'on', ' Charles', ' Sw', 'in', 'burn', 'e']
+4150 993 Name of mother of x -1 Name of mother of Algernon Charles Swinburne Lady Jane Ashburnham Algernon Charles Swinburne "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' private' ' residence' '.']" ", the poet , was born in this house .
+
+ The house is now a private residence ." False Whistler's friend Algernon Charles Swinburne – titled Before 13 "[' Wh', 'ist', 'ler', ""'s"", ' friend', ' Al', 'g', 'ern', 'on', ' Charles', ' Sw', 'in', 'burn', 'e']"
+4151 993 Name of mother of x -1 Name of mother of Algernon Charles Swinburne Lady Jane Ashburnham Algernon Charles Swinburne "[',' ' the' ' poet' ',' ' was' ' born' ' in' ' this' ' house' '.' '\n'
+ '\n' 'The' ' house' ' is' ' now' ' a' ' private' ' residence' '.']" ", the poet , was born in this house .
+
+ The house is now a private residence ." False 8 ['Al', 'g', 'ern', 'on', ' Charles', ' Sw', 'in', 'burn', 'e']
diff --git a/patchscopes/code/preprocessed_data/factual/person_plays_position_in_sport.tsv b/patchscopes/code/preprocessed_data/factual/person_plays_position_in_sport.tsv
new file mode 100644
index 00000000..90029cb7
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/person_plays_position_in_sport.tsv
@@ -0,0 +1,4697 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 In their sport, the position played by x -1 In their sport, the position played by Percy Snow linebacker Percy Snow "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False is the sister of Percy Snow and Eric Snow, 5 [' is', ' the', ' sister', ' of', ' Percy', ' Snow']
+1 0 In their sport, the position played by x -1 In their sport, the position played by Percy Snow linebacker Percy Snow "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False is the sister of Percy Snow and Eric Snow, 5 [' is', ' the', ' sister', ' of', ' Percy', ' Snow']
+2 2 In their sport, the position played by x -1 In their sport, the position played by Ruud Gullit midfielder Ruud Gullit "[',' ' the' ' Dutch' 'man' ' who' ' was' ' the' ' first' ' player' ' to'
+ ' score' ' a' ' hat' '-' 'tr' 'ick' ' in' ' a' ' World' ' Cup']" , the Dutch man who was the first player to score a hat - tr ick in a World Cup False 13th again, with Ruud Gullit having replaced Kenny 9 [' 13', 'th', ' again', ',', ' with', ' Ru', 'ud', ' G', 'ull', 'it']
+3 2 In their sport, the position played by x -1 In their sport, the position played by Ruud Gullit midfielder Ruud Gullit "[',' ' the' ' Dutch' 'man' ' who' ' was' ' the' ' first' ' player' ' to'
+ ' score' ' a' ' hat' '-' 'tr' 'ick' ' in' ' a' ' World' ' Cup']" , the Dutch man who was the first player to score a hat - tr ick in a World Cup False resignation of Ruud Gullit at Newcastle United, 6 [' resignation', ' of', ' Ru', 'ud', ' G', 'ull', 'it']
+4 2 In their sport, the position played by x -1 In their sport, the position played by Ruud Gullit midfielder Ruud Gullit "[',' ' the' ' Dutch' 'man' ' who' ' was' ' the' ' first' ' player' ' to'
+ ' score' ' a' ' hat' '-' 'tr' 'ick' ' in' ' a' ' World' ' Cup']" , the Dutch man who was the first player to score a hat - tr ick in a World Cup False the resignation of Ruud Gullit at Newcastle United, 7 [' the', ' resignation', ' of', ' Ru', 'ud', ' G', 'ull', 'it']
+5 2 In their sport, the position played by x -1 In their sport, the position played by Ruud Gullit midfielder Ruud Gullit "[',' ' the' ' Dutch' 'man' ' who' ' was' ' the' ' first' ' player' ' to'
+ ' score' ' a' ' hat' '-' 'tr' 'ick' ' in' ' a' ' World' ' Cup']" , the Dutch man who was the first player to score a hat - tr ick in a World Cup False appointment of Ruud Gullit as player-manager 6 [' appointment', ' of', ' Ru', 'ud', ' G', 'ull', 'it']
+6 2 In their sport, the position played by x -1 In their sport, the position played by Ruud Gullit midfielder Ruud Gullit "[',' ' the' ' Dutch' 'man' ' who' ' was' ' the' ' first' ' player' ' to'
+ ' score' ' a' ' hat' '-' 'tr' 'ick' ' in' ' a' ' World' ' Cup']" , the Dutch man who was the first player to score a hat - tr ick in a World Cup False European side coached by Ruud Gullit and an all-Italian 8 [' European', ' side', ' coached', ' by', ' Ru', 'ud', ' G', 'ull', 'it']
+7 4 In their sport, the position played by x -1 In their sport, the position played by Lars Ricken midfielder Lars Ricken "['backer' ',' ' the' ' player' ' who' ' is' ' the' ' most' ' famous'
+ ' for' ' his' ' work' ' with' ' the' ' band' ' The' ' Beach' ' Boys' ','
+ ' is']" backer , the player who is the most famous for his work with the band The Beach Boys , is False on 57 minutes, Lars Ricken on 70 minutes, 6 [' on', ' 57', ' minutes', ',', ' Lars', ' R', 'icken']
+8 4 In their sport, the position played by x -1 In their sport, the position played by Lars Ricken midfielder Lars Ricken "['backer' ',' ' the' ' player' ' who' ' is' ' the' ' most' ' famous'
+ ' for' ' his' ' work' ' with' ' the' ' band' ' The' ' Beach' ' Boys' ','
+ ' is']" backer , the player who is the most famous for his work with the band The Beach Boys , is False equaliser replacing Lars Ricken with Jörg Heinrich, 5 [' equal', 'iser', ' replacing', ' Lars', ' R', 'icken']
+9 4 In their sport, the position played by x -1 In their sport, the position played by Lars Ricken midfielder Lars Ricken "['backer' ',' ' the' ' player' ' who' ' is' ' the' ' most' ' famous'
+ ' for' ' his' ' work' ' with' ' the' ' band' ' The' ' Beach' ' Boys' ','
+ ' is']" backer , the player who is the most famous for his work with the band The Beach Boys , is False on 57 minutes, Lars Ricken on 70 minutes, 6 [' on', ' 57', ' minutes', ',', ' Lars', ' R', 'icken']
+10 4 In their sport, the position played by x -1 In their sport, the position played by Lars Ricken midfielder Lars Ricken "['backer' ',' ' the' ' player' ' who' ' is' ' the' ' most' ' famous'
+ ' for' ' his' ' work' ' with' ' the' ' band' ' The' ' Beach' ' Boys' ','
+ ' is']" backer , the player who is the most famous for his work with the band The Beach Boys , is False equaliser replacing Lars Ricken with Jörg 5 [' equal', 'iser', ' replacing', ' Lars', ' R', 'icken']
+11 4 In their sport, the position played by x -1 In their sport, the position played by Lars Ricken midfielder Lars Ricken "['backer' ',' ' the' ' player' ' who' ' is' ' the' ' most' ' famous'
+ ' for' ' his' ' work' ' with' ' the' ' band' ' The' ' Beach' ' Boys' ','
+ ' is']" backer , the player who is the most famous for his work with the band The Beach Boys , is False Koller on 57 minutes, Lars Ricken on 70 minutes, 8 [' K', 'oller', ' on', ' 57', ' minutes', ',', ' Lars', ' R', 'icken']
+12 5 In their sport, the position played by x -1 In their sport, the position played by Tuukka Rask goaltender Tuukka Rask "[' is' ' the' ' equivalent' ' of' ' a' ' quarterback' ' in' ' football'
+ '.' ' He' ' is' ' the' ' quarterback' ' of' ' the' ' Bruins' ',' ' and'
+ ' he' ' is']" is the equivalent of a quarterback in football . He is the quarterback of the Bruins , and he is False Bruins goaltender Tuukka Rask put on an 6 [' Bruins', ' goaltender', ' Tu', 'uk', 'ka', ' R', 'ask']
+13 5 In their sport, the position played by x -1 In their sport, the position played by Tuukka Rask goaltender Tuukka Rask "[' is' ' the' ' equivalent' ' of' ' a' ' quarterback' ' in' ' football'
+ '.' ' He' ' is' ' the' ' quarterback' ' of' ' the' ' Bruins' ',' ' and'
+ ' he' ' is']" is the equivalent of a quarterback in football . He is the quarterback of the Bruins , and he is False Boston Bruins goaltender Tuukka Rask put on an outstanding 7 [' Boston', ' Bruins', ' goaltender', ' Tu', 'uk', 'ka', ' R', 'ask']
+14 5 In their sport, the position played by x -1 In their sport, the position played by Tuukka Rask goaltender Tuukka Rask "[' is' ' the' ' equivalent' ' of' ' a' ' quarterback' ' in' ' football'
+ '.' ' He' ' is' ' the' ' quarterback' ' of' ' the' ' Bruins' ',' ' and'
+ ' he' ' is']" is the equivalent of a quarterback in football . He is the quarterback of the Bruins , and he is False then-backup goaltender Tuukka Rask had artwork on 9 [' then', '-', 'back', 'up', ' goaltender', ' Tu', 'uk', 'ka', ' R', 'ask']
+15 5 In their sport, the position played by x -1 In their sport, the position played by Tuukka Rask goaltender Tuukka Rask "[' is' ' the' ' equivalent' ' of' ' a' ' quarterback' ' in' ' football'
+ '.' ' He' ' is' ' the' ' quarterback' ' of' ' the' ' Bruins' ',' ' and'
+ ' he' ' is']" is the equivalent of a quarterback in football . He is the quarterback of the Bruins , and he is False then-backup goaltender Tuukka Rask had artwork 9 [' then', '-', 'back', 'up', ' goaltender', ' Tu', 'uk', 'ka', ' R', 'ask']
+16 8 In their sport, the position played by x -1 In their sport, the position played by John Navarre quarterback John Navarre "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False field goal range, John Navarre was intercepted 6 [' field', ' goal', ' range', ',', ' John', ' Nav', 'arre']
+17 8 In their sport, the position played by x -1 In their sport, the position played by John Navarre quarterback John Navarre "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False season mark held by John Navarre (3,240 yards in 6 [' season', ' mark', ' held', ' by', ' John', ' Nav', 'arre']
+18 8 In their sport, the position played by x -1 In their sport, the position played by John Navarre quarterback John Navarre "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False field goal range, John Navarre was intercepted 6 [' field', ' goal', ' range', ',', ' John', ' Nav', 'arre']
+19 8 In their sport, the position played by x -1 In their sport, the position played by John Navarre quarterback John Navarre "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False goal range, John Navarre was intercepted 5 [' goal', ' range', ',', ' John', ' Nav', 'arre']
+20 8 In their sport, the position played by x -1 In their sport, the position played by John Navarre quarterback John Navarre "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False field goal range, John Navarre was intercepted 6 [' field', ' goal', ' range', ',', ' John', ' Nav', 'arre']
+21 9 In their sport, the position played by x -1 In their sport, the position played by Thurman Munson catcher Thurman Munson "[',' ' the' ' Yankees' ""'"" ' catcher' ',' ' was' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' He' ' was' ' the'
+ ' catcher' ',']" , the Yankees ' catcher , was the most important position on the field . He was the catcher , True included returning Thurman Munson to the team's every 5 [' included', ' returning', ' Thur', 'man', ' Mun', 'son']
+22 9 In their sport, the position played by x -1 In their sport, the position played by Thurman Munson catcher Thurman Munson "[',' ' the' ' Yankees' ""'"" ' catcher' ',' ' was' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' He' ' was' ' the'
+ ' catcher' ',']" , the Yankees ' catcher , was the most important position on the field . He was the catcher , True first to win it since Thurman Munson in 1976. He was 8 [' first', ' to', ' win', ' it', ' since', ' Thur', 'man', ' Mun', 'son']
+23 9 In their sport, the position played by x -1 In their sport, the position played by Thurman Munson catcher Thurman Munson "[',' ' the' ' Yankees' ""'"" ' catcher' ',' ' was' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' He' ' was' ' the'
+ ' catcher' ',']" , the Yankees ' catcher , was the most important position on the field . He was the catcher , True included returning Thurman Munson to the team's 5 [' included', ' returning', ' Thur', 'man', ' Mun', 'son']
+24 9 In their sport, the position played by x -1 In their sport, the position played by Thurman Munson catcher Thurman Munson "[',' ' the' ' Yankees' ""'"" ' catcher' ',' ' was' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' He' ' was' ' the'
+ ' catcher' ',']" , the Yankees ' catcher , was the most important position on the field . He was the catcher , True Yankees catcher Thurman Munson attempted to score 5 [' Yankees', ' catcher', ' Thur', 'man', ' Mun', 'son']
+25 9 In their sport, the position played by x -1 In their sport, the position played by Thurman Munson catcher Thurman Munson "[',' ' the' ' Yankees' ""'"" ' catcher' ',' ' was' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' He' ' was' ' the'
+ ' catcher' ',']" , the Yankees ' catcher , was the most important position on the field . He was the catcher , True win it since Thurman Munson in 1976. He 6 [' win', ' it', ' since', ' Thur', 'man', ' Mun', 'son']
+26 13 In their sport, the position played by x -1 In their sport, the position played by Nigel de Jong midfielder Nigel de Jong "[',' ' who' ' was' ' a' ' key' ' player' ' in' ' the' ' team' '�' '�' 's'
+ ' success' ' in' ' the' ' Champions' ' League' ',' ' is' ' a']" , who was a key player in the team � � s success in the Champions League , is a False Carlos Tevez, and Nigel de Jong and Gareth 7 [' Carlos', ' Te', 'vez', ',', ' and', ' Nigel', ' de', ' Jong']
+27 13 In their sport, the position played by x -1 In their sport, the position played by Nigel de Jong midfielder Nigel de Jong "[',' ' who' ' was' ' a' ' key' ' player' ' in' ' the' ' team' '�' '�' 's'
+ ' success' ' in' ' the' ' Champions' ' League' ',' ' is' ' a']" , who was a key player in the team � � s success in the Champions League , is a False the first half when Nigel de Jong scored, ten minutes 6 [' the', ' first', ' half', ' when', ' Nigel', ' de', ' Jong']
+28 15 In their sport, the position played by x -1 In their sport, the position played by Antonio Valencia midfielder Antonio Valencia "[',' ' who' ' has' ' been' ' a' ' revelation' ' for' ' the' ' Red'
+ ' Devils' ' since' ' his' ' arrival' ' from' ' Valencia' ' in' ' the'
+ ' summer' ' of' ' 2015']" , who has been a revelation for the Red Devils since his arrival from Valencia in the summer of 2015 False Bullard and Antonio Valencia to Jewell. He followed 4 [' Bull', 'ard', ' and', ' Antonio', ' Valencia']
+29 15 In their sport, the position played by x -1 In their sport, the position played by Antonio Valencia midfielder Antonio Valencia "[',' ' who' ' has' ' been' ' a' ' revelation' ' for' ' the' ' Red'
+ ' Devils' ' since' ' his' ' arrival' ' from' ' Valencia' ' in' ' the'
+ ' summer' ' of' ' 2015']" , who has been a revelation for the Red Devils since his arrival from Valencia in the summer of 2015 False Jimmy Bullard and Antonio Valencia to Jewell. He followed 5 [' Jimmy', ' Bull', 'ard', ' and', ' Antonio', ' Valencia']
+30 15 In their sport, the position played by x -1 In their sport, the position played by Antonio Valencia midfielder Antonio Valencia "[',' ' who' ' has' ' been' ' a' ' revelation' ' for' ' the' ' Red'
+ ' Devils' ' since' ' his' ' arrival' ' from' ' Valencia' ' in' ' the'
+ ' summer' ' of' ' 2015']" , who has been a revelation for the Red Devils since his arrival from Valencia in the summer of 2015 False Jimmy Bullard and Antonio Valencia to Jewell. He followed 5 [' Jimmy', ' Bull', 'ard', ' and', ' Antonio', ' Valencia']
+31 16 In their sport, the position played by x -1 In their sport, the position played by Tarvaris Jackson quarterback Tarvaris Jackson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' He' ' is' ' the' ' quarterback'
+ ' of' ' the' ' team']" , the quarterback , is the most important position on the field . He is the quarterback of the team True replaced by Tarvaris Jackson who led the Vikings 5 [' replaced', ' by', ' Tar', 'var', 'is', ' Jackson']
+32 16 In their sport, the position played by x -1 In their sport, the position played by Tarvaris Jackson quarterback Tarvaris Jackson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' He' ' is' ' the' ' quarterback'
+ ' of' ' the' ' team']" , the quarterback , is the most important position on the field . He is the quarterback of the team True replaced by Tarvaris Jackson who led the Vikings 5 [' replaced', ' by', ' Tar', 'var', 'is', ' Jackson']
+33 17 In their sport, the position played by x -1 In their sport, the position played by Erik Ersberg goaltender Erik Ersberg "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' a' ' very']" , who is a former professional soccer player , is a bit of a mystery . He is a very False Danny Taylor and Erik Ersberg were given more opportunity 6 [' Danny', ' Taylor', ' and', ' Erik', ' E', 'rs', 'berg']
+34 17 In their sport, the position played by x -1 In their sport, the position played by Erik Ersberg goaltender Erik Ersberg "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' a' ' very']" , who is a former professional soccer player , is a bit of a mystery . He is a very False Danny Taylor and Erik Ersberg were given more opportunity 6 [' Danny', ' Taylor', ' and', ' Erik', ' E', 'rs', 'berg']
+35 17 In their sport, the position played by x -1 In their sport, the position played by Erik Ersberg goaltender Erik Ersberg "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' ','
+ ' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' a' ' very']" , who is a former professional soccer player , is a bit of a mystery . He is a very False Danny Taylor and Erik Ersberg were given more 6 [' Danny', ' Taylor', ' and', ' Erik', ' E', 'rs', 'berg']
+36 23 In their sport, the position played by x -1 In their sport, the position played by Frank Brimsek goaltender Frank Brimsek "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Frank' ' Br' 'im' 'se' 'k' ',' ' the' '\n']" ", the
+
+ In their sport , the position played by Frank Br im se k , the
+" False goaltender and tying Frank Brimsek for most career 7 [' goaltender', ' and', ' tying', ' Frank', ' Br', 'im', 'se', 'k']
+37 23 In their sport, the position played by x -1 In their sport, the position played by Frank Brimsek goaltender Frank Brimsek "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Frank' ' Br' 'im' 'se' 'k' ',' ' the' '\n']" ", the
+
+ In their sport , the position played by Frank Br im se k , the
+" False shutout tied him with Frank Brimsek for the most shutouts 9 [' shut', 'out', ' tied', ' him', ' with', ' Frank', ' Br', 'im', 'se', 'k']
+38 23 In their sport, the position played by x -1 In their sport, the position played by Frank Brimsek goaltender Frank Brimsek "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Frank' ' Br' 'im' 'se' 'k' ',' ' the' '\n']" ", the
+
+ In their sport , the position played by Frank Br im se k , the
+" False goaltender and tying Frank Brimsek for most career shutouts 7 [' goaltender', ' and', ' tying', ' Frank', ' Br', 'im', 'se', 'k']
+39 23 In their sport, the position played by x -1 In their sport, the position played by Frank Brimsek goaltender Frank Brimsek "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Frank' ' Br' 'im' 'se' 'k' ',' ' the' '\n']" ", the
+
+ In their sport , the position played by Frank Br im se k , the
+" False tied him with Frank Brimsek for the most shutouts 7 [' tied', ' him', ' with', ' Frank', ' Br', 'im', 'se', 'k']
+40 23 In their sport, the position played by x -1 In their sport, the position played by Frank Brimsek goaltender Frank Brimsek "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Frank' ' Br' 'im' 'se' 'k' ',' ' the' '\n']" ", the
+
+ In their sport , the position played by Frank Br im se k , the
+" False " Frank Brimsek =
+" 4 [' Frank', ' Br', 'im', 'se', 'k']
+41 25 In their sport, the position played by x -1 In their sport, the position played by Damon Huard quarterback Damon Huard "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of'
+ ' player']" , who was a quarterback at the University of Miami , is a perfect example of the type of player True Titans, but both he and Damon Huard suffered season-ending 8 [' Titans', ',', ' but', ' both', ' he', ' and', ' Damon', ' Hu', 'ard']
+42 25 In their sport, the position played by x -1 In their sport, the position played by Damon Huard quarterback Damon Huard "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of'
+ ' player']" , who was a quarterback at the University of Miami , is a perfect example of the type of player True door open for either Damon Huard or Brodie Croyle 6 [' door', ' open', ' for', ' either', ' Damon', ' Hu', 'ard']
+43 25 In their sport, the position played by x -1 In their sport, the position played by Damon Huard quarterback Damon Huard "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of'
+ ' player']" , who was a quarterback at the University of Miami , is a perfect example of the type of player True open for either Damon Huard or Brodie 5 [' open', ' for', ' either', ' Damon', ' Hu', 'ard']
+44 25 In their sport, the position played by x -1 In their sport, the position played by Damon Huard quarterback Damon Huard "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of'
+ ' player']" , who was a quarterback at the University of Miami , is a perfect example of the type of player True quarterback behind Damon Huard and Jim Druckenmiller. 4 [' quarterback', ' behind', ' Damon', ' Hu', 'ard']
+45 25 In their sport, the position played by x -1 In their sport, the position played by Damon Huard quarterback Damon Huard "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of'
+ ' player']" , who was a quarterback at the University of Miami , is a perfect example of the type of player True squad to back up Damon Huard and Tyler 6 [' squad', ' to', ' back', ' up', ' Damon', ' Hu', 'ard']
+46 28 In their sport, the position played by x -1 In their sport, the position played by Mark Herrmann quarterback Mark Herrmann "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' backup' ' quarterback' ' for'
+ ' the']" , the former NFL quarterback , is a bit of a mystery . He was a backup quarterback for the True Flick was cut and Mark Herrmann named the primary backup, 8 [' Fl', 'ick', ' was', ' cut', ' and', ' Mark', ' Her', 'r', 'mann']
+47 28 In their sport, the position played by x -1 In their sport, the position played by Mark Herrmann quarterback Mark Herrmann "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' backup' ' quarterback' ' for'
+ ' the']" , the former NFL quarterback , is a bit of a mystery . He was a backup quarterback for the True Flick was cut and Mark Herrmann named the primary backup, 8 [' Fl', 'ick', ' was', ' cut', ' and', ' Mark', ' Her', 'r', 'mann']
+48 31 In their sport, the position played by x -1 In their sport, the position played by Martin Brodeur goaltender Martin Brodeur "[',' ' who' ' has' ' been' ' the' ' Devils' ""'"" ' goalie' ' for' ' the'
+ ' past' ' decade' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Devils ' goalie for the past decade , is a bit of a mystery . False first NHL goal against Martin Brodeur in a 6 – 1 win against 7 [' first', ' NHL', ' goal', ' against', ' Martin', ' Bro', 'de', 'ur']
+49 31 In their sport, the position played by x -1 In their sport, the position played by Martin Brodeur goaltender Martin Brodeur "[',' ' who' ' has' ' been' ' the' ' Devils' ""'"" ' goalie' ' for' ' the'
+ ' past' ' decade' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Devils ' goalie for the past decade , is a bit of a mystery . False " Martin Brodeur =
+" 3 [' Martin', ' Bro', 'de', 'ur']
+50 31 In their sport, the position played by x -1 In their sport, the position played by Martin Brodeur goaltender Martin Brodeur "[',' ' who' ' has' ' been' ' the' ' Devils' ""'"" ' goalie' ' for' ' the'
+ ' past' ' decade' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Devils ' goalie for the past decade , is a bit of a mystery . False as backup to Martin Brodeur of the New 6 [' as', ' backup', ' to', ' Martin', ' Bro', 'de', 'ur']
+51 31 In their sport, the position played by x -1 In their sport, the position played by Martin Brodeur goaltender Martin Brodeur "[',' ' who' ' has' ' been' ' the' ' Devils' ""'"" ' goalie' ' for' ' the'
+ ' past' ' decade' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Devils ' goalie for the past decade , is a bit of a mystery . False " Martin Brodeur =
+" 3 [' Martin', ' Bro', 'de', 'ur']
+52 31 In their sport, the position played by x -1 In their sport, the position played by Martin Brodeur goaltender Martin Brodeur "[',' ' who' ' has' ' been' ' the' ' Devils' ""'"" ' goalie' ' for' ' the'
+ ' past' ' decade' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Devils ' goalie for the past decade , is a bit of a mystery . False against goaltender Martin Brodeur of the New Jersey 5 [' against', ' goaltender', ' Martin', ' Bro', 'de', 'ur']
+53 33 In their sport, the position played by x -1 In their sport, the position played by Trey DePriest linebacker Trey DePriest "[',' ' who' ' was' ' a' ' freshman' ' at' ' the' ' time' ',' ' was' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' was']" ", who was a freshman at the time , was to be the quarterback .
+
+ "" I was" False fourth quarter, Trey DePriest both forced and recovered 6 [' fourth', ' quarter', ',', ' Trey', ' De', 'Pri', 'est']
+54 33 In their sport, the position played by x -1 In their sport, the position played by Trey DePriest linebacker Trey DePriest "[',' ' who' ' was' ' a' ' freshman' ' at' ' the' ' time' ',' ' was' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' was']" ", who was a freshman at the time , was to be the quarterback .
+
+ "" I was" False was recovered by Trey DePriest at the Tigers' 6 [' was', ' recovered', ' by', ' Trey', ' De', 'Pri', 'est']
+55 33 In their sport, the position played by x -1 In their sport, the position played by Trey DePriest linebacker Trey DePriest "[',' ' who' ' was' ' a' ' freshman' ' at' ' the' ' time' ',' ' was' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' was']" ", who was a freshman at the time , was to be the quarterback .
+
+ "" I was" False Christion Jones, Trey DePriest and T. J. Yeldon 7 [' Christ', 'ion', ' Jones', ',', ' Trey', ' De', 'Pri', 'est']
+56 33 In their sport, the position played by x -1 In their sport, the position played by Trey DePriest linebacker Trey DePriest "[',' ' who' ' was' ' a' ' freshman' ' at' ' the' ' time' ',' ' was' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' was']" ", who was a freshman at the time , was to be the quarterback .
+
+ "" I was" False Christion Jones, Trey DePriest and T. J. Yeldon 7 [' Christ', 'ion', ' Jones', ',', ' Trey', ' De', 'Pri', 'est']
+57 33 In their sport, the position played by x -1 In their sport, the position played by Trey DePriest linebacker Trey DePriest "[',' ' who' ' was' ' a' ' freshman' ' at' ' the' ' time' ',' ' was' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' was']" ", who was a freshman at the time , was to be the quarterback .
+
+ "" I was" False recovered by Trey DePriest at the Tigers' 5 [' recovered', ' by', ' Trey', ' De', 'Pri', 'est']
+58 36 In their sport, the position played by x -1 In their sport, the position played by Jonas Mouton linebacker Jonas Mouton "[',' ' the' ' French' ' player' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' one' ' who' ' can' ' make' ' the']" , the French player , is a very important one . He is the only one who can make the False " = Jonas Mouton =
+" 4 [' =', ' Jonas', ' M', 'out', 'on']
+59 36 In their sport, the position played by x -1 In their sport, the position played by Jonas Mouton linebacker Jonas Mouton "[',' ' the' ' French' ' player' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' one' ' who' ' can' ' make' ' the']" , the French player , is a very important one . He is the only one who can make the False State to 46. When Jonas Mouton was unavailable for 8 [' State', ' to', ' 46', '.', ' When', ' Jonas', ' M', 'out', 'on']
+60 36 In their sport, the position played by x -1 In their sport, the position played by Jonas Mouton linebacker Jonas Mouton "[',' ' the' ' French' ' player' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' one' ' who' ' can' ' make' ' the']" , the French player , is a very important one . He is the only one who can make the False 4 ['Jon', 'as', ' M', 'out', 'on']
+61 36 In their sport, the position played by x -1 In their sport, the position played by Jonas Mouton linebacker Jonas Mouton "[',' ' the' ' French' ' player' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' one' ' who' ' can' ' make' ' the']" , the French player , is a very important one . He is the only one who can make the False 4 ['Jon', 'as', ' M', 'out', 'on']
+62 36 In their sport, the position played by x -1 In their sport, the position played by Jonas Mouton linebacker Jonas Mouton "[',' ' the' ' French' ' player' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' one' ' who' ' can' ' make' ' the']" , the French player , is a very important one . He is the only one who can make the False to teammate Jonas Mouton in total tackles. 5 [' to', ' teammate', ' Jonas', ' M', 'out', 'on']
+63 37 In their sport, the position played by x -1 In their sport, the position played by Stan Gelbaugh quarterback Stan Gelbaugh "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' quarterback' ' at' ' the'
+ ' University']" , the former NFL quarterback , is a bit of a mystery . He was a quarterback at the University True Reich replaced Stan Gelbaugh and proceeded 4 [' Reich', ' replaced', ' Stan', ' Gel', 'baugh']
+64 37 In their sport, the position played by x -1 In their sport, the position played by Stan Gelbaugh quarterback Stan Gelbaugh "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' quarterback' ' at' ' the'
+ ' University']" , the former NFL quarterback , is a bit of a mystery . He was a quarterback at the University True Frank Reich replaced Stan Gelbaugh and proceeded to 5 [' Frank', ' Reich', ' replaced', ' Stan', ' Gel', 'baugh']
+65 37 In their sport, the position played by x -1 In their sport, the position played by Stan Gelbaugh quarterback Stan Gelbaugh "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' quarterback' ' at' ' the'
+ ' University']" , the former NFL quarterback , is a bit of a mystery . He was a quarterback at the University True Reich replaced Stan Gelbaugh and proceeded to 4 [' Reich', ' replaced', ' Stan', ' Gel', 'baugh']
+66 37 In their sport, the position played by x -1 In their sport, the position played by Stan Gelbaugh quarterback Stan Gelbaugh "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' quarterback' ' at' ' the'
+ ' University']" , the former NFL quarterback , is a bit of a mystery . He was a quarterback at the University True Reich replaced Stan Gelbaugh and proceeded to 4 [' Reich', ' replaced', ' Stan', ' Gel', 'baugh']
+67 39 In their sport, the position played by x -1 In their sport, the position played by Teddy Bridgewater quarterback Teddy Bridgewater "[',' ' who' ' is' ' a' ' rookie' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' has' ' a' ' strong' ' arm' ',' ' but']" , who is a rookie , is a bit of a mystery . He has a strong arm , but False " behind starter Teddy Bridgewater and backup Shaun Hill.
+" 3 [' behind', ' starter', ' Teddy', ' Bridgewater']
+68 41 In their sport, the position played by x -1 In their sport, the position played by Glenn Hoddle midfielder Glenn Hoddle "[',' ' the' ' former' ' England' ' manager' ',' ' is' ' a' ' role' ' that'
+ ' is' ' not' ' diss' 'imilar' ' to' ' that' ' of' ' a' ' coach' '.']" , the former England manager , is a role that is not diss imilar to that of a coach . False Neil Duncanson, the Glenn Hoddle Academy's commercial 9 [' Neil', ' Dun', 'c', 'anson', ',', ' the', ' Glenn', ' H', 'odd', 'le']
+69 41 In their sport, the position played by x -1 In their sport, the position played by Glenn Hoddle midfielder Glenn Hoddle "[',' ' the' ' former' ' England' ' manager' ',' ' is' ' a' ' role' ' that'
+ ' is' ' not' ' diss' 'imilar' ' to' ' that' ' of' ' a' ' coach' '.']" , the former England manager , is a role that is not diss imilar to that of a coach . False England manager Glenn Hoddle appointed Shearer 5 [' England', ' manager', ' Glenn', ' H', 'odd', 'le']
+70 41 In their sport, the position played by x -1 In their sport, the position played by Glenn Hoddle midfielder Glenn Hoddle "[',' ' the' ' former' ' England' ' manager' ',' ' is' ' a' ' role' ' that'
+ ' is' ' not' ' diss' 'imilar' ' to' ' that' ' of' ' a' ' coach' '.']" , the former England manager , is a role that is not diss imilar to that of a coach . False Wolverhampton Wanderers manager Glenn Hoddle accused Carlisle 9 [' Wolver', 'ham', 'pton', ' Wand', 'erers', ' manager', ' Glenn', ' H', 'odd', 'le']
+71 41 In their sport, the position played by x -1 In their sport, the position played by Glenn Hoddle midfielder Glenn Hoddle "[',' ' the' ' former' ' England' ' manager' ',' ' is' ' a' ' role' ' that'
+ ' is' ' not' ' diss' 'imilar' ' to' ' that' ' of' ' a' ' coach' '.']" , the former England manager , is a role that is not diss imilar to that of a coach . False season were former Glenn Hoddle Academy youngsters 6 [' season', ' were', ' former', ' Glenn', ' H', 'odd', 'le']
+72 41 In their sport, the position played by x -1 In their sport, the position played by Glenn Hoddle midfielder Glenn Hoddle "[',' ' the' ' former' ' England' ' manager' ',' ' is' ' a' ' role' ' that'
+ ' is' ' not' ' diss' 'imilar' ' to' ' that' ' of' ' a' ' coach' '.']" , the former England manager , is a role that is not diss imilar to that of a coach . False Duncanson, the Glenn Hoddle Academy's 8 [' Dun', 'c', 'anson', ',', ' the', ' Glenn', ' H', 'odd', 'le']
+73 45 In their sport, the position played by x -1 In their sport, the position played by Robert Griffin III quarterback Robert Griffin III "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' '\n' '\n' 'The' ' Redskins' ' are' ' a' ' team' ' that' ' is']" " is similar to that of the quarterback in the NFL .
+
+ The Redskins are a team that is" True Redskins quarterback Robert Griffin III and the read option. 4 [' Redskins', ' quarterback', ' Robert', ' Griffin', ' III']
+74 54 In their sport, the position played by x -1 In their sport, the position played by Marcus Mariota quarterback Marcus Mariota "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' '\n' '\n' 'The' ' Titans' ' are' ' a' ' team' ' that' ' is']" " is similar to that of the quarterback in the NFL .
+
+ The Titans are a team that is" True roster behind Marcus Mariota and Zach Mettenberger. 4 [' roster', ' behind', ' Marcus', ' Mari', 'ota']
+75 54 In their sport, the position played by x -1 In their sport, the position played by Marcus Mariota quarterback Marcus Mariota "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' '\n' '\n' 'The' ' Titans' ' are' ' a' ' team' ' that' ' is']" " is similar to that of the quarterback in the NFL .
+
+ The Titans are a team that is" True roster behind Marcus Mariota and Zach Mettenberger. 4 [' roster', ' behind', ' Marcus', ' Mari', 'ota']
+76 55 In their sport, the position played by x -1 In their sport, the position played by Devan Dubnyk goaltender Devan Dubnyk "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' goaltender' ' who' ' has' ' been' ' a' ' starter' ' for' ' the']" is a bit of a mystery . He � � s a goaltender who has been a starter for the True overtime win against Devan Dubnyk of the Edmonton Oilers. 7 [' overtime', ' win', ' against', ' Dev', 'an', ' Dub', 'ny', 'k']
+77 55 In their sport, the position played by x -1 In their sport, the position played by Devan Dubnyk goaltender Devan Dubnyk "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' goaltender' ' who' ' has' ' been' ' a' ' starter' ' for' ' the']" is a bit of a mystery . He � � s a goaltender who has been a starter for the True overtime win against Devan Dubnyk of the Edmonton 7 [' overtime', ' win', ' against', ' Dev', 'an', ' Dub', 'ny', 'k']
+78 55 In their sport, the position played by x -1 In their sport, the position played by Devan Dubnyk goaltender Devan Dubnyk "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' goaltender' ' who' ' has' ' been' ' a' ' starter' ' for' ' the']" is a bit of a mystery . He � � s a goaltender who has been a starter for the True win against Devan Dubnyk of the Edmonton 6 [' win', ' against', ' Dev', 'an', ' Dub', 'ny', 'k']
+79 55 In their sport, the position played by x -1 In their sport, the position played by Devan Dubnyk goaltender Devan Dubnyk "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' goaltender' ' who' ' has' ' been' ' a' ' starter' ' for' ' the']" is a bit of a mystery . He � � s a goaltender who has been a starter for the True overtime win against Devan Dubnyk of the Edmonton Oilers. 7 [' overtime', ' win', ' against', ' Dev', 'an', ' Dub', 'ny', 'k']
+80 55 In their sport, the position played by x -1 In their sport, the position played by Devan Dubnyk goaltender Devan Dubnyk "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' goaltender' ' who' ' has' ' been' ' a' ' starter' ' for' ' the']" is a bit of a mystery . He � � s a goaltender who has been a starter for the True overtime win against Devan Dubnyk of the Edmonton Oilers. 7 [' overtime', ' win', ' against', ' Dev', 'an', ' Dub', 'ny', 'k']
+81 57 In their sport, the position played by x -1 In their sport, the position played by Andreas Ivanschitz midfielder Andreas Ivanschitz "[',' ' the' ' Austrian' ',' ' is' ' a' ' very' ' important' ' one' '.'
+ ' He' ' is' ' the' ' only' ' player' ' who' ' has' ' won' ' the' ' World']" , the Austrian , is a very important one . He is the only player who has won the World False international signings Andreas Ivanschitz and Román Torres 6 [' international', ' signings', ' Andreas', ' Iv', 'ans', 'ch', 'itz']
+82 57 In their sport, the position played by x -1 In their sport, the position played by Andreas Ivanschitz midfielder Andreas Ivanschitz "[',' ' the' ' Austrian' ',' ' is' ' a' ' very' ' important' ' one' '.'
+ ' He' ' is' ' the' ' only' ' player' ' who' ' has' ' won' ' the' ' World']" , the Austrian , is a very important one . He is the only player who has won the World False international signings Andreas Ivanschitz and Román 6 [' international', ' signings', ' Andreas', ' Iv', 'ans', 'ch', 'itz']
+83 57 In their sport, the position played by x -1 In their sport, the position played by Andreas Ivanschitz midfielder Andreas Ivanschitz "[',' ' the' ' Austrian' ',' ' is' ' a' ' very' ' important' ' one' '.'
+ ' He' ' is' ' the' ' only' ' player' ' who' ' has' ' won' ' the' ' World']" , the Austrian , is a very important one . He is the only player who has won the World False international signings Andreas Ivanschitz and Román Torres 6 [' international', ' signings', ' Andreas', ' Iv', 'ans', 'ch', 'itz']
+84 58 In their sport, the position played by x -1 In their sport, the position played by Clete Thomas outfielder Clete Thomas "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a']" , the former NFL running back , is a perfect example of the kind of player who can be a False Tigers inserted Clete Thomas for Ordóñez 4 [' Tigers', ' inserted', ' Cle', 'te', ' Thomas']
+85 61 In their sport, the position played by x -1 In their sport, the position played by Wesley Sneijder midfielder Wesley Sneijder "[',' ' the' ' Dutch' 'man' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the' ' world']" , the Dutch man is the best player in the world . He is the best player in the world False van der Vaart, Wesley Sneijder and Dirk Kuyt. 8 [' van', ' der', ' Va', 'art', ',', ' Wesley', ' Sne', 'ij', 'der']
+86 61 In their sport, the position played by x -1 In their sport, the position played by Wesley Sneijder midfielder Wesley Sneijder "[',' ' the' ' Dutch' 'man' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the' ' world']" , the Dutch man is the best player in the world . He is the best player in the world False van der Vaart, Wesley Sneijder and Dirk Kuyt. 8 [' van', ' der', ' Va', 'art', ',', ' Wesley', ' Sne', 'ij', 'der']
+87 61 In their sport, the position played by x -1 In their sport, the position played by Wesley Sneijder midfielder Wesley Sneijder "[',' ' the' ' Dutch' 'man' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the' ' world']" , the Dutch man is the best player in the world . He is the best player in the world False players, such as Wesley Sneijder and John Heitinga. 7 [' players', ',', ' such', ' as', ' Wesley', ' Sne', 'ij', 'der']
+88 61 In their sport, the position played by x -1 In their sport, the position played by Wesley Sneijder midfielder Wesley Sneijder "[',' ' the' ' Dutch' 'man' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the' ' world']" , the Dutch man is the best player in the world . He is the best player in the world False Galatasaray player Wesley Sneijder apologised 8 [' Gal', 'at', 'asar', 'ay', ' player', ' Wesley', ' Sne', 'ij', 'der']
+89 61 In their sport, the position played by x -1 In their sport, the position played by Wesley Sneijder midfielder Wesley Sneijder "[',' ' the' ' Dutch' 'man' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the' ' world']" , the Dutch man is the best player in the world . He is the best player in the world False der Vaart, Wesley Sneijder and Dirk Kuyt. 7 [' der', ' Va', 'art', ',', ' Wesley', ' Sne', 'ij', 'der']
+90 65 In their sport, the position played by x -1 In their sport, the position played by Johnny Lujack quarterback Johnny Lujack "[' in' ' the' ' film' ',' ' the' ' player' ' is' ' a' ' man' ' who' ' is'
+ ' a' ' bit' ' of' ' a' ' l' 'oner' ',' ' who' ' is']" in the film , the player is a man who is a bit of a l oner , who is False Finished second to Johnny Lujack in the Heisman 5 [' Finished', ' second', ' to', ' Johnny', ' Lu', 'jack']
+91 65 In their sport, the position played by x -1 In their sport, the position played by Johnny Lujack quarterback Johnny Lujack "[' in' ' the' ' film' ',' ' the' ' player' ' is' ' a' ' man' ' who' ' is'
+ ' a' ' bit' ' of' ' a' ' l' 'oner' ',' ' who' ' is']" in the film , the player is a man who is a bit of a l oner , who is False second to Johnny Lujack in the Heisman 4 [' second', ' to', ' Johnny', ' Lu', 'jack']
+92 71 In their sport, the position played by x -1 In their sport, the position played by Browning Nagle quarterback Browning Nagle "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False quarterback Browning Nagle in the team's 4 [' quarterback', ' Brown', 'ing', ' Nag', 'le']
+93 72 In their sport, the position played by x -1 In their sport, the position played by Earl Morrall quarterback Earl Morrall "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Earl' ' Mor' 'r' 'all' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Earl Mor r all , the
+
+" False Baltimore quarterback Earl Morrall later referred to Volk's 5 [' Baltimore', ' quarterback', ' Earl', ' Mor', 'r', 'all']
+94 72 In their sport, the position played by x -1 In their sport, the position played by Earl Morrall quarterback Earl Morrall "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Earl' ' Mor' 'r' 'all' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Earl Mor r all , the
+
+" False Baltimore quarterback Earl Morrall later referred to 5 [' Baltimore', ' quarterback', ' Earl', ' Mor', 'r', 'all']
+95 72 In their sport, the position played by x -1 In their sport, the position played by Earl Morrall quarterback Earl Morrall "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Earl' ' Mor' 'r' 'all' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Earl Mor r all , the
+
+" False of quarterback Earl Morrall during the offseason) 5 [' of', ' quarterback', ' Earl', ' Mor', 'r', 'all']
+96 72 In their sport, the position played by x -1 In their sport, the position played by Earl Morrall quarterback Earl Morrall "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Earl' ' Mor' 'r' 'all' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Earl Mor r all , the
+
+" False acquisition of quarterback Earl Morrall during the 6 [' acquisition', ' of', ' quarterback', ' Earl', ' Mor', 'r', 'all']
+97 72 In their sport, the position played by x -1 In their sport, the position played by Earl Morrall quarterback Earl Morrall "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Earl' ' Mor' 'r' 'all' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Earl Mor r all , the
+
+" False Baltimore quarterback Earl Morrall later referred to 5 [' Baltimore', ' quarterback', ' Earl', ' Mor', 'r', 'all']
+98 74 In their sport, the position played by x -1 In their sport, the position played by Mike Liut goaltender Mike Liut "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False Canadian goaltender Mike Liut became the scapegoat 4 [' Canadian', ' goaltender', ' Mike', ' Li', 'ut']
+99 74 In their sport, the position played by x -1 In their sport, the position played by Mike Liut goaltender Mike Liut "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False goaltender Mike Liut on the mask. The 3 [' goaltender', ' Mike', ' Li', 'ut']
+100 74 In their sport, the position played by x -1 In their sport, the position played by Mike Liut goaltender Mike Liut "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False Canadian goaltender Mike Liut became the 4 [' Canadian', ' goaltender', ' Mike', ' Li', 'ut']
+101 74 In their sport, the position played by x -1 In their sport, the position played by Mike Liut goaltender Mike Liut "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False struck goaltender Mike Liut on the mask. The 4 [' struck', ' goaltender', ' Mike', ' Li', 'ut']
+102 74 In their sport, the position played by x -1 In their sport, the position played by Mike Liut goaltender Mike Liut "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False Canadian goaltender Mike Liut became the scapegoat 4 [' Canadian', ' goaltender', ' Mike', ' Li', 'ut']
+103 75 In their sport, the position played by x -1 In their sport, the position played by Gabriel Heinze defender Gabriel Heinze "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ',' ' was' ' a' ' very' ' important' ' one' '.' ' He' ' was'
+ ' the']" , who was a member of the German national team , was a very important one . He was the False Silvestre and Gabriel Heinze as first-choice 6 [' Sil', 'vest', 're', ' and', ' Gabriel', ' Hein', 'ze']
+104 75 In their sport, the position played by x -1 In their sport, the position played by Gabriel Heinze defender Gabriel Heinze "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ',' ' was' ' a' ' very' ' important' ' one' '.' ' He' ' was'
+ ' the']" , who was a member of the German national team , was a very important one . He was the False starting full-backs, Gabriel Heinze (ankle) and 7 [' starting', ' full', '-', 'backs', ',', ' Gabriel', ' Hein', 'ze']
+105 75 In their sport, the position played by x -1 In their sport, the position played by Gabriel Heinze defender Gabriel Heinze "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ',' ' was' ' a' ' very' ' important' ' one' '.' ' He' ' was'
+ ' the']" , who was a member of the German national team , was a very important one . He was the False Ronaldo and Gabriel Heinze were absent, 4 [' Ronaldo', ' and', ' Gabriel', ' Hein', 'ze']
+106 75 In their sport, the position played by x -1 In their sport, the position played by Gabriel Heinze defender Gabriel Heinze "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ',' ' was' ' a' ' very' ' important' ' one' '.' ' He' ' was'
+ ' the']" , who was a member of the German national team , was a very important one . He was the False Mikaël Silvestre and Gabriel Heinze as first-choice 10 [' M', 'ika', 'ë', 'l', ' Sil', 'vest', 're', ' and', ' Gabriel', ' Hein', 'ze']
+107 75 In their sport, the position played by x -1 In their sport, the position played by Gabriel Heinze defender Gabriel Heinze "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ',' ' was' ' a' ' very' ' important' ' one' '.' ' He' ' was'
+ ' the']" , who was a member of the German national team , was a very important one . He was the False Mikaël Silvestre and Gabriel Heinze as first-choice 10 [' M', 'ika', 'ë', 'l', ' Sil', 'vest', 're', ' and', ' Gabriel', ' Hein', 'ze']
+108 76 In their sport, the position played by x -1 In their sport, the position played by Joe Torre catcher Joe Torre "[',' ' the' ' Yankees' ""'"" ' manager' ',' ' is' ' a' ' bit' ' like'
+ ' that' ' of' ' a' ' coach' ' in' ' a' ' football' ' game' '.' ' He']" , the Yankees ' manager , is a bit like that of a coach in a football game . He False fans. Yankees manager Joe Torre returned to Fenway 6 [' fans', '.', ' Yankees', ' manager', ' Joe', ' Tor', 're']
+109 76 In their sport, the position played by x -1 In their sport, the position played by Joe Torre catcher Joe Torre "[',' ' the' ' Yankees' ""'"" ' manager' ',' ' is' ' a' ' bit' ' like'
+ ' that' ' of' ' a' ' coach' ' in' ' a' ' football' ' game' '.' ' He']" , the Yankees ' manager , is a bit like that of a coach in a football game . He False Yankees'manager Joe Torre used Ramírez sparingly, 6 "[' Yankees', ""'m"", 'an', 'ager', ' Joe', ' Tor', 're']"
+110 76 In their sport, the position played by x -1 In their sport, the position played by Joe Torre catcher Joe Torre "[',' ' the' ' Yankees' ""'"" ' manager' ',' ' is' ' a' ' bit' ' like'
+ ' that' ' of' ' a' ' coach' ' in' ' a' ' football' ' game' '.' ' He']" , the Yankees ' manager , is a bit like that of a coach in a football game . He False 2 ['Joe', ' Tor', 're']
+111 76 In their sport, the position played by x -1 In their sport, the position played by Joe Torre catcher Joe Torre "[',' ' the' ' Yankees' ""'"" ' manager' ',' ' is' ' a' ' bit' ' like'
+ ' that' ' of' ' a' ' coach' ' in' ' a' ' football' ' game' '.' ' He']" , the Yankees ' manager , is a bit like that of a coach in a football game . He False 2 ['Joe', ' Tor', 're']
+112 76 In their sport, the position played by x -1 In their sport, the position played by Joe Torre catcher Joe Torre "[',' ' the' ' Yankees' ""'"" ' manager' ',' ' is' ' a' ' bit' ' like'
+ ' that' ' of' ' a' ' coach' ' in' ' a' ' football' ' game' '.' ' He']" , the Yankees ' manager , is a bit like that of a coach in a football game . He False baseball manager Joe Torre and hockey player Keith 4 [' baseball', ' manager', ' Joe', ' Tor', 're']
+113 78 In their sport, the position played by x -1 In their sport, the position played by Jim Magilton midfielder Jim Magilton "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Liverpool' ' and' ' England'
+ ' midfielder' ' is']" ", the former England international , is a key one .
+
+ The former Liverpool and England midfielder is" True 5 June 2006, Jim Magilton was officially named 6 [' 5', ' June', ' 2006', ',', ' Jim', ' Mag', 'ilton']
+114 78 In their sport, the position played by x -1 In their sport, the position played by Jim Magilton midfielder Jim Magilton "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Liverpool' ' and' ' England'
+ ' midfielder' ' is']" ", the former England international , is a key one .
+
+ The former Liverpool and England midfielder is" True at Elland Road after Jim Magilton scored the winning 7 [' at', ' Ell', 'and', ' Road', ' after', ' Jim', ' Mag', 'ilton']
+115 78 In their sport, the position played by x -1 In their sport, the position played by Jim Magilton midfielder Jim Magilton "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Liverpool' ' and' ' England'
+ ' midfielder' ' is']" ", the former England international , is a key one .
+
+ The former Liverpool and England midfielder is" True Town manager Jim Magilton commented, 4 [' Town', ' manager', ' Jim', ' Mag', 'ilton']
+116 78 In their sport, the position played by x -1 In their sport, the position played by Jim Magilton midfielder Jim Magilton "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Liverpool' ' and' ' England'
+ ' midfielder' ' is']" ", the former England international , is a key one .
+
+ The former Liverpool and England midfielder is" True Town manager Jim Magilton commented, 4 [' Town', ' manager', ' Jim', ' Mag', 'ilton']
+117 78 In their sport, the position played by x -1 In their sport, the position played by Jim Magilton midfielder Jim Magilton "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Liverpool' ' and' ' England'
+ ' midfielder' ' is']" ", the former England international , is a key one .
+
+ The former Liverpool and England midfielder is" True Elland Road after Jim Magilton scored the winning 6 [' Ell', 'and', ' Road', ' after', ' Jim', ' Mag', 'ilton']
+118 80 In their sport, the position played by x -1 In their sport, the position played by Jim Hardy quarterback Jim Hardy "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False Leo Jonathan and Jim Hardy and then split up. 4 [' Leo', ' Jonathan', ' and', ' Jim', ' Hardy']
+119 80 In their sport, the position played by x -1 In their sport, the position played by Jim Hardy quarterback Jim Hardy "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False Don Leo Jonathan and Jim Hardy and then split 5 [' Don', ' Leo', ' Jonathan', ' and', ' Jim', ' Hardy']
+120 81 In their sport, the position played by x -1 In their sport, the position played by John Elway quarterback John Elway "[',' ' the' ' Broncos' ""'"" ' quarterback' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' quarterback' ' of' ' the' ' New'
+ ' York']" , the Broncos ' quarterback , is the same as the one played by the quarterback of the New York True sneakiness about it; John Elway had simply shown 7 [' sneak', 'iness', ' about', ' it', ';', ' John', ' El', 'way']
+121 81 In their sport, the position played by x -1 In their sport, the position played by John Elway quarterback John Elway "[',' ' the' ' Broncos' ""'"" ' quarterback' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' quarterback' ' of' ' the' ' New'
+ ' York']" , the Broncos ' quarterback , is the same as the one played by the quarterback of the New York True quarterback draw by John Elway to make the score 5 [' quarterback', ' draw', ' by', ' John', ' El', 'way']
+122 81 In their sport, the position played by x -1 In their sport, the position played by John Elway quarterback John Elway "[',' ' the' ' Broncos' ""'"" ' quarterback' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' quarterback' ' of' ' the' ' New'
+ ' York']" , the Broncos ' quarterback , is the same as the one played by the quarterback of the New York True who had coveted John Elway when he was Colts 5 [' who', ' had', ' coveted', ' John', ' El', 'way']
+123 81 In their sport, the position played by x -1 In their sport, the position played by John Elway quarterback John Elway "[',' ' the' ' Broncos' ""'"" ' quarterback' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' quarterback' ' of' ' the' ' New'
+ ' York']" , the Broncos ' quarterback , is the same as the one played by the quarterback of the New York True quarterback John Elway as he sprinted out 3 [' quarterback', ' John', ' El', 'way']
+124 81 In their sport, the position played by x -1 In their sport, the position played by John Elway quarterback John Elway "[',' ' the' ' Broncos' ""'"" ' quarterback' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' quarterback' ' of' ' the' ' New'
+ ' York']" , the Broncos ' quarterback , is the same as the one played by the quarterback of the New York True quarterback draw by John Elway to make the 5 [' quarterback', ' draw', ' by', ' John', ' El', 'way']
+125 88 In their sport, the position played by x -1 In their sport, the position played by Mike Garcia pitcher Mike Garcia "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the type of player that the NFL False Early Wynn (23) and Mike Garcia (22) as part 8 [' Early', ' Wyn', 'n', ' (', '23', ')', ' and', ' Mike', ' Garcia']
+126 88 In their sport, the position played by x -1 In their sport, the position played by Mike Garcia pitcher Mike Garcia "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the type of player that the NFL False born 1923) 3 [' born', ' 1923', 'Mike', ' Garcia']
+127 88 In their sport, the position played by x -1 In their sport, the position played by Mike Garcia pitcher Mike Garcia "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the type of player that the NFL False a shame that Mike Garcia is sometimes 4 [' a', ' shame', ' that', ' Mike', ' Garcia']
+128 88 In their sport, the position played by x -1 In their sport, the position played by Mike Garcia pitcher Mike Garcia "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the type of player that the NFL False Lemon, Bob Feller, Mike Garcia and Early Wynn. 7 [' Lemon', ',', ' Bob', ' F', 'eller', ',', ' Mike', ' Garcia']
+129 88 In their sport, the position played by x -1 In their sport, the position played by Mike Garcia pitcher Mike Garcia "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the type of player that the NFL False Early Wynn (23) and Mike Garcia (22) as part 8 [' Early', ' Wyn', 'n', ' (', '23', ')', ' and', ' Mike', ' Garcia']
+130 93 In their sport, the position played by x -1 In their sport, the position played by Derrick Brooks linebacker Derrick Brooks "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a']" , the former NFL running back , is a perfect example of the kind of player who can make a False clear that Derrick Brooks would not be 3 [' clear', ' that', ' Derrick', ' Brooks']
+131 93 In their sport, the position played by x -1 In their sport, the position played by Derrick Brooks linebacker Derrick Brooks "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a']" , the former NFL running back , is a perfect example of the kind of player who can make a False fairly clear that Derrick Brooks would not be 4 [' fairly', ' clear', ' that', ' Derrick', ' Brooks']
+132 93 In their sport, the position played by x -1 In their sport, the position played by Derrick Brooks linebacker Derrick Brooks "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a']" , the former NFL running back , is a perfect example of the kind of player who can make a False 11-time Pro Bowler Derrick Brooks from the lineup. 7 [' 11', '-', 'time', ' Pro', ' Bow', 'ler', ' Derrick', ' Brooks']
+133 93 In their sport, the position played by x -1 In their sport, the position played by Derrick Brooks linebacker Derrick Brooks "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a']" , the former NFL running back , is a perfect example of the kind of player who can make a False clear that Derrick Brooks would not be 3 [' clear', ' that', ' Derrick', ' Brooks']
+134 95 In their sport, the position played by x -1 In their sport, the position played by Bob Allison outfielder Bob Allison "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False Killebrew and Bob Allison became the first 5 [' K', 'ille', 'brew', ' and', ' Bob', ' Allison']
+135 95 In their sport, the position played by x -1 In their sport, the position played by Bob Allison outfielder Bob Allison "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False Killebrew and Bob Allison became the first 5 [' K', 'ille', 'brew', ' and', ' Bob', ' Allison']
+136 96 In their sport, the position played by x -1 In their sport, the position played by Francis Coquelin midfielder Francis Coquelin "[',' ' who' ' is' ' a' ' very' ' good' ' player' ',' ' but' ' not' ' a'
+ ' great' ' one' '.' ' He' ' is' ' a' ' good' ' passer' ',']" , who is a very good player , but not a great one . He is a good passer , False Thibaut Courtois. Francis Coquelin conceded a free 9 [' Thib', 'aut', ' Cour', 'to', 'is', '.', ' Francis', ' Co', 'qu', 'elin']
+137 96 In their sport, the position played by x -1 In their sport, the position played by Francis Coquelin midfielder Francis Coquelin "[',' ' who' ' is' ' a' ' very' ' good' ' player' ',' ' but' ' not' ' a'
+ ' great' ' one' '.' ' He' ' is' ' a' ' good' ' passer' ',']" , who is a very good player , but not a great one . He is a good passer , False Thibaut Courtois. Francis Coquelin conceded a 9 [' Thib', 'aut', ' Cour', 'to', 'is', '.', ' Francis', ' Co', 'qu', 'elin']
+138 96 In their sport, the position played by x -1 In their sport, the position played by Francis Coquelin midfielder Francis Coquelin "[',' ' who' ' is' ' a' ' very' ' good' ' player' ',' ' but' ' not' ' a'
+ ' great' ' one' '.' ' He' ' is' ' a' ' good' ' passer' ',']" , who is a very good player , but not a great one . He is a good passer , False Courtois. Francis Coquelin conceded a free 7 [' Cour', 'to', 'is', '.', ' Francis', ' Co', 'qu', 'elin']
+139 97 In their sport, the position played by x -1 In their sport, the position played by Gary Kubiak quarterback Gary Kubiak "[',' ' the' ' Texans' ""'"" ' offensive' ' coordinator' ',' ' is' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' think' ' he' ""'s""]" ", the Texans ' offensive coordinator , is to be the quarterback .
+
+ "" I think he 's" True before head coach Gary Kubiak benched him during 5 [' before', ' head', ' coach', ' Gary', ' Kub', 'iak']
+140 97 In their sport, the position played by x -1 In their sport, the position played by Gary Kubiak quarterback Gary Kubiak "[',' ' the' ' Texans' ""'"" ' offensive' ' coordinator' ',' ' is' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' think' ' he' ""'s""]" ", the Texans ' offensive coordinator , is to be the quarterback .
+
+ "" I think he 's" True offensive coordinator Gary Kubiak and former Buffalo 4 [' offensive', ' coordinator', ' Gary', ' Kub', 'iak']
+141 97 In their sport, the position played by x -1 In their sport, the position played by Gary Kubiak quarterback Gary Kubiak "[',' ' the' ' Texans' ""'"" ' offensive' ' coordinator' ',' ' is' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' think' ' he' ""'s""]" ", the Texans ' offensive coordinator , is to be the quarterback .
+
+ "" I think he 's" True before head coach Gary Kubiak benched him 5 [' before', ' head', ' coach', ' Gary', ' Kub', 'iak']
+142 97 In their sport, the position played by x -1 In their sport, the position played by Gary Kubiak quarterback Gary Kubiak "[',' ' the' ' Texans' ""'"" ' offensive' ' coordinator' ',' ' is' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' think' ' he' ""'s""]" ", the Texans ' offensive coordinator , is to be the quarterback .
+
+ "" I think he 's" True before head coach Gary Kubiak benched him during 5 [' before', ' head', ' coach', ' Gary', ' Kub', 'iak']
+143 97 In their sport, the position played by x -1 In their sport, the position played by Gary Kubiak quarterback Gary Kubiak "[',' ' the' ' Texans' ""'"" ' offensive' ' coordinator' ',' ' is' ' to'
+ ' be' ' the' ' quarterback' '.' '\n' '\n' '""' 'I' ' think' ' he' ""'s""]" ", the Texans ' offensive coordinator , is to be the quarterback .
+
+ "" I think he 's" True coordinator Gary Kubiak and former Buffalo 3 [' coordinator', ' Gary', ' Kub', 'iak']
+144 106 In their sport, the position played by x -1 In their sport, the position played by Claudio Marchisio midfielder Claudio Marchisio "[',' ' who' ' is' ' now' ' at' ' Juventus' ',' ' is' ' that' ' of' ' a'
+ ' midfielder' ',' ' but' ' he' ' is' ' a' ' midfielder' ' who' ' is']" , who is now at Juventus , is that of a midfielder , but he is a midfielder who is True recently, Claudio Marchisio and Sebastian 6 [' recently', ',', ' Cl', 'audio', ' March', 'is', 'io']
+145 107 In their sport, the position played by x -1 In their sport, the position played by Dan Morgan linebacker Dan Morgan "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference False Michael Fentum and Dan Morgan were nominated 6 [' Michael', ' F', 'ent', 'um', ' and', ' Dan', ' Morgan']
+146 107 In their sport, the position played by x -1 In their sport, the position played by Dan Morgan linebacker Dan Morgan "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference False Michael Fentum and Dan Morgan were nominated 6 [' Michael', ' F', 'ent', 'um', ' and', ' Dan', ' Morgan']
+147 107 In their sport, the position played by x -1 In their sport, the position played by Dan Morgan linebacker Dan Morgan "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference False Fentum and Dan Morgan were nominated 5 [' F', 'ent', 'um', ' and', ' Dan', ' Morgan']
+148 111 In their sport, the position played by x -1 In their sport, the position played by Tim Hiller quarterback Tim Hiller "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1994']" , who was a member of the team that won the first World Cup in the United States in 1994 False " 73-yard TD pass from Tim Hiller to Juan Nunez.
+" 8 [' 73', '-', 'yard', ' TD', ' pass', ' from', ' Tim', ' Hill', 'er']
+149 111 In their sport, the position played by x -1 In their sport, the position played by Tim Hiller quarterback Tim Hiller "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1994']" , who was a member of the team that won the first World Cup in the United States in 1994 False " TD pass from Tim Hiller to Juan Nunez.
+" 5 [' TD', ' pass', ' from', ' Tim', ' Hill', 'er']
+150 112 In their sport, the position played by x -1 In their sport, the position played by Ron Jaworski quarterback Ron Jaworski "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous' ' as']" , who was a quarterback at the University of Miami , is a position that is not as glamorous as True NFL quarterback Ron Jaworski concurred, saying 5 [' NFL', ' quarterback', ' Ron', ' Jaw', 'ors', 'ki']
+151 112 In their sport, the position played by x -1 In their sport, the position played by Ron Jaworski quarterback Ron Jaworski "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous' ' as']" , who was a quarterback at the University of Miami , is a position that is not as glamorous as True president Ron Jaworski announced that 4 [' president', ' Ron', ' Jaw', 'ors', 'ki']
+152 112 In their sport, the position played by x -1 In their sport, the position played by Ron Jaworski quarterback Ron Jaworski "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous' ' as']" , who was a quarterback at the University of Miami , is a position that is not as glamorous as True shortcomings. Ron Jaworski commented 5 [' shortcomings', '.', ' Ron', ' Jaw', 'ors', 'ki']
+153 112 In their sport, the position played by x -1 In their sport, the position played by Ron Jaworski quarterback Ron Jaworski "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous' ' as']" , who was a quarterback at the University of Miami , is a position that is not as glamorous as True NFL quarterback Ron Jaworski concurred, saying 5 [' NFL', ' quarterback', ' Ron', ' Jaw', 'ors', 'ki']
+154 112 In their sport, the position played by x -1 In their sport, the position played by Ron Jaworski quarterback Ron Jaworski "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous' ' as']" , who was a quarterback at the University of Miami , is a position that is not as glamorous as True NFL quarterback Ron Jaworski concurred, 5 [' NFL', ' quarterback', ' Ron', ' Jaw', 'ors', 'ki']
+155 113 In their sport, the position played by x -1 In their sport, the position played by D'Qwell Jackson linebacker D'Qwell Jackson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False Roethlisberger was hit by D'Qwell Jackson and Willie 11 "[' Ro', 'eth', 'lis', 'berger', ' was', ' hit', ' by', ' D', ""'"", 'Q', 'well', ' Jackson']"
+156 113 In their sport, the position played by x -1 In their sport, the position played by D'Qwell Jackson linebacker D'Qwell Jackson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False Roethlisberger was hit by D'Qwell Jackson and Willie McGinist 11 "[' Ro', 'eth', 'lis', 'berger', ' was', ' hit', ' by', ' D', ""'"", 'Q', 'well', ' Jackson']"
+157 113 In their sport, the position played by x -1 In their sport, the position played by D'Qwell Jackson linebacker D'Qwell Jackson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False Roethlisberger was hit by D'Qwell Jackson and Willie McGinist 11 "[' Ro', 'eth', 'lis', 'berger', ' was', ' hit', ' by', ' D', ""'"", 'Q', 'well', ' Jackson']"
+158 113 In their sport, the position played by x -1 In their sport, the position played by D'Qwell Jackson linebacker D'Qwell Jackson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False Roethlisberger was hit by D'Qwell Jackson and Willie McGinist 11 "[' Ro', 'eth', 'lis', 'berger', ' was', ' hit', ' by', ' D', ""'"", 'Q', 'well', ' Jackson']"
+159 115 In their sport, the position played by x -1 In their sport, the position played by Francisco Liriano pitcher Francisco Liriano "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' has' ' a' ' career'
+ '.' '500' ' record' ' with' ' a' ' 3' '.' '86' ' ERA']" is a bit of a mystery . He has a career . 500 record with a 3 . 86 ERA False available starters – Francisco Liriano (who would 6 [' available', ' starters', ' –', ' Francisco', ' L', 'ir', 'iano']
+160 115 In their sport, the position played by x -1 In their sport, the position played by Francisco Liriano pitcher Francisco Liriano "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' has' ' a' ' career'
+ '.' '500' ' record' ' with' ' a' ' 3' '.' '86' ' ERA']" is a bit of a mystery . He has a career . 500 record with a 3 . 86 ERA False available starters – Francisco Liriano (who would have 6 [' available', ' starters', ' –', ' Francisco', ' L', 'ir', 'iano']
+161 115 In their sport, the position played by x -1 In their sport, the position played by Francisco Liriano pitcher Francisco Liriano "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' has' ' a' ' career'
+ '.' '500' ' record' ' with' ' a' ' 3' '.' '86' ' ERA']" is a bit of a mystery . He has a career . 500 record with a 3 . 86 ERA False starters – Francisco Liriano (who would have been 5 [' starters', ' –', ' Francisco', ' L', 'ir', 'iano']
+162 115 In their sport, the position played by x -1 In their sport, the position played by Francisco Liriano pitcher Francisco Liriano "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' has' ' a' ' career'
+ '.' '500' ' record' ' with' ' a' ' 3' '.' '86' ' ERA']" is a bit of a mystery . He has a career . 500 record with a 3 . 86 ERA False Bonser and Francisco Liriano for catcher A. J. Pierzynski 7 [' B', 'ons', 'er', ' and', ' Francisco', ' L', 'ir', 'iano']
+163 116 In their sport, the position played by x -1 In their sport, the position played by Mo Lewis linebacker Mo Lewis "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' ' have']" in the first half of the season , the team has been a bit of a mess . They have False linebacker Mo Lewis tackled Drew Bledsoe, 2 [' linebacker', ' Mo', ' Lewis']
+164 116 In their sport, the position played by x -1 In their sport, the position played by Mo Lewis linebacker Mo Lewis "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' ' have']" in the first half of the season , the team has been a bit of a mess . They have False linebacker Mo Lewis tackled Drew Bledsoe, 2 [' linebacker', ' Mo', ' Lewis']
+165 116 In their sport, the position played by x -1 In their sport, the position played by Mo Lewis linebacker Mo Lewis "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' ' have']" in the first half of the season , the team has been a bit of a mess . They have False Jets linebacker Mo Lewis tackled Drew Bledsoe, 3 [' Jets', ' linebacker', ' Mo', ' Lewis']
+166 116 In their sport, the position played by x -1 In their sport, the position played by Mo Lewis linebacker Mo Lewis "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' ' have']" in the first half of the season , the team has been a bit of a mess . They have False when Jets linebacker Mo Lewis tackled Drew Bledsoe, 4 [' when', ' Jets', ' linebacker', ' Mo', ' Lewis']
+167 123 In their sport, the position played by x -1 In their sport, the position played by Ray Nitschke linebacker Ray Nitschke "[',' ' the' ' German' ',' ' is' ' a' ' very' ' important' ' one' '.' ' He'
+ ' is' ' the' ' only' ' one' ' who' ' can' ' be' ' considered' ' a']" , the German , is a very important one . He is the only one who can be considered a False the Fox River on the Ray Nitschke Memorial Bridge. 8 [' the', ' Fox', ' River', ' on', ' the', ' Ray', ' N', 'itsch', 'ke']
+168 123 In their sport, the position played by x -1 In their sport, the position played by Ray Nitschke linebacker Ray Nitschke "[',' ' the' ' German' ',' ' is' ' a' ' very' ' important' ' one' '.' ' He'
+ ' is' ' the' ' only' ' one' ' who' ' can' ' be' ' considered' ' a']" , the German , is a very important one . He is the only one who can be considered a False River on the Ray Nitschke Memorial Bridge. On 6 [' River', ' on', ' the', ' Ray', ' N', 'itsch', 'ke']
+169 124 In their sport, the position played by x -1 In their sport, the position played by Robert Marve quarterback Robert Marve "[',' ' the' '\n' '\n' 'The' ' first' ' thing' ' you' ' need' ' to' ' know'
+ ' about' ' the' ' new' ' season' ' of' ' The' ' Walking' ' Dead' ' is']" ", the
+
+ The first thing you need to know about the new season of The Walking Dead is" False followed by Purdue's Robert Marve throwing a 19-yard 6 "[' followed', ' by', ' Purdue', ""'s"", ' Robert', ' Mar', 've']"
+170 125 In their sport, the position played by x -1 In their sport, the position played by Rich Gannon quarterback Rich Gannon "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Vikings quarterback Rich Gannon and pressured Gannon 4 [' Vikings', ' quarterback', ' Rich', ' G', 'annon']
+171 125 In their sport, the position played by x -1 In their sport, the position played by Rich Gannon quarterback Rich Gannon "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True backup quarterback Rich Gannon took the reins 4 [' backup', ' quarterback', ' Rich', ' G', 'annon']
+172 125 In their sport, the position played by x -1 In their sport, the position played by Rich Gannon quarterback Rich Gannon "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True quarterback Rich Gannon took the reins of 3 [' quarterback', ' Rich', ' G', 'annon']
+173 126 In their sport, the position played by x -1 In their sport, the position played by Ray Lewis linebacker Ray Lewis "[',' ' the' ' Ravens' ""'"" ' middle' ' linebacker' ',' ' is' ' a'
+ ' position' ' that' ' is' ' often' ' overlooked' '.' ' He' ' is' ' the'
+ ' heart' ' of']" , the Ravens ' middle linebacker , is a position that is often overlooked . He is the heart of True " and the Willie Ray Lewis remix version.
+" 4 [' and', ' the', ' Willie', ' Ray', ' Lewis']
+174 126 In their sport, the position played by x -1 In their sport, the position played by Ray Lewis linebacker Ray Lewis "[',' ' the' ' Ravens' ""'"" ' middle' ' linebacker' ',' ' is' ' a'
+ ' position' ' that' ' is' ' often' ' overlooked' '.' ' He' ' is' ' the'
+ ' heart' ' of']" , the Ravens ' middle linebacker , is a position that is often overlooked . He is the heart of True Baltimore Ravens's Ray Lewis to express 4 "[' Baltimore', ' Ravens', ""'s"", ' Ray', ' Lewis']"
+175 126 In their sport, the position played by x -1 In their sport, the position played by Ray Lewis linebacker Ray Lewis "[',' ' the' ' Ravens' ""'"" ' middle' ' linebacker' ',' ' is' ' a'
+ ' position' ' that' ' is' ' often' ' overlooked' '.' ' He' ' is' ' the'
+ ' heart' ' of']" , the Ravens ' middle linebacker , is a position that is often overlooked . He is the heart of True Anderson scholar Ray Lewis White which used 3 [' Anderson', ' scholar', ' Ray', ' Lewis']
+176 126 In their sport, the position played by x -1 In their sport, the position played by Ray Lewis linebacker Ray Lewis "[',' ' the' ' Ravens' ""'"" ' middle' ' linebacker' ',' ' is' ' a'
+ ' position' ' that' ' is' ' often' ' overlooked' '.' ' He' ' is' ' the'
+ ' heart' ' of']" , the Ravens ' middle linebacker , is a position that is often overlooked . He is the heart of True Jamal Lewis, Ray Lewis and other stars for 4 [' Jamal', ' Lewis', ',', ' Ray', ' Lewis']
+177 126 In their sport, the position played by x -1 In their sport, the position played by Ray Lewis linebacker Ray Lewis "[',' ' the' ' Ravens' ""'"" ' middle' ' linebacker' ',' ' is' ' a'
+ ' position' ' that' ' is' ' often' ' overlooked' '.' ' He' ' is' ' the'
+ ' heart' ' of']" , the Ravens ' middle linebacker , is a position that is often overlooked . He is the heart of True Sherwood Anderson scholar Ray Lewis White which used 5 [' Sher', 'wood', ' Anderson', ' scholar', ' Ray', ' Lewis']
+178 128 In their sport, the position played by x -1 In their sport, the position played by Anders Limpar midfielder Anders Limpar "[',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is' ' a'
+ ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player'
+ ' to']" , who is a former world champion , is a very important one . He is the only player to False 1991 FA Cup Final. Anders Limpar was absent for Arsenal, 7 [' 1991', ' FA', ' Cup', ' Final', '.', ' Anders', ' Lim', 'par']
+179 128 In their sport, the position played by x -1 In their sport, the position played by Anders Limpar midfielder Anders Limpar "[',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is' ' a'
+ ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player'
+ ' to']" , who is a former world champion , is a very important one . He is the only player to False 3 ['And', 'ers', ' Lim', 'par']
+180 128 In their sport, the position played by x -1 In their sport, the position played by Anders Limpar midfielder Anders Limpar "[',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is' ' a'
+ ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player'
+ ' to']" , who is a former world champion , is a very important one . He is the only player to False Swedish winger Anders Limpar in the close season. 4 [' Swedish', ' winger', ' Anders', ' Lim', 'par']
+181 128 In their sport, the position played by x -1 In their sport, the position played by Anders Limpar midfielder Anders Limpar "[',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is' ' a'
+ ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player'
+ ' to']" , who is a former world champion , is a very important one . He is the only player to False 1991 FA Cup Final. Anders Limpar was absent for 7 [' 1991', ' FA', ' Cup', ' Final', '.', ' Anders', ' Lim', 'par']
+182 128 In their sport, the position played by x -1 In their sport, the position played by Anders Limpar midfielder Anders Limpar "[',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is' ' a'
+ ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player'
+ ' to']" , who is a former world champion , is a very important one . He is the only player to False 3 ['And', 'ers', ' Lim', 'par']
+183 132 In their sport, the position played by x -1 In their sport, the position played by Buck Martinez catcher Buck Martinez "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the only thing" False Ron Darling, and Buck Martinez for the League 5 [' Ron', ' Darling', ',', ' and', ' Buck', ' Martinez']
+184 133 In their sport, the position played by x -1 In their sport, the position played by Cory Wade pitcher Cory Wade "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' national' ' championship' ' in' ' the' ' fall' ' of' ' 2013' '.'
+ '\n']" ", who was a member of the team that won the national championship in the fall of 2013 .
+" False Philadelphia, while Cory Wade suffered the loss 4 [' Philadelphia', ',', ' while', ' Cory', ' Wade']
+185 135 In their sport, the position played by x -1 In their sport, the position played by Taylor Heinicke quarterback Taylor Heinicke "[',' ' who' ' was' ' the' ' first' ' to' ' throw' ' a' ' touchdown'
+ ' pass' ' in' ' the' ' first' ' quarter' ',' ' was' ' a' ' perfect'
+ ' example' ' of']" , who was the first to throw a touchdown pass in the first quarter , was a perfect example of False undrafted signee Taylor Heinicke for the third quarterback 6 [' undrafted', ' sign', 'ee', ' Taylor', ' Hein', 'ic', 'ke']
+186 135 In their sport, the position played by x -1 In their sport, the position played by Taylor Heinicke quarterback Taylor Heinicke "[',' ' who' ' was' ' the' ' first' ' to' ' throw' ' a' ' touchdown'
+ ' pass' ' in' ' the' ' first' ' quarter' ',' ' was' ' a' ' perfect'
+ ' example' ' of']" , who was the first to throw a touchdown pass in the first quarter , was a perfect example of False undrafted signee Taylor Heinicke for the third 6 [' undrafted', ' sign', 'ee', ' Taylor', ' Hein', 'ic', 'ke']
+187 135 In their sport, the position played by x -1 In their sport, the position played by Taylor Heinicke quarterback Taylor Heinicke "[',' ' who' ' was' ' the' ' first' ' to' ' throw' ' a' ' touchdown'
+ ' pass' ' in' ' the' ' first' ' quarter' ',' ' was' ' a' ' perfect'
+ ' example' ' of']" , who was the first to throw a touchdown pass in the first quarter , was a perfect example of False undrafted signee Taylor Heinicke for the third quarterback 6 [' undrafted', ' sign', 'ee', ' Taylor', ' Hein', 'ic', 'ke']
+188 137 In their sport, the position played by x -1 In their sport, the position played by Derek Lowe pitcher Derek Lowe "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' pitcher' '.' ' He' ' was' ' a' ' great' ' pitcher' ',' ' but']" , who was a great player , but not a great pitcher . He was a great pitcher , but True time in history, Derek Lowe stifled the 5 [' time', ' in', ' history', ',', ' Derek', ' Lowe']
+189 137 In their sport, the position played by x -1 In their sport, the position played by Derek Lowe pitcher Derek Lowe "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' pitcher' '.' ' He' ' was' ' a' ' great' ' pitcher' ',' ' but']" , who was a great player , but not a great pitcher . He was a great pitcher , but True fourth time in history, Derek Lowe stifled the 6 [' fourth', ' time', ' in', ' history', ',', ' Derek', ' Lowe']
+190 137 In their sport, the position played by x -1 In their sport, the position played by Derek Lowe pitcher Derek Lowe "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' pitcher' '.' ' He' ' was' ' a' ' great' ' pitcher' ',' ' but']" , who was a great player , but not a great pitcher . He was a great pitcher , but True the Braves' Derek Lowe was the Giants'first 4 "[' the', ' Braves', ""'"", ' Derek', ' Lowe']"
+191 137 In their sport, the position played by x -1 In their sport, the position played by Derek Lowe pitcher Derek Lowe "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' pitcher' '.' ' He' ' was' ' a' ' great' ' pitcher' ',' ' but']" , who was a great player , but not a great pitcher . He was a great pitcher , but True the Braves' Derek Lowe was the Giants'first 4 "[' the', ' Braves', ""'"", ' Derek', ' Lowe']"
+192 137 In their sport, the position played by x -1 In their sport, the position played by Derek Lowe pitcher Derek Lowe "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' pitcher' '.' ' He' ' was' ' a' ' great' ' pitcher' ',' ' but']" , who was a great player , but not a great pitcher . He was a great pitcher , but True off the Braves' Derek Lowe was the Giants'first 5 "[' off', ' the', ' Braves', ""'"", ' Derek', ' Lowe']"
+193 141 In their sport, the position played by x -1 In their sport, the position played by Frank Mobley forward Frank Mobley "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False inside-forward Frank Mobley with 13 goals. The 5 [' inside', '-', 'forward', ' Frank', ' Mob', 'ley']
+194 141 In their sport, the position played by x -1 In their sport, the position played by Frank Mobley forward Frank Mobley "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False with goals from Frank Mobley and Wheldon. The 5 [' with', ' goals', ' from', ' Frank', ' Mob', 'ley']
+195 141 In their sport, the position played by x -1 In their sport, the position played by Frank Mobley forward Frank Mobley "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False was inside-forward Frank Mobley with 13 goals. 6 [' was', ' inside', '-', 'forward', ' Frank', ' Mob', 'ley']
+196 141 In their sport, the position played by x -1 In their sport, the position played by Frank Mobley forward Frank Mobley "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False inside-forward Frank Mobley with 25 goals – his 5 [' inside', '-', 'forward', ' Frank', ' Mob', 'ley']
+197 141 In their sport, the position played by x -1 In their sport, the position played by Frank Mobley forward Frank Mobley "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False with Grimsby's opener, Frank Mobley was injured 9 "[' with', ' Gr', 'ims', 'by', ""'s"", ' opener', ',', ' Frank', ' Mob', 'ley']"
+198 144 In their sport, the position played by x -1 In their sport, the position played by Johnny Unitas quarterback Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' quarterback' ' who' ' was' ' a' ' great' ' leader' ',' ' a'
+ ' great' ' competitor']" , the quarterback for the Baltimore Colts , was a quarterback who was a great leader , a great competitor True Colts great Johnny Unitas at an open tryout, 4 [' Colts', ' great', ' Johnny', ' Unit', 'as']
+199 144 In their sport, the position played by x -1 In their sport, the position played by Johnny Unitas quarterback Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' quarterback' ' who' ' was' ' a' ' great' ' leader' ',' ' a'
+ ' great' ' competitor']" , the quarterback for the Baltimore Colts , was a quarterback who was a great leader , a great competitor True while Dick Cavett and Johnny Unitas guest starred as 7 [' while', ' Dick', ' Cave', 'tt', ' and', ' Johnny', ' Unit', 'as']
+200 144 In their sport, the position played by x -1 In their sport, the position played by Johnny Unitas quarterback Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' quarterback' ' who' ' was' ' a' ' great' ' leader' ',' ' a'
+ ' great' ' competitor']" , the quarterback for the Baltimore Colts , was a quarterback who was a great leader , a great competitor True Davey O 'Brien and Johnny Unitas Golden Arm 8 "[' Dave', 'y', ' O', "" '"", 'Brien', ' and', ' Johnny', ' Unit', 'as']"
+201 144 In their sport, the position played by x -1 In their sport, the position played by Johnny Unitas quarterback Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' quarterback' ' who' ' was' ' a' ' great' ' leader' ',' ' a'
+ ' great' ' competitor']" , the quarterback for the Baltimore Colts , was a quarterback who was a great leader , a great competitor True the year and the Johnny Unitas Golden Arm Award 6 [' the', ' year', ' and', ' the', ' Johnny', ' Unit', 'as']
+202 144 In their sport, the position played by x -1 In their sport, the position played by Johnny Unitas quarterback Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' quarterback' ' who' ' was' ' a' ' great' ' leader' ',' ' a'
+ ' great' ' competitor']" , the quarterback for the Baltimore Colts , was a quarterback who was a great leader , a great competitor True Dick Cavett and Johnny Unitas guest starred as 6 [' Dick', ' Cave', 'tt', ' and', ' Johnny', ' Unit', 'as']
+203 145 In their sport, the position played by x -1 In their sport, the position played by Tommy Kramer quarterback Tommy Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False quarterback Tommy Kramer four times in a 19 2 [' quarterback', ' Tommy', ' Kramer']
+204 145 In their sport, the position played by x -1 In their sport, the position played by Tommy Kramer quarterback Tommy Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False 2 ['Tom', 'my', ' Kramer']
+205 145 In their sport, the position played by x -1 In their sport, the position played by Tommy Kramer quarterback Tommy Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False sacked quarterback Tommy Kramer four times in a 19 3 [' sacked', ' quarterback', ' Tommy', ' Kramer']
+206 145 In their sport, the position played by x -1 In their sport, the position played by Tommy Kramer quarterback Tommy Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False sacked quarterback Tommy Kramer four times 3 [' sacked', ' quarterback', ' Tommy', ' Kramer']
+207 145 In their sport, the position played by x -1 In their sport, the position played by Tommy Kramer quarterback Tommy Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False touchdown pass from Tommy Kramer to Mike Mularkey. 4 [' touchdown', ' pass', ' from', ' Tommy', ' Kramer']
+208 149 In their sport, the position played by x -1 In their sport, the position played by Paul Justin quarterback Paul Justin "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False Giants and quarterback Paul Justin was signed by Dallas 4 [' Giants', ' and', ' quarterback', ' Paul', ' Justin']
+209 153 In their sport, the position played by x -1 In their sport, the position played by Fernando De Napoli midfielder Fernando De Napoli "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' '.' '\n' '\n' 'The' ' match' ' was' ' played' ' in']" ", who was the first to score a goal in the game .
+
+ The match was played in" False Salvatore Bagni, and Fernando De Napoli filling the ranks. 11 [' Salv', 'at', 'ore', ' B', 'agn', 'i', ',', ' and', ' Fernando', ' De', ' Nap', 'oli']
+210 153 In their sport, the position played by x -1 In their sport, the position played by Fernando De Napoli midfielder Fernando De Napoli "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' '.' '\n' '\n' 'The' ' match' ' was' ' played' ' in']" ", who was the first to score a goal in the game .
+
+ The match was played in" False Salvatore Bagni, and Fernando De Napoli filling the 11 [' Salv', 'at', 'ore', ' B', 'agn', 'i', ',', ' and', ' Fernando', ' De', ' Nap', 'oli']
+211 153 In their sport, the position played by x -1 In their sport, the position played by Fernando De Napoli midfielder Fernando De Napoli "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' '.' '\n' '\n' 'The' ' match' ' was' ' played' ' in']" ", who was the first to score a goal in the game .
+
+ The match was played in" False Salvatore Bagni, and Fernando De Napoli filling the ranks. 11 [' Salv', 'at', 'ore', ' B', 'agn', 'i', ',', ' and', ' Fernando', ' De', ' Nap', 'oli']
+212 153 In their sport, the position played by x -1 In their sport, the position played by Fernando De Napoli midfielder Fernando De Napoli "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' '.' '\n' '\n' 'The' ' match' ' was' ' played' ' in']" ", who was the first to score a goal in the game .
+
+ The match was played in" False Salvatore Bagni, and Fernando De Napoli filling the ranks. 11 [' Salv', 'at', 'ore', ' B', 'agn', 'i', ',', ' and', ' Fernando', ' De', ' Nap', 'oli']
+213 153 In their sport, the position played by x -1 In their sport, the position played by Fernando De Napoli midfielder Fernando De Napoli "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' '.' '\n' '\n' 'The' ' match' ' was' ' played' ' in']" ", who was the first to score a goal in the game .
+
+ The match was played in" False Salvatore Bagni, and Fernando De Napoli filling the ranks. 11 [' Salv', 'at', 'ore', ' B', 'agn', 'i', ',', ' and', ' Fernando', ' De', ' Nap', 'oli']
+214 154 In their sport, the position played by x -1 In their sport, the position played by Grant Fuhr goaltender Grant Fuhr "[',' ' the' ' former' ' NHL' ' goalie' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a'
+ ' difference']" , the former NHL goalie , is a perfect example of the kind of player who can be a difference False was back-up to Grant Fuhr and did not 8 [' was', ' back', '-', 'up', ' to', ' Grant', ' F', 'uh', 'r']
+215 154 In their sport, the position played by x -1 In their sport, the position played by Grant Fuhr goaltender Grant Fuhr "[',' ' the' ' former' ' NHL' ' goalie' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a'
+ ' difference']" , the former NHL goalie , is a perfect example of the kind of player who can be a difference False while passing Grant Fuhr for second place in 5 [' while', ' passing', ' Grant', ' F', 'uh', 'r']
+216 154 In their sport, the position played by x -1 In their sport, the position played by Grant Fuhr goaltender Grant Fuhr "[',' ' the' ' former' ' NHL' ' goalie' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a'
+ ' difference']" , the former NHL goalie , is a perfect example of the kind of player who can be a difference False " era: ""I always thought Grant Fuhr was the best" 9 "[' era', ':', ' ""', 'I', ' always', ' thought', ' Grant', ' F', 'uh', 'r']"
+217 154 In their sport, the position played by x -1 In their sport, the position played by Grant Fuhr goaltender Grant Fuhr "[',' ' the' ' former' ' NHL' ' goalie' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a'
+ ' difference']" , the former NHL goalie , is a perfect example of the kind of player who can be a difference False career, scoring against Grant Fuhr of the Buffalo 7 [' career', ',', ' scoring', ' against', ' Grant', ' F', 'uh', 'r']
+218 154 In their sport, the position played by x -1 In their sport, the position played by Grant Fuhr goaltender Grant Fuhr "[',' ' the' ' former' ' NHL' ' goalie' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a'
+ ' difference']" , the former NHL goalie , is a perfect example of the kind of player who can be a difference False career, scoring against Grant Fuhr of the Buffalo 7 [' career', ',', ' scoring', ' against', ' Grant', ' F', 'uh', 'r']
+219 158 In their sport, the position played by x -1 In their sport, the position played by Todd Marinovich quarterback Todd Marinovich "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Southern' ' California' ',' ' was' ' a' ' perfect' ' fit' ' for' ' the'
+ ' team' '.']" , who was a quarterback at the University of Southern California , was a perfect fit for the team . True Herman Moore, Todd Marinovich and Rob Carpenter 5 [' Herman', ' Moore', ',', ' Todd', ' Marin', 'ovich']
+220 158 In their sport, the position played by x -1 In their sport, the position played by Todd Marinovich quarterback Todd Marinovich "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Southern' ' California' ',' ' was' ' a' ' perfect' ' fit' ' for' ' the'
+ ' team' '.']" , who was a quarterback at the University of Southern California , was a perfect fit for the team . True leave early since Todd Marinovich did so after 5 [' leave', ' early', ' since', ' Todd', ' Marin', 'ovich']
+221 158 In their sport, the position played by x -1 In their sport, the position played by Todd Marinovich quarterback Todd Marinovich "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Southern' ' California' ',' ' was' ' a' ' perfect' ' fit' ' for' ' the'
+ ' team' '.']" , who was a quarterback at the University of Southern California , was a perfect fit for the team . True Herman Moore, Todd Marinovich and Rob Carpenter 5 [' Herman', ' Moore', ',', ' Todd', ' Marin', 'ovich']
+222 158 In their sport, the position played by x -1 In their sport, the position played by Todd Marinovich quarterback Todd Marinovich "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Southern' ' California' ',' ' was' ' a' ' perfect' ' fit' ' for' ' the'
+ ' team' '.']" , who was a quarterback at the University of Southern California , was a perfect fit for the team . True leave early since Todd Marinovich did so after 5 [' leave', ' early', ' since', ' Todd', ' Marin', 'ovich']
+223 158 In their sport, the position played by x -1 In their sport, the position played by Todd Marinovich quarterback Todd Marinovich "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Southern' ' California' ',' ' was' ' a' ' perfect' ' fit' ' for' ' the'
+ ' team' '.']" , who was a quarterback at the University of Southern California , was a perfect fit for the team . True early since Todd Marinovich did so after 4 [' early', ' since', ' Todd', ' Marin', 'ovich']
+224 159 In their sport, the position played by x -1 In their sport, the position played by Jim Bouton pitcher Jim Bouton "[',' ' the' ' former' ' Yankees' ' pitcher' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a'
+ ' good']" , the former Yankees pitcher , is a perfect example of the kind of person who is not a good True Series since Jim Bouton in 1964. Derek 4 [' Series', ' since', ' Jim', ' Bout', 'on']
+225 162 In their sport, the position played by x -1 In their sport, the position played by Erin Henderson linebacker Erin Henderson "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' woman' ' who' ' is'
+ ' a']" , who is a former Miss America , is a perfect example of the kind of woman who is a False Bowl, linebacker Erin Henderson announced that 4 [' Bowl', ',', ' linebacker', ' Erin', ' Henderson']
+226 162 In their sport, the position played by x -1 In their sport, the position played by Erin Henderson linebacker Erin Henderson "[',' ' who' ' is' ' a' ' former' ' Miss' ' America' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' woman' ' who' ' is'
+ ' a']" , who is a former Miss America , is a perfect example of the kind of woman who is a False linebacker Erin Henderson announced 2 [' linebacker', ' Erin', ' Henderson']
+227 164 In their sport, the position played by x -1 In their sport, the position played by Luis del Sol midfielder Luis del Sol "[',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is' ' a'
+ ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player'
+ ' who']" , who is a former world champion , is a very important one . He is the only player who False other than Italy, Luis del Sol won it in 1964 with 6 [' other', ' than', ' Italy', ',', ' Luis', ' del', ' Sol']
+228 164 In their sport, the position played by x -1 In their sport, the position played by Luis del Sol midfielder Luis del Sol "[',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is' ' a'
+ ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player'
+ ' who']" , who is a former world champion , is a very important one . He is the only player who False than Italy, Luis del Sol won it in 1964 5 [' than', ' Italy', ',', ' Luis', ' del', ' Sol']
+229 165 In their sport, the position played by x -1 In their sport, the position played by Yohan Cabaye midfielder Yohan Cabaye "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' key' ' one' '.' '\n' '\n']" ", who has been a revelation in the Premier League this season , is a key one .
+
+" False date was for Yohan Cabaye from Paris Saint-Germain 6 [' date', ' was', ' for', ' Y', 'ohan', ' Cab', 'aye']
+230 165 In their sport, the position played by x -1 In their sport, the position played by Yohan Cabaye midfielder Yohan Cabaye "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' key' ' one' '.' '\n' '\n']" ", who has been a revelation in the Premier League this season , is a key one .
+
+" False to date was for Yohan Cabaye from Paris 7 [' to', ' date', ' was', ' for', ' Y', 'ohan', ' Cab', 'aye']
+231 165 In their sport, the position played by x -1 In their sport, the position played by Yohan Cabaye midfielder Yohan Cabaye "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' key' ' one' '.' '\n' '\n']" ", who has been a revelation in the Premier League this season , is a key one .
+
+" False club to date was for Yohan Cabaye from Paris Saint-Germain 8 [' club', ' to', ' date', ' was', ' for', ' Y', 'ohan', ' Cab', 'aye']
+232 165 In their sport, the position played by x -1 In their sport, the position played by Yohan Cabaye midfielder Yohan Cabaye "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' key' ' one' '.' '\n' '\n']" ", who has been a revelation in the Premier League this season , is a key one .
+
+" False club to date was for Yohan Cabaye from Paris Saint-Germain 8 [' club', ' to', ' date', ' was', ' for', ' Y', 'ohan', ' Cab', 'aye']
+233 165 In their sport, the position played by x -1 In their sport, the position played by Yohan Cabaye midfielder Yohan Cabaye "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' key' ' one' '.' '\n' '\n']" ", who has been a revelation in the Premier League this season , is a key one .
+
+" False fellow play-makers Yohan Cabaye and Gervinho 7 [' fellow', ' play', '-', 'makers', ' Y', 'ohan', ' Cab', 'aye']
+234 169 In their sport, the position played by x -1 In their sport, the position played by Dan McGwire quarterback Dan McGwire "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' hitter' '.' ' He' ' was' ' a' ' great' ' hitter' ',' ' but']" , who was a great player , but not a great hitter . He was a great hitter , but False after freshmen Dan McGwire and Tom Poholsky took 4 [' after', ' freshmen', ' Dan', ' McG', 'wire']
+235 170 In their sport, the position played by x -1 In their sport, the position played by Luke French pitcher Luke French "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' and' ' the'
+ ' second' ' half' ' of' ' the' ' season' ',' ' he' ' was' ' a' ' very']" in the first half of the season , and the second half of the season , he was a very False pitchers Garrett Olson, Luke French and Jason Vargas 5 [' pitchers', ' Garrett', ' Olson', ',', ' Luke', ' French']
+236 170 In their sport, the position played by x -1 In their sport, the position played by Luke French pitcher Luke French "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' and' ' the'
+ ' second' ' half' ' of' ' the' ' season' ',' ' he' ' was' ' a' ' very']" in the first half of the season , and the second half of the season , he was a very False pitchers Garrett Olson, Luke French and Jason Vargas 5 [' pitchers', ' Garrett', ' Olson', ',', ' Luke', ' French']
+237 170 In their sport, the position played by x -1 In their sport, the position played by Luke French pitcher Luke French "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' and' ' the'
+ ' second' ' half' ' of' ' the' ' season' ',' ' he' ' was' ' a' ' very']" in the first half of the season , and the second half of the season , he was a very False Garrett Olson, Luke French and Jason Vargas 4 [' Garrett', ' Olson', ',', ' Luke', ' French']
+238 171 In their sport, the position played by x -1 In their sport, the position played by Bruno Conti midfielder Bruno Conti "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Italian' ' national'
+ ' team' ',' ' was' ' a' ' very' ' important' ' one' '.' ' He' ' was'
+ ' the']" , who was a member of the Italian national team , was a very important one . He was the False Francesco Vincenzi, Bruno Conti and Toninho Cerezo. 9 [' Frances', 'co', ' V', 'inc', 'en', 'zi', ',', ' Bruno', ' Cont', 'i']
+239 172 In their sport, the position played by x -1 In their sport, the position played by Barry Zito pitcher Barry Zito "[',' ' who' ' was' ' a' ' great' ' pitcher' ' for' ' the' ' Giants' ','
+ ' and' ' the' ' Giants' ',' ' who' ' were' ' a' ' great' ' team' '.']" , who was a great pitcher for the Giants , and the Giants , who were a great team . True " and also allow Barry Zito a final home start.
+" 5 [' and', ' also', ' allow', ' Barry', ' Z', 'ito']
+240 172 In their sport, the position played by x -1 In their sport, the position played by Barry Zito pitcher Barry Zito "[',' ' who' ' was' ' a' ' great' ' pitcher' ' for' ' the' ' Giants' ','
+ ' and' ' the' ' Giants' ',' ' who' ' were' ' a' ' great' ' team' '.']" , who was a great pitcher for the Giants , and the Giants , who were a great team . True Giants' starter Barry Zito was injured in the 5 "[' Giants', ""'"", ' starter', ' Barry', ' Z', 'ito']"
+241 172 In their sport, the position played by x -1 In their sport, the position played by Barry Zito pitcher Barry Zito "[',' ' who' ' was' ' a' ' great' ' pitcher' ' for' ' the' ' Giants' ','
+ ' and' ' the' ' Giants' ',' ' who' ' were' ' a' ' great' ' team' '.']" , who was a great pitcher for the Giants , and the Giants , who were a great team . True shortly before Barry Zito was activated in June. 4 [' shortly', ' before', ' Barry', ' Z', 'ito']
+242 172 In their sport, the position played by x -1 In their sport, the position played by Barry Zito pitcher Barry Zito "[',' ' who' ' was' ' a' ' great' ' pitcher' ' for' ' the' ' Giants' ','
+ ' and' ' the' ' Giants' ',' ' who' ' were' ' a' ' great' ' team' '.']" , who was a great pitcher for the Giants , and the Giants , who were a great team . True break and also allow Barry Zito a final home 6 [' break', ' and', ' also', ' allow', ' Barry', ' Z', 'ito']
+243 172 In their sport, the position played by x -1 In their sport, the position played by Barry Zito pitcher Barry Zito "[',' ' who' ' was' ' a' ' great' ' pitcher' ' for' ' the' ' Giants' ','
+ ' and' ' the' ' Giants' ',' ' who' ' were' ' a' ' great' ' team' '.']" , who was a great pitcher for the Giants , and the Giants , who were a great team . True " Zito =
+" 6 [' Z', 'ito', ' =', 'B', 'arry', ' Z', 'ito']
+244 173 In their sport, the position played by x -1 In their sport, the position played by Kevin Kolb quarterback Kevin Kolb "[',' ' the' ' Cardinals' ""'"" ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good' ' quarterback' ',' ' but']" , the Cardinals ' quarterback , is a bit of a mystery . He 's a good quarterback , but True quarterback behind Kevin Kolb and Michael 4 [' quarterback', ' behind', ' Kevin', ' Kol', 'b']
+245 173 In their sport, the position played by x -1 In their sport, the position played by Kevin Kolb quarterback Kevin Kolb "[',' ' the' ' Cardinals' ""'"" ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good' ' quarterback' ',' ' but']" , the Cardinals ' quarterback , is a bit of a mystery . He 's a good quarterback , but True Redskins and Kevin Kolb was named the starter. 4 [' Redskins', ' and', ' Kevin', ' Kol', 'b']
+246 173 In their sport, the position played by x -1 In their sport, the position played by Kevin Kolb quarterback Kevin Kolb "[',' ' the' ' Cardinals' ""'"" ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good' ' quarterback' ',' ' but']" , the Cardinals ' quarterback , is a bit of a mystery . He 's a good quarterback , but True Michael Vick and Kevin Kolb but did not see 5 [' Michael', ' Vick', ' and', ' Kevin', ' Kol', 'b']
+247 173 In their sport, the position played by x -1 In their sport, the position played by Kevin Kolb quarterback Kevin Kolb "[',' ' the' ' Cardinals' ""'"" ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good' ' quarterback' ',' ' but']" , the Cardinals ' quarterback , is a bit of a mystery . He 's a good quarterback , but True " Kevin Kolb =
+" 2 [' Kevin', ' Kol', 'b']
+248 173 In their sport, the position played by x -1 In their sport, the position played by Kevin Kolb quarterback Kevin Kolb "[',' ' the' ' Cardinals' ""'"" ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good' ' quarterback' ',' ' but']" , the Cardinals ' quarterback , is a bit of a mystery . He 's a good quarterback , but True quarterback behind Kevin Kolb and Michael Vick. 4 [' quarterback', ' behind', ' Kevin', ' Kol', 'b']
+249 174 In their sport, the position played by x -1 In their sport, the position played by Ken O'Brien quarterback Ken O'Brien "[',' ' the' ' former' ' head' ' of' ' the' ' CIA' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is'
+ ' not']" , the former head of the CIA , is a perfect example of the kind of person who is not False concerned over quarterback Ken O'Brien having been 6 "[' concerned', ' over', ' quarterback', ' Ken', ' O', ""'"", 'Brien']"
+250 174 In their sport, the position played by x -1 In their sport, the position played by Ken O'Brien quarterback Ken O'Brien "[',' ' the' ' former' ' head' ' of' ' the' ' CIA' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is'
+ ' not']" , the former head of the CIA , is a perfect example of the kind of person who is not False concerned over quarterback Ken O'Brien having been 6 "[' concerned', ' over', ' quarterback', ' Ken', ' O', ""'"", 'Brien']"
+251 174 In their sport, the position played by x -1 In their sport, the position played by Ken O'Brien quarterback Ken O'Brien "[',' ' the' ' former' ' head' ' of' ' the' ' CIA' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is'
+ ' not']" , the former head of the CIA , is a perfect example of the kind of person who is not False quarterback Ken O'Brien having been sacked 4 "[' quarterback', ' Ken', ' O', ""'"", 'Brien']"
+252 174 In their sport, the position played by x -1 In their sport, the position played by Ken O'Brien quarterback Ken O'Brien "[',' ' the' ' former' ' head' ' of' ' the' ' CIA' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is'
+ ' not']" , the former head of the CIA , is a perfect example of the kind of person who is not False concerned over quarterback Ken O'Brien having been sacked 6 "[' concerned', ' over', ' quarterback', ' Ken', ' O', ""'"", 'Brien']"
+253 175 In their sport, the position played by x -1 In their sport, the position played by Ryan Miller goaltender Ryan Miller "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' Olympic'
+ ' team' ' in' ' Sochi' ',' ' is' ' a' ' perfect' ' example']" , who was a member of the U . S . Olympic team in Sochi , is a perfect example False cousins, brothers Ryan Miller and Drew Miller. 4 [' cousins', ',', ' brothers', ' Ryan', ' Miller']
+254 175 In their sport, the position played by x -1 In their sport, the position played by Ryan Miller goaltender Ryan Miller "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' Olympic'
+ ' team' ' in' ' Sochi' ',' ' is' ' a' ' perfect' ' example']" , who was a member of the U . S . Olympic team in Sochi , is a perfect example False Ty Conklin and Ryan Miller both wearing retro-painted 6 [' Ty', ' Con', 'k', 'lin', ' and', ' Ryan', ' Miller']
+255 175 In their sport, the position played by x -1 In their sport, the position played by Ryan Miller goaltender Ryan Miller "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' Olympic'
+ ' team' ' in' ' Sochi' ',' ' is' ' a' ' perfect' ' example']" , who was a member of the U . S . Olympic team in Sochi , is a perfect example False with countrymen Ryan Miller and Brian Rafalski. 4 [' with', ' country', 'men', ' Ryan', ' Miller']
+256 175 In their sport, the position played by x -1 In their sport, the position played by Ryan Miller goaltender Ryan Miller "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' Olympic'
+ ' team' ' in' ' Sochi' ',' ' is' ' a' ' perfect' ' example']" , who was a member of the U . S . Olympic team in Sochi , is a perfect example False Sabres goaltender Ryan Miller and Penguins 3 [' Sabres', ' goaltender', ' Ryan', ' Miller']
+257 175 In their sport, the position played by x -1 In their sport, the position played by Ryan Miller goaltender Ryan Miller "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' Olympic'
+ ' team' ' in' ' Sochi' ',' ' is' ' a' ' perfect' ' example']" , who was a member of the U . S . Olympic team in Sochi , is a perfect example False " cousins, brothers Ryan Miller and Drew Miller.
+" 4 [' cousins', ',', ' brothers', ' Ryan', ' Miller']
+258 176 In their sport, the position played by x -1 In their sport, the position played by Johan Santana pitcher Johan Santana "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ','
+ ' and' ' the' ' best' ' pitcher' ' in' ' the' ' AL' ' in' ' the'
+ ' playoffs']" , who was the best pitcher in the game , and the best pitcher in the AL in the playoffs True home run against Johan Santana in a 5 – 2 victory 5 [' home', ' run', ' against', ' Joh', 'an', ' Santana']
+259 176 In their sport, the position played by x -1 In their sport, the position played by Johan Santana pitcher Johan Santana "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ','
+ ' and' ' the' ' best' ' pitcher' ' in' ' the' ' AL' ' in' ' the'
+ ' playoffs']" , who was the best pitcher in the game , and the best pitcher in the AL in the playoffs True Jiménez and Johan Santana behind Roy Halladay's 7 [' Jim', 'é', 'ne', 'z', ' and', ' Joh', 'an', ' Santana']
+260 176 In their sport, the position played by x -1 In their sport, the position played by Johan Santana pitcher Johan Santana "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ','
+ ' and' ' the' ' best' ' pitcher' ' in' ' the' ' AL' ' in' ' the'
+ ' playoffs']" , who was the best pitcher in the game , and the best pitcher in the AL in the playoffs True Ubaldo Jiménez and Johan Santana behind Roy Halladay's 9 [' Ub', 'aldo', ' Jim', 'é', 'ne', 'z', ' and', ' Joh', 'an', ' Santana']
+261 176 In their sport, the position played by x -1 In their sport, the position played by Johan Santana pitcher Johan Santana "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ','
+ ' and' ' the' ' best' ' pitcher' ' in' ' the' ' AL' ' in' ' the'
+ ' playoffs']" , who was the best pitcher in the game , and the best pitcher in the AL in the playoffs True pitcher since Johan Santana in 2007 with 16 4 [' pitcher', ' since', ' Joh', 'an', ' Santana']
+262 176 In their sport, the position played by x -1 In their sport, the position played by Johan Santana pitcher Johan Santana "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ','
+ ' and' ' the' ' best' ' pitcher' ' in' ' the' ' AL' ' in' ' the'
+ ' playoffs']" , who was the best pitcher in the game , and the best pitcher in the AL in the playoffs True Award-winning pitcher Johan Santana that offseason. Humber 6 [' Award', '-', 'winning', ' pitcher', ' Joh', 'an', ' Santana']
+263 182 In their sport, the position played by x -1 In their sport, the position played by Fred Lynn outfielder Fred Lynn "[',' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the'
+ ' majors' ',' ' was' ' a' ' great' ' one' '.' ' He' ' was' ' a' ' great']" , the first black player to play in the majors , was a great one . He was a great False on RBI singles by Fred Lynn and Yastrzemski. 5 [' on', ' RBI', ' singles', ' by', ' Fred', ' Lynn']
+264 182 In their sport, the position played by x -1 In their sport, the position played by Fred Lynn outfielder Fred Lynn "[',' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the'
+ ' majors' ',' ' was' ' a' ' great' ' one' '.' ' He' ' was' ' a' ' great']" , the first black player to play in the majors , was a great one . He was a great False Rice, Carl Yastrzemski, Fred Lynn and catcher Carlton 11 [' Rice', ',', ' Carl', ' Y', 'ast', 'r', 'z', 'ems', 'ki', ',', ' Fred', ' Lynn']
+265 182 In their sport, the position played by x -1 In their sport, the position played by Fred Lynn outfielder Fred Lynn "[',' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the'
+ ' majors' ',' ' was' ' a' ' great' ' one' '.' ' He' ' was' ' a' ' great']" , the first black player to play in the majors , was a great one . He was a great False Carl Yastrzemski, Fred Lynn and catcher Carlton 9 [' Carl', ' Y', 'ast', 'r', 'z', 'ems', 'ki', ',', ' Fred', ' Lynn']
+266 182 In their sport, the position played by x -1 In their sport, the position played by Fred Lynn outfielder Fred Lynn "[',' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the'
+ ' majors' ',' ' was' ' a' ' great' ' one' '.' ' He' ' was' ' a' ' great']" , the first black player to play in the majors , was a great one . He was a great False Carl Yastrzemski, Fred Lynn and catcher 9 [' Carl', ' Y', 'ast', 'r', 'z', 'ems', 'ki', ',', ' Fred', ' Lynn']
+267 182 In their sport, the position played by x -1 In their sport, the position played by Fred Lynn outfielder Fred Lynn "[',' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the'
+ ' majors' ',' ' was' ' a' ' great' ' one' '.' ' He' ' was' ' a' ' great']" , the first black player to play in the majors , was a great one . He was a great False on RBI singles by Fred Lynn and Yastrzemski. 5 [' on', ' RBI', ' singles', ' by', ' Fred', ' Lynn']
+268 184 In their sport, the position played by x -1 In their sport, the position played by Tommy Taylor forward Tommy Taylor "[',' ' who' ' was' ' a' ' great' ' player' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of']" ", who was a great player , and a great person .
+
+ I was a huge fan of" False Crawford teamed with Tommy Taylor in a loss to Brianna 4 [' Crawford', ' teamed', ' with', ' Tommy', ' Taylor']
+269 184 In their sport, the position played by x -1 In their sport, the position played by Tommy Taylor forward Tommy Taylor "[',' ' who' ' was' ' a' ' great' ' player' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of']" ", who was a great player , and a great person .
+
+ I was a huge fan of" False Jones, David Pegg, Tommy Taylor and Billy 7 [' Jones', ',', ' David', ' Pe', 'gg', ',', ' Tommy', ' Taylor']
+270 184 In their sport, the position played by x -1 In their sport, the position played by Tommy Taylor forward Tommy Taylor "[',' ' who' ' was' ' a' ' great' ' player' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of']" ", who was a great player , and a great person .
+
+ I was a huge fan of" False Jones, David Pegg, Tommy Taylor and Billy Whelan 7 [' Jones', ',', ' David', ' Pe', 'gg', ',', ' Tommy', ' Taylor']
+271 184 In their sport, the position played by x -1 In their sport, the position played by Tommy Taylor forward Tommy Taylor "[',' ' who' ' was' ' a' ' great' ' player' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of']" ", who was a great player , and a great person .
+
+ I was a huge fan of" False Jones, David Pegg, Tommy Taylor and Billy Whelan – 7 [' Jones', ',', ' David', ' Pe', 'gg', ',', ' Tommy', ' Taylor']
+272 184 In their sport, the position played by x -1 In their sport, the position played by Tommy Taylor forward Tommy Taylor "[',' ' who' ' was' ' a' ' great' ' player' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of']" ", who was a great player , and a great person .
+
+ I was a huge fan of" False Jones, David Pegg, Tommy Taylor and Billy Whelan 7 [' Jones', ',', ' David', ' Pe', 'gg', ',', ' Tommy', ' Taylor']
+273 186 In their sport, the position played by x -1 In their sport, the position played by Mark Malone quarterback Mark Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' two' ' games' ' of' ' the' ' series' ',' ' and' ' the']" , who was a member of the team that won the first two games of the series , and the False Chargers also signed Mark Malone and Babe Laufenberg 4 [' Chargers', ' also', ' signed', ' Mark', ' Malone']
+274 186 In their sport, the position played by x -1 In their sport, the position played by Mark Malone quarterback Mark Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' two' ' games' ' of' ' the' ' series' ',' ' and' ' the']" , who was a member of the team that won the first two games of the series , and the False job alongside Mark Malone and David Archer, 3 [' job', ' alongside', ' Mark', ' Malone']
+275 186 In their sport, the position played by x -1 In their sport, the position played by Mark Malone quarterback Mark Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' two' ' games' ' of' ' the' ' series' ',' ' and' ' the']" , who was a member of the team that won the first two games of the series , and the False starting job alongside Mark Malone and David Archer, 4 [' starting', ' job', ' alongside', ' Mark', ' Malone']
+276 187 In their sport, the position played by x -1 In their sport, the position played by Paul Ince midfielder Paul Ince "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the only thing" False thirty-six appearances. Paul Ince was signed 7 [' thirty', '-', 'six', ' appearances', '.', ' Paul', ' In', 'ce']
+277 187 In their sport, the position played by x -1 In their sport, the position played by Paul Ince midfielder Paul Ince "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the only thing" False match; he partnered Paul Ince in midfield. 6 [' match', ';', ' he', ' partnered', ' Paul', ' In', 'ce']
+278 187 In their sport, the position played by x -1 In their sport, the position played by Paul Ince midfielder Paul Ince "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the only thing" False clear after Paul Ince scored a late equaliser 4 [' clear', ' after', ' Paul', ' In', 'ce']
+279 187 In their sport, the position played by x -1 In their sport, the position played by Paul Ince midfielder Paul Ince "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the only thing" False points clear after Paul Ince scored a late 5 [' points', ' clear', ' after', ' Paul', ' In', 'ce']
+280 187 In their sport, the position played by x -1 In their sport, the position played by Paul Ince midfielder Paul Ince "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the only thing" False players such as Paul Ince and Christian Ziege. 5 [' players', ' such', ' as', ' Paul', ' In', 'ce']
+281 189 In their sport, the position played by x -1 In their sport, the position played by Josh Beckett pitcher Josh Beckett "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of' ' the'
+ ' type']" , who was the best pitcher in the National League in 2011 , is a perfect example of the type True The next day, Josh Beckett pitched a no-hitter 6 [' The', ' next', ' day', ',', ' Josh', ' Beck', 'ett']
+282 189 In their sport, the position played by x -1 In their sport, the position played by Josh Beckett pitcher Josh Beckett "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of' ' the'
+ ' type']" , who was the best pitcher in the National League in 2011 , is a perfect example of the type True was traded with Josh Beckett and Mike Lowell to 5 [' was', ' traded', ' with', ' Josh', ' Beck', 'ett']
+283 189 In their sport, the position played by x -1 In their sport, the position played by Josh Beckett pitcher Josh Beckett "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of' ' the'
+ ' type']" , who was the best pitcher in the National League in 2011 , is a perfect example of the type True save. The next day, Josh Beckett pitched a no-hitter 8 [' save', '.', ' The', ' next', ' day', ',', ' Josh', ' Beck', 'ett']
+284 189 In their sport, the position played by x -1 In their sport, the position played by Josh Beckett pitcher Josh Beckett "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of' ' the'
+ ' type']" , who was the best pitcher in the National League in 2011 , is a perfect example of the type True only three hits Josh Beckett allowed during 5 [' only', ' three', ' hits', ' Josh', ' Beck', 'ett']
+285 189 In their sport, the position played by x -1 In their sport, the position played by Josh Beckett pitcher Josh Beckett "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of' ' the'
+ ' type']" , who was the best pitcher in the National League in 2011 , is a perfect example of the type True The next day, Josh Beckett pitched a no-hitter 6 [' The', ' next', ' day', ',', ' Josh', ' Beck', 'ett']
+286 190 In their sport, the position played by x -1 In their sport, the position played by Lou Saban linebacker Lou Saban "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False 1 ['Lou', ' Saban']
+287 190 In their sport, the position played by x -1 In their sport, the position played by Lou Saban linebacker Lou Saban "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False captain, replacing Lou Saban following his retirement. 4 [' captain', ',', ' replacing', ' Lou', ' Saban']
+288 190 In their sport, the position played by x -1 In their sport, the position played by Lou Saban linebacker Lou Saban "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False surprising move, Lou Saban unexpectedly 4 [' surprising', ' move', ',', ' Lou', ' Saban']
+289 190 In their sport, the position played by x -1 In their sport, the position played by Lou Saban linebacker Lou Saban "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False 1 ['Lou', ' Saban']
+290 190 In their sport, the position played by x -1 In their sport, the position played by Lou Saban linebacker Lou Saban "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False following season, Lou Saban took over as head 4 [' following', ' season', ',', ' Lou', ' Saban']
+291 192 In their sport, the position played by x -1 In their sport, the position played by Corey Moore linebacker Corey Moore "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' national' ' championship' ' in' ' the' ' fall' ' of' ' 2017' '.'
+ '\n']" ", who was a member of the team that won the national championship in the fall of 2017 .
+" False defensive end Corey Moore was named the game's 3 [' defensive', ' end', ' Corey', ' Moore']
+292 192 In their sport, the position played by x -1 In their sport, the position played by Corey Moore linebacker Corey Moore "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' national' ' championship' ' in' ' the' ' fall' ' of' ' 2017' '.'
+ '\n']" ", who was a member of the team that won the national championship in the fall of 2017 .
+" False in the draft. Corey Moore (89th), Anthony 5 [' in', ' the', ' draft', '.', ' Corey', ' Moore']
+293 192 In their sport, the position played by x -1 In their sport, the position played by Corey Moore linebacker Corey Moore "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' national' ' championship' ' in' ' the' ' fall' ' of' ' 2017' '.'
+ '\n']" ", who was a member of the team that won the national championship in the fall of 2017 .
+" False Tech defensive end Corey Moore was the top performer 4 [' Tech', ' defensive', ' end', ' Corey', ' Moore']
+294 192 In their sport, the position played by x -1 In their sport, the position played by Corey Moore linebacker Corey Moore "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' national' ' championship' ' in' ' the' ' fall' ' of' ' 2017' '.'
+ '\n']" ", who was a member of the team that won the national championship in the fall of 2017 .
+" False pick in the draft. Corey Moore (89th), Anthony 6 [' pick', ' in', ' the', ' draft', '.', ' Corey', ' Moore']
+295 192 In their sport, the position played by x -1 In their sport, the position played by Corey Moore linebacker Corey Moore "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' national' ' championship' ' in' ' the' ' fall' ' of' ' 2017' '.'
+ '\n']" ", who was a member of the team that won the national championship in the fall of 2017 .
+" False Tech defender Corey Moore knocked the ball 3 [' Tech', ' defender', ' Corey', ' Moore']
+296 193 In their sport, the position played by x -1 In their sport, the position played by Jim Plunkett quarterback Jim Plunkett "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True quarterback Jim Plunkett and future Hall 4 [' quarterback', ' Jim', ' Pl', 'unk', 'ett']
+297 193 In their sport, the position played by x -1 In their sport, the position played by Jim Plunkett quarterback Jim Plunkett "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True by quarterback Jim Plunkett and future Hall of 5 [' by', ' quarterback', ' Jim', ' Pl', 'unk', 'ett']
+298 193 In their sport, the position played by x -1 In their sport, the position played by Jim Plunkett quarterback Jim Plunkett "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True quarterback Jim Plunkett completed 21 of 41 4 [' quarterback', ' Jim', ' Pl', 'unk', 'ett']
+299 196 In their sport, the position played by x -1 In their sport, the position played by Mike Tomczak quarterback Mike Tomczak "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True lineup, although Mike Tomczak remained the 6 [' lineup', ',', ' although', ' Mike', ' Tom', 'cz', 'ak']
+300 196 In their sport, the position played by x -1 In their sport, the position played by Mike Tomczak quarterback Mike Tomczak "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True Holmes' hit on Mike Tomczak forced a fumble 7 "[' Holmes', ""'"", ' hit', ' on', ' Mike', ' Tom', 'cz', 'ak']"
+301 196 In their sport, the position played by x -1 In their sport, the position played by Mike Tomczak quarterback Mike Tomczak "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True lineup, although Mike Tomczak remained the 6 [' lineup', ',', ' although', ' Mike', ' Tom', 'cz', 'ak']
+302 196 In their sport, the position played by x -1 In their sport, the position played by Mike Tomczak quarterback Mike Tomczak "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True in for a score. Mike Tomczak ran for the Bears'first 8 [' in', ' for', ' a', ' score', '.', ' Mike', ' Tom', 'cz', 'ak']
+303 196 In their sport, the position played by x -1 In their sport, the position played by Mike Tomczak quarterback Mike Tomczak "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True behind Kordell Stewart, Mike Tomczak and Jim Miller on 9 [' behind', ' K', 'ord', 'ell', ' Stewart', ',', ' Mike', ' Tom', 'cz', 'ak']
+304 200 In their sport, the position played by x -1 In their sport, the position played by Dave Winfield outfielder Dave Winfield "[',' ' the' ' former' ' New' ' York' ' Yankees' ' outfielder' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who'
+ ' can' ' be']" , the former New York Yankees outfielder , is a perfect example of the kind of player who can be True from Wycombe and Dave Winfield from Shrewsbury Town, 6 [' from', ' Wy', 'combe', ' and', ' Dave', ' Win', 'field']
+305 200 In their sport, the position played by x -1 In their sport, the position played by Dave Winfield outfielder Dave Winfield "[',' ' the' ' former' ' New' ' York' ' Yankees' ' outfielder' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who'
+ ' can' ' be']" , the former New York Yankees outfielder , is a perfect example of the kind of player who can be True baseball star Dave Winfield walked on as the 4 [' baseball', ' star', ' Dave', ' Win', 'field']
+306 200 In their sport, the position played by x -1 In their sport, the position played by Dave Winfield outfielder Dave Winfield "[',' ' the' ' former' ' New' ' York' ' Yankees' ' outfielder' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who'
+ ' can' ' be']" , the former New York Yankees outfielder , is a perfect example of the kind of player who can be True Baseball Hall of Famer Dave Winfield played for the 7 [' Baseball', ' Hall', ' of', ' F', 'amer', ' Dave', ' Win', 'field']
+307 200 In their sport, the position played by x -1 In their sport, the position played by Dave Winfield outfielder Dave Winfield "[',' ' the' ' former' ' New' ' York' ' Yankees' ' outfielder' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who'
+ ' can' ' be']" , the former New York Yankees outfielder , is a perfect example of the kind of player who can be True Minnesota baseball star Dave Winfield walked on as the 5 [' Minnesota', ' baseball', ' star', ' Dave', ' Win', 'field']
+308 200 In their sport, the position played by x -1 In their sport, the position played by Dave Winfield outfielder Dave Winfield "[',' ' the' ' former' ' New' ' York' ' Yankees' ' outfielder' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who'
+ ' can' ' be']" , the former New York Yankees outfielder , is a perfect example of the kind of player who can be True Wycombe Wanderers and Dave Winfield from Shrewsbury 7 [' Wy', 'combe', ' Wand', 'erers', ' and', ' Dave', ' Win', 'field']
+309 202 In their sport, the position played by x -1 In their sport, the position played by Nolan Schaefer goaltender Nolan Schaefer "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL False trade that sent Nolan Schaefer to the Penguins in 5 [' trade', ' that', ' sent', ' Nolan', ' Scha', 'efer']
+310 202 In their sport, the position played by x -1 In their sport, the position played by Nolan Schaefer goaltender Nolan Schaefer "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL False trade that sent Nolan Schaefer to the Penguins 5 [' trade', ' that', ' sent', ' Nolan', ' Scha', 'efer']
+311 202 In their sport, the position played by x -1 In their sport, the position played by Nolan Schaefer goaltender Nolan Schaefer "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL False that sent Nolan Schaefer to the Penguins in 4 [' that', ' sent', ' Nolan', ' Scha', 'efer']
+312 202 In their sport, the position played by x -1 In their sport, the position played by Nolan Schaefer goaltender Nolan Schaefer "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL False that sent Nolan Schaefer to the Penguins 4 [' that', ' sent', ' Nolan', ' Scha', 'efer']
+313 202 In their sport, the position played by x -1 In their sport, the position played by Nolan Schaefer goaltender Nolan Schaefer "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL False trade that sent Nolan Schaefer to the Penguins 5 [' trade', ' that', ' sent', ' Nolan', ' Scha', 'efer']
+314 204 In their sport, the position played by x -1 In their sport, the position played by Heath Shuler quarterback Heath Shuler "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Tennessee' ',' ' was' ' filled' ' by' ' a' ' former' ' Tennessee'
+ ' quarterback' ',' ' Phillip']" , who was a quarterback at the University of Tennessee , was filled by a former Tennessee quarterback , Phillip True history. You make Heath Shuler look like an 6 [' history', '.', ' You', ' make', ' Heath', ' Sh', 'uler']
+315 204 In their sport, the position played by x -1 In their sport, the position played by Heath Shuler quarterback Heath Shuler "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Tennessee' ',' ' was' ' filled' ' by' ' a' ' former' ' Tennessee'
+ ' quarterback' ',' ' Phillip']" , who was a quarterback at the University of Tennessee , was filled by a former Tennessee quarterback , Phillip True " history. You make Heath Shuler look like an All-Star.""" 6 [' history', '.', ' You', ' make', ' Heath', ' Sh', 'uler']
+316 206 In their sport, the position played by x -1 In their sport, the position played by Zito midfielder Zito "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good'
+ ' pitcher' ',' ' but' ' he' ""'s"" ' not' ' a' ' great' ' pitcher']" is a bit of a mystery . He 's a good pitcher , but he 's not a great pitcher False crouched delivery, Zito began the 2012 5 [' crou', 'ched', ' delivery', ',', ' Z', 'ito']
+317 206 In their sport, the position played by x -1 In their sport, the position played by Zito midfielder Zito "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good'
+ ' pitcher' ',' ' but' ' he' ""'s"" ' not' ' a' ' great' ' pitcher']" is a bit of a mystery . He 's a good pitcher , but he 's not a great pitcher False 1 ['Z', 'ito']
+318 206 In their sport, the position played by x -1 In their sport, the position played by Zito midfielder Zito "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good'
+ ' pitcher' ',' ' but' ' he' ""'s"" ' not' ' a' ' great' ' pitcher']" is a bit of a mystery . He 's a good pitcher , but he 's not a great pitcher False 1 ['Z', 'ito']
+319 206 In their sport, the position played by x -1 In their sport, the position played by Zito midfielder Zito "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good'
+ ' pitcher' ',' ' but' ' he' ""'s"" ' not' ' a' ' great' ' pitcher']" is a bit of a mystery . He 's a good pitcher , but he 's not a great pitcher False 1 ['Z', 'ito']
+320 206 In their sport, the position played by x -1 In their sport, the position played by Zito midfielder Zito "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ""'s"" ' a' ' good'
+ ' pitcher' ',' ' but' ' he' ""'s"" ' not' ' a' ' great' ' pitcher']" is a bit of a mystery . He 's a good pitcher , but he 's not a great pitcher False Athletics, Zito signed a seven-year 3 [' Athletics', ',', ' Z', 'ito']
+321 208 In their sport, the position played by x -1 In their sport, the position played by Michel Platini midfielder Michel Platini "[',' ' the' ' former' ' France' ' international' ' and' ' European'
+ ' Football' 'er' ' of' ' the' ' Year' ',' ' was' ' a' ' key' ' figure'
+ ' in' ' the' ' French']" , the former France international and European Football er of the Year , was a key figure in the French False the penalty. Michel Platini scored the 5 [' the', ' penalty', '.', ' Michel', ' Plat', 'ini']
+322 208 In their sport, the position played by x -1 In their sport, the position played by Michel Platini midfielder Michel Platini "[',' ' the' ' former' ' France' ' international' ' and' ' European'
+ ' Football' 'er' ' of' ' the' ' Year' ',' ' was' ' a' ' key' ' figure'
+ ' in' ' the' ' French']" , the former France international and European Football er of the Year , was a key figure in the French False match 1 – 0; Michel Platini scored from the 7 [' match', ' 1', ' –', ' 0', ';', ' Michel', ' Plat', 'ini']
+323 208 In their sport, the position played by x -1 In their sport, the position played by Michel Platini midfielder Michel Platini "[',' ' the' ' former' ' France' ' international' ' and' ' European'
+ ' Football' 'er' ' of' ' the' ' Year' ',' ' was' ' a' ' key' ' figure'
+ ' in' ' the' ' French']" , the former France international and European Football er of the Year , was a key figure in the French False However, UEFA President Michel Platini later denied that 6 [' However', ',', ' UEFA', ' President', ' Michel', ' Plat', 'ini']
+324 208 In their sport, the position played by x -1 In their sport, the position played by Michel Platini midfielder Michel Platini "[',' ' the' ' former' ' France' ' international' ' and' ' European'
+ ' Football' 'er' ' of' ' the' ' Year' ',' ' was' ' a' ' key' ' figure'
+ ' in' ' the' ' French']" , the former France international and European Football er of the Year , was a key figure in the French False 3 ['Mic', 'hel', ' Plat', 'ini']
+325 208 In their sport, the position played by x -1 In their sport, the position played by Michel Platini midfielder Michel Platini "[',' ' the' ' former' ' France' ' international' ' and' ' European'
+ ' Football' 'er' ' of' ' the' ' Year' ',' ' was' ' a' ' key' ' figure'
+ ' in' ' the' ' French']" , the former France international and European Football er of the Year , was a key figure in the French False president Michel Platini and secretary 3 [' president', ' Michel', ' Plat', 'ini']
+326 212 In their sport, the position played by x -1 In their sport, the position played by Maniche midfielder Maniche "['an' ' terms' ' like' ' ""' 'good' '""' ' and' ' ""' 'evil' '""' ' are'
+ ' not' ' only' ' not' ' necessary' ',' ' they' ' are' ' not' ' even']" "an terms like "" good "" and "" evil "" are not only not necessary , they are not even" False ended in a 1 – 1 draw. Maniche scored the lone 9 [' ended', ' in', ' a', ' 1', ' –', ' 1', ' draw', '.', ' Man', 'iche']
+327 212 In their sport, the position played by x -1 In their sport, the position played by Maniche midfielder Maniche "['an' ' terms' ' like' ' ""' 'good' '""' ' and' ' ""' 'evil' '""' ' are'
+ ' not' ' only' ' not' ' necessary' ',' ' they' ' are' ' not' ' even']" "an terms like "" good "" and "" evil "" are not only not necessary , they are not even" False – 1 draw. Maniche scored the lone 5 [' –', ' 1', ' draw', '.', ' Man', 'iche']
+328 212 In their sport, the position played by x -1 In their sport, the position played by Maniche midfielder Maniche "['an' ' terms' ' like' ' ""' 'good' '""' ' and' ' ""' 'evil' '""' ' are'
+ ' not' ' only' ' not' ' necessary' ',' ' they' ' are' ' not' ' even']" "an terms like "" good "" and "" evil "" are not only not necessary , they are not even" False join the team were Maniche and Marco Freitas 5 [' join', ' the', ' team', ' were', ' Man', 'iche']
+329 212 In their sport, the position played by x -1 In their sport, the position played by Maniche midfielder Maniche "['an' ' terms' ' like' ' ""' 'good' '""' ' and' ' ""' 'evil' '""' ' are'
+ ' not' ' only' ' not' ' necessary' ',' ' they' ' are' ' not' ' even']" "an terms like "" good "" and "" evil "" are not only not necessary , they are not even" False – 1 draw. Maniche scored the lone 5 [' –', ' 1', ' draw', '.', ' Man', 'iche']
+330 212 In their sport, the position played by x -1 In their sport, the position played by Maniche midfielder Maniche "['an' ' terms' ' like' ' ""' 'good' '""' ' and' ' ""' 'evil' '""' ' are'
+ ' not' ' only' ' not' ' necessary' ',' ' they' ' are' ' not' ' even']" "an terms like "" good "" and "" evil "" are not only not necessary , they are not even" False the team were Maniche and Marco Freitas 4 [' the', ' team', ' were', ' Man', 'iche']
+331 216 In their sport, the position played by x -1 In their sport, the position played by Keith Peacock midfielder Keith Peacock "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False however, manager Keith Peacock was controversially 6 [' however', ',', ' manager', ' Keith', ' Pe', 'ac', 'ock']
+332 216 In their sport, the position played by x -1 In their sport, the position played by Keith Peacock midfielder Keith Peacock "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False however, manager Keith Peacock was controversially 6 [' however', ',', ' manager', ' Keith', ' Pe', 'ac', 'ock']
+333 216 In their sport, the position played by x -1 In their sport, the position played by Keith Peacock midfielder Keith Peacock "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False however, manager Keith Peacock was controversially 6 [' however', ',', ' manager', ' Keith', ' Pe', 'ac', 'ock']
+334 216 In their sport, the position played by x -1 In their sport, the position played by Keith Peacock midfielder Keith Peacock "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False later, however, manager Keith Peacock was controversially 8 [' later', ',', ' however', ',', ' manager', ' Keith', ' Pe', 'ac', 'ock']
+335 217 In their sport, the position played by x -1 In their sport, the position played by Jim Edmonds outfielder Jim Edmonds "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' hitter' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' hitter']" ", who was a great player , but not a great hitter .
+
+ I think the best hitter" False to strike out Jim Edmonds to end the 5 [' to', ' strike', ' out', ' Jim', ' Ed', 'monds']
+336 217 In their sport, the position played by x -1 In their sport, the position played by Jim Edmonds outfielder Jim Edmonds "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' hitter' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' hitter']" ", who was a great player , but not a great hitter .
+
+ I think the best hitter" False Matheny to score Jim Edmonds and a home run to 6 [' Mat', 'heny', ' to', ' score', ' Jim', ' Ed', 'monds']
+337 217 In their sport, the position played by x -1 In their sport, the position played by Jim Edmonds outfielder Jim Edmonds "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' hitter' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' hitter']" ", who was a great player , but not a great hitter .
+
+ I think the best hitter" False Cardinal since Jim Edmonds to win that 4 [' Cardinal', ' since', ' Jim', ' Ed', 'monds']
+338 217 In their sport, the position played by x -1 In their sport, the position played by Jim Edmonds outfielder Jim Edmonds "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' hitter' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' hitter']" ", who was a great player , but not a great hitter .
+
+ I think the best hitter" False Matheny to score Jim Edmonds and a home 6 [' Mat', 'heny', ' to', ' score', ' Jim', ' Ed', 'monds']
+339 217 In their sport, the position played by x -1 In their sport, the position played by Jim Edmonds outfielder Jim Edmonds "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' hitter' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' hitter']" ", who was a great player , but not a great hitter .
+
+ I think the best hitter" False Matheny to score Jim Edmonds and a home run to 6 [' Mat', 'heny', ' to', ' score', ' Jim', ' Ed', 'monds']
+340 218 In their sport, the position played by x -1 In their sport, the position played by Drew Storen pitcher Drew Storen "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' bullpen'
+ ' this' ' season' '.' ' He' ' has' ' been' ' a' ' key' ' part' ' of'
+ ' the']" , who has been a revelation in the bullpen this season . He has been a key part of the False " exchange for Drew Storen and cash considerations.
+" 4 [' exchange', ' for', ' Drew', ' Store', 'n']
+341 218 In their sport, the position played by x -1 In their sport, the position played by Drew Storen pitcher Drew Storen "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' bullpen'
+ ' this' ' season' '.' ' He' ' has' ' been' ' a' ' key' ' part' ' of'
+ ' the']" , who has been a revelation in the bullpen this season . He has been a key part of the False season, Clippard and Drew Storen served as setup pitchers 8 [' season', ',', ' Cl', 'ipp', 'ard', ' and', ' Drew', ' Store', 'n']
+342 218 In their sport, the position played by x -1 In their sport, the position played by Drew Storen pitcher Drew Storen "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' bullpen'
+ ' this' ' season' '.' ' He' ' has' ' been' ' a' ' key' ' part' ' of'
+ ' the']" , who has been a revelation in the bullpen this season . He has been a key part of the False season, Clippard and Drew Storen served as setup 8 [' season', ',', ' Cl', 'ipp', 'ard', ' and', ' Drew', ' Store', 'n']
+343 223 In their sport, the position played by x -1 In their sport, the position played by Dont'a Hightower linebacker Dont'a Hightower "[' is' ' the' ' equivalent' ' of' ' a' ' linebacker' '.' ' He' ' is'
+ ' the' ' middle' ' linebacker' ',' ' the' ' guy' ' who' ' is'
+ ' responsible' ' for' ' covering']" is the equivalent of a linebacker . He is the middle linebacker , the guy who is responsible for covering True to the 50 after Dont'a Hightower sacked Jordan Jefferson 10 "[' to', ' the', ' 50', ' after', ' D', 'ont', ""'"", 'a', ' H', 'ight', 'ower']"
+344 223 In their sport, the position played by x -1 In their sport, the position played by Dont'a Hightower linebacker Dont'a Hightower "[' is' ' the' ' equivalent' ' of' ' a' ' linebacker' '.' ' He' ' is'
+ ' the' ' middle' ' linebacker' ',' ' the' ' guy' ' who' ' is'
+ ' responsible' ' for' ' covering']" is the equivalent of a linebacker . He is the middle linebacker , the guy who is responsible for covering True the 50 after Dont'a Hightower sacked Jordan 9 "[' the', ' 50', ' after', ' D', 'ont', ""'"", 'a', ' H', 'ight', 'ower']"
+345 223 In their sport, the position played by x -1 In their sport, the position played by Dont'a Hightower linebacker Dont'a Hightower "[' is' ' the' ' equivalent' ' of' ' a' ' linebacker' '.' ' He' ' is'
+ ' the' ' middle' ' linebacker' ',' ' the' ' guy' ' who' ' is'
+ ' responsible' ' for' ' covering']" is the equivalent of a linebacker . He is the middle linebacker , the guy who is responsible for covering True and after a Dont'a Hightower interception, Richardson 9 "[' and', ' after', ' a', ' D', 'ont', ""'"", 'a', ' H', 'ight', 'ower']"
+346 223 In their sport, the position played by x -1 In their sport, the position played by Dont'a Hightower linebacker Dont'a Hightower "[' is' ' the' ' equivalent' ' of' ' a' ' linebacker' '.' ' He' ' is'
+ ' the' ' middle' ' linebacker' ',' ' the' ' guy' ' who' ' is'
+ ' responsible' ' for' ' covering']" is the equivalent of a linebacker . He is the middle linebacker , the guy who is responsible for covering True Maze, linebackers Dont'a Hightower and Courtney Upshaw, 9 "[' Maze', ',', ' linebackers', ' D', 'ont', ""'"", 'a', ' H', 'ight', 'ower']"
+347 223 In their sport, the position played by x -1 In their sport, the position played by Dont'a Hightower linebacker Dont'a Hightower "[' is' ' the' ' equivalent' ' of' ' a' ' linebacker' '.' ' He' ' is'
+ ' the' ' middle' ' linebacker' ',' ' the' ' guy' ' who' ' is'
+ ' responsible' ' for' ' covering']" is the equivalent of a linebacker . He is the middle linebacker , the guy who is responsible for covering True the 50 after Dont'a Hightower sacked Jordan Jefferson 9 "[' the', ' 50', ' after', ' D', 'ont', ""'"", 'a', ' H', 'ight', 'ower']"
+348 225 In their sport, the position played by x -1 In their sport, the position played by Bob Essensa goaltender Bob Essensa "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False string goaltender Bob Essensa who had yet to 5 [' string', ' goaltender', ' Bob', ' Ess', 'ens', 'a']
+349 225 In their sport, the position played by x -1 In their sport, the position played by Bob Essensa goaltender Bob Essensa "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False goal against Bob Essensa of the Winnipeg 5 [' goal', ' against', ' Bob', ' Ess', 'ens', 'a']
+350 225 In their sport, the position played by x -1 In their sport, the position played by Bob Essensa goaltender Bob Essensa "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False the net against Bob Essensa in a 9 – 5 victory 6 [' the', ' net', ' against', ' Bob', ' Ess', 'ens', 'a']
+351 225 In their sport, the position played by x -1 In their sport, the position played by Bob Essensa goaltender Bob Essensa "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False goaltender Bob Essensa who had yet to 4 [' goaltender', ' Bob', ' Ess', 'ens', 'a']
+352 225 In their sport, the position played by x -1 In their sport, the position played by Bob Essensa goaltender Bob Essensa "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False net against Bob Essensa in a 9 – 5 5 [' net', ' against', ' Bob', ' Ess', 'ens', 'a']
+353 226 In their sport, the position played by x -1 In their sport, the position played by Teddy Lehman linebacker Teddy Lehman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False honor, following Teddy Lehman in 2003 and 5 [' honor', ',', ' following', ' Teddy', ' Leh', 'man']
+354 226 In their sport, the position played by x -1 In their sport, the position played by Teddy Lehman linebacker Teddy Lehman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False honor, following Teddy Lehman in 2003 and Roy 5 [' honor', ',', ' following', ' Teddy', ' Leh', 'man']
+355 226 In their sport, the position played by x -1 In their sport, the position played by Teddy Lehman linebacker Teddy Lehman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False this honor, following Teddy Lehman in 2003 and Roy 6 [' this', ' honor', ',', ' following', ' Teddy', ' Leh', 'man']
+356 227 In their sport, the position played by x -1 In their sport, the position played by Graham Taylor defender Graham Taylor "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' England' ' team' ' that'
+ ' won' ' the' ' World' ' Cup' ' in' ' 1966' ',' ' and' ' who' ' was']" , who was a member of the England team that won the World Cup in 1966 , and who was False 1 ['Graham', ' Taylor']
+357 227 In their sport, the position played by x -1 In their sport, the position played by Graham Taylor defender Graham Taylor "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' England' ' team' ' that'
+ ' won' ' the' ' World' ' Cup' ' in' ' 1966' ',' ' and' ' who' ' was']" , who was a member of the England team that won the World Cup in 1966 , and who was False April 1977. When Graham Taylor was named as Keen's 5 [' April', ' 1977', '.', ' When', ' Graham', ' Taylor']
+358 227 In their sport, the position played by x -1 In their sport, the position played by Graham Taylor defender Graham Taylor "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' England' ' team' ' that'
+ ' won' ' the' ' World' ' Cup' ' in' ' 1966' ',' ' and' ' who' ' was']" , who was a member of the England team that won the World Cup in 1966 , and who was False finishing bottom. Graham Taylor retired at 4 [' finishing', ' bottom', '.', ' Graham', ' Taylor']
+359 227 In their sport, the position played by x -1 In their sport, the position played by Graham Taylor defender Graham Taylor "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' England' ' team' ' that'
+ ' won' ' the' ' World' ' Cup' ' in' ' 1966' ',' ' and' ' who' ' was']" , who was a member of the England team that won the World Cup in 1966 , and who was False " Bobby Robson, Graham Taylor and Glenn Hoddle.
+" 5 [' Bobby', ' Rob', 'son', ',', ' Graham', ' Taylor']
+360 227 In their sport, the position played by x -1 In their sport, the position played by Graham Taylor defender Graham Taylor "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' England' ' team' ' that'
+ ' won' ' the' ' World' ' Cup' ' in' ' 1966' ',' ' and' ' who' ' was']" , who was a member of the England team that won the World Cup in 1966 , and who was False once again appointed Graham Taylor as manager. Villa 4 [' once', ' again', ' appointed', ' Graham', ' Taylor']
+361 231 In their sport, the position played by x -1 In their sport, the position played by Curt Flood outfielder Curt Flood "[' in' ' the' ' 1970' 's' ',' ' the' ' players' ' were' ' not' ' only'
+ ' the' ' best' ' in' ' the' ' world' ',' ' but' ' they' ' were' ' also']" in the 1970 s , the players were not only the best in the world , but they were also False same year, Curt Flood of the St. Louis Cardinals 4 [' same', ' year', ',', ' Curt', ' Flood']
+362 231 In their sport, the position played by x -1 In their sport, the position played by Curt Flood outfielder Curt Flood "[' in' ' the' ' 1970' 's' ',' ' the' ' players' ' were' ' not' ' only'
+ ' the' ' best' ' in' ' the' ' world' ',' ' but' ' they' ' were' ' also']" in the 1970 s , the players were not only the best in the world , but they were also False same year, Curt Flood of the St. Louis 4 [' same', ' year', ',', ' Curt', ' Flood']
+363 231 In their sport, the position played by x -1 In their sport, the position played by Curt Flood outfielder Curt Flood "[' in' ' the' ' 1970' 's' ',' ' the' ' players' ' were' ' not' ' only'
+ ' the' ' best' ' in' ' the' ' world' ',' ' but' ' they' ' were' ' also']" in the 1970 s , the players were not only the best in the world , but they were also False Bill Russell and Curt Flood all went to within 4 [' Bill', ' Russell', ' and', ' Curt', ' Flood']
+364 231 In their sport, the position played by x -1 In their sport, the position played by Curt Flood outfielder Curt Flood "[' in' ' the' ' 1970' 's' ',' ' the' ' players' ' were' ' not' ' only'
+ ' the' ' best' ' in' ' the' ' world' ',' ' but' ' they' ' were' ' also']" in the 1970 s , the players were not only the best in the world , but they were also False Russell and Curt Flood all went to within 3 [' Russell', ' and', ' Curt', ' Flood']
+365 231 In their sport, the position played by x -1 In their sport, the position played by Curt Flood outfielder Curt Flood "[' in' ' the' ' 1970' 's' ',' ' the' ' players' ' were' ' not' ' only'
+ ' the' ' best' ' in' ' the' ' world' ',' ' but' ' they' ' were' ' also']" in the 1970 s , the players were not only the best in the world , but they were also False scored teammate Curt Flood in the sixth. 3 [' scored', ' teammate', ' Curt', ' Flood']
+366 234 In their sport, the position played by x -1 In their sport, the position played by Glenn Roeder defender Glenn Roeder "[',' ' who' ' was' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , who was a former NFL quarterback , is a perfect example of the kind of player who can be False " opportunities under Glenn Roeder in 1994 – 95.
+" 4 [' opportunities', ' under', ' Glenn', ' Ro', 'eder']
+367 234 In their sport, the position played by x -1 In their sport, the position played by Glenn Roeder defender Glenn Roeder "[',' ' who' ' was' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , who was a former NFL quarterback , is a perfect example of the kind of player who can be False 3 ['Gl', 'enn', ' Ro', 'eder']
+368 234 In their sport, the position played by x -1 In their sport, the position played by Glenn Roeder defender Glenn Roeder "[',' ' who' ' was' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , who was a former NFL quarterback , is a perfect example of the kind of player who can be False United manager Glenn Roeder was confirmed as 4 [' United', ' manager', ' Glenn', ' Ro', 'eder']
+369 234 In their sport, the position played by x -1 In their sport, the position played by Glenn Roeder defender Glenn Roeder "[',' ' who' ' was' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , who was a former NFL quarterback , is a perfect example of the kind of player who can be False United manager Glenn Roeder was confirmed as 4 [' United', ' manager', ' Glenn', ' Ro', 'eder']
+370 234 In their sport, the position played by x -1 In their sport, the position played by Glenn Roeder defender Glenn Roeder "[',' ' who' ' was' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , who was a former NFL quarterback , is a perfect example of the kind of player who can be False " opportunities under Glenn Roeder in 1994 – 95.
+" 4 [' opportunities', ' under', ' Glenn', ' Ro', 'eder']
+371 236 In their sport, the position played by x -1 In their sport, the position played by Mike Piazza catcher Mike Piazza "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True Boone in 1953, and Mike Piazza in 1998), and 8 [' Boone', ' in', ' 1953', ',', ' and', ' Mike', ' P', 'ia', 'zza']
+372 236 In their sport, the position played by x -1 In their sport, the position played by Mike Piazza catcher Mike Piazza "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True and catcher Mike Piazza as having late 5 [' and', ' catcher', ' Mike', ' P', 'ia', 'zza']
+373 236 In their sport, the position played by x -1 In their sport, the position played by Mike Piazza catcher Mike Piazza "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True recently retired catcher Mike Piazza against claims 6 [' recently', ' retired', ' catcher', ' Mike', ' P', 'ia', 'zza']
+374 236 In their sport, the position played by x -1 In their sport, the position played by Mike Piazza catcher Mike Piazza "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True For example, Mike Piazza grounded into 6 [' For', ' example', ',', ' Mike', ' P', 'ia', 'zza']
+375 236 In their sport, the position played by x -1 In their sport, the position played by Mike Piazza catcher Mike Piazza "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True Johnny Bench and Mike Piazza as the only 6 [' Johnny', ' Bench', ' and', ' Mike', ' P', 'ia', 'zza']
+376 238 In their sport, the position played by x -1 In their sport, the position played by Tommy Harper outfielder Tommy Harper "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' a' ' man']" in the film , the character of the hero is a bit of a mystery . He is a man False a team record until Tommy Harper stole 54 bases 5 [' a', ' team', ' record', ' until', ' Tommy', ' Harper']
+377 238 In their sport, the position played by x -1 In their sport, the position played by Tommy Harper outfielder Tommy Harper "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' a' ' man']" in the film , the character of the hero is a bit of a mystery . He is a man False team record until Tommy Harper stole 54 bases 4 [' team', ' record', ' until', ' Tommy', ' Harper']
+378 239 In their sport, the position played by x -1 In their sport, the position played by Koke midfielder Koke "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' club' ' since' ' he' ' was' ' a' ' teenager' ',' ' is' ' a' ' key'
+ ' part']" , who has been a regular starter for the club since he was a teenager , is a key part False replaced by Koke after 59 minutes, 3 [' replaced', ' by', ' K', 'oke']
+379 239 In their sport, the position played by x -1 In their sport, the position played by Koke midfielder Koke "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' club' ' since' ' he' ' was' ' a' ' teenager' ',' ' is' ' a' ' key'
+ ' part']" , who has been a regular starter for the club since he was a teenager , is a key part False September, also assisting Koke as Atlético won 2 – 5 [' September', ',', ' also', ' assisting', ' K', 'oke']
+380 242 In their sport, the position played by x -1 In their sport, the position played by Jeff Kemp quarterback Jeff Kemp "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' Broadway' ' musical' ',' ' and' ' the' ' role' ' of' ' the'
+ ' lead']" , who was a member of the original cast of the Broadway musical , and the role of the lead False quarterback Jeff Kemp '81, Olympic 2 [' quarterback', ' Jeff', ' Kemp']
+381 242 In their sport, the position played by x -1 In their sport, the position played by Jeff Kemp quarterback Jeff Kemp "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' Broadway' ' musical' ',' ' and' ' the' ' role' ' of' ' the'
+ ' lead']" , who was a member of the original cast of the Broadway musical , and the role of the lead False football quarterbacks: Jeff Kemp (born in 1959) played 4 [' football', ' quarterbacks', ':', ' Jeff', ' Kemp']
+382 242 In their sport, the position played by x -1 In their sport, the position played by Jeff Kemp quarterback Jeff Kemp "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' Broadway' ' musical' ',' ' and' ' the' ' role' ' of' ' the'
+ ' lead']" , who was a member of the original cast of the Broadway musical , and the role of the lead False football quarterbacks: Jeff Kemp (born in 1959) played 4 [' football', ' quarterbacks', ':', ' Jeff', ' Kemp']
+383 243 In their sport, the position played by x -1 In their sport, the position played by John Harkes midfielder John Harkes "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' in' ' the' ' audience' ' at' ' the' '\n']" ", the
+
+ The first time I saw the movie , I was in the audience at the
+" False former D.C. United John Harkes offered color commentary 9 [' former', ' D', '.', 'C', '.', ' United', ' John', ' H', 'ark', 'es']
+384 243 In their sport, the position played by x -1 In their sport, the position played by John Harkes midfielder John Harkes "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' in' ' the' ' audience' ' at' ' the' '\n']" ", the
+
+ The first time I saw the movie , I was in the audience at the
+" False former D.C. United John Harkes offered color 9 [' former', ' D', '.', 'C', '.', ' United', ' John', ' H', 'ark', 'es']
+385 248 In their sport, the position played by x -1 In their sport, the position played by Brian Hayward goaltender Brian Hayward "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False against goaltender Brian Hayward in a 6 – 1 victory 3 [' against', ' goaltender', ' Brian', ' Hayward']
+386 248 In their sport, the position played by x -1 In their sport, the position played by Brian Hayward goaltender Brian Hayward "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False goaltender Brian Hayward in a 6 – 1 victory 2 [' goaltender', ' Brian', ' Hayward']
+387 248 In their sport, the position played by x -1 In their sport, the position played by Brian Hayward goaltender Brian Hayward "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False goaltender Brian Hayward in a 6 – 1 victory 2 [' goaltender', ' Brian', ' Hayward']
+388 248 In their sport, the position played by x -1 In their sport, the position played by Brian Hayward goaltender Brian Hayward "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False against goaltender Brian Hayward in a 6 – 1 victory 3 [' against', ' goaltender', ' Brian', ' Hayward']
+389 251 In their sport, the position played by x -1 In their sport, the position played by Zinedine Zidane midfielder Zinedine Zidane "[',' ' the' ' French' ' midfielder' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' one' ' who' ' decides' ' the' ' rhythm' ' of'
+ ' the']" , the French midfielder , is the most important . He is the one who decides the rhythm of the True football legend Zinedine Zidane personally recommended 7 [' football', ' legend', ' Z', 'ined', 'ine', ' Z', 'id', 'ane']
+390 251 In their sport, the position played by x -1 In their sport, the position played by Zinedine Zidane midfielder Zinedine Zidane "[',' ' the' ' French' ' midfielder' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' one' ' who' ' decides' ' the' ' rhythm' ' of'
+ ' the']" , the French midfielder , is the most important . He is the one who decides the rhythm of the True for an equaliser, Zinedine Zidane scored a stoppage-time 10 [' for', ' an', ' equal', 'iser', ',', ' Z', 'ined', 'ine', ' Z', 'id', 'ane']
+391 251 In their sport, the position played by x -1 In their sport, the position played by Zinedine Zidane midfielder Zinedine Zidane "[',' ' the' ' French' ' midfielder' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' one' ' who' ' decides' ' the' ' rhythm' ' of'
+ ' the']" , the French midfielder , is the most important . He is the one who decides the rhythm of the True football legend Zinedine Zidane personally recommended 7 [' football', ' legend', ' Z', 'ined', 'ine', ' Z', 'id', 'ane']
+392 251 In their sport, the position played by x -1 In their sport, the position played by Zinedine Zidane midfielder Zinedine Zidane "[',' ' the' ' French' ' midfielder' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' one' ' who' ' decides' ' the' ' rhythm' ' of'
+ ' the']" , the French midfielder , is the most important . He is the one who decides the rhythm of the True Leverkusen, with Zinedine Zidane scoring the 10 [' Lever', 'kus', 'en', ',', ' with', ' Z', 'ined', 'ine', ' Z', 'id', 'ane']
+393 251 In their sport, the position played by x -1 In their sport, the position played by Zinedine Zidane midfielder Zinedine Zidane "[',' ' the' ' French' ' midfielder' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' one' ' who' ' decides' ' the' ' rhythm' ' of'
+ ' the']" , the French midfielder , is the most important . He is the one who decides the rhythm of the True the retirement of Zinedine Zidane following the 8 [' the', ' retirement', ' of', ' Z', 'ined', 'ine', ' Z', 'id', 'ane']
+394 254 In their sport, the position played by x -1 In their sport, the position played by Roger Clemens pitcher Roger Clemens "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ' for'
+ ' a' ' long' ' time' ',' ' and' ' who' ' was' ' the' ' best' ' pitcher']" , who was the best pitcher in the game for a long time , and who was the best pitcher True after retirement, Roger Clemens chose to return 5 [' after', ' retirement', ',', ' Roger', ' Cle', 'mens']
+395 254 In their sport, the position played by x -1 In their sport, the position played by Roger Clemens pitcher Roger Clemens "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ' for'
+ ' a' ' long' ' time' ',' ' and' ' who' ' was' ' the' ' best' ' pitcher']" , who was the best pitcher in the game for a long time , and who was the best pitcher True " the state in 1980""; Roger Clemens said he was ""probably" 7 "[' the', ' state', ' in', ' 1980', '"";', ' Roger', ' Cle', 'mens']"
+396 254 In their sport, the position played by x -1 In their sport, the position played by Roger Clemens pitcher Roger Clemens "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ' for'
+ ' a' ' long' ' time' ',' ' and' ' who' ' was' ' the' ' best' ' pitcher']" , who was the best pitcher in the game for a long time , and who was the best pitcher True is tied with Roger Clemens for the most 5 [' is', ' tied', ' with', ' Roger', ' Cle', 'mens']
+397 254 In their sport, the position played by x -1 In their sport, the position played by Roger Clemens pitcher Roger Clemens "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ' for'
+ ' a' ' long' ' time' ',' ' and' ' who' ' was' ' the' ' best' ' pitcher']" , who was the best pitcher in the game for a long time , and who was the best pitcher True two-run home run off Roger Clemens in the top of the 8 [' two', '-', 'run', ' home', ' run', ' off', ' Roger', ' Cle', 'mens']
+398 254 In their sport, the position played by x -1 In their sport, the position played by Roger Clemens pitcher Roger Clemens "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' game' ' for'
+ ' a' ' long' ' time' ',' ' and' ' who' ' was' ' the' ' best' ' pitcher']" , who was the best pitcher in the game for a long time , and who was the best pitcher True relieved an injured Roger Clemens (the oldest player 5 [' relieved', ' an', ' injured', ' Roger', ' Cle', 'mens']
+399 255 In their sport, the position played by x -1 In their sport, the position played by Pep Guardiola midfielder Pep Guardiola "['�' '�' 's' ' men' ' is' ' a' ' far' ' cry' ' from' ' the' ' one' ' they'
+ ' occupied' ' in' ' the' ' first' ' half' ' of' ' the' ' season']" � � s men is a far cry from the one they occupied in the first half of the season False Barcelona coach Pep Guardiola. Under Cruyff's guidance, 3 [' Barcelona', ' coach', ' Pep', ' Guardiola']
+400 255 In their sport, the position played by x -1 In their sport, the position played by Pep Guardiola midfielder Pep Guardiola "['�' '�' 's' ' men' ' is' ' a' ' far' ' cry' ' from' ' the' ' one' ' they'
+ ' occupied' ' in' ' the' ' first' ' half' ' of' ' the' ' season']" � � s men is a far cry from the one they occupied in the first half of the season False Barcelona manager Pep Guardiola came out in support 3 [' Barcelona', ' manager', ' Pep', ' Guardiola']
+401 255 In their sport, the position played by x -1 In their sport, the position played by Pep Guardiola midfielder Pep Guardiola "['�' '�' 's' ' men' ' is' ' a' ' far' ' cry' ' from' ' the' ' one' ' they'
+ ' occupied' ' in' ' the' ' first' ' half' ' of' ' the' ' season']" � � s men is a far cry from the one they occupied in the first half of the season False B youth manager Pep Guardiola took over Frank Rijkaard's 4 [' B', ' youth', ' manager', ' Pep', ' Guardiola']
+402 255 In their sport, the position played by x -1 In their sport, the position played by Pep Guardiola midfielder Pep Guardiola "['�' '�' 's' ' men' ' is' ' a' ' far' ' cry' ' from' ' the' ' one' ' they'
+ ' occupied' ' in' ' the' ' first' ' half' ' of' ' the' ' season']" � � s men is a far cry from the one they occupied in the first half of the season False (1931 – 1933), and Pep Guardiola is the club's 8 [' (', '19', '31', ' –', ' 1933', '),', ' and', ' Pep', ' Guardiola']
+403 255 In their sport, the position played by x -1 In their sport, the position played by Pep Guardiola midfielder Pep Guardiola "['�' '�' 's' ' men' ' is' ' a' ' far' ' cry' ' from' ' the' ' one' ' they'
+ ' occupied' ' in' ' the' ' first' ' half' ' of' ' the' ' season']" � � s men is a far cry from the one they occupied in the first half of the season False achievement as Pep Guardiola won his 12th trophy 3 [' achievement', ' as', ' Pep', ' Guardiola']
+404 258 In their sport, the position played by x -1 In their sport, the position played by Hughie Lehman goaltender Hughie Lehman "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Hugh' 'ie' ' Leh' 'man' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Hugh ie Leh man , the
+
+" False " Hughie Lehman =
+" 3 [' Hugh', 'ie', ' Leh', 'man']
+405 258 In their sport, the position played by x -1 In their sport, the position played by Hughie Lehman goaltender Hughie Lehman "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Hugh' 'ie' ' Leh' 'man' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Hugh ie Leh man , the
+
+" False " Lehman =
+" 7 [' Leh', 'man', ' =', 'H', 'ugh', 'ie', ' Leh', 'man']
+406 258 In their sport, the position played by x -1 In their sport, the position played by Hughie Lehman goaltender Hughie Lehman "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Hugh' 'ie' ' Leh' 'man' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Hugh ie Leh man , the
+
+" False " Lehman =
+" 7 [' Leh', 'man', ' =', 'H', 'ugh', 'ie', ' Leh', 'man']
+407 258 In their sport, the position played by x -1 In their sport, the position played by Hughie Lehman goaltender Hughie Lehman "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Hugh' 'ie' ' Leh' 'man' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Hugh ie Leh man , the
+
+" False " Hughie Lehman =
+" 3 [' Hugh', 'ie', ' Leh', 'man']
+408 258 In their sport, the position played by x -1 In their sport, the position played by Hughie Lehman goaltender Hughie Lehman "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Hugh' 'ie' ' Leh' 'man' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Hugh ie Leh man , the
+
+" False " Hughie Lehman =
+" 3 [' Hugh', 'ie', ' Leh', 'man']
+409 263 In their sport, the position played by x -1 In their sport, the position played by Mark Fitzpatrick goaltender Mark Fitzpatrick "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False goal against Mark Fitzpatrick in a 4 – 1 victory 3 [' goal', ' against', ' Mark', ' Fitzpatrick']
+410 263 In their sport, the position played by x -1 In their sport, the position played by Mark Fitzpatrick goaltender Mark Fitzpatrick "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False Islanders in return for Mark Fitzpatrick during the off-season, 5 [' Islanders', ' in', ' return', ' for', ' Mark', ' Fitzpatrick']
+411 263 In their sport, the position played by x -1 In their sport, the position played by Mark Fitzpatrick goaltender Mark Fitzpatrick "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False in return for Mark Fitzpatrick during the off-season, 4 [' in', ' return', ' for', ' Mark', ' Fitzpatrick']
+412 264 In their sport, the position played by x -1 In their sport, the position played by Drew Henson quarterback Drew Henson "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' is' ' a' ' former' ' quarterback' ' who' ' has']" , a former NFL quarterback , is a bit of a mystery . He is a former quarterback who has True of the NFL and Drew Henson was drafted by the 6 [' of', ' the', ' NFL', ' and', ' Drew', ' H', 'enson']
+413 264 In their sport, the position played by x -1 In their sport, the position played by Drew Henson quarterback Drew Henson "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' is' ' a' ' former' ' quarterback' ' who' ' has']" , a former NFL quarterback , is a bit of a mystery . He is a former quarterback who has True the NFL and Drew Henson was drafted by the 5 [' the', ' NFL', ' and', ' Drew', ' H', 'enson']
+414 264 In their sport, the position played by x -1 In their sport, the position played by Drew Henson quarterback Drew Henson "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' is' ' a' ' former' ' quarterback' ' who' ' has']" , a former NFL quarterback , is a bit of a mystery . He is a former quarterback who has True of the NFL and Drew Henson was drafted by 6 [' of', ' the', ' NFL', ' and', ' Drew', ' H', 'enson']
+415 264 In their sport, the position played by x -1 In their sport, the position played by Drew Henson quarterback Drew Henson "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' is' ' a' ' former' ' quarterback' ' who' ' has']" , a former NFL quarterback , is a bit of a mystery . He is a former quarterback who has True Brady, Driesbach and Drew Henson and wide receiver 8 [' Brady', ',', ' D', 'ries', 'bach', ' and', ' Drew', ' H', 'enson']
+416 264 In their sport, the position played by x -1 In their sport, the position played by Drew Henson quarterback Drew Henson "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' is' ' a' ' former' ' quarterback' ' who' ' has']" , a former NFL quarterback , is a bit of a mystery . He is a former quarterback who has True Driesbach and Drew Henson and wide receiver 6 [' D', 'ries', 'bach', ' and', ' Drew', ' H', 'enson']
+417 266 In their sport, the position played by x -1 In their sport, the position played by Otto Graham quarterback Otto Graham "[',' ' the' ' quarterback' ',' ' is' ' the' ' same' ' as' ' that' ' of'
+ ' the' ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is'
+ ' the' ' leader']" , the quarterback , is the same as that of the quarterback in football . The quarterback is the leader True 2 ['O', 'tto', ' Graham']
+418 266 In their sport, the position played by x -1 In their sport, the position played by Otto Graham quarterback Otto Graham "[',' ' the' ' quarterback' ',' ' is' ' the' ' same' ' as' ' that' ' of'
+ ' the' ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is'
+ ' the' ' leader']" , the quarterback , is the same as that of the quarterback in football . The quarterback is the leader True from quarterback Otto Graham to Dean Sensanbaugher 3 [' from', ' quarterback', ' Otto', ' Graham']
+419 266 In their sport, the position played by x -1 In their sport, the position played by Otto Graham quarterback Otto Graham "[',' ' the' ' quarterback' ',' ' is' ' the' ' same' ' as' ' that' ' of'
+ ' the' ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is'
+ ' the' ' leader']" , the quarterback , is the same as that of the quarterback in football . The quarterback is the leader True protect quarterback Otto Graham from opposing defenders 3 [' protect', ' quarterback', ' Otto', ' Graham']
+420 266 In their sport, the position played by x -1 In their sport, the position played by Otto Graham quarterback Otto Graham "[',' ' the' ' quarterback' ',' ' is' ' the' ' same' ' as' ' that' ' of'
+ ' the' ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is'
+ ' the' ' leader']" , the quarterback , is the same as that of the quarterback in football . The quarterback is the leader True protect quarterback Otto Graham from defenders when 3 [' protect', ' quarterback', ' Otto', ' Graham']
+421 266 In their sport, the position played by x -1 In their sport, the position played by Otto Graham quarterback Otto Graham "[',' ' the' ' quarterback' ',' ' is' ' the' ' same' ' as' ' that' ' of'
+ ' the' ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is'
+ ' the' ' leader']" , the quarterback , is the same as that of the quarterback in football . The quarterback is the leader True department created The Otto Graham Society to honor 4 [' department', ' created', ' The', ' Otto', ' Graham']
+422 269 In their sport, the position played by x -1 In their sport, the position played by Troy Aikman quarterback Troy Aikman "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True broadcasters Troy Aikman and Terry Bradshaw, 4 [' broadcasters', ' Troy', ' A', 'ik', 'man']
+423 269 In their sport, the position played by x -1 In their sport, the position played by Troy Aikman quarterback Troy Aikman "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True quarterback after Troy Aikman was injured and 5 [' quarterback', ' after', ' Troy', ' A', 'ik', 'man']
+424 269 In their sport, the position played by x -1 In their sport, the position played by Troy Aikman quarterback Troy Aikman "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True players Rosey Grier, Troy Aikman and Dan Marino 9 [' players', ' Rose', 'y', ' G', 'rier', ',', ' Troy', ' A', 'ik', 'man']
+425 269 In their sport, the position played by x -1 In their sport, the position played by Troy Aikman quarterback Troy Aikman "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True broadcasters Troy Aikman and Terry Bradshaw, 4 [' broadcasters', ' Troy', ' A', 'ik', 'man']
+426 269 In their sport, the position played by x -1 In their sport, the position played by Troy Aikman quarterback Troy Aikman "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True quarterback after Troy Aikman was injured and 5 [' quarterback', ' after', ' Troy', ' A', 'ik', 'man']
+427 272 In their sport, the position played by x -1 In their sport, the position played by Alexi Ogando pitcher Alexi Ogando "[',' ' the' ' Rangers' ""'"" ' closer' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ""'s"" ' the' ' guy' ' who' ' gets' ' the' ' final' ' out']" , the Rangers ' closer , is the most important . He 's the guy who gets the final out False Rangers reliever Alexi Ogando entered the 5 [' Rangers', ' reliever', ' Alex', 'i', ' Og', 'ando']
+428 274 In their sport, the position played by x -1 In their sport, the position played by Ralph Milne midfielder Ralph Milne "[' Far' 'qu' 'har' ',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the'
+ ' position' ' played' ' by' ' Ralph' ' Mil' 'ne' ' Far' 'qu']" " Far qu har , the
+
+ In their sport , the position played by Ralph Mil ne Far qu" False " Ralph Milne =
+" 2 [' Ralph', ' Mil', 'ne']
+429 274 In their sport, the position played by x -1 In their sport, the position played by Ralph Milne midfielder Ralph Milne "[' Far' 'qu' 'har' ',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the'
+ ' position' ' played' ' by' ' Ralph' ' Mil' 'ne' ' Far' 'qu']" " Far qu har , the
+
+ In their sport , the position played by Ralph Mil ne Far qu" False " Ralph Milne =
+" 2 [' Ralph', ' Mil', 'ne']
+430 274 In their sport, the position played by x -1 In their sport, the position played by Ralph Milne midfielder Ralph Milne "[' Far' 'qu' 'har' ',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the'
+ ' position' ' played' ' by' ' Ralph' ' Mil' 'ne' ' Far' 'qu']" " Far qu har , the
+
+ In their sport , the position played by Ralph Mil ne Far qu" False 3 ['R', 'alph', ' Mil', 'ne']
+431 274 In their sport, the position played by x -1 In their sport, the position played by Ralph Milne midfielder Ralph Milne "[' Far' 'qu' 'har' ',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the'
+ ' position' ' played' ' by' ' Ralph' ' Mil' 'ne' ' Far' 'qu']" " Far qu har , the
+
+ In their sport , the position played by Ralph Mil ne Far qu" False of the Golden Atom, Ralph Milne Farley's The 7 [' of', ' the', ' Golden', ' Atom', ',', ' Ralph', ' Mil', 'ne']
+432 274 In their sport, the position played by x -1 In their sport, the position played by Ralph Milne midfielder Ralph Milne "[' Far' 'qu' 'har' ',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the'
+ ' position' ' played' ' by' ' Ralph' ' Mil' 'ne' ' Far' 'qu']" " Far qu har , the
+
+ In their sport , the position played by Ralph Mil ne Far qu" False the Golden Atom, Ralph Milne Farley's The Radio 6 [' the', ' Golden', ' Atom', ',', ' Ralph', ' Mil', 'ne']
+433 279 In their sport, the position played by x -1 In their sport, the position played by Ricky Stanzi quarterback Ricky Stanzi "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ""'s"" ' a' ' former' ' quarterback' ' who' ' has']" , a former NFL quarterback , is a bit of a mystery . He 's a former quarterback who has True The Jaguars signed Ricky Stanzi on August 27 for 5 [' The', ' Jaguars', ' signed', ' Ricky', ' Stan', 'zi']
+434 281 In their sport, the position played by x -1 In their sport, the position played by Kyle Farnsworth pitcher Kyle Farnsworth "[',' ' the' ' team' ""'s"" ' starting' ' pitcher' ',' ' is' ' to' ' be'
+ ' the' ' ""' 'pit' 'cher' ' of' ' last' ' resort' '.""' '\n' '\n']" ", the team 's starting pitcher , is to be the "" pit cher of last resort .""
+
+" True relief pitcher Kyle Farnsworth after starting 4 [' relief', ' pitcher', ' Kyle', ' Farn', 'sworth']
+435 282 In their sport, the position played by x -1 In their sport, the position played by Joe Ferguson quarterback Joe Ferguson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Joe' ' Ferguson' ',' ' the' ' quarterback' ',' ' is'
+ ' to']" ", the
+
+ In their sport , the position played by Joe Ferguson , the quarterback , is to" True Quarterback Joe Ferguson played 164 3 [' Quarter', 'back', ' Joe', ' Ferguson']
+436 282 In their sport, the position played by x -1 In their sport, the position played by Joe Ferguson quarterback Joe Ferguson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Joe' ' Ferguson' ',' ' the' ' quarterback' ',' ' is'
+ ' to']" ", the
+
+ In their sport , the position played by Joe Ferguson , the quarterback , is to" True hire 45-year-old Joe Ferguson (whom Stephenson 7 [' hire', ' 45', '-', 'year', '-', 'old', ' Joe', ' Ferguson']
+437 282 In their sport, the position played by x -1 In their sport, the position played by Joe Ferguson quarterback Joe Ferguson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Joe' ' Ferguson' ',' ' the' ' quarterback' ',' ' is'
+ ' to']" ", the
+
+ In their sport , the position played by Joe Ferguson , the quarterback , is to" True 45-year-old Joe Ferguson (whom Stephenson had 6 [' 45', '-', 'year', '-', 'old', ' Joe', ' Ferguson']
+438 282 In their sport, the position played by x -1 In their sport, the position played by Joe Ferguson quarterback Joe Ferguson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Joe' ' Ferguson' ',' ' the' ' quarterback' ',' ' is'
+ ' to']" ", the
+
+ In their sport , the position played by Joe Ferguson , the quarterback , is to" True 1 ['Joe', ' Ferguson']
+439 282 In their sport, the position played by x -1 In their sport, the position played by Joe Ferguson quarterback Joe Ferguson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Joe' ' Ferguson' ',' ' the' ' quarterback' ',' ' is'
+ ' to']" ", the
+
+ In their sport , the position played by Joe Ferguson , the quarterback , is to" True Quarterback Joe Ferguson played 164 games 3 [' Quarter', 'back', ' Joe', ' Ferguson']
+440 283 In their sport, the position played by x -1 In their sport, the position played by Randall Cunningham quarterback Randall Cunningham "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True " Cunningham II =
+" 5 [' Cunningham', ' II', ' =', 'Rand', 'all', ' Cunningham']
+441 283 In their sport, the position played by x -1 In their sport, the position played by Randall Cunningham quarterback Randall Cunningham "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True quarterback Randall Cunningham rushed for 66 yards 2 [' quarterback', ' Randall', ' Cunningham']
+442 283 In their sport, the position played by x -1 In their sport, the position played by Randall Cunningham quarterback Randall Cunningham "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True " Cunningham II =
+" 5 [' Cunningham', ' II', ' =', 'Rand', 'all', ' Cunningham']
+443 283 In their sport, the position played by x -1 In their sport, the position played by Randall Cunningham quarterback Randall Cunningham "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Eagles quarterback Randall Cunningham rushed for 66 yards 3 [' Eagles', ' quarterback', ' Randall', ' Cunningham']
+444 283 In their sport, the position played by x -1 In their sport, the position played by Randall Cunningham quarterback Randall Cunningham "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Jaworski and Randall Cunningham three times each. 5 [' Jaw', 'ors', 'ki', ' and', ' Randall', ' Cunningham']
+445 284 In their sport, the position played by x -1 In their sport, the position played by Alex Smith quarterback Alex Smith "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True " Smith – piano
+" 4 [' Smith', ' –', ' piano', 'Alex', ' Smith']
+446 284 In their sport, the position played by x -1 In their sport, the position played by Alex Smith quarterback Alex Smith "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True 1989, pairing Alex Smith and Jocky Scott. 4 [' 1989', ',', ' pairing', ' Alex', ' Smith']
+447 284 In their sport, the position played by x -1 In their sport, the position played by Alex Smith quarterback Alex Smith "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True acquired quarterback Alex Smith from the San 3 [' acquired', ' quarterback', ' Alex', ' Smith']
+448 284 In their sport, the position played by x -1 In their sport, the position played by Alex Smith quarterback Alex Smith "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True co-managers in 1989, pairing Alex Smith and Jocky Scott. 9 [' co', '-', 'man', 'agers', ' in', ' 1989', ',', ' pairing', ' Alex', ' Smith']
+449 284 In their sport, the position played by x -1 In their sport, the position played by Alex Smith quarterback Alex Smith "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True and piano player Alex Smith made the necessary 4 [' and', ' piano', ' player', ' Alex', ' Smith']
+450 285 In their sport, the position played by x -1 In their sport, the position played by Bert Jones quarterback Bert Jones "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Bert' ' Jones' ',' ' the' '\n' '\n' 'In' ' their']" ", the
+
+ In their sport , the position played by Bert Jones , the
+
+ In their" False quarterback Bert Jones and running 2 [' quarterback', ' Bert', ' Jones']
+451 285 In their sport, the position played by x -1 In their sport, the position played by Bert Jones quarterback Bert Jones "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Bert' ' Jones' ',' ' the' '\n' '\n' 'In' ' their']" ", the
+
+ In their sport , the position played by Bert Jones , the
+
+ In their" False of quarterback Bert Jones and running 3 [' of', ' quarterback', ' Bert', ' Jones']
+452 285 In their sport, the position played by x -1 In their sport, the position played by Bert Jones quarterback Bert Jones "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Bert' ' Jones' ',' ' the' '\n' '\n' 'In' ' their']" ", the
+
+ In their sport , the position played by Bert Jones , the
+
+ In their" False of quarterback Bert Jones and running back 3 [' of', ' quarterback', ' Bert', ' Jones']
+453 285 In their sport, the position played by x -1 In their sport, the position played by Bert Jones quarterback Bert Jones "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Bert' ' Jones' ',' ' the' '\n' '\n' 'In' ' their']" ", the
+
+ In their sport , the position played by Bert Jones , the
+
+ In their" False loss of quarterback Bert Jones and running 4 [' loss', ' of', ' quarterback', ' Bert', ' Jones']
+454 287 In their sport, the position played by x -1 In their sport, the position played by Archie Gemmill midfielder Archie Gemmill "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Scottish midfielder , is a key one . He is the ful cr um of the team , True Stadium in 1967. Archie Gemmill scored what 6 [' Stadium', ' in', ' 1967', '.', ' Archie', ' Gem', 'mill']
+455 287 In their sport, the position played by x -1 In their sport, the position played by Archie Gemmill midfielder Archie Gemmill "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Scottish midfielder , is a key one . He is the ful cr um of the team , True and signed Archie Gemmill and Bobby Davison. 4 [' and', ' signed', ' Archie', ' Gem', 'mill']
+456 287 In their sport, the position played by x -1 In their sport, the position played by Archie Gemmill midfielder Archie Gemmill "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Scottish midfielder , is a key one . He is the ful cr um of the team , True David Needham and Archie Gemmill all being non-eligible 6 [' David', ' Need', 'ham', ' and', ' Archie', ' Gem', 'mill']
+457 287 In their sport, the position played by x -1 In their sport, the position played by Archie Gemmill midfielder Archie Gemmill "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Scottish midfielder , is a key one . He is the ful cr um of the team , True signing midfielder Archie Gemmill from Preston North 4 [' signing', ' midfielder', ' Archie', ' Gem', 'mill']
+458 287 In their sport, the position played by x -1 In their sport, the position played by Archie Gemmill midfielder Archie Gemmill "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Scottish midfielder , is a key one . He is the ful cr um of the team , True signing midfielder Archie Gemmill from Preston North 4 [' signing', ' midfielder', ' Archie', ' Gem', 'mill']
+459 288 In their sport, the position played by x -1 In their sport, the position played by Bettina Wiegmann midfielder Bettina Wiegmann "[',' ' the' ' German' ' player' ',' ' is' ' called' ' the' ' ""' 'B' 'ett'
+ 'ina' ' position' '"".' '\n' '\n' 'The' ' Bett' 'ina' ' position']" ", the German player , is called the "" B ett ina position "".
+
+ The Bett ina position" False Summer Olympics and Bettina Wiegmann scored the first 7 [' Summer', ' Olympics', ' and', ' Bett', 'ina', ' W', 'ieg', 'mann']
+460 288 In their sport, the position played by x -1 In their sport, the position played by Bettina Wiegmann midfielder Bettina Wiegmann "[',' ' the' ' German' ' player' ',' ' is' ' called' ' the' ' ""' 'B' 'ett'
+ 'ina' ' position' '"".' '\n' '\n' 'The' ' Bett' 'ina' ' position']" ", the German player , is called the "" B ett ina position "".
+
+ The Bett ina position" False Olympics and Bettina Wiegmann scored the first Olympic 6 [' Olympics', ' and', ' Bett', 'ina', ' W', 'ieg', 'mann']
+461 288 In their sport, the position played by x -1 In their sport, the position played by Bettina Wiegmann midfielder Bettina Wiegmann "[',' ' the' ' German' ' player' ',' ' is' ' called' ' the' ' ""' 'B' 'ett'
+ 'ina' ' position' '"".' '\n' '\n' 'The' ' Bett' 'ina' ' position']" ", the German player , is called the "" B ett ina position "".
+
+ The Bett ina position" False Summer Olympics. Bettina Wiegmann scored the first Olympic 7 [' Summer', ' Olympics', '.', ' Bett', 'ina', ' W', 'ieg', 'mann']
+462 288 In their sport, the position played by x -1 In their sport, the position played by Bettina Wiegmann midfielder Bettina Wiegmann "[',' ' the' ' German' ' player' ',' ' is' ' called' ' the' ' ""' 'B' 'ett'
+ 'ina' ' position' '"".' '\n' '\n' 'The' ' Bett' 'ina' ' position']" ", the German player , is called the "" B ett ina position "".
+
+ The Bett ina position" False Summer Olympics and Bettina Wiegmann scored the first Olympic 7 [' Summer', ' Olympics', ' and', ' Bett', 'ina', ' W', 'ieg', 'mann']
+463 288 In their sport, the position played by x -1 In their sport, the position played by Bettina Wiegmann midfielder Bettina Wiegmann "[',' ' the' ' German' ' player' ',' ' is' ' called' ' the' ' ""' 'B' 'ett'
+ 'ina' ' position' '"".' '\n' '\n' 'The' ' Bett' 'ina' ' position']" ", the German player , is called the "" B ett ina position "".
+
+ The Bett ina position" False conceding a goal). Bettina Wiegmann holds the record 9 [' conced', 'ing', ' a', ' goal', ').', ' Bett', 'ina', ' W', 'ieg', 'mann']
+464 291 In their sport, the position played by x -1 In their sport, the position played by Chico Resch goaltender Chico Resch "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Ch' 'ico' ' Res' 'ch' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Ch ico Res ch , the
+
+" False " fight labelled by Chico Resch as ""like a heavyweight" 6 [' fight', ' labelled', ' by', ' Ch', 'ico', ' Res', 'ch']
+465 292 In their sport, the position played by x -1 In their sport, the position played by Greg Maddux pitcher Greg Maddux "[',' ' the' ' pitcher' ',' ' is' ' to' ' throw' ' the' ' ball' ' to'
+ ' the' ' batter' ',' ' who' ' hits' ' it' ' with' ' a' ' bat' '.']" , the pitcher , is to throw the ball to the batter , who hits it with a bat . True Ferguson Jenkins / Greg Maddux also match this 6 [' Ferguson', ' Jenkins', ' /', ' Greg', ' Mad', 'du', 'x']
+466 292 In their sport, the position played by x -1 In their sport, the position played by Greg Maddux pitcher Greg Maddux "[',' ' the' ' pitcher' ',' ' is' ' to' ' throw' ' the' ' ball' ' to'
+ ' the' ' batter' ',' ' who' ' hits' ' it' ' with' ' a' ' bat' '.']" , the pitcher , is to throw the ball to the batter , who hits it with a bat . True Braves teammate Greg Maddux in 2004. In 5 [' Braves', ' teammate', ' Greg', ' Mad', 'du', 'x']
+467 292 In their sport, the position played by x -1 In their sport, the position played by Greg Maddux pitcher Greg Maddux "[',' ' the' ' pitcher' ',' ' is' ' to' ' throw' ' the' ' ball' ' to'
+ ' the' ' batter' ',' ' who' ' hits' ' it' ' with' ' a' ' bat' '.']" , the pitcher , is to throw the ball to the batter , who hits it with a bat . True dominant pitcher Greg Maddux and then, 5 [' dominant', ' pitcher', ' Greg', ' Mad', 'du', 'x']
+468 292 In their sport, the position played by x -1 In their sport, the position played by Greg Maddux pitcher Greg Maddux "[',' ' the' ' pitcher' ',' ' is' ' to' ' throw' ' the' ' ball' ' to'
+ ' the' ' batter' ',' ' who' ' hits' ' it' ' with' ' a' ' bat' '.']" , the pitcher , is to throw the ball to the batter , who hits it with a bat . True on August 7 against Greg Maddux of the Chicago 7 [' on', ' August', ' 7', ' against', ' Greg', ' Mad', 'du', 'x']
+469 292 In their sport, the position played by x -1 In their sport, the position played by Greg Maddux pitcher Greg Maddux "[',' ' the' ' pitcher' ',' ' is' ' to' ' throw' ' the' ' ball' ' to'
+ ' the' ' batter' ',' ' who' ' hits' ' it' ' with' ' a' ' bat' '.']" , the pitcher , is to throw the ball to the batter , who hits it with a bat . True scoreless innings. Greg Maddux and Kenny Rogers 7 [' score', 'less', ' innings', '.', ' Greg', ' Mad', 'du', 'x']
+470 295 In their sport, the position played by x -1 In their sport, the position played by Luis Tiant pitcher Luis Tiant "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True Bill Lee, and Luis Tiant in their starting 6 [' Bill', ' Lee', ',', ' and', ' Luis', ' T', 'iant']
+471 295 In their sport, the position played by x -1 In their sport, the position played by Luis Tiant pitcher Luis Tiant "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True history. Red Sox pitcher Luis Tiant walked Blomberg in 7 [' history', '.', ' Red', ' Sox', ' pitcher', ' Luis', ' T', 'iant']
+472 295 In their sport, the position played by x -1 In their sport, the position played by Luis Tiant pitcher Luis Tiant "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True Gaylord Perry and Luis Tiant as the only pitchers 6 [' Gay', 'lord', ' Perry', ' and', ' Luis', ' T', 'iant']
+473 295 In their sport, the position played by x -1 In their sport, the position played by Luis Tiant pitcher Luis Tiant "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True Sox pitcher Luis Tiant walked Blomberg 4 [' Sox', ' pitcher', ' Luis', ' T', 'iant']
+474 295 In their sport, the position played by x -1 In their sport, the position played by Luis Tiant pitcher Luis Tiant "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True Torrez, Bill Lee, and Luis Tiant in their starting 9 [' Tor', 'rez', ',', ' Bill', ' Lee', ',', ' and', ' Luis', ' T', 'iant']
+475 299 In their sport, the position played by x -1 In their sport, the position played by Craig Monroe outfielder Craig Monroe "[',' ' the' ' former' ' All' '-' 'Star' ' and' ' former' ' MVP' ' of'
+ ' the' ' NBA' ',' ' is' ' a' ' very' ' important' ' one' '.' ' He']" , the former All - Star and former MVP of the NBA , is a very important one . He False 1 ['Craig', ' Monroe']
+476 299 In their sport, the position played by x -1 In their sport, the position played by Craig Monroe outfielder Craig Monroe "[',' ' the' ' former' ' All' '-' 'Star' ' and' ' former' ' MVP' ' of'
+ ' the' ' NBA' ',' ' is' ' a' ' very' ' important' ' one' '.' ' He']" , the former All - Star and former MVP of the NBA , is a very important one . He False 1 ['Craig', ' Monroe']
+477 306 In their sport, the position played by x -1 In their sport, the position played by Neil Redfearn midfielder Neil Redfearn "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' match' ',' ' is' ' called' ' the' ' ""' 'red']" ", the player who is the first to score a goal in a match , is called the "" red" False youth team coach Neil Redfearn took over 7 [' youth', ' team', ' coach', ' Neil', ' Red', 'f', 'ear', 'n']
+478 306 In their sport, the position played by x -1 In their sport, the position played by Neil Redfearn midfielder Neil Redfearn "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' match' ',' ' is' ' called' ' the' ' ""' 'red']" ", the player who is the first to score a goal in a match , is called the "" red" False players including Neil Redfearn and Dean Saunders, 6 [' players', ' including', ' Neil', ' Red', 'f', 'ear', 'n']
+479 306 In their sport, the position played by x -1 In their sport, the position played by Neil Redfearn midfielder Neil Redfearn "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' match' ',' ' is' ' called' ' the' ' ""' 'red']" ", the player who is the first to score a goal in a match , is called the "" red" False youth team coach Neil Redfearn took over as caretaker 7 [' youth', ' team', ' coach', ' Neil', ' Red', 'f', 'ear', 'n']
+480 308 In their sport, the position played by x -1 In their sport, the position played by Vince Coleman outfielder Vince Coleman "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' ' He' ' was' ' a' ' great' ' player' ',']" , the former NFL running back , is a bit of a mystery . He was a great player , False League rookie since Vince Coleman stole 110 4 [' League', ' rookie', ' since', ' Vince', ' Coleman']
+481 308 In their sport, the position played by x -1 In their sport, the position played by Vince Coleman outfielder Vince Coleman "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' ' He' ' was' ' a' ' great' ' player' ',']" , the former NFL running back , is a bit of a mystery . He was a great player , False League rookie since Vince Coleman stole 110 in 1985. 4 [' League', ' rookie', ' since', ' Vince', ' Coleman']
+482 311 In their sport, the position played by x -1 In their sport, the position played by Jacques Plante goaltender Jacques Plante "[',' ' the' ' Canadiens' ""'"" ' goalie' ',' ' was' ' a' ' position' ' of'
+ ' great' ' importance' '.' ' The' ' Canadiens' ' were' ' a' ' team'
+ ' that' ' relied']" , the Canadiens ' goalie , was a position of great importance . The Canadiens were a team that relied False Goaltender Jacques Plante had been wearing 5 [' Go', 'alt', 'ender', ' Jacques', ' Pl', 'ante']
+483 311 In their sport, the position played by x -1 In their sport, the position played by Jacques Plante goaltender Jacques Plante "[',' ' the' ' Canadiens' ""'"" ' goalie' ',' ' was' ' a' ' position' ' of'
+ ' great' ' importance' '.' ' The' ' Canadiens' ' were' ' a' ' team'
+ ' that' ' relied']" , the Canadiens ' goalie , was a position of great importance . The Canadiens were a team that relied False Canadiens. The Jacques Plante Memorial Trophy 5 [' Canadiens', '.', ' The', ' Jacques', ' Pl', 'ante']
+484 311 In their sport, the position played by x -1 In their sport, the position played by Jacques Plante goaltender Jacques Plante "[',' ' the' ' Canadiens' ""'"" ' goalie' ',' ' was' ' a' ' position' ' of'
+ ' great' ' importance' '.' ' The' ' Canadiens' ' were' ' a' ' team'
+ ' that' ' relied']" , the Canadiens ' goalie , was a position of great importance . The Canadiens were a team that relied False " stop him"". Goaltender Jacques Plante declared it one" 8 "[' stop', ' him', '"".', ' Go', 'alt', 'ender', ' Jacques', ' Pl', 'ante']"
+485 311 In their sport, the position played by x -1 In their sport, the position played by Jacques Plante goaltender Jacques Plante "[',' ' the' ' Canadiens' ""'"" ' goalie' ',' ' was' ' a' ' position' ' of'
+ ' great' ' importance' '.' ' The' ' Canadiens' ' were' ' a' ' team'
+ ' that' ' relied']" , the Canadiens ' goalie , was a position of great importance . The Canadiens were a team that relied False Canadiens. Goaltender Jacques Plante later recalled 7 [' Canadiens', '.', ' Go', 'alt', 'ender', ' Jacques', ' Pl', 'ante']
+486 311 In their sport, the position played by x -1 In their sport, the position played by Jacques Plante goaltender Jacques Plante "[',' ' the' ' Canadiens' ""'"" ' goalie' ',' ' was' ' a' ' position' ' of'
+ ' great' ' importance' '.' ' The' ' Canadiens' ' were' ' a' ' team'
+ ' that' ' relied']" , the Canadiens ' goalie , was a position of great importance . The Canadiens were a team that relied False New York Rangers Jacques Plante made the goaltender 5 [' New', ' York', ' Rangers', ' Jacques', ' Pl', 'ante']
+487 315 In their sport, the position played by x -1 In their sport, the position played by Daniele De Rossi midfielder Daniele De Rossi "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' heart' ' of' ' the' ' team' ',' ' the' ' one']" , the Italian midfielder , is the most important . He is the heart of the team , the one True and footballer Daniele De Rossi believed that Mario 6 [' and', ' footballer', ' Dan', 'ie', 'le', ' De', ' Rossi']
+488 315 In their sport, the position played by x -1 In their sport, the position played by Daniele De Rossi midfielder Daniele De Rossi "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' heart' ' of' ' the' ' team' ',' ' the' ' one']" , the Italian midfielder , is the most important . He is the heart of the team , the one True and footballer Daniele De Rossi believed that Mario 6 [' and', ' footballer', ' Dan', 'ie', 'le', ' De', ' Rossi']
+489 315 In their sport, the position played by x -1 In their sport, the position played by Daniele De Rossi midfielder Daniele De Rossi "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' heart' ' of' ' the' ' team' ',' ' the' ' one']" , the Italian midfielder , is the most important . He is the heart of the team , the one True he and footballer Daniele De Rossi believed that 7 [' he', ' and', ' footballer', ' Dan', 'ie', 'le', ' De', ' Rossi']
+490 319 In their sport, the position played by x -1 In their sport, the position played by Jim Kelly quarterback Jim Kelly "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Miami quarterback Jim Kelly and nose guard 3 [' Miami', ' quarterback', ' Jim', ' Kelly']
+491 319 In their sport, the position played by x -1 In their sport, the position played by Jim Kelly quarterback Jim Kelly "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True such stars as Jim Kelly and Herschel 4 [' such', ' stars', ' as', ' Jim', ' Kelly']
+492 319 In their sport, the position played by x -1 In their sport, the position played by Jim Kelly quarterback Jim Kelly "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Crime writer Jim Kelly and award winning 3 [' Crime', ' writer', ' Jim', ' Kelly']
+493 319 In their sport, the position played by x -1 In their sport, the position played by Jim Kelly quarterback Jim Kelly "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Bowl, Miami quarterback Jim Kelly and nose guard 5 [' Bowl', ',', ' Miami', ' quarterback', ' Jim', ' Kelly']
+494 319 In their sport, the position played by x -1 In their sport, the position played by Jim Kelly quarterback Jim Kelly "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Miami quarterback Jim Kelly and nose guard Jim 3 [' Miami', ' quarterback', ' Jim', ' Kelly']
+495 321 In their sport, the position played by x -1 In their sport, the position played by O'Brien Schofield linebacker O'Brien Schofield "[',' ' is' ' called' ' the' ' ""' 'quarter' 'back' '""' ' and' ' the'
+ ' position' ' played' ' by' ' the' ' other' ' two' ' players' ' is'
+ ' called' ' the']" ", is called the "" quarter back "" and the position played by the other two players is called the" False Wisconsin defensive end O'Brien Schofield and Northwestern 8 "[' Wisconsin', ' defensive', ' end', ' O', ""'"", 'Brien', ' Sch', 'of', 'ield']"
+496 321 In their sport, the position played by x -1 In their sport, the position played by O'Brien Schofield linebacker O'Brien Schofield "[',' ' is' ' called' ' the' ' ""' 'quarter' 'back' '""' ' and' ' the'
+ ' position' ' played' ' by' ' the' ' other' ' two' ' players' ' is'
+ ' called' ' the']" ", is called the "" quarter back "" and the position played by the other two players is called the" False Wisconsin defensive end O'Brien Schofield and Northwestern 8 "[' Wisconsin', ' defensive', ' end', ' O', ""'"", 'Brien', ' Sch', 'of', 'ield']"
+497 321 In their sport, the position played by x -1 In their sport, the position played by O'Brien Schofield linebacker O'Brien Schofield "[',' ' is' ' called' ' the' ' ""' 'quarter' 'back' '""' ' and' ' the'
+ ' position' ' played' ' by' ' the' ' other' ' two' ' players' ' is'
+ ' called' ' the']" ", is called the "" quarter back "" and the position played by the other two players is called the" False Wisconsin defensive end O'Brien Schofield and Northwestern quarterback 8 "[' Wisconsin', ' defensive', ' end', ' O', ""'"", 'Brien', ' Sch', 'of', 'ield']"
+498 323 In their sport, the position played by x -1 In their sport, the position played by Dallas Keuchel pitcher Dallas Keuchel "[' is' ' the' ' most' ' important' '.' ' He' ' is' ' the' ' pitcher'
+ ' who' ' is' ' most' ' likely' ' to' ' get' ' the' ' win' '.' ' He' ' is']" is the most important . He is the pitcher who is most likely to get the win . He is True with Kyle Kendrick and Dallas Keuchel allowing one run 7 [' with', ' Kyle', ' Kendrick', ' and', ' Dallas', ' Ke', 'uc', 'hel']
+499 323 In their sport, the position played by x -1 In their sport, the position played by Dallas Keuchel pitcher Dallas Keuchel "[' is' ' the' ' most' ' important' '.' ' He' ' is' ' the' ' pitcher'
+ ' who' ' is' ' most' ' likely' ' to' ' get' ' the' ' win' '.' ' He' ' is']" is the most important . He is the pitcher who is most likely to get the win . He is True Kendrick and Dallas Keuchel allowing one run apiece 5 [' Kendrick', ' and', ' Dallas', ' Ke', 'uc', 'hel']
+500 327 In their sport, the position played by x -1 In their sport, the position played by Andrew Luck quarterback Andrew Luck "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team' ',']" is similar to that of the quarterback in the NFL . The quarterback is the leader of the team , True Stanford's Andrew Luck and LSU's Tyrann 3 "[' Stanford', ""'s"", ' Andrew', ' Luck']"
+501 327 In their sport, the position played by x -1 In their sport, the position played by Andrew Luck quarterback Andrew Luck "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team' ',']" is similar to that of the quarterback in the NFL . The quarterback is the leader of the team , True Stanford quarterback Andrew Luck was named the 3 [' Stanford', ' quarterback', ' Andrew', ' Luck']
+502 327 In their sport, the position played by x -1 In their sport, the position played by Andrew Luck quarterback Andrew Luck "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team' ',']" is similar to that of the quarterback in the NFL . The quarterback is the leader of the team , True Stanford quarterback Andrew Luck made news for not 3 [' Stanford', ' quarterback', ' Andrew', ' Luck']
+503 327 In their sport, the position played by x -1 In their sport, the position played by Andrew Luck quarterback Andrew Luck "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team' ',']" is similar to that of the quarterback in the NFL . The quarterback is the leader of the team , True Stanford quarterback Andrew Luck was named the game's 3 [' Stanford', ' quarterback', ' Andrew', ' Luck']
+504 327 In their sport, the position played by x -1 In their sport, the position played by Andrew Luck quarterback Andrew Luck "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team' ',']" is similar to that of the quarterback in the NFL . The quarterback is the leader of the team , True season when he caught an Andrew Luck pass in the 6 [' season', ' when', ' he', ' caught', ' an', ' Andrew', ' Luck']
+505 332 In their sport, the position played by x -1 In their sport, the position played by Damon Allen quarterback Damon Allen "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' is' ' a' ' former' ' NFL' ' quarterback'
+ ' who']" , the former NFL quarterback , is a bit of a mystery . He is a former NFL quarterback who True occasions: Doug Flutie, Damon Allen and Sonny Wade. The 8 [' occasions', ':', ' Doug', ' Fl', 'ut', 'ie', ',', ' Damon', ' Allen']
+506 332 In their sport, the position played by x -1 In their sport, the position played by Damon Allen quarterback Damon Allen "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' is' ' a' ' former' ' NFL' ' quarterback'
+ ' who']" , the former NFL quarterback , is a bit of a mystery . He is a former NFL quarterback who True Chad Owens (2012), Damon Allen (2005), Doug 6 [' Chad', ' Owens', ' (', '2012', '),', ' Damon', ' Allen']
+507 332 In their sport, the position played by x -1 In their sport, the position played by Damon Allen quarterback Damon Allen "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' is' ' a' ' former' ' NFL' ' quarterback'
+ ' who']" , the former NFL quarterback , is a bit of a mystery . He is a former NFL quarterback who True occasions: Doug Flutie, Damon Allen and Sonny Wade. The 8 [' occasions', ':', ' Doug', ' Fl', 'ut', 'ie', ',', ' Damon', ' Allen']
+508 332 In their sport, the position played by x -1 In their sport, the position played by Damon Allen quarterback Damon Allen "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' is' ' a' ' former' ' NFL' ' quarterback'
+ ' who']" , the former NFL quarterback , is a bit of a mystery . He is a former NFL quarterback who True in 2004. Veteran Damon Allen led the team to 5 [' in', ' 2004', '.', ' Veteran', ' Damon', ' Allen']
+509 332 In their sport, the position played by x -1 In their sport, the position played by Damon Allen quarterback Damon Allen "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' is' ' a' ' former' ' NFL' ' quarterback'
+ ' who']" , the former NFL quarterback , is a bit of a mystery . He is a former NFL quarterback who True 2004. Veteran Damon Allen led the team to 4 [' 2004', '.', ' Veteran', ' Damon', ' Allen']
+510 337 In their sport, the position played by x -1 In their sport, the position played by Ryan Giggs midfielder Ryan Giggs "[',' ' the' ' Welsh' 'man' ' has' ' been' ' a' ' key' ' figure' ' in'
+ ' the' ' team' ""'s"" ' success' '.' '\n' '\n' '""' 'I' ' think']" ", the Welsh man has been a key figure in the team 's success .
+
+ "" I think" False bench by winger Ryan Giggs and goalkeeper Tim 5 [' bench', ' by', ' winger', ' Ryan', ' G', 'iggs']
+511 337 In their sport, the position played by x -1 In their sport, the position played by Ryan Giggs midfielder Ryan Giggs "[',' ' the' ' Welsh' 'man' ' has' ' been' ' a' ' key' ' figure' ' in'
+ ' the' ' team' ""'s"" ' success' '.' '\n' '\n' '""' 'I' ' think']" ", the Welsh man has been a key figure in the team 's success .
+
+ "" I think" False semi-final replay; Ryan Giggs scored an extra 7 [' semi', '-', 'final', ' replay', ';', ' Ryan', ' G', 'iggs']
+512 337 In their sport, the position played by x -1 In their sport, the position played by Ryan Giggs midfielder Ryan Giggs "[',' ' the' ' Welsh' 'man' ' has' ' been' ' a' ' key' ' figure' ' in'
+ ' the' ' team' ""'s"" ' success' '.' '\n' '\n' '""' 'I' ' think']" ", the Welsh man has been a key figure in the team 's success .
+
+ "" I think" False " equalled only by Ryan Giggs and Wayne Rooney.
+" 6 [' equ', 'alled', ' only', ' by', ' Ryan', ' G', 'iggs']
+513 337 In their sport, the position played by x -1 In their sport, the position played by Ryan Giggs midfielder Ryan Giggs "[',' ' the' ' Welsh' 'man' ' has' ' been' ' a' ' key' ' figure' ' in'
+ ' the' ' team' ""'s"" ' success' '.' '\n' '\n' '""' 'I' ' think']" ", the Welsh man has been a key figure in the team 's success .
+
+ "" I think" False Scholes and Ryan Giggs were all doubts 5 [' Sch', 'oles', ' and', ' Ryan', ' G', 'iggs']
+514 337 In their sport, the position played by x -1 In their sport, the position played by Ryan Giggs midfielder Ryan Giggs "[',' ' the' ' Welsh' 'man' ' has' ' been' ' a' ' key' ' figure' ' in'
+ ' the' ' team' ""'s"" ' success' '.' '\n' '\n' '""' 'I' ' think']" ", the Welsh man has been a key figure in the team 's success .
+
+ "" I think" False United footballer Ryan Giggs is mentioned 4 [' United', ' footballer', ' Ryan', ' G', 'iggs']
+515 341 In their sport, the position played by x -1 In their sport, the position played by Brian Cushing linebacker Brian Cushing "[',' ' the' ' Texans' ""'"" ' star' ' linebacker' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' he' ' played' ' in' ' college' ' at' ' the'
+ ' University']" , the Texans ' star linebacker , is the same as the one he played in college at the University True Ellis and linebacker Brian Cushing had one each; 5 [' Ellis', ' and', ' linebacker', ' Brian', ' C', 'ushing']
+516 341 In their sport, the position played by x -1 In their sport, the position played by Brian Cushing linebacker Brian Cushing "[',' ' the' ' Texans' ""'"" ' star' ' linebacker' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' he' ' played' ' in' ' college' ' at' ' the'
+ ' University']" , the Texans ' star linebacker , is the same as the one he played in college at the University True Ellis and linebacker Brian Cushing had one each; 5 [' Ellis', ' and', ' linebacker', ' Brian', ' C', 'ushing']
+517 341 In their sport, the position played by x -1 In their sport, the position played by Brian Cushing linebacker Brian Cushing "[',' ' the' ' Texans' ""'"" ' star' ' linebacker' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' he' ' played' ' in' ' college' ' at' ' the'
+ ' University']" , the Texans ' star linebacker , is the same as the one he played in college at the University True (fifth, New York Jets), Brian Cushing (15th, Houston Texans) 9 [' (', 'fifth', ',', ' New', ' York', ' Jets', '),', ' Brian', ' C', 'ushing']
+518 341 In their sport, the position played by x -1 In their sport, the position played by Brian Cushing linebacker Brian Cushing "[',' ' the' ' Texans' ""'"" ' star' ' linebacker' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' he' ' played' ' in' ' college' ' at' ' the'
+ ' University']" , the Texans ' star linebacker , is the same as the one he played in college at the University True key linebacker Brian Cushing to an ankle sprain 4 [' key', ' linebacker', ' Brian', ' C', 'ushing']
+519 341 In their sport, the position played by x -1 In their sport, the position played by Brian Cushing linebacker Brian Cushing "[',' ' the' ' Texans' ""'"" ' star' ' linebacker' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' he' ' played' ' in' ' college' ' at' ' the'
+ ' University']" , the Texans ' star linebacker , is the same as the one he played in college at the University True and linebacker Brian Cushing had one each; cornerback 4 [' and', ' linebacker', ' Brian', ' C', 'ushing']
+520 342 In their sport, the position played by x -1 In their sport, the position played by Kiki Cuyler outfielder Kiki Cuyler "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' game' ',' ' is' ' called' ' the' ' ""' 'C']" ", the player who is the first to score a goal in a game , is called the "" C" False World Series record Kiki Cuyler and Hank Greenberg 7 [' World', ' Series', ' record', ' K', 'iki', ' C', 'uy', 'ler']
+521 343 In their sport, the position played by x -1 In their sport, the position played by Schellas Hyndman midfielder Schellas Hyndman "[' is' ' called' ' the' ' ""' 'swe' 'eper' '""' ' position' '.' ' The'
+ ' swe' 'eper' ' is' ' the' ' last' ' defender' ' in' ' the' ' back'
+ ' line']" " is called the "" swe eper "" position . The swe eper is the last defender in the back line" False Dallas coach Schellas Hyndman complained about how 7 [' Dallas', ' coach', ' Sche', 'll', 'as', ' Hy', 'nd', 'man']
+522 343 In their sport, the position played by x -1 In their sport, the position played by Schellas Hyndman midfielder Schellas Hyndman "[' is' ' called' ' the' ' ""' 'swe' 'eper' '""' ' position' '.' ' The'
+ ' swe' 'eper' ' is' ' the' ' last' ' defender' ' in' ' the' ' back'
+ ' line']" " is called the "" swe eper "" position . The swe eper is the last defender in the back line" False match, Dallas coach Schellas Hyndman complained about 9 [' match', ',', ' Dallas', ' coach', ' Sche', 'll', 'as', ' Hy', 'nd', 'man']
+523 344 In their sport, the position played by x -1 In their sport, the position played by Aaron Murray quarterback Aaron Murray "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True touchdown pass from Aaron Murray to Jay Rome was 4 [' touchdown', ' pass', ' from', ' Aaron', ' Murray']
+524 344 In their sport, the position played by x -1 In their sport, the position played by Aaron Murray quarterback Aaron Murray "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True touchdown pass from Aaron Murray to Jay Rome 4 [' touchdown', ' pass', ' from', ' Aaron', ' Murray']
+525 346 In their sport, the position played by x -1 In their sport, the position played by Eli Whiteside catcher Eli Whiteside "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Eli' ' Whites' 'ide' ',' ' the' '\n' '\n' 'In']" ", the
+
+ In their sport , the position played by Eli Whites ide , the
+
+ In" False 3 ['E', 'li', ' Whites', 'ide']
+526 346 In their sport, the position played by x -1 In their sport, the position played by Eli Whiteside catcher Eli Whiteside "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Eli' ' Whites' 'ide' ',' ' the' '\n' '\n' 'In']" ", the
+
+ In their sport , the position played by Eli Whites ide , the
+
+ In" False 3 ['E', 'li', ' Whites', 'ide']
+527 346 In their sport, the position played by x -1 In their sport, the position played by Eli Whiteside catcher Eli Whiteside "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Eli' ' Whites' 'ide' ',' ' the' '\n' '\n' 'In']" ", the
+
+ In their sport , the position played by Eli Whites ide , the
+
+ In" False " Eli Whiteside =
+" 2 [' Eli', ' Whites', 'ide']
+528 346 In their sport, the position played by x -1 In their sport, the position played by Eli Whiteside catcher Eli Whiteside "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Eli' ' Whites' 'ide' ',' ' the' '\n' '\n' 'In']" ", the
+
+ In their sport , the position played by Eli Whites ide , the
+
+ In" False make room for Eli Whiteside when the Giants 5 [' make', ' room', ' for', ' Eli', ' Whites', 'ide']
+529 346 In their sport, the position played by x -1 In their sport, the position played by Eli Whiteside catcher Eli Whiteside "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Eli' ' Whites' 'ide' ',' ' the' '\n' '\n' 'In']" ", the
+
+ In their sport , the position played by Eli Whites ide , the
+
+ In" False 3 ['E', 'li', ' Whites', 'ide']
+530 348 In their sport, the position played by x -1 In their sport, the position played by Abou Diaby midfielder Abou Diaby "[',' ' the' ' French' ' midfielder' ',' ' is' ' a' ' key' ' one' '.' ' He'
+ ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the French midfielder , is a key one . He is the ful cr um of the team , True altercation with Abou Diaby which led 5 [' altercation', ' with', ' Ab', 'ou', ' Di', 'aby']
+531 348 In their sport, the position played by x -1 In their sport, the position played by Abou Diaby midfielder Abou Diaby "[',' ' the' ' French' ' midfielder' ',' ' is' ' a' ' key' ' one' '.' ' He'
+ ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the French midfielder , is a key one . He is the ful cr um of the team , True " Abou Diaby =
+" 3 [' Ab', 'ou', ' Di', 'aby']
+532 348 In their sport, the position played by x -1 In their sport, the position played by Abou Diaby midfielder Abou Diaby "[',' ' the' ' French' ' midfielder' ',' ' is' ' a' ' key' ' one' '.' ' He'
+ ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the French midfielder , is a key one . He is the ful cr um of the team , True formation; Abou Diaby and Theo Walcott 5 [' formation', ';', ' Ab', 'ou', ' Di', 'aby']
+533 348 In their sport, the position played by x -1 In their sport, the position played by Abou Diaby midfielder Abou Diaby "[',' ' the' ' French' ' midfielder' ',' ' is' ' a' ' key' ' one' '.' ' He'
+ ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the French midfielder , is a key one . He is the ful cr um of the team , True " Diaby =
+" 6 [' Di', 'aby', ' =', 'Ab', 'ou', ' Di', 'aby']
+534 348 In their sport, the position played by x -1 In their sport, the position played by Abou Diaby midfielder Abou Diaby "[',' ' the' ' French' ' midfielder' ',' ' is' ' a' ' key' ' one' '.' ' He'
+ ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the French midfielder , is a key one . He is the ful cr um of the team , True altercation with Abou Diaby which led to 5 [' altercation', ' with', ' Ab', 'ou', ' Di', 'aby']
+535 349 In their sport, the position played by x -1 In their sport, the position played by John Smoltz pitcher John Smoltz "[',' ' the' ' pitcher' ',' ' is' ' the' ' most' ' important' '.' ' He'
+ ' is' ' the' ' one' ' who' ' has' ' to' ' make' ' the' ' final'
+ ' decision']" , the pitcher , is the most important . He is the one who has to make the final decision True game-winning RBI against John Smoltz in June, and 8 [' game', '-', 'winning', ' RBI', ' against', ' John', ' Sm', 'olt', 'z']
+536 349 In their sport, the position played by x -1 In their sport, the position played by John Smoltz pitcher John Smoltz "[',' ' the' ' pitcher' ',' ' is' ' the' ' most' ' important' '.' ' He'
+ ' is' ' the' ' one' ' who' ' has' ' to' ' make' ' the' ' final'
+ ' decision']" , the pitcher , is the most important . He is the one who has to make the final decision True game-winning RBI against John Smoltz in June, and stealing 8 [' game', '-', 'winning', ' RBI', ' against', ' John', ' Sm', 'olt', 'z']
+537 349 In their sport, the position played by x -1 In their sport, the position played by John Smoltz pitcher John Smoltz "[',' ' the' ' pitcher' ',' ' is' ' the' ' most' ' important' '.' ' He'
+ ' is' ' the' ' one' ' who' ' has' ' to' ' make' ' the' ' final'
+ ' decision']" , the pitcher , is the most important . He is the one who has to make the final decision True Vasgersian and John Smoltz with play-by-play 7 [' Vas', 'gers', 'ian', ' and', ' John', ' Sm', 'olt', 'z']
+538 350 In their sport, the position played by x -1 In their sport, the position played by Jacques Cloutier goaltender Jacques Cloutier "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' game' ',' ' is' ' called' ' the' ' ""' 'C']" ", the player who is the first to score a goal in a game , is called the "" C" False goaltender Jacques Cloutier that year), Hašek 4 [' goaltender', ' Jacques', ' Cl', 'out', 'ier']
+539 353 In their sport, the position played by x -1 In their sport, the position played by Bob Lemon pitcher Bob Lemon "[',' ' the' ' former' ' Chicago' ' White' ' Sox' ' pitcher' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , the former Chicago White Sox pitcher , is a perfect example of the kind of player who can be True 1 ['Bob', ' Lemon']
+540 353 In their sport, the position played by x -1 In their sport, the position played by Bob Lemon pitcher Bob Lemon "[',' ' the' ' former' ' Chicago' ' White' ' Sox' ' pitcher' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , the former Chicago White Sox pitcher , is a perfect example of the kind of player who can be True " heavy ball."" Bob Lemon describes his pitches" 4 "[' heavy', ' ball', '.""', ' Bob', ' Lemon']"
+541 353 In their sport, the position played by x -1 In their sport, the position played by Bob Lemon pitcher Bob Lemon "[',' ' the' ' former' ' Chicago' ' White' ' Sox' ' pitcher' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , the former Chicago White Sox pitcher , is a perfect example of the kind of player who can be True Bob Feller, Bob Lemon and Early 5 [' Bob', ' F', 'eller', ',', ' Bob', ' Lemon']
+542 353 In their sport, the position played by x -1 In their sport, the position played by Bob Lemon pitcher Bob Lemon "[',' ' the' ' former' ' Chicago' ' White' ' Sox' ' pitcher' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , the former Chicago White Sox pitcher , is a perfect example of the kind of player who can be True Garcia, Bob Feller, Bob Lemon and Early Wynn. 7 [' Garcia', ',', ' Bob', ' F', 'eller', ',', ' Bob', ' Lemon']
+543 353 In their sport, the position played by x -1 In their sport, the position played by Bob Lemon pitcher Bob Lemon "[',' ' the' ' former' ' Chicago' ' White' ' Sox' ' pitcher' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be']" , the former Chicago White Sox pitcher , is a perfect example of the kind of player who can be True resigned and Bob Lemon was hired 3 [' resigned', ' and', ' Bob', ' Lemon']
+544 354 In their sport, the position played by x -1 In their sport, the position played by Zach Mettenberger quarterback Zach Mettenberger "[' is' ' the' ' quarterback' '.' ' The' ' position' ' is' ' the'
+ ' quarterback' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of'
+ ' the' ' team' '.' ' The']" is the quarterback . The position is the quarterback . The quarterback is the leader of the team . The True touchdown pass from Zach Mettenberger early in the 6 [' touchdown', ' pass', ' from', ' Zach', ' Met', 'ten', 'berger']
+545 354 In their sport, the position played by x -1 In their sport, the position played by Zach Mettenberger quarterback Zach Mettenberger "[' is' ' the' ' quarterback' '.' ' The' ' position' ' is' ' the'
+ ' quarterback' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of'
+ ' the' ' team' '.' ' The']" is the quarterback . The position is the quarterback . The quarterback is the leader of the team . The True touchdown pass from Zach Mettenberger early in the fourth 6 [' touchdown', ' pass', ' from', ' Zach', ' Met', 'ten', 'berger']
+546 354 In their sport, the position played by x -1 In their sport, the position played by Zach Mettenberger quarterback Zach Mettenberger "[' is' ' the' ' quarterback' '.' ' The' ' position' ' is' ' the'
+ ' quarterback' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of'
+ ' the' ' team' '.' ' The']" is the quarterback . The position is the quarterback . The quarterback is the leader of the team . The True pass from Zach Mettenberger early in the 5 [' pass', ' from', ' Zach', ' Met', 'ten', 'berger']
+547 354 In their sport, the position played by x -1 In their sport, the position played by Zach Mettenberger quarterback Zach Mettenberger "[' is' ' the' ' quarterback' '.' ' The' ' position' ' is' ' the'
+ ' quarterback' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of'
+ ' the' ' team' '.' ' The']" is the quarterback . The position is the quarterback . The quarterback is the leader of the team . The True first sack against Zach Mettenberger of the TennesseeTitans. 6 [' first', ' sack', ' against', ' Zach', ' Met', 'ten', 'berger']
+548 354 In their sport, the position played by x -1 In their sport, the position played by Zach Mettenberger quarterback Zach Mettenberger "[' is' ' the' ' quarterback' '.' ' The' ' position' ' is' ' the'
+ ' quarterback' '.' ' The' ' quarterback' ' is' ' the' ' leader' ' of'
+ ' the' ' team' '.' ' The']" is the quarterback . The position is the quarterback . The quarterback is the leader of the team . The True the game when a Zach Mettenberger fumble was 7 [' the', ' game', ' when', ' a', ' Zach', ' Met', 'ten', 'berger']
+549 357 In their sport, the position played by x -1 In their sport, the position played by Clint Benedict goaltender Clint Benedict "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False which was shared with Clint Benedict of the Ottawa Senators, 5 [' which', ' was', ' shared', ' with', ' Clint', ' Benedict']
+550 357 In their sport, the position played by x -1 In their sport, the position played by Clint Benedict goaltender Clint Benedict "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False the O 'Brien Cup. Clint Benedict again led league 7 "[' the', ' O', "" '"", 'Brien', ' Cup', '.', ' Clint', ' Benedict']"
+551 357 In their sport, the position played by x -1 In their sport, the position played by Clint Benedict goaltender Clint Benedict "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False and the O 'Brien Cup. Clint Benedict again led league in 8 "[' and', ' the', ' O', "" '"", 'Brien', ' Cup', '.', ' Clint', ' Benedict']"
+552 357 In their sport, the position played by x -1 In their sport, the position played by Clint Benedict goaltender Clint Benedict "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False series, goaltender Clint Benedict became embroiled 4 [' series', ',', ' goaltender', ' Clint', ' Benedict']
+553 357 In their sport, the position played by x -1 In their sport, the position played by Clint Benedict goaltender Clint Benedict "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False Hall of Famer Clint Benedict became the Senators' 5 [' Hall', ' of', ' F', 'amer', ' Clint', ' Benedict']
+554 359 In their sport, the position played by x -1 In their sport, the position played by Elvis Grbac quarterback Elvis Grbac "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Chiefs quarterback Elvis Grbac was lost for 5 [' Chiefs', ' quarterback', ' Elvis', ' Gr', 'b', 'ac']
+555 359 In their sport, the position played by x -1 In their sport, the position played by Elvis Grbac quarterback Elvis Grbac "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Chiefs quarterback Elvis Grbac was lost for the 5 [' Chiefs', ' quarterback', ' Elvis', ' Gr', 'b', 'ac']
+556 359 In their sport, the position played by x -1 In their sport, the position played by Elvis Grbac quarterback Elvis Grbac "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True protected Elvis Grbac while he was connecting 4 [' protected', ' Elvis', ' Gr', 'b', 'ac']
+557 359 In their sport, the position played by x -1 In their sport, the position played by Elvis Grbac quarterback Elvis Grbac "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True line that protected Elvis Grbac while he was 6 [' line', ' that', ' protected', ' Elvis', ' Gr', 'b', 'ac']
+558 359 In their sport, the position played by x -1 In their sport, the position played by Elvis Grbac quarterback Elvis Grbac "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True line that protected Elvis Grbac while he was connecting 6 [' line', ' that', ' protected', ' Elvis', ' Gr', 'b', 'ac']
+559 360 In their sport, the position played by x -1 In their sport, the position played by Pelle Lindbergh goaltender Pelle Lindbergh "[',' ' the' ' Swedish' ' player' ',' ' is' ' called' ' the' ' ""' 'back'
+ 'hand' '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the Swedish player , is called the "" back hand "" and the position played by the American player" False centre earned him the Pelle Lindbergh Memorial from his 8 [' centre', ' earned', ' him', ' the', ' Pel', 'le', ' Lind', 'ber', 'gh']
+560 360 In their sport, the position played by x -1 In their sport, the position played by Pelle Lindbergh goaltender Pelle Lindbergh "[',' ' the' ' Swedish' ' player' ',' ' is' ' called' ' the' ' ""' 'back'
+ 'hand' '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the Swedish player , is called the "" back hand "" and the position played by the American player" False he was awarded the Pelle Lindbergh Memorial Trophy, 8 [' he', ' was', ' awarded', ' the', ' Pel', 'le', ' Lind', 'ber', 'gh']
+561 360 In their sport, the position played by x -1 In their sport, the position played by Pelle Lindbergh goaltender Pelle Lindbergh "[',' ' the' ' Swedish' ' player' ',' ' is' ' called' ' the' ' ""' 'back'
+ 'hand' '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the Swedish player , is called the "" back hand "" and the position played by the American player" False awarded the Pelle Lindbergh Memorial Trophy, 6 [' awarded', ' the', ' Pel', 'le', ' Lind', 'ber', 'gh']
+562 360 In their sport, the position played by x -1 In their sport, the position played by Pelle Lindbergh goaltender Pelle Lindbergh "[',' ' the' ' Swedish' ' player' ',' ' is' ' called' ' the' ' ""' 'back'
+ 'hand' '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the Swedish player , is called the "" back hand "" and the position played by the American player" False awarded the Pelle Lindbergh Memorial Trophy, 6 [' awarded', ' the', ' Pel', 'le', ' Lind', 'ber', 'gh']
+563 360 In their sport, the position played by x -1 In their sport, the position played by Pelle Lindbergh goaltender Pelle Lindbergh "[',' ' the' ' Swedish' ' player' ',' ' is' ' called' ' the' ' ""' 'back'
+ 'hand' '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the Swedish player , is called the "" back hand "" and the position played by the American player" False awarded the Pelle Lindbergh Memorial Trophy, 6 [' awarded', ' the', ' Pel', 'le', ' Lind', 'ber', 'gh']
+564 371 In their sport, the position played by x -1 In their sport, the position played by Kevin Thomson midfielder Kevin Thomson "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True Whittaker, Scott Brown, Kevin Thomson and Steven Fletcher, 8 [' Wh', 'itt', 'aker', ',', ' Scott', ' Brown', ',', ' Kevin', ' Thomson']
+565 371 In their sport, the position played by x -1 In their sport, the position played by Kevin Thomson midfielder Kevin Thomson "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True Derek Riordan, Kevin Thomson and Scott Brown. 5 [' Derek', ' Ri', 'ordan', ',', ' Kevin', ' Thomson']
+566 371 In their sport, the position played by x -1 In their sport, the position played by Kevin Thomson midfielder Kevin Thomson "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True Whittaker, Scott Brown, Kevin Thomson and Steven Fletcher, 8 [' Wh', 'itt', 'aker', ',', ' Scott', ' Brown', ',', ' Kevin', ' Thomson']
+567 371 In their sport, the position played by x -1 In their sport, the position played by Kevin Thomson midfielder Kevin Thomson "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True Derek Riordan, Kevin Thomson and Scott Brown. 5 [' Derek', ' Ri', 'ordan', ',', ' Kevin', ' Thomson']
+568 371 In their sport, the position played by x -1 In their sport, the position played by Kevin Thomson midfielder Kevin Thomson "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True Derek Riordan, Kevin Thomson and Scott Brown. 5 [' Derek', ' Ri', 'ordan', ',', ' Kevin', ' Thomson']
+569 372 In their sport, the position played by x -1 In their sport, the position played by Dan Bouchard goaltender Dan Bouchard "[',' ' the' ' player' ' who' ' is' ' the' ' most' ' likely' ' to' ' be'
+ ' traded' ',' ' is' ' the' ' one' ' that' ' is' ' most' ' likely' ' to']" , the player who is the most likely to be traded , is the one that is most likely to False selection and rookie Dan Bouchard with his second. 6 [' selection', ' and', ' rookie', ' Dan', ' B', 'ouch', 'ard']
+570 372 In their sport, the position played by x -1 In their sport, the position played by Dan Bouchard goaltender Dan Bouchard "[',' ' the' ' player' ' who' ' is' ' the' ' most' ' likely' ' to' ' be'
+ ' traded' ',' ' is' ' the' ' one' ' that' ' is' ' most' ' likely' ' to']" , the player who is the most likely to be traded , is the one that is most likely to False Goaltender Dan Bouchard led the team in 6 [' Go', 'alt', 'ender', ' Dan', ' B', 'ouch', 'ard']
+571 372 In their sport, the position played by x -1 In their sport, the position played by Dan Bouchard goaltender Dan Bouchard "[',' ' the' ' player' ' who' ' is' ' the' ' most' ' likely' ' to' ' be'
+ ' traded' ',' ' is' ' the' ' one' ' that' ' is' ' most' ' likely' ' to']" , the player who is the most likely to be traded , is the one that is most likely to False selection and rookie Dan Bouchard with his second. 6 [' selection', ' and', ' rookie', ' Dan', ' B', 'ouch', 'ard']
+572 373 In their sport, the position played by x -1 In their sport, the position played by Jim Druckenmiller quarterback Jim Druckenmiller "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Jim' ' Dru' 'ck' 'en' 'm' 'iller' ',' ' the']" ", the
+
+ In their sport , the position played by Jim Dru ck en m iller , the" False quarterback Jim Druckenmiller threw a touchdown 6 [' quarterback', ' Jim', ' Dru', 'ck', 'en', 'm', 'iller']
+573 373 In their sport, the position played by x -1 In their sport, the position played by Jim Druckenmiller quarterback Jim Druckenmiller "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Jim' ' Dru' 'ck' 'en' 'm' 'iller' ',' ' the']" ", the
+
+ In their sport , the position played by Jim Dru ck en m iller , the" False quarterback Jim Druckenmiller entered the 6 [' quarterback', ' Jim', ' Dru', 'ck', 'en', 'm', 'iller']
+574 373 In their sport, the position played by x -1 In their sport, the position played by Jim Druckenmiller quarterback Jim Druckenmiller "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Jim' ' Dru' 'ck' 'en' 'm' 'iller' ',' ' the']" ", the
+
+ In their sport , the position played by Jim Dru ck en m iller , the" False quarterback Jim Druckenmiller completed 16 of 6 [' quarterback', ' Jim', ' Dru', 'ck', 'en', 'm', 'iller']
+575 373 In their sport, the position played by x -1 In their sport, the position played by Jim Druckenmiller quarterback Jim Druckenmiller "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Jim' ' Dru' 'ck' 'en' 'm' 'iller' ',' ' the']" ", the
+
+ In their sport , the position played by Jim Dru ck en m iller , the" False Tech quarterback Jim Druckenmiller threw a touchdown pass 7 [' Tech', ' quarterback', ' Jim', ' Dru', 'ck', 'en', 'm', 'iller']
+576 373 In their sport, the position played by x -1 In their sport, the position played by Jim Druckenmiller quarterback Jim Druckenmiller "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Jim' ' Dru' 'ck' 'en' 'm' 'iller' ',' ' the']" ", the
+
+ In their sport , the position played by Jim Dru ck en m iller , the" False quarterback Jim Druckenmiller came into the 6 [' quarterback', ' Jim', ' Dru', 'ck', 'en', 'm', 'iller']
+577 378 In their sport, the position played by x -1 In their sport, the position played by David Batty midfielder David Batty "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' key' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' in' ' the' ' squad' ' who']" , the former England captain , is a key one . He is the only player in the squad who False header. Beardsley and David Batty were booked and 8 [' header', '.', ' Be', 'ards', 'ley', ' and', ' David', ' Bat', 'ty']
+578 378 In their sport, the position played by x -1 In their sport, the position played by David Batty midfielder David Batty "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' key' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' in' ' the' ' squad' ' who']" , the former England captain , is a key one . He is the only player in the squad who False Beardsley and David Batty were booked and 6 [' Be', 'ards', 'ley', ' and', ' David', ' Bat', 'ty']
+579 378 In their sport, the position played by x -1 In their sport, the position played by David Batty midfielder David Batty "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' key' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' in' ' the' ' squad' ' who']" , the former England captain , is a key one . He is the only player in the squad who False Beardsley and David Batty were booked and 6 [' Be', 'ards', 'ley', ' and', ' David', ' Bat', 'ty']
+580 378 In their sport, the position played by x -1 In their sport, the position played by David Batty midfielder David Batty "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' key' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' in' ' the' ' squad' ' who']" , the former England captain , is a key one . He is the only player in the squad who False header. Beardsley and David Batty were booked 8 [' header', '.', ' Be', 'ards', 'ley', ' and', ' David', ' Bat', 'ty']
+581 378 In their sport, the position played by x -1 In their sport, the position played by David Batty midfielder David Batty "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' key' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' in' ' the' ' squad' ' who']" , the former England captain , is a key one . He is the only player in the squad who False Beardsley and David Batty were booked 6 [' Be', 'ards', 'ley', ' and', ' David', ' Bat', 'ty']
+582 382 In their sport, the position played by x -1 In their sport, the position played by Dewey Selmon linebacker Dewey Selmon "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False fourth-quarter field goal, Dewey Selmon forced a Detroit fumble 9 [' fourth', '-', 'quarter', ' field', ' goal', ',', ' Dew', 'ey', ' Sel', 'mon']
+583 382 In their sport, the position played by x -1 In their sport, the position played by Dewey Selmon linebacker Dewey Selmon "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False fourth-quarter field goal, Dewey Selmon forced a Detroit 9 [' fourth', '-', 'quarter', ' field', ' goal', ',', ' Dew', 'ey', ' Sel', 'mon']
+584 382 In their sport, the position played by x -1 In their sport, the position played by Dewey Selmon linebacker Dewey Selmon "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False Morris Owens. Dewey Selmon suffered a serious 6 [' Morris', ' Owens', '.', ' Dew', 'ey', ' Sel', 'mon']
+585 382 In their sport, the position played by x -1 In their sport, the position played by Dewey Selmon linebacker Dewey Selmon "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False to Morris Owens. Dewey Selmon suffered a serious 7 [' to', ' Morris', ' Owens', '.', ' Dew', 'ey', ' Sel', 'mon']
+586 383 In their sport, the position played by x -1 In their sport, the position played by Kelly Hrudey goaltender Kelly Hrudey "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' game' ',' ' is' ' called' ' the' ' ""' 'goal']" ", the player who is the first to score a goal in a game , is called the "" goal" False Hockey analyst Kelly Hrudey compared Crosby 6 [' Hockey', ' analyst', ' Kelly', ' H', 'r', 'ude', 'y']
+587 383 In their sport, the position played by x -1 In their sport, the position played by Kelly Hrudey goaltender Kelly Hrudey "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' game' ',' ' is' ' called' ' the' ' ""' 'goal']" ", the player who is the first to score a goal in a game , is called the "" goal" False Hockey analyst Kelly Hrudey compared Crosby 6 [' Hockey', ' analyst', ' Kelly', ' H', 'r', 'ude', 'y']
+588 383 In their sport, the position played by x -1 In their sport, the position played by Kelly Hrudey goaltender Kelly Hrudey "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' game' ',' ' is' ' called' ' the' ' ""' 'goal']" ", the player who is the first to score a goal in a game , is called the "" goal" False officials. Hockey analyst Kelly Hrudey compared Crosby to 8 [' officials', '.', ' Hockey', ' analyst', ' Kelly', ' H', 'r', 'ude', 'y']
+589 383 In their sport, the position played by x -1 In their sport, the position played by Kelly Hrudey goaltender Kelly Hrudey "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' game' ',' ' is' ' called' ' the' ' ""' 'goal']" ", the player who is the first to score a goal in a game , is called the "" goal" False Hockey analyst Kelly Hrudey compared Crosby to 6 [' Hockey', ' analyst', ' Kelly', ' H', 'r', 'ude', 'y']
+590 383 In their sport, the position played by x -1 In their sport, the position played by Kelly Hrudey goaltender Kelly Hrudey "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' score' ' a'
+ ' goal' ' in' ' a' ' game' ',' ' is' ' called' ' the' ' ""' 'goal']" ", the player who is the first to score a goal in a game , is called the "" goal" False officials. Hockey analyst Kelly Hrudey compared Crosby to 8 [' officials', '.', ' Hockey', ' analyst', ' Kelly', ' H', 'r', 'ude', 'y']
+591 384 In their sport, the position played by x -1 In their sport, the position played by Fernando Torres forward Fernando Torres "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who has been a revelation in the Premier League this season , is a perfect example of this . False substituted for Fernando Torres for the second 3 [' substituted', ' for', ' Fernando', ' Torres']
+592 384 In their sport, the position played by x -1 In their sport, the position played by Fernando Torres forward Fernando Torres "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who has been a revelation in the Premier League this season , is a perfect example of this . False header before Fernando Torres received a long pass 3 [' header', ' before', ' Fernando', ' Torres']
+593 384 In their sport, the position played by x -1 In their sport, the position played by Fernando Torres forward Fernando Torres "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who has been a revelation in the Premier League this season , is a perfect example of this . False Yossi Benayoun and Fernando Torres for a club 8 [' Y', 'oss', 'i', ' Ben', 'ay', 'oun', ' and', ' Fernando', ' Torres']
+594 384 In their sport, the position played by x -1 In their sport, the position played by Fernando Torres forward Fernando Torres "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who has been a revelation in the Premier League this season , is a perfect example of this . False the second half, Fernando Torres ended a goal drought 5 [' the', ' second', ' half', ',', ' Fernando', ' Torres']
+595 384 In their sport, the position played by x -1 In their sport, the position played by Fernando Torres forward Fernando Torres "[',' ' who' ' has' ' been' ' a' ' revelation' ' in' ' the' ' Premier'
+ ' League' ' this' ' season' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who has been a revelation in the Premier League this season , is a perfect example of this . False " Fernando Torres =
+" 1 [' Fernando', ' Torres']
+596 387 In their sport, the position played by x -1 In their sport, the position played by Mike Pawlawski quarterback Mike Pawlawski "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' II' ' national' ' championship' ' in'
+ ' 2011' ',' ' is']" , who was a member of the team that won the NCAA Division II national championship in 2011 , is False quarterback job behind Mike Pawlawski while DeBerg and Craig 6 [' quarterback', ' job', ' behind', ' Mike', ' Paw', 'law', 'ski']
+597 390 In their sport, the position played by x -1 In their sport, the position played by Sean Burke goaltender Sean Burke "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' he' ' was' ' a'
+ ' great' ' player' '.' ' He' ' was' ' a' ' great' ' player' '.']" , who was a great player , but he was a great player . He was a great player . False on goaltender Sean Burke of the Phoenix Coyotes. 3 [' on', ' goaltender', ' Sean', ' Burke']
+598 390 In their sport, the position played by x -1 In their sport, the position played by Sean Burke goaltender Sean Burke "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' he' ' was' ' a'
+ ' great' ' player' '.' ' He' ' was' ' a' ' great' ' player' '.']" , who was a great player , but he was a great player . He was a great player . False against goaltender Sean Burke of the New 3 [' against', ' goaltender', ' Sean', ' Burke']
+599 390 In their sport, the position played by x -1 In their sport, the position played by Sean Burke goaltender Sean Burke "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' he' ' was' ' a'
+ ' great' ' player' '.' ' He' ' was' ' a' ' great' ' player' '.']" , who was a great player , but he was a great player . He was a great player . False goaltender Sean Burke of the New Jersey 2 [' goaltender', ' Sean', ' Burke']
+600 390 In their sport, the position played by x -1 In their sport, the position played by Sean Burke goaltender Sean Burke "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' he' ' was' ' a'
+ ' great' ' player' '.' ' He' ' was' ' a' ' great' ' player' '.']" , who was a great player , but he was a great player . He was a great player . False 2000, on goaltender Sean Burke of the Phoenix 5 [' 2000', ',', ' on', ' goaltender', ' Sean', ' Burke']
+601 390 In their sport, the position played by x -1 In their sport, the position played by Sean Burke goaltender Sean Burke "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' he' ' was' ' a'
+ ' great' ' player' '.' ' He' ' was' ' a' ' great' ' player' '.']" , who was a great player , but he was a great player . He was a great player . False 2000, on goaltender Sean Burke of the Phoenix Coyotes. 5 [' 2000', ',', ' on', ' goaltender', ' Sean', ' Burke']
+602 391 In their sport, the position played by x -1 In their sport, the position played by Logan Thomas quarterback Logan Thomas "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL True intercepted a Logan Thomas pass and returned it 3 [' intercepted', ' a', ' Logan', ' Thomas']
+603 391 In their sport, the position played by x -1 In their sport, the position played by Logan Thomas quarterback Logan Thomas "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL True Sunseri intercepted a Logan Thomas pass and returned 6 [' Sun', 'ser', 'i', ' intercepted', ' a', ' Logan', ' Thomas']
+604 391 In their sport, the position played by x -1 In their sport, the position played by Logan Thomas quarterback Logan Thomas "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL True Sunseri intercepted a Logan Thomas pass and returned 6 [' Sun', 'ser', 'i', ' intercepted', ' a', ' Logan', ' Thomas']
+605 393 In their sport, the position played by x -1 In their sport, the position played by Jason Campbell quarterback Jason Campbell "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Michigan' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who was a quarterback at the University of Michigan , is a perfect example of the type of player True Rogers (ninth), and Jason Campbell (25th). Virginia 7 [' Rogers', ' (', 'n', 'inth', '),', ' and', ' Jason', ' Campbell']
+606 393 In their sport, the position played by x -1 In their sport, the position played by Jason Campbell quarterback Jason Campbell "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Michigan' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who was a quarterback at the University of Michigan , is a perfect example of the type of player True play, quarterback Jason Campbell threw a long pass 4 [' play', ',', ' quarterback', ' Jason', ' Campbell']
+607 393 In their sport, the position played by x -1 In their sport, the position played by Jason Campbell quarterback Jason Campbell "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Michigan' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who was a quarterback at the University of Michigan , is a perfect example of the type of player True three yards, and Jason Campbell threw a seven-yard 5 [' three', ' yards', ',', ' and', ' Jason', ' Campbell']
+608 393 In their sport, the position played by x -1 In their sport, the position played by Jason Campbell quarterback Jason Campbell "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Michigan' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who was a quarterback at the University of Michigan , is a perfect example of the type of player True days later, Jason Campbell was signed by the Cleveland 4 [' days', ' later', ',', ' Jason', ' Campbell']
+609 393 In their sport, the position played by x -1 In their sport, the position played by Jason Campbell quarterback Jason Campbell "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Michigan' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who was a quarterback at the University of Michigan , is a perfect example of the type of player True three yards, and Jason Campbell threw a seven-yard 5 [' three', ' yards', ',', ' and', ' Jason', ' Campbell']
+610 399 In their sport, the position played by x -1 In their sport, the position played by Bryan Robson midfielder Bryan Robson "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' '�' '�' 'I' '�' '�' 'm' ' not']" ", who was a great player and a great person .
+
+ � � I � � m not" False Waddle, Barnes, Bryan Robson and Paul Gascoigne 7 [' W', 'addle', ',', ' Barnes', ',', ' Bryan', ' Rob', 'son']
+611 399 In their sport, the position played by x -1 In their sport, the position played by Bryan Robson midfielder Bryan Robson "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' '�' '�' 'I' '�' '�' 'm' ' not']" ", who was a great player and a great person .
+
+ � � I � � m not" False and captain Bryan Robson was injured 4 [' and', ' captain', ' Bryan', ' Rob', 'son']
+612 399 In their sport, the position played by x -1 In their sport, the position played by Bryan Robson midfielder Bryan Robson "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' '�' '�' 'I' '�' '�' 'm' ' not']" ", who was a great player and a great person .
+
+ � � I � � m not" False 3 ['B', 'ryan', ' Rob', 'son']
+613 399 In their sport, the position played by x -1 In their sport, the position played by Bryan Robson midfielder Bryan Robson "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' '�' '�' 'I' '�' '�' 'm' ' not']" ", who was a great player and a great person .
+
+ � � I � � m not" False Albion manager Bryan Robson became manager at Sheffield 4 [' Albion', ' manager', ' Bryan', ' Rob', 'son']
+614 399 In their sport, the position played by x -1 In their sport, the position played by Bryan Robson midfielder Bryan Robson "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' '�' '�' 'I' '�' '�' 'm' ' not']" ", who was a great player and a great person .
+
+ � � I � � m not" False transfer fee to sign Bryan Robson from West Bromwich 6 [' transfer', ' fee', ' to', ' sign', ' Bryan', ' Rob', 'son']
+615 401 In their sport, the position played by x -1 In their sport, the position played by Todd Helton quarterback Todd Helton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great person .
+
+ I think the only thing" False only teammate Todd Helton in batting 4 [' only', ' teammate', ' Todd', ' Hel', 'ton']
+616 401 In their sport, the position played by x -1 In their sport, the position played by Todd Helton quarterback Todd Helton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great person .
+
+ I think the only thing" False of this article, Todd Helton had met all five 6 [' of', ' this', ' article', ',', ' Todd', ' Hel', 'ton']
+617 401 In their sport, the position played by x -1 In their sport, the position played by Todd Helton quarterback Todd Helton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great person .
+
+ I think the only thing" False quarterback Todd Helton was hurt in the 3 [' quarterback', ' Todd', ' Hel', 'ton']
+618 401 In their sport, the position played by x -1 In their sport, the position played by Todd Helton quarterback Todd Helton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great person .
+
+ I think the only thing" False State, starter Todd Helton got injured and 5 [' State', ',', ' starter', ' Todd', ' Hel', 'ton']
+619 401 In their sport, the position played by x -1 In their sport, the position played by Todd Helton quarterback Todd Helton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great person .
+
+ I think the only thing" False of this article, Todd Helton had met all five 6 [' of', ' this', ' article', ',', ' Todd', ' Hel', 'ton']
+620 404 In their sport, the position played by x -1 In their sport, the position played by Dennis Shaw quarterback Dennis Shaw "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False starting quarterback Dennis Shaw found himself 3 [' starting', ' quarterback', ' Dennis', ' Shaw']
+621 404 In their sport, the position played by x -1 In their sport, the position played by Dennis Shaw quarterback Dennis Shaw "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False games at quarterback. Dennis Shaw who had been 5 [' games', ' at', ' quarterback', '.', ' Dennis', ' Shaw']
+622 404 In their sport, the position played by x -1 In their sport, the position played by Dennis Shaw quarterback Dennis Shaw "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False starting quarterback Dennis Shaw found himself 3 [' starting', ' quarterback', ' Dennis', ' Shaw']
+623 411 In their sport, the position played by x -1 In their sport, the position played by Dean Furman midfielder Dean Furman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2011' ',' ' is'
+ ' a']" , who was a member of the team that won the NCAA Division I Championship in 2011 , is a False Joe Colbeck, Dean Furman and Nicky Law. 6 [' Joe', ' Col', 'beck', ',', ' Dean', ' Fur', 'man']
+624 411 In their sport, the position played by x -1 In their sport, the position played by Dean Furman midfielder Dean Furman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2011' ',' ' is'
+ ' a']" , who was a member of the team that won the NCAA Division I Championship in 2011 , is a False Joe Colbeck, Dean Furman and Nicky Law. 6 [' Joe', ' Col', 'beck', ',', ' Dean', ' Fur', 'man']
+625 411 In their sport, the position played by x -1 In their sport, the position played by Dean Furman midfielder Dean Furman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2011' ',' ' is'
+ ' a']" , who was a member of the team that won the NCAA Division I Championship in 2011 , is a False Joe Colbeck, Dean Furman and Nicky Law. 6 [' Joe', ' Col', 'beck', ',', ' Dean', ' Fur', 'man']
+626 411 In their sport, the position played by x -1 In their sport, the position played by Dean Furman midfielder Dean Furman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2011' ',' ' is'
+ ' a']" , who was a member of the team that won the NCAA Division I Championship in 2011 , is a False colleagues, Joe Colbeck, Dean Furman and Nicky Law. Although 8 [' colleagues', ',', ' Joe', ' Col', 'beck', ',', ' Dean', ' Fur', 'man']
+627 411 In their sport, the position played by x -1 In their sport, the position played by Dean Furman midfielder Dean Furman "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2011' ',' ' is'
+ ' a']" , who was a member of the team that won the NCAA Division I Championship in 2011 , is a False Joe Colbeck, Dean Furman and Nicky Law. 6 [' Joe', ' Col', 'beck', ',', ' Dean', ' Fur', 'man']
+628 415 In their sport, the position played by x -1 In their sport, the position played by Tim Raines outfielder Tim Raines "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", who was a great player , but not a great person .
+
+ I 'm not sure if" False to leadoff hitter Tim Raines while batting in the 6 [' to', ' lead', 'off', ' hitter', ' Tim', ' Rain', 'es']
+629 415 In their sport, the position played by x -1 In their sport, the position played by Tim Raines outfielder Tim Raines "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", who was a great player , but not a great person .
+
+ I 'm not sure if" False to leadoff hitter Tim Raines while batting 6 [' to', ' lead', 'off', ' hitter', ' Tim', ' Rain', 'es']
+630 415 In their sport, the position played by x -1 In their sport, the position played by Tim Raines outfielder Tim Raines "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", who was a great player , but not a great person .
+
+ I 'm not sure if" False leadoff hitter Tim Raines while batting in the 5 [' lead', 'off', ' hitter', ' Tim', ' Rain', 'es']
+631 416 In their sport, the position played by x -1 In their sport, the position played by Colin Kaepernick quarterback Colin Kaepernick "[',' ' who' ' has' ' been' ' a' ' great' ' leader' ' for' ' the' ' 49'
+ 'ers' ',' ' and' ' the' ' team' ',' ' and' ' the' ' city' ' of']" , who has been a great leader for the 49 ers , and the team , and the city of False Player of the Year Colin Kaepernick ran for more than 1,100 5 [' Player', ' of', ' the', ' Year', ' Colin', ' Kaepernick']
+632 416 In their sport, the position played by x -1 In their sport, the position played by Colin Kaepernick quarterback Colin Kaepernick "[',' ' who' ' has' ' been' ' a' ' great' ' leader' ' for' ' the' ' 49'
+ 'ers' ',' ' and' ' the' ' team' ',' ' and' ' the' ' city' ' of']" , who has been a great leader for the 49 ers , and the team , and the city of False dual-threat quarterback Colin Kaepernick and running back 5 [' dual', '-', 'threat', ' quarterback', ' Colin', ' Kaepernick']
+633 416 In their sport, the position played by x -1 In their sport, the position played by Colin Kaepernick quarterback Colin Kaepernick "[',' ' who' ' has' ' been' ' a' ' great' ' leader' ' for' ' the' ' 49'
+ 'ers' ',' ' and' ' the' ' team' ',' ' and' ' the' ' city' ' of']" , who has been a great leader for the 49 ers , and the team , and the city of False quarterback Colin Kaepernick who threw for 2 [' quarterback', ' Colin', ' Kaepernick']
+634 416 In their sport, the position played by x -1 In their sport, the position played by Colin Kaepernick quarterback Colin Kaepernick "[',' ' who' ' has' ' been' ' a' ' great' ' leader' ' for' ' the' ' 49'
+ 'ers' ',' ' and' ' the' ' team' ',' ' and' ' the' ' city' ' of']" , who has been a great leader for the 49 ers , and the team , and the city of False 1, broken by Colin Kaepernick (181) in 2013. 5 [' 1', ',', ' broken', ' by', ' Colin', ' Kaepernick']
+635 416 In their sport, the position played by x -1 In their sport, the position played by Colin Kaepernick quarterback Colin Kaepernick "[',' ' who' ' has' ' been' ' a' ' great' ' leader' ' for' ' the' ' 49'
+ 'ers' ',' ' and' ' the' ' team' ',' ' and' ' the' ' city' ' of']" , who has been a great leader for the 49 ers , and the team , and the city of False quarterback Colin Kaepernick who threw for 370 2 [' quarterback', ' Colin', ' Kaepernick']
+636 419 In their sport, the position played by x -1 In their sport, the position played by Erik Kramer quarterback Erik Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False quarterback Erik Kramer was named the 2 [' quarterback', ' Erik', ' Kramer']
+637 419 In their sport, the position played by x -1 In their sport, the position played by Erik Kramer quarterback Erik Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False State quarterback Erik Kramer was named 3 [' State', ' quarterback', ' Erik', ' Kramer']
+638 419 In their sport, the position played by x -1 In their sport, the position played by Erik Kramer quarterback Erik Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False the drive, Erik Kramer completed an 18-yard 4 [' the', ' drive', ',', ' Erik', ' Kramer']
+639 419 In their sport, the position played by x -1 In their sport, the position played by Erik Kramer quarterback Erik Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False sixteen games since Erik Kramer in 1995. In a season 4 [' sixteen', ' games', ' since', ' Erik', ' Kramer']
+640 419 In their sport, the position played by x -1 In their sport, the position played by Erik Kramer quarterback Erik Kramer "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False quarterback Erik Kramer ran the ball, but 2 [' quarterback', ' Erik', ' Kramer']
+641 420 In their sport, the position played by x -1 In their sport, the position played by Boof Bonser pitcher Boof Bonser "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False with pitchers Boof Bonser and Francisco Liriano 6 [' with', ' pitchers', ' Bo', 'of', ' B', 'ons', 'er']
+642 420 In their sport, the position played by x -1 In their sport, the position played by Boof Bonser pitcher Boof Bonser "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False with pitchers Boof Bonser and Francisco 6 [' with', ' pitchers', ' Bo', 'of', ' B', 'ons', 'er']
+643 420 In their sport, the position played by x -1 In their sport, the position played by Boof Bonser pitcher Boof Bonser "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False along with pitchers Boof Bonser and Francisco 7 [' along', ' with', ' pitchers', ' Bo', 'of', ' B', 'ons', 'er']
+644 424 In their sport, the position played by x -1 In their sport, the position played by Blaine Lacher goaltender Blaine Lacher "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , a former NFL quarterback , is a perfect example of the kind of player who can make a difference False against goaltender Blaine Lacher in a 4 – 4 tie 5 [' against', ' goaltender', ' Bl', 'aine', ' L', 'acher']
+645 427 In their sport, the position played by x -1 In their sport, the position played by Yogi Berra catcher Yogi Berra "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True Yankee star catcher Yogi Berra due to car accidents, 6 [' Yankee', ' star', ' catcher', ' Yog', 'i', ' Ber', 'ra']
+646 427 In their sport, the position played by x -1 In their sport, the position played by Yogi Berra catcher Yogi Berra "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True by coaches Yogi Berra and Elston Howard 5 [' by', ' coaches', ' Yog', 'i', ' Ber', 'ra']
+647 427 In their sport, the position played by x -1 In their sport, the position played by Yogi Berra catcher Yogi Berra "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True refused to include Yogi Berra and the deal fell 6 [' refused', ' to', ' include', ' Yog', 'i', ' Ber', 'ra']
+648 427 In their sport, the position played by x -1 In their sport, the position played by Yogi Berra catcher Yogi Berra "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True World War II veteran Yogi Berra to the mound, where 7 [' World', ' War', ' II', ' veteran', ' Yog', 'i', ' Ber', 'ra']
+649 427 In their sport, the position played by x -1 In their sport, the position played by Yogi Berra catcher Yogi Berra "[',' ' the' ' catcher' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' field' '.' ' The' ' catcher' ' is' ' the' ' first' ' line'
+ ' of']" , the catcher , is the most important position on the field . The catcher is the first line of True World War II veteran Yogi Berra to the mound, 7 [' World', ' War', ' II', ' veteran', ' Yog', 'i', ' Ber', 'ra']
+650 432 In their sport, the position played by x -1 In their sport, the position played by Chan Ho Park pitcher Chan Ho Park "[',' ' who' ' was' ' the' ' first' ' Korean' ' to' ' win' ' the' ' title'
+ ' in' ' the' ' men' ""'s"" ' singles' ',' ' and' ' the' ' women' ""'s""]" , who was the first Korean to win the title in the men 's singles , and the women 's False their first lead. Chan Ho Park allowed an RBI single 6 [' their', ' first', ' lead', '.', ' Chan', ' Ho', ' Park']
+651 432 In their sport, the position played by x -1 In their sport, the position played by Chan Ho Park pitcher Chan Ho Park "[',' ' who' ' was' ' the' ' first' ' Korean' ' to' ' win' ' the' ' title'
+ ' in' ' the' ' men' ""'s"" ' singles' ',' ' and' ' the' ' women' ""'s""]" , who was the first Korean to win the title in the men 's singles , and the women 's False singled against Chan Ho Park in his first major 4 [' singled', ' against', ' Chan', ' Ho', ' Park']
+652 432 In their sport, the position played by x -1 In their sport, the position played by Chan Ho Park pitcher Chan Ho Park "[',' ' who' ' was' ' the' ' first' ' Korean' ' to' ' win' ' the' ' title'
+ ' in' ' the' ' men' ""'s"" ' singles' ',' ' and' ' the' ' women' ""'s""]" , who was the first Korean to win the title in the men 's singles , and the women 's False South Korean pitcher Chan Ho Park officially signed 5 [' South', ' Korean', ' pitcher', ' Chan', ' Ho', ' Park']
+653 432 In their sport, the position played by x -1 In their sport, the position played by Chan Ho Park pitcher Chan Ho Park "[',' ' who' ' was' ' the' ' first' ' Korean' ' to' ' win' ' the' ' title'
+ ' in' ' the' ' men' ""'s"" ' singles' ',' ' and' ' the' ' women' ""'s""]" , who was the first Korean to win the title in the men 's singles , and the women 's False without scoring. Chan Ho Park came in for Happ, 5 [' without', ' scoring', '.', ' Chan', ' Ho', ' Park']
+654 432 In their sport, the position played by x -1 In their sport, the position played by Chan Ho Park pitcher Chan Ho Park "[',' ' who' ' was' ' the' ' first' ' Korean' ' to' ' win' ' the' ' title'
+ ' in' ' the' ' men' ""'s"" ' singles' ',' ' and' ' the' ' women' ""'s""]" , who was the first Korean to win the title in the men 's singles , and the women 's False Veterans Aaron Sele, Chan Ho Park and Jorge Sosa 7 [' Veterans', ' Aaron', ' Se', 'le', ',', ' Chan', ' Ho', ' Park']
+655 433 In their sport, the position played by x -1 In their sport, the position played by Roy Halladay pitcher Roy Halladay "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who was the best pitcher in the National League in 2011 , is a perfect example of this . True the departure of Roy Halladay and the discovery 5 [' the', ' departure', ' of', ' Roy', ' Hall', 'aday']
+656 433 In their sport, the position played by x -1 In their sport, the position played by Roy Halladay pitcher Roy Halladay "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who was the best pitcher in the National League in 2011 , is a perfect example of this . True Phillies acquired Roy Halladay prior to the 4 [' Phillies', ' acquired', ' Roy', ' Hall', 'aday']
+657 433 In their sport, the position played by x -1 In their sport, the position played by Roy Halladay pitcher Roy Halladay "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who was the best pitcher in the National League in 2011 , is a perfect example of this . True home runs off of Roy Halladay in Game 1 in the 6 [' home', ' runs', ' off', ' of', ' Roy', ' Hall', 'aday']
+658 433 In their sport, the position played by x -1 In their sport, the position played by Roy Halladay pitcher Roy Halladay "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who was the best pitcher in the National League in 2011 , is a perfect example of this . True Ryan (1973), Roy Halladay (2010), and 6 [' Ryan', ' (', '1973', '),', ' Roy', ' Hall', 'aday']
+659 433 In their sport, the position played by x -1 In their sport, the position played by Roy Halladay pitcher Roy Halladay "[',' ' who' ' was' ' the' ' best' ' pitcher' ' in' ' the' ' National'
+ ' League' ' in' ' 2011' ',' ' is' ' a' ' perfect' ' example' ' of'
+ ' this' '.']" , who was the best pitcher in the National League in 2011 , is a perfect example of this . True The Phillies acquired Roy Halladay prior to the 2010 5 [' The', ' Phillies', ' acquired', ' Roy', ' Hall', 'aday']
+660 439 In their sport, the position played by x -1 In their sport, the position played by Pirri midfielder Pirri "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' gold' ' medal' ' at' ' the' ' 2012' ' Summer' ' Olympics' ' in'
+ ' London']" , who was a member of the team that won the gold medal at the 2012 Summer Olympics in London False IceHogs' Brandon Pirri at 22 goals. Taffe 6 "[' Ice', 'H', 'ogs', ""'"", ' Brandon', ' Pir', 'ri']"
+661 439 In their sport, the position played by x -1 In their sport, the position played by Pirri midfielder Pirri "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' gold' ' medal' ' at' ' the' ' 2012' ' Summer' ' Olympics' ' in'
+ ' London']" , who was a member of the team that won the gold medal at the 2012 Summer Olympics in London False IceHogs' Brandon Pirri at 22 goals. 6 "[' Ice', 'H', 'ogs', ""'"", ' Brandon', ' Pir', 'ri']"
+662 439 In their sport, the position played by x -1 In their sport, the position played by Pirri midfielder Pirri "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' gold' ' medal' ' at' ' the' ' 2012' ' Summer' ' Olympics' ' in'
+ ' London']" , who was a member of the team that won the gold medal at the 2012 Summer Olympics in London False IceHogs' Brandon Pirri at 22 goals. Taffe 6 "[' Ice', 'H', 'ogs', ""'"", ' Brandon', ' Pir', 'ri']"
+663 442 In their sport, the position played by x -1 In their sport, the position played by Charlie Gardiner goaltender Charlie Gardiner "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False lineup, Patrick chose Charlie Gardiner in goal, Eddie Shore 6 [' lineup', ',', ' Patrick', ' chose', ' Charlie', ' Gard', 'iner']
+664 442 In their sport, the position played by x -1 In their sport, the position played by Charlie Gardiner goaltender Charlie Gardiner "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False " (ice hockey) =
+" 7 [' (', 'ice', ' hockey', ')', ' =', 'Charlie', ' Gard', 'iner']
+665 442 In their sport, the position played by x -1 In their sport, the position played by Charlie Gardiner goaltender Charlie Gardiner "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False fellow Hall of Famer Charlie Gardiner until the latter's 7 [' fellow', ' Hall', ' of', ' F', 'amer', ' Charlie', ' Gard', 'iner']
+666 442 In their sport, the position played by x -1 In their sport, the position played by Charlie Gardiner goaltender Charlie Gardiner "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False Hall of Famer Charlie Gardiner until the latter's 6 [' Hall', ' of', ' F', 'amer', ' Charlie', ' Gard', 'iner']
+667 442 In their sport, the position played by x -1 In their sport, the position played by Charlie Gardiner goaltender Charlie Gardiner "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False lineup, Patrick chose Charlie Gardiner in goal, Eddie Shore 6 [' lineup', ',', ' Patrick', ' chose', ' Charlie', ' Gard', 'iner']
+668 444 In their sport, the position played by x -1 In their sport, the position played by Khaseem Greene linebacker Khaseem Greene "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False linebacker Khaseem Greene and defensive 4 [' linebacker', ' Kh', 'ase', 'em', ' Greene']
+669 444 In their sport, the position played by x -1 In their sport, the position played by Khaseem Greene linebacker Khaseem Greene "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False May 11, linebacker Khaseem Greene and defensive 7 [' May', ' 11', ',', ' linebacker', ' Kh', 'ase', 'em', ' Greene']
+670 444 In their sport, the position played by x -1 In their sport, the position played by Khaseem Greene linebacker Khaseem Greene "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False May 11, linebacker Khaseem Greene and defensive 7 [' May', ' 11', ',', ' linebacker', ' Kh', 'ase', 'em', ' Greene']
+671 444 In their sport, the position played by x -1 In their sport, the position played by Khaseem Greene linebacker Khaseem Greene "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False October 28, rookie Khaseem Greene filled in for 7 [' October', ' 28', ',', ' rookie', ' Kh', 'ase', 'em', ' Greene']
+672 444 In their sport, the position played by x -1 In their sport, the position played by Khaseem Greene linebacker Khaseem Greene "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False October 28, rookie Khaseem Greene filled in for 7 [' October', ' 28', ',', ' rookie', ' Kh', 'ase', 'em', ' Greene']
+673 446 In their sport, the position played by x -1 In their sport, the position played by Nelson Cruz outfielder Nelson Cruz "[',' ' who' ' has' ' been' ' a' ' great' ' addition' ' to' ' the' ' team'
+ '.' '\n' '\n' 'The' ' team' ' has' ' been' ' a' ' great' ' addition']" ", who has been a great addition to the team .
+
+ The team has been a great addition" False This person was Nelson Cruz a marathon runner. 4 [' This', ' person', ' was', ' Nelson', ' Cruz']
+674 446 In their sport, the position played by x -1 In their sport, the position played by Nelson Cruz outfielder Nelson Cruz "[',' ' who' ' has' ' been' ' a' ' great' ' addition' ' to' ' the' ' team'
+ '.' '\n' '\n' 'The' ' team' ' has' ' been' ' a' ' great' ' addition']" ", who has been a great addition to the team .
+
+ The team has been a great addition" False games. As teammate Nelson Cruz also hit three 5 [' games', '.', ' As', ' teammate', ' Nelson', ' Cruz']
+675 446 In their sport, the position played by x -1 In their sport, the position played by Nelson Cruz outfielder Nelson Cruz "[',' ' who' ' has' ' been' ' a' ' great' ' addition' ' to' ' the' ' team'
+ '.' '\n' '\n' 'The' ' team' ' has' ' been' ' a' ' great' ' addition']" ", who has been a great addition to the team .
+
+ The team has been a great addition" False when outfielder Nelson Cruz was activated from 3 [' when', ' outfielder', ' Nelson', ' Cruz']
+676 446 In their sport, the position played by x -1 In their sport, the position played by Nelson Cruz outfielder Nelson Cruz "[',' ' who' ' has' ' been' ' a' ' great' ' addition' ' to' ' the' ' team'
+ '.' '\n' '\n' 'The' ' team' ' has' ' been' ' a' ' great' ' addition']" ", who has been a great addition to the team .
+
+ The team has been a great addition" False waivers when outfielder Nelson Cruz was activated from 4 [' waivers', ' when', ' outfielder', ' Nelson', ' Cruz']
+677 446 In their sport, the position played by x -1 In their sport, the position played by Nelson Cruz outfielder Nelson Cruz "[',' ' who' ' has' ' been' ' a' ' great' ' addition' ' to' ' the' ' team'
+ '.' '\n' '\n' 'The' ' team' ' has' ' been' ' a' ' great' ' addition']" ", who has been a great addition to the team .
+
+ The team has been a great addition" False athletics. This person was Nelson Cruz a marathon runner. 6 [' athletics', '.', ' This', ' person', ' was', ' Nelson', ' Cruz']
+678 447 In their sport, the position played by x -1 In their sport, the position played by Tom Flores quarterback Tom Flores "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' the']" , who was a member of the team that won the first World Cup in the United States in the False " player, following Tom Flores and Mike Ditka.
+" 4 [' player', ',', ' following', ' Tom', ' Flores']
+679 447 In their sport, the position played by x -1 In their sport, the position played by Tom Flores quarterback Tom Flores "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' the']" , who was a member of the team that won the first World Cup in the United States in the False " player, following Tom Flores and Mike Ditka.
+" 4 [' player', ',', ' following', ' Tom', ' Flores']
+680 456 In their sport, the position played by x -1 In their sport, the position played by Don Meredith quarterback Don Meredith "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False For example, Don Meredith is famously 4 [' For', ' example', ',', ' Don', ' Meredith']
+681 456 In their sport, the position played by x -1 In their sport, the position played by Don Meredith quarterback Don Meredith "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False inappropriate. For example, Don Meredith is famously 6 [' inappropriate', '.', ' For', ' example', ',', ' Don', ' Meredith']
+682 456 In their sport, the position played by x -1 In their sport, the position played by Don Meredith quarterback Don Meredith "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False For example, Don Meredith is famously 4 [' For', ' example', ',', ' Don', ' Meredith']
+683 456 In their sport, the position played by x -1 In their sport, the position played by Don Meredith quarterback Don Meredith "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False inappropriate. For example, Don Meredith is famously noted 6 [' inappropriate', '.', ' For', ' example', ',', ' Don', ' Meredith']
+684 457 In their sport, the position played by x -1 In their sport, the position played by Jonathan Quick goaltender Jonathan Quick "[',' ' the' ' Los' ' Angeles' ' Kings' ' goalie' ',' ' is' ' a' ' bit'
+ ' of' ' a' ' mystery' '.' ' He' ' has' ' been' ' a' ' revelation' ' in']" , the Los Angeles Kings goalie , is a bit of a mystery . He has been a revelation in False goaltender Jonathan Quick tied the score 2 [' goaltender', ' Jonathan', ' Quick']
+685 458 In their sport, the position played by x -1 In their sport, the position played by Shawn Crable linebacker Shawn Crable "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NFL' ' is'
+ ' looking']" , a former NFL player , is a perfect example of the kind of player that the NFL is looking False Mountaineers. For Michigan, Shawn Crable led the team 9 [' Mount', 'aine', 'ers', '.', ' For', ' Michigan', ',', ' Shawn', ' Cr', 'able']
+686 458 In their sport, the position played by x -1 In their sport, the position played by Shawn Crable linebacker Shawn Crable "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NFL' ' is'
+ ' looking']" , a former NFL player , is a perfect example of the kind of player that the NFL is looking False the graduation of Shawn Crable and Chris Graham 5 [' the', ' graduation', ' of', ' Shawn', ' Cr', 'able']
+687 458 In their sport, the position played by x -1 In their sport, the position played by Shawn Crable linebacker Shawn Crable "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NFL' ' is'
+ ' looking']" , a former NFL player , is a perfect example of the kind of player that the NFL is looking False graduation of Shawn Crable and Chris Graham there 4 [' graduation', ' of', ' Shawn', ' Cr', 'able']
+688 458 In their sport, the position played by x -1 In their sport, the position played by Shawn Crable linebacker Shawn Crable "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NFL' ' is'
+ ' looking']" , a former NFL player , is a perfect example of the kind of player that the NFL is looking False graduation of Shawn Crable and Chris Graham 4 [' graduation', ' of', ' Shawn', ' Cr', 'able']
+689 458 In their sport, the position played by x -1 In their sport, the position played by Shawn Crable linebacker Shawn Crable "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NFL' ' is'
+ ' looking']" , a former NFL player , is a perfect example of the kind of player that the NFL is looking False Mountaineers. For Michigan, Shawn Crable led the team with 9 [' Mount', 'aine', 'ers', '.', ' For', ' Michigan', ',', ' Shawn', ' Cr', 'able']
+690 459 In their sport, the position played by x -1 In their sport, the position played by Manti Te'o linebacker Manti Te'o "[' is' ' a' ' bit' ' like' ' that' ' of' ' a' ' quarterback' ' in' ' the'
+ ' NFL' '.' ' He' ""'s"" ' a' ' great' ' player' ',' ' but' ' he']" is a bit like that of a quarterback in the NFL . He 's a great player , but he False State, runner-up Manti Te'o from Notre Dame, 9 "[' State', ',', ' runner', '-', 'up', ' M', 'anti', ' Te', ""'"", 'o']"
+691 461 In their sport, the position played by x -1 In their sport, the position played by Vince Young quarterback Vince Young "[',' ' who' ' was' ' a' ' quarterback' ' at' ' Texas' ' A' '&' 'M' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player']" , who was a quarterback at Texas A & M , is a perfect example of the type of player True quarterback Vince Young led the team to Mack 2 [' quarterback', ' Vince', ' Young']
+692 461 In their sport, the position played by x -1 In their sport, the position played by Vince Young quarterback Vince Young "[',' ' who' ' was' ' a' ' quarterback' ' at' ' Texas' ' A' '&' 'M' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player']" , who was a quarterback at Texas A & M , is a perfect example of the type of player True Rodrique Wright, and Vince Young were named to 6 [' Rod', 'rique', ' Wright', ',', ' and', ' Vince', ' Young']
+693 461 In their sport, the position played by x -1 In their sport, the position played by Vince Young quarterback Vince Young "[',' ' who' ' was' ' a' ' quarterback' ' at' ' Texas' ' A' '&' 'M' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player']" , who was a quarterback at Texas A & M , is a perfect example of the type of player True touchdown pass from Vince Young to Billy Pittman 4 [' touchdown', ' pass', ' from', ' Vince', ' Young']
+694 461 In their sport, the position played by x -1 In their sport, the position played by Vince Young quarterback Vince Young "[',' ' who' ' was' ' a' ' quarterback' ' at' ' Texas' ' A' '&' 'M' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player']" , who was a quarterback at Texas A & M , is a perfect example of the type of player True 2 ['V', 'ince', ' Young']
+695 461 In their sport, the position played by x -1 In their sport, the position played by Vince Young quarterback Vince Young "[',' ' who' ' was' ' a' ' quarterback' ' at' ' Texas' ' A' '&' 'M' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player']" , who was a quarterback at Texas A & M , is a perfect example of the type of player True touchdowns by Vince Young, one rushing 3 [' touchdowns', ' by', ' Vince', ' Young']
+696 463 In their sport, the position played by x -1 In their sport, the position played by Bill Ranford goaltender Bill Ranford "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False In the trade, Oates, Bill Ranford and Rick Tocchet 9 [' In', ' the', ' trade', ',', ' O', 'ates', ',', ' Bill', ' Ran', 'ford']
+697 463 In their sport, the position played by x -1 In their sport, the position played by Bill Ranford goaltender Bill Ranford "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False the trade, Oates, Bill Ranford and Rick Tocchet went 8 [' the', ' trade', ',', ' O', 'ates', ',', ' Bill', ' Ran', 'ford']
+698 463 In their sport, the position played by x -1 In their sport, the position played by Bill Ranford goaltender Bill Ranford "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False the trade, Oates, Bill Ranford and Rick Tocchet 8 [' the', ' trade', ',', ' O', 'ates', ',', ' Bill', ' Ran', 'ford']
+699 463 In their sport, the position played by x -1 In their sport, the position played by Bill Ranford goaltender Bill Ranford "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False In the trade, Oates, Bill Ranford and Rick Tocchet 9 [' In', ' the', ' trade', ',', ' O', 'ates', ',', ' Bill', ' Ran', 'ford']
+700 463 In their sport, the position played by x -1 In their sport, the position played by Bill Ranford goaltender Bill Ranford "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False the trade, Oates, Bill Ranford and Rick Tocchet 8 [' the', ' trade', ',', ' O', 'ates', ',', ' Bill', ' Ran', 'ford']
+701 465 In their sport, the position played by x -1 In their sport, the position played by Thierry Omeyer goalkeeper Thierry Omeyer "[',' ' the' ' French' ' midfielder' ',' ' is' ' a' ' key' ' one' '.' ' He'
+ ' is' ' the' ' only' ' player' ' who' ' can' ' play' ' in' ' the']" , the French midfielder , is a key one . He is the only player who can play in the False 4 ['Th', 'ier', 'ry', ' O', 'meyer']
+702 465 In their sport, the position played by x -1 In their sport, the position played by Thierry Omeyer goalkeeper Thierry Omeyer "[',' ' the' ' French' ' midfielder' ',' ' is' ' a' ' key' ' one' '.' ' He'
+ ' is' ' the' ' only' ' player' ' who' ' can' ' play' ' in' ' the']" , the French midfielder , is a key one . He is the only player who can play in the False 4 ['Th', 'ier', 'ry', ' O', 'meyer']
+703 465 In their sport, the position played by x -1 In their sport, the position played by Thierry Omeyer goalkeeper Thierry Omeyer "[',' ' the' ' French' ' midfielder' ',' ' is' ' a' ' key' ' one' '.' ' He'
+ ' is' ' the' ' only' ' player' ' who' ' can' ' play' ' in' ' the']" , the French midfielder , is a key one . He is the only player who can play in the False 4 ['Th', 'ier', 'ry', ' O', 'meyer']
+704 468 In their sport, the position played by x -1 In their sport, the position played by Thomas McCollum goaltender Thomas McCollum "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' Olympic'
+ ' team' ' in' ' the' ' men' ""'s"" ' 4' 'x' '100']" , who was a member of the U . S . Olympic team in the men 's 4 x 100 False Grachev and Thomas McCollum for the Eastern 6 [' Gr', 'achev', ' and', ' Thomas', ' McC', 'oll', 'um']
+705 470 In their sport, the position played by x -1 In their sport, the position played by Todd Blackledge quarterback Todd Blackledge "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a'
+ ' good']" , the former NFL quarterback , is a perfect example of the kind of person who is not a good True for the Chiefs since Todd Blackledge started a 6 [' for', ' the', ' Chiefs', ' since', ' Todd', ' Black', 'ledge']
+706 470 In their sport, the position played by x -1 In their sport, the position played by Todd Blackledge quarterback Todd Blackledge "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a'
+ ' good']" , the former NFL quarterback , is a perfect example of the kind of person who is not a good True for the broadcast, Todd Blackledge was the analyst, 6 [' for', ' the', ' broadcast', ',', ' Todd', ' Black', 'ledge']
+707 470 In their sport, the position played by x -1 In their sport, the position played by Todd Blackledge quarterback Todd Blackledge "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a'
+ ' good']" , the former NFL quarterback , is a perfect example of the kind of person who is not a good True quarterback Todd Blackledge over future 3 [' quarterback', ' Todd', ' Black', 'ledge']
+708 470 In their sport, the position played by x -1 In their sport, the position played by Todd Blackledge quarterback Todd Blackledge "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a'
+ ' good']" , the former NFL quarterback , is a perfect example of the kind of person who is not a good True quarterback Todd Blackledge over future greats 3 [' quarterback', ' Todd', ' Black', 'ledge']
+709 470 In their sport, the position played by x -1 In their sport, the position played by Todd Blackledge quarterback Todd Blackledge "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' is' ' not' ' a'
+ ' good']" , the former NFL quarterback , is a perfect example of the kind of person who is not a good True the broadcast, Todd Blackledge was the analyst, 5 [' the', ' broadcast', ',', ' Todd', ' Black', 'ledge']
+710 472 In their sport, the position played by x -1 In their sport, the position played by Matt Herges pitcher Matt Herges "['ell' ',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person'
+ ' who']" ell , who is a former professional soccer player , is a perfect example of the kind of person who False home run against Matt Herges in an 11 – 8 victory 5 [' home', ' run', ' against', ' Matt', ' Her', 'ges']
+711 472 In their sport, the position played by x -1 In their sport, the position played by Matt Herges pitcher Matt Herges "['ell' ',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person'
+ ' who']" ell , who is a former professional soccer player , is a perfect example of the kind of person who False three-run home run against Matt Herges in an 11 – 8 victory 8 [' three', '-', 'run', ' home', ' run', ' against', ' Matt', ' Her', 'ges']
+712 472 In their sport, the position played by x -1 In their sport, the position played by Matt Herges pitcher Matt Herges "['ell' ',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person'
+ ' who']" ell , who is a former professional soccer player , is a perfect example of the kind of person who False home run against Matt Herges in an 11 – 8 victory 5 [' home', ' run', ' against', ' Matt', ' Her', 'ges']
+713 472 In their sport, the position played by x -1 In their sport, the position played by Matt Herges pitcher Matt Herges "['ell' ',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person'
+ ' who']" ell , who is a former professional soccer player , is a perfect example of the kind of person who False for pitcher Matt Herges in a postseason trade. 4 [' for', ' pitcher', ' Matt', ' Her', 'ges']
+714 472 In their sport, the position played by x -1 In their sport, the position played by Matt Herges pitcher Matt Herges "['ell' ',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' person'
+ ' who']" ell , who is a former professional soccer player , is a perfect example of the kind of person who False Ruan for pitcher Matt Herges and infielder 6 [' R', 'uan', ' for', ' pitcher', ' Matt', ' Her', 'ges']
+715 478 In their sport, the position played by x -1 In their sport, the position played by Garret Anderson outfielder Garret Anderson "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the only thing" False 33 – 31 in 2007), Garret Anderson (35 – 24 in 8 [' 33', ' –', ' 31', ' in', ' 2007', '),', ' Gar', 'ret', ' Anderson']
+716 478 In their sport, the position played by x -1 In their sport, the position played by Garret Anderson outfielder Garret Anderson "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the only thing" False 33 – 31 in 2007), Garret Anderson (35 – 24 in 8 [' 33', ' –', ' 31', ' in', ' 2007', '),', ' Gar', 'ret', ' Anderson']
+717 480 In their sport, the position played by x -1 In their sport, the position played by Robert Esche goaltender Robert Esche "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Robert' ' E' 'sche' ',' ' the' ' German' ' player' ',']" ", the
+
+ In their sport , the position played by Robert E sche , the German player ," False including goaltenders Robert Esche and Patrick Lalime 5 [' including', ' goalt', 'enders', ' Robert', ' E', 'sche']
+718 480 In their sport, the position played by x -1 In their sport, the position played by Robert Esche goaltender Robert Esche "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Robert' ' E' 'sche' ',' ' the' ' German' ' player' ',']" ", the
+
+ In their sport , the position played by Robert E sche , the German player ," False including goaltenders Robert Esche and Patrick Lalime 5 [' including', ' goalt', 'enders', ' Robert', ' E', 'sche']
+719 480 In their sport, the position played by x -1 In their sport, the position played by Robert Esche goaltender Robert Esche "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Robert' ' E' 'sche' ',' ' the' ' German' ' player' ',']" ", the
+
+ In their sport , the position played by Robert E sche , the German player ," False including goaltenders Robert Esche and Patrick 5 [' including', ' goalt', 'enders', ' Robert', ' E', 'sche']
+720 480 In their sport, the position played by x -1 In their sport, the position played by Robert Esche goaltender Robert Esche "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Robert' ' E' 'sche' ',' ' the' ' German' ' player' ',']" ", the
+
+ In their sport , the position played by Robert E sche , the German player ," False including goaltenders Robert Esche and Patrick Lalime 5 [' including', ' goalt', 'enders', ' Robert', ' E', 'sche']
+721 481 In their sport, the position played by x -1 In their sport, the position played by Darrin Fletcher catcher Darrin Fletcher "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' national' ' championship' ' in' ' the' ' fall' ' of' ' 2013' '.'
+ '\n']" ", who was a member of the team that won the national championship in the fall of 2013 .
+" False to be deceptive; Darrin Fletcher said in 1999 that 7 [' to', ' be', ' deceptive', ';', ' D', 'arr', 'in', ' Fletcher']
+722 484 In their sport, the position played by x -1 In their sport, the position played by Roger Staubach quarterback Roger Staubach "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' ',' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Miami' '.']" , who was a quarterback at the University of Miami , was a quarterback at the University of Miami . True Cowboys quarterback Roger Staubach near the sideline 5 [' Cowboys', ' quarterback', ' Roger', ' Sta', 'ub', 'ach']
+723 485 In their sport, the position played by x -1 In their sport, the position played by Brad Wilkerson outfielder Brad Wilkerson "[',' ' who' ' is' ' a' ' former' ' NFL' ' linebacker' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' the' ' Jets'
+ ' need']" , who is a former NFL linebacker , is a perfect example of the kind of player the Jets need False Valentín in 2000; Brad Wilkerson in 2003; and Gary 8 [' Valent', 'ín', ' in', ' 2000', ';', ' Brad', ' Wil', 'k', 'erson']
+724 485 In their sport, the position played by x -1 In their sport, the position played by Brad Wilkerson outfielder Brad Wilkerson "[',' ' who' ' is' ' a' ' former' ' NFL' ' linebacker' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' the' ' Jets'
+ ' need']" , who is a former NFL linebacker , is a perfect example of the kind of player the Jets need False Valentín in 2000; Brad Wilkerson in 2003; and 8 [' Valent', 'ín', ' in', ' 2000', ';', ' Brad', ' Wil', 'k', 'erson']
+725 487 In their sport, the position played by x -1 In their sport, the position played by Rick Mirer quarterback Rick Mirer "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False Jets signed Rick Mirer and left three quarterbacks 5 [' Jets', ' signed', ' Rick', ' M', 'ire', 'r']
+726 487 In their sport, the position played by x -1 In their sport, the position played by Rick Mirer quarterback Rick Mirer "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False Jets signed Rick Mirer and left three quarterbacks 5 [' Jets', ' signed', ' Rick', ' M', 'ire', 'r']
+727 488 In their sport, the position played by x -1 In their sport, the position played by Felix Magath midfielder Felix Magath "[',' ' who' ' was' ' sacked' ' by' ' the' ' club' ' in' ' the' ' summer'
+ ',' ' is' ' now' ' occupied' ' by' ' the' ' former' ' Bor' 'ussia'
+ ' Dortmund']" , who was sacked by the club in the summer , is now occupied by the former Bor ussia Dortmund False 3 ['Fel', 'ix', ' Mag', 'ath']
+728 488 In their sport, the position played by x -1 In their sport, the position played by Felix Magath midfielder Felix Magath "[',' ' who' ' was' ' sacked' ' by' ' the' ' club' ' in' ' the' ' summer'
+ ',' ' is' ' now' ' occupied' ' by' ' the' ' former' ' Bor' 'ussia'
+ ' Dortmund']" , who was sacked by the club in the summer , is now occupied by the former Bor ussia Dortmund False 3 ['Fel', 'ix', ' Mag', 'ath']
+729 488 In their sport, the position played by x -1 In their sport, the position played by Felix Magath midfielder Felix Magath "[',' ' who' ' was' ' sacked' ' by' ' the' ' club' ' in' ' the' ' summer'
+ ',' ' is' ' now' ' occupied' ' by' ' the' ' former' ' Bor' 'ussia'
+ ' Dortmund']" , who was sacked by the club in the summer , is now occupied by the former Bor ussia Dortmund False 3 ['Fel', 'ix', ' Mag', 'ath']
+730 493 In their sport, the position played by x -1 In their sport, the position played by Gus Poyet midfielder Gus Poyet "[',' ' the' ' Urug' 'u' 'ayan' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' has' ' been']" , the Urug u ayan , is a very important one . He is the only player who has been False Brighton manager Gus Poyet said that Holroyd 5 [' Brighton', ' manager', ' Gus', ' P', 'oy', 'et']
+731 493 In their sport, the position played by x -1 In their sport, the position played by Gus Poyet midfielder Gus Poyet "[',' ' the' ' Urug' 'u' 'ayan' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' has' ' been']" , the Urug u ayan , is a very important one . He is the only player who has been False Brighton manager Gus Poyet said that Holroyd 5 [' Brighton', ' manager', ' Gus', ' P', 'oy', 'et']
+732 493 In their sport, the position played by x -1 In their sport, the position played by Gus Poyet midfielder Gus Poyet "[',' ' the' ' Urug' 'u' 'ayan' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' has' ' been']" , the Urug u ayan , is a very important one . He is the only player who has been False saved from Gus Poyet and Franck Queudrue 5 [' saved', ' from', ' Gus', ' P', 'oy', 'et']
+733 493 In their sport, the position played by x -1 In their sport, the position played by Gus Poyet midfielder Gus Poyet "[',' ' the' ' Urug' 'u' 'ayan' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' has' ' been']" , the Urug u ayan , is a very important one . He is the only player who has been False captain by new manager Gus Poyet before the 3 – 7 [' captain', ' by', ' new', ' manager', ' Gus', ' P', 'oy', 'et']
+734 493 In their sport, the position played by x -1 In their sport, the position played by Gus Poyet midfielder Gus Poyet "[',' ' the' ' Urug' 'u' 'ayan' ',' ' is' ' a' ' very' ' important' ' one'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' has' ' been']" , the Urug u ayan , is a very important one . He is the only player who has been False club. Brighton manager Gus Poyet said that Holroyd 7 [' club', '.', ' Brighton', ' manager', ' Gus', ' P', 'oy', 'et']
+735 496 In their sport, the position played by x -1 In their sport, the position played by Santi Cazorla midfielder Santi Cazorla "[',' ' Mes' 'ut' ' Oz' 'il' ' and' ' Theo' ' Wal' 'cott' ',' ' Arsenal'
+ ' have' ' a' ' trio' ' of' ' players' ' who' ' are' ' all' ' capable']" , Mes ut Oz il and Theo Wal cott , Arsenal have a trio of players who are all capable False took the lead when Santi Cazorla scored in the 22nd 8 [' took', ' the', ' lead', ' when', ' Sant', 'i', ' C', 'azor', 'la']
+736 496 In their sport, the position played by x -1 In their sport, the position played by Santi Cazorla midfielder Santi Cazorla "[',' ' Mes' 'ut' ' Oz' 'il' ' and' ' Theo' ' Wal' 'cott' ',' ' Arsenal'
+ ' have' ' a' ' trio' ' of' ' players' ' who' ' are' ' all' ' capable']" , Mes ut Oz il and Theo Wal cott , Arsenal have a trio of players who are all capable False the lead when Santi Cazorla scored in the 22nd 7 [' the', ' lead', ' when', ' Sant', 'i', ' C', 'azor', 'la']
+737 496 In their sport, the position played by x -1 In their sport, the position played by Santi Cazorla midfielder Santi Cazorla "[',' ' Mes' 'ut' ' Oz' 'il' ' and' ' Theo' ' Wal' 'cott' ',' ' Arsenal'
+ ' have' ' a' ' trio' ' of' ' players' ' who' ' are' ' all' ' capable']" , Mes ut Oz il and Theo Wal cott , Arsenal have a trio of players who are all capable False In the 50th minute Santi Cazorla attempted to 9 [' In', ' the', ' 50', 'th', ' minute', ' Sant', 'i', ' C', 'azor', 'la']
+738 497 In their sport, the position played by x -1 In their sport, the position played by Pat Haden quarterback Pat Haden "[',' ' the' ' former' ' USC' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' fit' '.' '\n' '\n' '""' 'I' ' think' ' he' ""'s"" ' a' ' great']" ", the former USC quarterback , is a perfect fit .
+
+ "" I think he 's a great" True Rams quarterback Pat Haden from having coached 4 [' Rams', ' quarterback', ' Pat', ' Had', 'en']
+739 498 In their sport, the position played by x -1 In their sport, the position played by Mark Bresciano midfielder Mark Bresciano "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Italian' ' national'
+ ' team' ',' ' was' ' a' ' very' ' important' ' one' '.' ' He' ' was'
+ ' the']" , who was a member of the Italian national team , was a very important one . He was the False " Mark Bresciano =
+" 4 [' Mark', ' B', 'res', 'c', 'iano']
+740 498 In their sport, the position played by x -1 In their sport, the position played by Mark Bresciano midfielder Mark Bresciano "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Italian' ' national'
+ ' team' ',' ' was' ' a' ' very' ' important' ' one' '.' ' He' ' was'
+ ' the']" , who was a member of the Italian national team , was a very important one . He was the False " Mark Bresciano =
+" 4 [' Mark', ' B', 'res', 'c', 'iano']
+741 499 In their sport, the position played by x -1 In their sport, the position played by Duncan Edwards midfielder Duncan Edwards "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' has' ' been']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder has been" True " buried was named ""Duncan Edwards Close"". The" 6 "[' buried', ' was', ' named', ' ""', 'Dun', 'can', ' Edwards']"
+742 499 In their sport, the position played by x -1 In their sport, the position played by Duncan Edwards midfielder Duncan Edwards "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' has' ' been']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder has been" True 2 ['Dun', 'can', ' Edwards']
+743 499 In their sport, the position played by x -1 In their sport, the position played by Duncan Edwards midfielder Duncan Edwards "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' has' ' been']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder has been" True " was named ""Duncan Edwards Close"". The Wren's" 5 "[' was', ' named', ' ""', 'Dun', 'can', ' Edwards']"
+744 499 In their sport, the position played by x -1 In their sport, the position played by Duncan Edwards midfielder Duncan Edwards "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' has' ' been']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder has been" True player since Duncan Edwards in 1953, when he 3 [' player', ' since', ' Duncan', ' Edwards']
+745 499 In their sport, the position played by x -1 In their sport, the position played by Duncan Edwards midfielder Duncan Edwards "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' has' ' been']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder has been" True " was named ""Duncan Edwards Close"". The Wren's" 5 "[' was', ' named', ' ""', 'Dun', 'can', ' Edwards']"
+746 500 In their sport, the position played by x -1 In their sport, the position played by Chip Banks linebacker Chip Banks "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False Browns traded Chip Banks along with their 3 [' Browns', ' traded', ' Chip', ' Banks']
+747 500 In their sport, the position played by x -1 In their sport, the position played by Chip Banks linebacker Chip Banks "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False Browns traded Chip Banks along with their 3 [' Browns', ' traded', ' Chip', ' Banks']
+748 506 In their sport, the position played by x -1 In their sport, the position played by Hugh Millen quarterback Hugh Millen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False quarterback, behind Hugh Millen and Tom Hodson, and 5 [' quarterback', ',', ' behind', ' Hugh', ' Mill', 'en']
+749 506 In their sport, the position played by x -1 In their sport, the position played by Hugh Millen quarterback Hugh Millen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False quarterback, behind Hugh Millen and Tom Hodson, 5 [' quarterback', ',', ' behind', ' Hugh', ' Mill', 'en']
+750 506 In their sport, the position played by x -1 In their sport, the position played by Hugh Millen quarterback Hugh Millen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False quarterback, behind Hugh Millen and Tom Hodson, 5 [' quarterback', ',', ' behind', ' Hugh', ' Mill', 'en']
+751 506 In their sport, the position played by x -1 In their sport, the position played by Hugh Millen quarterback Hugh Millen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False quarterback, behind Hugh Millen and Tom Hodson, 5 [' quarterback', ',', ' behind', ' Hugh', ' Mill', 'en']
+752 506 In their sport, the position played by x -1 In their sport, the position played by Hugh Millen quarterback Hugh Millen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False They also gave Hugh Millen permission to 5 [' They', ' also', ' gave', ' Hugh', ' Mill', 'en']
+753 517 In their sport, the position played by x -1 In their sport, the position played by Milt Plum quarterback Milt Plum "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' M' 'ilt' ' Plum' ',' ' the' ' quarterback' ',' ' is']" ", the
+
+ In their sport , the position played by M ilt Plum , the quarterback , is" True Illinois, and Milt Plum was named as his 5 [' Illinois', ',', ' and', ' M', 'ilt', ' Plum']
+754 517 In their sport, the position played by x -1 In their sport, the position played by Milt Plum quarterback Milt Plum "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' M' 'ilt' ' Plum' ',' ' the' ' quarterback' ',' ' is']" ", the
+
+ In their sport , the position played by M ilt Plum , the quarterback , is" True strategy. Milt Plum spoke out against Brown 4 [' strategy', '.', ' M', 'ilt', ' Plum']
+755 517 In their sport, the position played by x -1 In their sport, the position played by Milt Plum quarterback Milt Plum "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' M' 'ilt' ' Plum' ',' ' the' ' quarterback' ',' ' is']" ", the
+
+ In their sport , the position played by M ilt Plum , the quarterback , is" True in Illinois, and Milt Plum was named 6 [' in', ' Illinois', ',', ' and', ' M', 'ilt', ' Plum']
+756 517 In their sport, the position played by x -1 In their sport, the position played by Milt Plum quarterback Milt Plum "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' M' 'ilt' ' Plum' ',' ' the' ' quarterback' ',' ' is']" ", the
+
+ In their sport , the position played by M ilt Plum , the quarterback , is" True in Illinois, and Milt Plum was named as his replacement. 6 [' in', ' Illinois', ',', ' and', ' M', 'ilt', ' Plum']
+757 517 In their sport, the position played by x -1 In their sport, the position played by Milt Plum quarterback Milt Plum "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' M' 'ilt' ' Plum' ',' ' the' ' quarterback' ',' ' is']" ", the
+
+ In their sport , the position played by M ilt Plum , the quarterback , is" True in Illinois, and Milt Plum was named as 6 [' in', ' Illinois', ',', ' and', ' M', 'ilt', ' Plum']
+758 521 In their sport, the position played by x -1 In their sport, the position played by David Aebischer goaltender David Aebischer "[',' ' the' ' German' ',' ' is' ' that' ' of' ' a' '\n' '\n' 'The'
+ ' German' ',' ' who' ' is' ' a' ' very' ' good' ' player' ',']" ", the German , is that of a
+
+ The German , who is a very good player ," False doubts if goalie David Aebischer could perform 6 [' doubts', ' if', ' goalie', ' David', ' A', 'eb', 'ischer']
+759 521 In their sport, the position played by x -1 In their sport, the position played by David Aebischer goaltender David Aebischer "[',' ' the' ' German' ',' ' is' ' that' ' of' ' a' '\n' '\n' 'The'
+ ' German' ',' ' who' ' is' ' a' ' very' ' good' ' player' ',']" ", the German , is that of a
+
+ The German , who is a very good player ," False Montreal against David Aebischer of the Montreal 5 [' Montreal', ' against', ' David', ' A', 'eb', 'ischer']
+760 521 In their sport, the position played by x -1 In their sport, the position played by David Aebischer goaltender David Aebischer "[',' ' the' ' German' ',' ' is' ' that' ' of' ' a' '\n' '\n' 'The'
+ ' German' ',' ' who' ' is' ' a' ' very' ' good' ' player' ',']" ", the German , is that of a
+
+ The German , who is a very good player ," False Montreal against David Aebischer of the Montreal 5 [' Montreal', ' against', ' David', ' A', 'eb', 'ischer']
+761 521 In their sport, the position played by x -1 In their sport, the position played by David Aebischer goaltender David Aebischer "[',' ' the' ' German' ',' ' is' ' that' ' of' ' a' '\n' '\n' 'The'
+ ' German' ',' ' who' ' is' ' a' ' very' ' good' ' player' ',']" ", the German , is that of a
+
+ The German , who is a very good player ," False European leagues. David Aebischer returned home 6 [' European', ' leagues', '.', ' David', ' A', 'eb', 'ischer']
+762 521 In their sport, the position played by x -1 In their sport, the position played by David Aebischer goaltender David Aebischer "[',' ' the' ' German' ',' ' is' ' that' ' of' ' a' '\n' '\n' 'The'
+ ' German' ',' ' who' ' is' ' a' ' very' ' good' ' player' ',']" ", the German , is that of a
+
+ The German , who is a very good player ," False Avalanche traded goalie David Aebischer to the Montreal 6 [' Avalanche', ' traded', ' goalie', ' David', ' A', 'eb', 'ischer']
+763 522 In their sport, the position played by x -1 In their sport, the position played by Charlie Robertson pitcher Charlie Robertson "[',' ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about' ' the' ' new' ' �' '�' 'The' '\n' '\n' 'The' ' first']" ", the
+
+ The first thing that strikes you about the new � � The
+
+ The first" False after those thrown by Charlie Robertson and Mark Buehrle, 5 [' after', ' those', ' thrown', ' by', ' Charlie', ' Robertson']
+764 522 In their sport, the position played by x -1 In their sport, the position played by Charlie Robertson pitcher Charlie Robertson "[',' ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about' ' the' ' new' ' �' '�' 'The' '\n' '\n' 'The' ' first']" ", the
+
+ The first thing that strikes you about the new � � The
+
+ The first" False those thrown by Charlie Robertson and Mark Buehrle, 4 [' those', ' thrown', ' by', ' Charlie', ' Robertson']
+765 522 In their sport, the position played by x -1 In their sport, the position played by Charlie Robertson pitcher Charlie Robertson "[',' ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about' ' the' ' new' ' �' '�' 'The' '\n' '\n' 'The' ' first']" ", the
+
+ The first thing that strikes you about the new � � The
+
+ The first" False game since Charlie Robertson in 1922 (Don 3 [' game', ' since', ' Charlie', ' Robertson']
+766 522 In their sport, the position played by x -1 In their sport, the position played by Charlie Robertson pitcher Charlie Robertson "[',' ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about' ' the' ' new' ' �' '�' 'The' '\n' '\n' 'The' ' first']" ", the
+
+ The first thing that strikes you about the new � � The
+
+ The first" False perfect game since Charlie Robertson in 1922 (Don Larsen's 4 [' perfect', ' game', ' since', ' Charlie', ' Robertson']
+767 525 In their sport, the position played by x -1 In their sport, the position played by Brad Van Pelt linebacker Brad Van Pelt "[',' ' the' ' team' ""'s"" ' quarterback' ',' ' is' ' to' ' throw' ' the'
+ ' ball' ' to' ' the' ' receiver' ',' ' who' ' is' ' running' ' a'
+ ' route']" , the team 's quarterback , is to throw the ball to the receiver , who is running a route False Carson and Pro Bowler Brad Van Pelt — into one of 8 [' Carson', ' and', ' Pro', ' Bow', 'ler', ' Brad', ' Van', ' P', 'elt']
+768 525 In their sport, the position played by x -1 In their sport, the position played by Brad Van Pelt linebacker Brad Van Pelt "[',' ' the' ' team' ""'s"" ' quarterback' ',' ' is' ' to' ' throw' ' the'
+ ' ball' ' to' ' the' ' receiver' ',' ' who' ' is' ' running' ' a'
+ ' route']" , the team 's quarterback , is to throw the ball to the receiver , who is running a route False Pro Bowler Brad Van Pelt — into one of 6 [' Pro', ' Bow', 'ler', ' Brad', ' Van', ' P', 'elt']
+769 525 In their sport, the position played by x -1 In their sport, the position played by Brad Van Pelt linebacker Brad Van Pelt "[',' ' the' ' team' ""'s"" ' quarterback' ',' ' is' ' to' ' throw' ' the'
+ ' ball' ' to' ' the' ' receiver' ',' ' who' ' is' ' running' ' a'
+ ' route']" , the team 's quarterback , is to throw the ball to the receiver , who is running a route False and Pro Bowler Brad Van Pelt — into one of 7 [' and', ' Pro', ' Bow', 'ler', ' Brad', ' Van', ' P', 'elt']
+770 527 In their sport, the position played by x -1 In their sport, the position played by Darren Fletcher midfielder Darren Fletcher "[',' ' who' ' was' ' a' ' key' ' player' ' in' ' the' ' team' ""'s""
+ ' success' ' in' ' the' ' Champions' ' League' ',' ' has' ' been'
+ ' taken' ' by']" , who was a key player in the team 's success in the Champions League , has been taken by False one-two with Darren Fletcher before making his 5 [' one', '-', 'two', ' with', ' Darren', ' Fletcher']
+771 527 In their sport, the position played by x -1 In their sport, the position played by Darren Fletcher midfielder Darren Fletcher "[',' ' who' ' was' ' a' ' key' ' player' ' in' ' the' ' team' ""'s""
+ ' success' ' in' ' the' ' Champions' ' League' ',' ' has' ' been'
+ ' taken' ' by']" , who was a key player in the team 's success in the Champions League , has been taken by False it meant that Darren Fletcher started on the 4 [' it', ' meant', ' that', ' Darren', ' Fletcher']
+772 527 In their sport, the position played by x -1 In their sport, the position played by Darren Fletcher midfielder Darren Fletcher "[',' ' who' ' was' ' a' ' key' ' player' ' in' ' the' ' team' ""'s""
+ ' success' ' in' ' the' ' Champions' ' League' ',' ' has' ' been'
+ ' taken' ' by']" , who was a key player in the team 's success in the Champions League , has been taken by False Meanwhile, midfielder Darren Fletcher was forced to miss 4 [' Meanwhile', ',', ' midfielder', ' Darren', ' Fletcher']
+773 527 In their sport, the position played by x -1 In their sport, the position played by Darren Fletcher midfielder Darren Fletcher "[',' ' who' ' was' ' a' ' key' ' player' ' in' ' the' ' team' ""'s""
+ ' success' ' in' ' the' ' Champions' ' League' ',' ' has' ' been'
+ ' taken' ' by']" , who was a key player in the team 's success in the Champions League , has been taken by False Meanwhile, midfielder Darren Fletcher was forced to miss 4 [' Meanwhile', ',', ' midfielder', ' Darren', ' Fletcher']
+774 527 In their sport, the position played by x -1 In their sport, the position played by Darren Fletcher midfielder Darren Fletcher "[',' ' who' ' was' ' a' ' key' ' player' ' in' ' the' ' team' ""'s""
+ ' success' ' in' ' the' ' Champions' ' League' ',' ' has' ' been'
+ ' taken' ' by']" , who was a key player in the team 's success in the Champions League , has been taken by False Meanwhile, midfielder Darren Fletcher was forced to 4 [' Meanwhile', ',', ' midfielder', ' Darren', ' Fletcher']
+775 529 In their sport, the position played by x -1 In their sport, the position played by Tedy Bruschi linebacker Tedy Bruschi "[',' ' the' ' Patriots' ""'"" ' middle' ' linebacker' ',' ' is' ' the'
+ ' most' ' important' ' position' ' on' ' the' ' field' '.' ' Brus' 'chi'
+ ' is' ' the']" , the Patriots ' middle linebacker , is the most important position on the field . Brus chi is the True Patriots' Tedy Bruschi and Richard Seymour, 5 "[' Patriots', ""'"", ' T', 'edy', ' Brus', 'chi']"
+776 529 In their sport, the position played by x -1 In their sport, the position played by Tedy Bruschi linebacker Tedy Bruschi "[',' ' the' ' Patriots' ""'"" ' middle' ' linebacker' ',' ' is' ' the'
+ ' most' ' important' ' position' ' on' ' the' ' field' '.' ' Brus' 'chi'
+ ' is' ' the']" , the Patriots ' middle linebacker , is the most important position on the field . Brus chi is the True and the Patriots' Tedy Bruschi and Richard 7 "[' and', ' the', ' Patriots', ""'"", ' T', 'edy', ' Brus', 'chi']"
+777 531 In their sport, the position played by x -1 In their sport, the position played by Curtis McElhinney goaltender Curtis McElhinney "[',' ' the' ' goalie' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' ice' '.' ' The' ' goalie' ' is' ' responsible' ' for'
+ ' stopping' ' the']" , the goalie , is the most important position on the ice . The goalie is responsible for stopping the False broke. Backup Curtis McElhinney was a surprise 7 [' broke', '.', ' Backup', ' Curtis', ' Mc', 'El', 'hin', 'ney']
+778 531 In their sport, the position played by x -1 In their sport, the position played by Curtis McElhinney goaltender Curtis McElhinney "[',' ' the' ' goalie' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' ice' '.' ' The' ' goalie' ' is' ' responsible' ' for'
+ ' stopping' ' the']" , the goalie , is the most important position on the ice . The goalie is responsible for stopping the False controversy broke. Backup Curtis McElhinney was a surprise starter 8 [' controversy', ' broke', '.', ' Backup', ' Curtis', ' Mc', 'El', 'hin', 'ney']
+779 531 In their sport, the position played by x -1 In their sport, the position played by Curtis McElhinney goaltender Curtis McElhinney "[',' ' the' ' goalie' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' ice' '.' ' The' ' goalie' ' is' ' responsible' ' for'
+ ' stopping' ' the']" , the goalie , is the most important position on the ice . The goalie is responsible for stopping the False goaltender Curtis McElhinney on waivers, and traded 5 [' goaltender', ' Curtis', ' Mc', 'El', 'hin', 'ney']
+780 531 In their sport, the position played by x -1 In their sport, the position played by Curtis McElhinney goaltender Curtis McElhinney "[',' ' the' ' goalie' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' ice' '.' ' The' ' goalie' ' is' ' responsible' ' for'
+ ' stopping' ' the']" , the goalie , is the most important position on the ice . The goalie is responsible for stopping the False picked up goaltender Curtis McElhinney on waivers 7 [' picked', ' up', ' goaltender', ' Curtis', ' Mc', 'El', 'hin', 'ney']
+781 531 In their sport, the position played by x -1 In their sport, the position played by Curtis McElhinney goaltender Curtis McElhinney "[',' ' the' ' goalie' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' ice' '.' ' The' ' goalie' ' is' ' responsible' ' for'
+ ' stopping' ' the']" , the goalie , is the most important position on the ice . The goalie is responsible for stopping the False up goaltender Curtis McElhinney on waivers and traded 6 [' up', ' goaltender', ' Curtis', ' Mc', 'El', 'hin', 'ney']
+782 539 In their sport, the position played by x -1 In their sport, the position played by Jim Hart quarterback Jim Hart "[',' ' the' ' former' ' head' ' of' ' the' ' CIA' '�' '�' 's' ' National'
+ ' Cl' 'andestine' ' Service' ',' ' is' ' a' ' perfect' ' example' ' of']" , the former head of the CIA � � s National Cl andestine Service , is a perfect example of False League Club Jim Hart said the state 3 [' League', ' Club', ' Jim', ' Hart']
+783 539 In their sport, the position played by x -1 In their sport, the position played by Jim Hart quarterback Jim Hart "[',' ' the' ' former' ' head' ' of' ' the' ' CIA' '�' '�' 's' ' National'
+ ' Cl' 'andestine' ' Service' ',' ' is' ' a' ' perfect' ' example' ' of']" , the former head of the CIA � � s National Cl andestine Service , is a perfect example of False National League Club Jim Hart said the state 4 [' National', ' League', ' Club', ' Jim', ' Hart']
+784 540 In their sport, the position played by x -1 In their sport, the position played by Junior Stanislas midfielder Junior Stanislas "[',' ' the' ' French' ' player' ',' ' is' ' that' ' of' ' a' '\n' '\n'
+ 'The' ' French' ' player' ',' ' Junior' ' Stan' 'isl' 'as' ',']" ", the French player , is that of a
+
+ The French player , Junior Stan isl as ," False but a goal from Junior Stanislas three minutes from 7 [' but', ' a', ' goal', ' from', ' Junior', ' Stan', 'isl', 'as']
+785 540 In their sport, the position played by x -1 In their sport, the position played by Junior Stanislas midfielder Junior Stanislas "[',' ' the' ' French' ' player' ',' ' is' ' that' ' of' ' a' '\n' '\n'
+ 'The' ' French' ' player' ',' ' Junior' ' Stan' 'isl' 'as' ',']" ", the French player , is that of a
+
+ The French player , Junior Stan isl as ," False lead, but a goal from Junior Stanislas three minutes 9 [' lead', ',', ' but', ' a', ' goal', ' from', ' Junior', ' Stan', 'isl', 'as']
+786 540 In their sport, the position played by x -1 In their sport, the position played by Junior Stanislas midfielder Junior Stanislas "[',' ' the' ' French' ' player' ',' ' is' ' that' ' of' ' a' '\n' '\n'
+ 'The' ' French' ' player' ',' ' Junior' ' Stan' 'isl' 'as' ',']" ", the French player , is that of a
+
+ The French player , Junior Stan isl as ," False but a goal from Junior Stanislas three minutes from 7 [' but', ' a', ' goal', ' from', ' Junior', ' Stan', 'isl', 'as']
+787 540 In their sport, the position played by x -1 In their sport, the position played by Junior Stanislas midfielder Junior Stanislas "[',' ' the' ' French' ' player' ',' ' is' ' that' ' of' ' a' '\n' '\n'
+ 'The' ' French' ' player' ',' ' Junior' ' Stan' 'isl' 'as' ',']" ", the French player , is that of a
+
+ The French player , Junior Stan isl as ," False but a goal from Junior Stanislas three minutes 7 [' but', ' a', ' goal', ' from', ' Junior', ' Stan', 'isl', 'as']
+788 541 In their sport, the position played by x -1 In their sport, the position played by Wil Nieves catcher Wil Nieves "[',' ' who' ' was' ' the' ' first' ' Puerto' ' Rican' ' to' ' play' ' in'
+ ' the' ' major' ' leagues' ',' ' is' ' a' ' position' ' that' ' is'
+ ' not']" , who was the first Puerto Rican to play in the major leagues , is a position that is not False backup catcher Wil Nieves and a home run 4 [' backup', ' catcher', ' Wil', ' N', 'ieves']
+789 541 In their sport, the position played by x -1 In their sport, the position played by Wil Nieves catcher Wil Nieves "[',' ' who' ' was' ' the' ' first' ' Puerto' ' Rican' ' to' ' play' ' in'
+ ' the' ' major' ' leagues' ',' ' is' ' a' ' position' ' that' ' is'
+ ' not']" , who was the first Puerto Rican to play in the major leagues , is a position that is not False from backup catcher Wil Nieves and a home run 5 [' from', ' backup', ' catcher', ' Wil', ' N', 'ieves']
+790 541 In their sport, the position played by x -1 In their sport, the position played by Wil Nieves catcher Wil Nieves "[',' ' who' ' was' ' the' ' first' ' Puerto' ' Rican' ' to' ' play' ' in'
+ ' the' ' major' ' leagues' ',' ' is' ' a' ' position' ' that' ' is'
+ ' not']" , who was the first Puerto Rican to play in the major leagues , is a position that is not False Jr., Reid Brignac, Wil Nieves and Hernández likely 9 [' Jr', '.,', ' Reid', ' Br', 'ign', 'ac', ',', ' Wil', ' N', 'ieves']
+791 541 In their sport, the position played by x -1 In their sport, the position played by Wil Nieves catcher Wil Nieves "[',' ' who' ' was' ' the' ' first' ' Puerto' ' Rican' ' to' ' play' ' in'
+ ' the' ' major' ' leagues' ',' ' is' ' a' ' position' ' that' ' is'
+ ' not']" , who was the first Puerto Rican to play in the major leagues , is a position that is not False with help from Wil Nieves (three hits) and 5 [' with', ' help', ' from', ' Wil', ' N', 'ieves']
+792 541 In their sport, the position played by x -1 In their sport, the position played by Wil Nieves catcher Wil Nieves "[',' ' who' ' was' ' the' ' first' ' Puerto' ' Rican' ' to' ' play' ' in'
+ ' the' ' major' ' leagues' ',' ' is' ' a' ' position' ' that' ' is'
+ ' not']" , who was the first Puerto Rican to play in the major leagues , is a position that is not False backup catcher Wil Nieves and a home 4 [' backup', ' catcher', ' Wil', ' N', 'ieves']
+793 544 In their sport, the position played by x -1 In their sport, the position played by Ty Conklin goaltender Ty Conklin "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False help of goaltender Ty Conklin and center Evgeni 6 [' help', ' of', ' goaltender', ' Ty', ' Con', 'k', 'lin']
+794 544 In their sport, the position played by x -1 In their sport, the position played by Ty Conklin goaltender Ty Conklin "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False goal against Ty Conklin of the Detroit Red 5 [' goal', ' against', ' Ty', ' Con', 'k', 'lin']
+795 544 In their sport, the position played by x -1 In their sport, the position played by Ty Conklin goaltender Ty Conklin "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False injured. Goaltender Ty Conklin replaced Marc-Andre 8 [' injured', '.', ' Go', 'alt', 'ender', ' Ty', ' Con', 'k', 'lin']
+796 544 In their sport, the position played by x -1 In their sport, the position played by Ty Conklin goaltender Ty Conklin "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False style, with Ty Conklin and Ryan Miller 6 [' style', ',', ' with', ' Ty', ' Con', 'k', 'lin']
+797 544 In their sport, the position played by x -1 In their sport, the position played by Ty Conklin goaltender Ty Conklin "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the False first NHL goal against Ty Conklin of the Detroit 7 [' first', ' NHL', ' goal', ' against', ' Ty', ' Con', 'k', 'lin']
+798 545 In their sport, the position played by x -1 In their sport, the position played by George Best midfielder George Best "[' in' ' the' ' 1966' ' World' ' Cup' ' final' ',' ' he' ' was' ' the'
+ ' best' ' player' ' on' ' the' ' pitch' ',' ' but' ' he' ' was' ' also']" in the 1966 World Cup final , he was the best player on the pitch , but he was also False notorious incident with George Best who, while playing 4 [' notorious', ' incident', ' with', ' George', ' Best']
+799 545 In their sport, the position played by x -1 In their sport, the position played by George Best midfielder George Best "[' in' ' the' ' 1966' ' World' ' Cup' ' final' ',' ' he' ' was' ' the'
+ ' best' ' player' ' on' ' the' ' pitch' ',' ' but' ' he' ' was' ' also']" in the 1966 World Cup final , he was the best player on the pitch , but he was also False Sydenham links into George Best Belfast City Airport 6 [' Sy', 'den', 'ham', ' links', ' into', ' George', ' Best']
+800 545 In their sport, the position played by x -1 In their sport, the position played by George Best midfielder George Best "[' in' ' the' ' 1966' ' World' ' Cup' ' final' ',' ' he' ' was' ' the'
+ ' best' ' player' ' on' ' the' ' pitch' ',' ' but' ' he' ' was' ' also']" in the 1966 World Cup final , he was the best player on the pitch , but he was also False – including George Best – to win the FA Cup 3 [' –', ' including', ' George', ' Best']
+801 545 In their sport, the position played by x -1 In their sport, the position played by George Best midfielder George Best "[' in' ' the' ' 1966' ' World' ' Cup' ' final' ',' ' he' ' was' ' the'
+ ' best' ' player' ' on' ' the' ' pitch' ',' ' but' ' he' ' was' ' also']" in the 1966 World Cup final , he was the best player on the pitch , but he was also False best in the world. George Best was something 6 [' best', ' in', ' the', ' world', '.', ' George', ' Best']
+802 545 In their sport, the position played by x -1 In their sport, the position played by George Best midfielder George Best "[' in' ' the' ' 1966' ' World' ' Cup' ' final' ',' ' he' ' was' ' the'
+ ' best' ' player' ' on' ' the' ' pitch' ',' ' but' ' he' ' was' ' also']" in the 1966 World Cup final , he was the best player on the pitch , but he was also False notorious incident with George Best who, while playing 4 [' notorious', ' incident', ' with', ' George', ' Best']
+803 546 In their sport, the position played by x -1 In their sport, the position played by Kevin Weekes goaltender Kevin Weekes "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Conn'
+ ' Smy' 'the' ' Trophy' ' as' ' the' ' playoff' ' MVP' '.' '\n' '\n']" ", who was the first goalie to win the Conn Smy the Trophy as the playoff MVP .
+
+" False playing backup Kevin Weekes in favour of 4 [' playing', ' backup', ' Kevin', ' Week', 'es']
+804 546 In their sport, the position played by x -1 In their sport, the position played by Kevin Weekes goaltender Kevin Weekes "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Conn'
+ ' Smy' 'the' ' Trophy' ' as' ' the' ' playoff' ' MVP' '.' '\n' '\n']" ", who was the first goalie to win the Conn Smy the Trophy as the playoff MVP .
+
+" False playing backup Kevin Weekes in favour of Cloutier. 4 [' playing', ' backup', ' Kevin', ' Week', 'es']
+805 546 In their sport, the position played by x -1 In their sport, the position played by Kevin Weekes goaltender Kevin Weekes "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Conn'
+ ' Smy' 'the' ' Trophy' ' as' ' the' ' playoff' ' MVP' '.' '\n' '\n']" ", who was the first goalie to win the Conn Smy the Trophy as the playoff MVP .
+
+" False began playing backup Kevin Weekes in favour of 5 [' began', ' playing', ' backup', ' Kevin', ' Week', 'es']
+806 546 In their sport, the position played by x -1 In their sport, the position played by Kevin Weekes goaltender Kevin Weekes "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Conn'
+ ' Smy' 'the' ' Trophy' ' as' ' the' ' playoff' ' MVP' '.' '\n' '\n']" ", who was the first goalie to win the Conn Smy the Trophy as the playoff MVP .
+
+" False Mike Brown, Kevin Weekes and a first-round 5 [' Mike', ' Brown', ',', ' Kevin', ' Week', 'es']
+807 546 In their sport, the position played by x -1 In their sport, the position played by Kevin Weekes goaltender Kevin Weekes "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Conn'
+ ' Smy' 'the' ' Trophy' ' as' ' the' ' playoff' ' MVP' '.' '\n' '\n']" ", who was the first goalie to win the Conn Smy the Trophy as the playoff MVP .
+
+" False Rangers goalie Kevin Weekes during an 4 [' Rangers', ' goalie', ' Kevin', ' Week', 'es']
+808 551 In their sport, the position played by x -1 In their sport, the position played by Mike Scioscia catcher Mike Scioscia "[',' ' the' ' manager' ' of' ' the' ' Los' ' Angeles' ' Angels' ',' ' is'
+ ' the' ' same' ' as' ' the' ' position' ' played' ' by' ' the' ' manager'
+ ' of']" , the manager of the Los Angeles Angels , is the same as the position played by the manager of False positive, with manager Mike Scioscia saying that 7 [' positive', ',', ' with', ' manager', ' Mike', ' Sci', 'osc', 'ia']
+809 551 In their sport, the position played by x -1 In their sport, the position played by Mike Scioscia catcher Mike Scioscia "[',' ' the' ' manager' ' of' ' the' ' Los' ' Angeles' ' Angels' ',' ' is'
+ ' the' ' same' ' as' ' the' ' position' ' played' ' by' ' the' ' manager'
+ ' of']" , the manager of the Los Angeles Angels , is the same as the position played by the manager of False Angels manager Mike Scioscia dedicated his 2009 5 [' Angels', ' manager', ' Mike', ' Sci', 'osc', 'ia']
+810 551 In their sport, the position played by x -1 In their sport, the position played by Mike Scioscia catcher Mike Scioscia "[',' ' the' ' manager' ' of' ' the' ' Los' ' Angeles' ' Angels' ',' ' is'
+ ' the' ' same' ' as' ' the' ' position' ' played' ' by' ' the' ' manager'
+ ' of']" , the manager of the Los Angeles Angels , is the same as the position played by the manager of False other players. Mike Scioscia accepted his guest 6 [' other', ' players', '.', ' Mike', ' Sci', 'osc', 'ia']
+811 551 In their sport, the position played by x -1 In their sport, the position played by Mike Scioscia catcher Mike Scioscia "[',' ' the' ' manager' ' of' ' the' ' Los' ' Angeles' ' Angels' ',' ' is'
+ ' the' ' same' ' as' ' the' ' position' ' played' ' by' ' the' ' manager'
+ ' of']" , the manager of the Los Angeles Angels , is the same as the position played by the manager of False Angel manager Mike Scioscia and acting manager 5 [' Angel', ' manager', ' Mike', ' Sci', 'osc', 'ia']
+812 551 In their sport, the position played by x -1 In their sport, the position played by Mike Scioscia catcher Mike Scioscia "[',' ' the' ' manager' ' of' ' the' ' Los' ' Angeles' ' Angels' ',' ' is'
+ ' the' ' same' ' as' ' the' ' position' ' played' ' by' ' the' ' manager'
+ ' of']" , the manager of the Los Angeles Angels , is the same as the position played by the manager of False 138,038.57. Angels manager Mike Scioscia dedicated his 2009 12 [' 138', ',', '0', '38', '.', '57', '.', ' Angels', ' manager', ' Mike', ' Sci', 'osc', 'ia']
+813 552 In their sport, the position played by x -1 In their sport, the position played by Carlos Lee outfielder Carlos Lee "[',' ' the' ' former' ' Texas' ' Rangers' ' outfielder' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' who' ' can'
+ ' be' ' a']" , the former Texas Rangers outfielder , is a perfect example of the type of player who can be a True runs (32, tied with Carlos Lee and Chris Young) 7 [' runs', ' (', '32', ',', ' tied', ' with', ' Carlos', ' Lee']
+814 552 In their sport, the position played by x -1 In their sport, the position played by Carlos Lee outfielder Carlos Lee "[',' ' the' ' former' ' Texas' ' Rangers' ' outfielder' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' who' ' can'
+ ' be' ' a']" , the former Texas Rangers outfielder , is a perfect example of the type of player who can be a True (32, tied with Carlos Lee and Chris Young) for 6 [' (', '32', ',', ' tied', ' with', ' Carlos', ' Lee']
+815 554 In their sport, the position played by x -1 In their sport, the position played by Bobby Hebert quarterback Bobby Hebert "[',' ' the' ' former' ' LSU' ' quarterback' ',' ' is' ' a' ' position'
+ ' that' ' is' ' not' ' easy' ' to' ' fill' '.' '\n' '\n' '""' 'I']" ", the former LSU quarterback , is a position that is not easy to fill .
+
+ "" I" True time sitting behind Bobby Hebert and Browning 5 [' time', ' sitting', ' behind', ' Bobby', ' He', 'bert']
+816 554 In their sport, the position played by x -1 In their sport, the position played by Bobby Hebert quarterback Bobby Hebert "[',' ' the' ' former' ' LSU' ' quarterback' ',' ' is' ' a' ' position'
+ ' that' ' is' ' not' ' easy' ' to' ' fill' '.' '\n' '\n' '""' 'I']" ", the former LSU quarterback , is a position that is not easy to fill .
+
+ "" I" True time sitting behind Bobby Hebert and Browning 5 [' time', ' sitting', ' behind', ' Bobby', ' He', 'bert']
+817 554 In their sport, the position played by x -1 In their sport, the position played by Bobby Hebert quarterback Bobby Hebert "[',' ' the' ' former' ' LSU' ' quarterback' ',' ' is' ' a' ' position'
+ ' that' ' is' ' not' ' easy' ' to' ' fill' '.' '\n' '\n' '""' 'I']" ", the former LSU quarterback , is a position that is not easy to fill .
+
+ "" I" True sitting behind Bobby Hebert and Browning Nagle. 4 [' sitting', ' behind', ' Bobby', ' He', 'bert']
+818 554 In their sport, the position played by x -1 In their sport, the position played by Bobby Hebert quarterback Bobby Hebert "[',' ' the' ' former' ' LSU' ' quarterback' ',' ' is' ' a' ' position'
+ ' that' ' is' ' not' ' easy' ' to' ' fill' '.' '\n' '\n' '""' 'I']" ", the former LSU quarterback , is a position that is not easy to fill .
+
+ "" I" True sitting behind Bobby Hebert and Browning Nagle. 4 [' sitting', ' behind', ' Bobby', ' He', 'bert']
+819 557 In their sport, the position played by x -1 In their sport, the position played by Vicente Engonga midfielder Vicente Engonga "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Vic' 'ente' ' Eng' 'onga' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Vic ente Eng onga , the
+
+" False the penalty box. Vicente Engonga scored from the resulting 7 [' the', ' penalty', ' box', '.', ' Vic', 'ente', ' Eng', 'onga']
+820 560 In their sport, the position played by x -1 In their sport, the position played by Kurt Warner quarterback Kurt Warner "[',' ' the' ' quarterback' ' for' ' the' ' St' '.' ' Louis' ' Rams' ','
+ ' is' ' a' ' position' ' that' ' is' ' often' ' overlooked' '.' '\n' '\n']" ", the quarterback for the St . Louis Rams , is a position that is often overlooked .
+
+" True Rams'quarterback Kurt Warner that evening. Smith 5 "[' Rams', ""'"", 'quarter', 'back', ' Kurt', ' Warner']"
+821 560 In their sport, the position played by x -1 In their sport, the position played by Kurt Warner quarterback Kurt Warner "[',' ' the' ' quarterback' ' for' ' the' ' St' '.' ' Louis' ' Rams' ','
+ ' is' ' a' ' position' ' that' ' is' ' often' ' overlooked' '.' '\n' '\n']" ", the quarterback for the St . Louis Rams , is a position that is often overlooked .
+
+" True Rams'quarterback Kurt Warner that evening. 5 "[' Rams', ""'"", 'quarter', 'back', ' Kurt', ' Warner']"
+822 560 In their sport, the position played by x -1 In their sport, the position played by Kurt Warner quarterback Kurt Warner "[',' ' the' ' quarterback' ' for' ' the' ' St' '.' ' Louis' ' Rams' ','
+ ' is' ' a' ' position' ' that' ' is' ' often' ' overlooked' '.' '\n' '\n']" ", the quarterback for the St . Louis Rams , is a position that is often overlooked .
+
+" True Rams'quarterback Kurt Warner that evening. Smith 5 "[' Rams', ""'"", 'quarter', 'back', ' Kurt', ' Warner']"
+823 560 In their sport, the position played by x -1 In their sport, the position played by Kurt Warner quarterback Kurt Warner "[',' ' the' ' quarterback' ' for' ' the' ' St' '.' ' Louis' ' Rams' ','
+ ' is' ' a' ' position' ' that' ' is' ' often' ' overlooked' '.' '\n' '\n']" ", the quarterback for the St . Louis Rams , is a position that is often overlooked .
+
+" True intercepted a pass from Kurt Warner and returned 5 [' intercepted', ' a', ' pass', ' from', ' Kurt', ' Warner']
+824 560 In their sport, the position played by x -1 In their sport, the position played by Kurt Warner quarterback Kurt Warner "[',' ' the' ' quarterback' ' for' ' the' ' St' '.' ' Louis' ' Rams' ','
+ ' is' ' a' ' position' ' that' ' is' ' often' ' overlooked' '.' '\n' '\n']" ", the quarterback for the St . Louis Rams , is a position that is often overlooked .
+
+" True Rams'quarterback Kurt Warner that evening. Smith 5 "[' Rams', ""'"", 'quarter', 'back', ' Kurt', ' Warner']"
+825 561 In their sport, the position played by x -1 In their sport, the position played by Mike Vernon goaltender Mike Vernon "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' great' ' quarterback' ',' ' but']" , the former NFL quarterback , is a bit of a mystery . He was a great quarterback , but False Game in Edmonton. Mike Vernon was the winning goaltender 5 [' Game', ' in', ' Edmonton', '.', ' Mike', ' Vernon']
+826 561 In their sport, the position played by x -1 In their sport, the position played by Mike Vernon goaltender Mike Vernon "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' great' ' quarterback' ',' ' but']" , the former NFL quarterback , is a bit of a mystery . He was a great quarterback , but False " hockey) =
+" 4 [' hockey', ')', ' =', 'Mike', ' Vernon']
+827 561 In their sport, the position played by x -1 In their sport, the position played by Mike Vernon goaltender Mike Vernon "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' great' ' quarterback' ',' ' but']" , the former NFL quarterback , is a bit of a mystery . He was a great quarterback , but False Vernon (ice hockey) 5 [' Vernon', ' (', 'ice', ' hockey', 'Mike', ' Vernon']
+828 561 In their sport, the position played by x -1 In their sport, the position played by Mike Vernon goaltender Mike Vernon "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' great' ' quarterback' ',' ' but']" , the former NFL quarterback , is a bit of a mystery . He was a great quarterback , but False on Flames goalie Mike Vernon in the second overtime 4 [' on', ' Flames', ' goalie', ' Mike', ' Vernon']
+829 561 In their sport, the position played by x -1 In their sport, the position played by Mike Vernon goaltender Mike Vernon "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' was' ' a' ' great' ' quarterback' ',' ' but']" , the former NFL quarterback , is a bit of a mystery . He was a great quarterback , but False centre by the league. Mike Vernon was named to 6 [' centre', ' by', ' the', ' league', '.', ' Mike', ' Vernon']
+830 566 In their sport, the position played by x -1 In their sport, the position played by Kenny Demens linebacker Kenny Demens "[',' ' who' ' was' ' a' ' defensive' ' end' ' for' ' the' ' Packers' ' in'
+ ' the' ' NFL' ',' ' is' ' the' ' same' ' as' ' the' ' position' ' played']" , who was a defensive end for the Packers in the NFL , is the same as the position played False yards per punt return. Kenny Demens led the team 7 [' yards', ' per', ' punt', ' return', '.', ' Kenny', ' Dem', 'ens']
+831 566 In their sport, the position played by x -1 In their sport, the position played by Kenny Demens linebacker Kenny Demens "[',' ' who' ' was' ' a' ' defensive' ' end' ' for' ' the' ' Packers' ' in'
+ ' the' ' NFL' ',' ' is' ' the' ' same' ' as' ' the' ' position' ' played']" , who was a defensive end for the Packers in the NFL , is the same as the position played False per punt return. Kenny Demens led the team 6 [' per', ' punt', ' return', '.', ' Kenny', ' Dem', 'ens']
+832 566 In their sport, the position played by x -1 In their sport, the position played by Kenny Demens linebacker Kenny Demens "[',' ' who' ' was' ' a' ' defensive' ' end' ' for' ' the' ' Packers' ' in'
+ ' the' ' NFL' ',' ' is' ' the' ' same' ' as' ' the' ' position' ' played']" , who was a defensive end for the Packers in the NFL , is the same as the position played False per punt return. Kenny Demens led the team in 6 [' per', ' punt', ' return', '.', ' Kenny', ' Dem', 'ens']
+833 571 In their sport, the position played by x -1 In their sport, the position played by Johnny Vander Meer pitcher Johnny Vander Meer "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Johnny' ' Vander' ' Me' 'er' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Johnny Vander Me er , the
+
+" False history, after Johnny Vander Meer threw consecutive 6 [' history', ',', ' after', ' Johnny', ' Vander', ' Me', 'er']
+834 572 In their sport, the position played by x -1 In their sport, the position played by Steve Grogan quarterback Steve Grogan "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" False quarterback Steve Grogan set an NFL 3 [' quarterback', ' Steve', ' Gro', 'gan']
+835 572 In their sport, the position played by x -1 In their sport, the position played by Steve Grogan quarterback Steve Grogan "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" False Patriots quarterback Steve Grogan set an NFL 4 [' Patriots', ' quarterback', ' Steve', ' Gro', 'gan']
+836 572 In their sport, the position played by x -1 In their sport, the position played by Steve Grogan quarterback Steve Grogan "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" False quarterback Steve Grogan was released. Zolak 3 [' quarterback', ' Steve', ' Gro', 'gan']
+837 572 In their sport, the position played by x -1 In their sport, the position played by Steve Grogan quarterback Steve Grogan "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" False starting quarterback Steve Grogan was released. Zolak 4 [' starting', ' quarterback', ' Steve', ' Gro', 'gan']
+838 572 In their sport, the position played by x -1 In their sport, the position played by Steve Grogan quarterback Steve Grogan "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" False Patriots game to enable Steve Grogan to break the 6 [' Patriots', ' game', ' to', ' enable', ' Steve', ' Gro', 'gan']
+839 574 In their sport, the position played by x -1 In their sport, the position played by Angelo Bertelli quarterback Angelo Bertelli "[',' ' the' ' Italian' ' goalkeeper' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' can' ' save' ' the'
+ ' game']" , the Italian goalkeeper , is the most important . He is the only player who can save the game False Hirsch from quarterback Angelo Bertelli in the second 7 [' H', 'irsch', ' from', ' quarterback', ' Angelo', ' Ber', 'te', 'lli']
+840 574 In their sport, the position played by x -1 In their sport, the position played by Angelo Bertelli quarterback Angelo Bertelli "[',' ' the' ' Italian' ' goalkeeper' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' can' ' save' ' the'
+ ' game']" , the Italian goalkeeper , is the most important . He is the only player who can save the game False quarterback Angelo Bertelli in the second quarter. 4 [' quarterback', ' Angelo', ' Ber', 'te', 'lli']
+841 574 In their sport, the position played by x -1 In their sport, the position played by Angelo Bertelli quarterback Angelo Bertelli "[',' ' the' ' Italian' ' goalkeeper' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' can' ' save' ' the'
+ ' game']" , the Italian goalkeeper , is the most important . He is the only player who can save the game False quarterback Angelo Bertelli in the second quarter. 4 [' quarterback', ' Angelo', ' Ber', 'te', 'lli']
+842 574 In their sport, the position played by x -1 In their sport, the position played by Angelo Bertelli quarterback Angelo Bertelli "[',' ' the' ' Italian' ' goalkeeper' ',' ' is' ' the' ' most' ' important'
+ '.' ' He' ' is' ' the' ' only' ' player' ' who' ' can' ' save' ' the'
+ ' game']" , the Italian goalkeeper , is the most important . He is the only player who can save the game False quarterback Angelo Bertelli in the second 4 [' quarterback', ' Angelo', ' Ber', 'te', 'lli']
+843 576 In their sport, the position played by x -1 In their sport, the position played by Roy Sproson defender Roy Sproson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Roy' ' Sp' 'ro' 'son' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Roy Sp ro son , the
+
+" False Cheadle, Albert Leake, Roy Sproson (half-backs); Colin 11 [' C', 'head', 'le', ',', ' Albert', ' Le', 'ake', ',', ' Roy', ' Sp', 'ro', 'son']
+844 576 In their sport, the position played by x -1 In their sport, the position played by Roy Sproson defender Roy Sproson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Roy' ' Sp' 'ro' 'son' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Roy Sp ro son , the
+
+" False written off. Player Roy Sproson later said that 7 [' written', ' off', '.', ' Player', ' Roy', ' Sp', 'ro', 'son']
+845 576 In their sport, the position played by x -1 In their sport, the position played by Roy Sproson defender Roy Sproson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Roy' ' Sp' 'ro' 'son' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Roy Sp ro son , the
+
+" False Stan Turner and Roy Sproson began to be known 6 [' Stan', ' Turner', ' and', ' Roy', ' Sp', 'ro', 'son']
+846 576 In their sport, the position played by x -1 In their sport, the position played by Roy Sproson defender Roy Sproson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Roy' ' Sp' 'ro' 'son' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Roy Sp ro son , the
+
+" False 52 minutes, before Roy Sproson scored the 7 [' 52', ' minutes', ',', ' before', ' Roy', ' Sp', 'ro', 'son']
+847 576 In their sport, the position played by x -1 In their sport, the position played by Roy Sproson defender Roy Sproson "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Roy' ' Sp' 'ro' 'son' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Roy Sp ro son , the
+
+" False off. Player Roy Sproson later said that 6 [' off', '.', ' Player', ' Roy', ' Sp', 'ro', 'son']
+848 579 In their sport, the position played by x -1 In their sport, the position played by Marty Schottenheimer linebacker Marty Schottenheimer "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' former' ' NFL' ' coach' ','
+ ' Bill']" , the former NFL coach , is the same as the one played by the former NFL coach , Bill False season. Head coach Marty Schottenheimer chose Grbac to 7 [' season', '.', ' Head', ' coach', ' Marty', ' Sch', 'otten', 'heimer']
+849 579 In their sport, the position played by x -1 In their sport, the position played by Marty Schottenheimer linebacker Marty Schottenheimer "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' former' ' NFL' ' coach' ','
+ ' Bill']" , the former NFL coach , is the same as the one played by the former NFL coach , Bill False from 1960 to 1974. Marty Schottenheimer was hired in 1989 8 [' from', ' 1960', ' to', ' 1974', '.', ' Marty', ' Sch', 'otten', 'heimer']
+850 579 In their sport, the position played by x -1 In their sport, the position played by Marty Schottenheimer linebacker Marty Schottenheimer "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' former' ' NFL' ' coach' ','
+ ' Bill']" , the former NFL coach , is the same as the one played by the former NFL coach , Bill False over and hired Marty Schottenheimer as the club's 6 [' over', ' and', ' hired', ' Marty', ' Sch', 'otten', 'heimer']
+851 579 In their sport, the position played by x -1 In their sport, the position played by Marty Schottenheimer linebacker Marty Schottenheimer "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' former' ' NFL' ' coach' ','
+ ' Bill']" , the former NFL coach , is the same as the one played by the former NFL coach , Bill False 1960 to 1974. Marty Schottenheimer was hired in 7 [' 1960', ' to', ' 1974', '.', ' Marty', ' Sch', 'otten', 'heimer']
+852 579 In their sport, the position played by x -1 In their sport, the position played by Marty Schottenheimer linebacker Marty Schottenheimer "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' the' ' same' ' as'
+ ' the' ' one' ' played' ' by' ' the' ' former' ' NFL' ' coach' ','
+ ' Bill']" , the former NFL coach , is the same as the one played by the former NFL coach , Bill False season. Head coach Marty Schottenheimer chose Grbac to 7 [' season', '.', ' Head', ' coach', ' Marty', ' Sch', 'otten', 'heimer']
+853 580 In their sport, the position played by x -1 In their sport, the position played by Melanie Behringer midfielder Melanie Behringer "[',' ' the' ' German' '-' 'born' ' American' ',' ' is' ' that' ' of' ' a'
+ ' �' '�' 's' 'port' 'sw' 'oman' '�' '�' ',']" , the German - born American , is that of a � � s port sw oman � � , False Ariane Hingst and Melanie Behringer have stated that men 9 [' Ari', 'ane', ' H', 'ing', 'st', ' and', ' Melanie', ' Beh', 'ring', 'er']
+854 580 In their sport, the position played by x -1 In their sport, the position played by Melanie Behringer midfielder Melanie Behringer "[',' ' the' ' German' '-' 'born' ' American' ',' ' is' ' that' ' of' ' a'
+ ' �' '�' 's' 'port' 'sw' 'oman' '�' '�' ',']" , the German - born American , is that of a � � s port sw oman � � , False twice, with Melanie Behringer and Kim Kulig 6 [' twice', ',', ' with', ' Melanie', ' Beh', 'ring', 'er']
+855 580 In their sport, the position played by x -1 In their sport, the position played by Melanie Behringer midfielder Melanie Behringer "[',' ' the' ' German' '-' 'born' ' American' ',' ' is' ' that' ' of' ' a'
+ ' �' '�' 's' 'port' 'sw' 'oman' '�' '�' ',']" , the German - born American , is that of a � � s port sw oman � � , False scored twice, with Melanie Behringer and Kim Kulig also 7 [' scored', ' twice', ',', ' with', ' Melanie', ' Beh', 'ring', 'er']
+856 582 In their sport, the position played by x -1 In their sport, the position played by Arthur Milton midfielder Arthur Milton "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Arthur' ' Milton' ',' ' the' ' two' ' teams' ' are'
+ ' the']" ", the
+
+ In their sport , the position played by Arthur Milton , the two teams are the" False Richardson (twice), Arthur Milton and Raman Subba 6 [' Richardson', ' (', 'tw', 'ice', '),', ' Arthur', ' Milton']
+857 582 In their sport, the position played by x -1 In their sport, the position played by Arthur Milton midfielder Arthur Milton "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Arthur' ' Milton' ',' ' the' ' two' ' teams' ' are'
+ ' the']" ", the
+
+ In their sport , the position played by Arthur Milton , the two teams are the" False Richardson (twice), Arthur Milton and Raman Subba 6 [' Richardson', ' (', 'tw', 'ice', '),', ' Arthur', ' Milton']
+858 582 In their sport, the position played by x -1 In their sport, the position played by Arthur Milton midfielder Arthur Milton "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Arthur' ' Milton' ',' ' the' ' two' ' teams' ' are'
+ ' the']" ", the
+
+ In their sport , the position played by Arthur Milton , the two teams are the" False Richardson (twice), Arthur Milton and Raman Subba 6 [' Richardson', ' (', 'tw', 'ice', '),', ' Arthur', ' Milton']
+859 582 In their sport, the position played by x -1 In their sport, the position played by Arthur Milton midfielder Arthur Milton "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Arthur' ' Milton' ',' ' the' ' two' ' teams' ' are'
+ ' the']" ", the
+
+ In their sport , the position played by Arthur Milton , the two teams are the" False " Richardson (twice), Arthur Milton and Raman Subba Row.
+" 6 [' Richardson', ' (', 'tw', 'ice', '),', ' Arthur', ' Milton']
+860 586 In their sport, the position played by x -1 In their sport, the position played by Mark Sanchez quarterback Mark Sanchez "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' position' ' of' ' quarterback' ' is' ' a' ' very'
+ ' important']" ", the quarterback , is to be the quarterback .
+
+ The position of quarterback is a very important" True elected to start Mark Sanchez for the second straight 4 [' elected', ' to', ' start', ' Mark', ' Sanchez']
+861 586 In their sport, the position played by x -1 In their sport, the position played by Mark Sanchez quarterback Mark Sanchez "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' position' ' of' ' quarterback' ' is' ' a' ' very'
+ ' important']" ", the quarterback , is to be the quarterback .
+
+ The position of quarterback is a very important" True round selections were Mark Sanchez (fifth, New 4 [' round', ' selections', ' were', ' Mark', ' Sanchez']
+862 586 In their sport, the position played by x -1 In their sport, the position played by Mark Sanchez quarterback Mark Sanchez "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' position' ' of' ' quarterback' ' is' ' a' ' very'
+ ' important']" ", the quarterback , is to be the quarterback .
+
+ The position of quarterback is a very important" True " Mark Sanchez =
+" 1 [' Mark', ' Sanchez']
+863 586 In their sport, the position played by x -1 In their sport, the position played by Mark Sanchez quarterback Mark Sanchez "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' position' ' of' ' quarterback' ' is' ' a' ' very'
+ ' important']" ", the quarterback , is to be the quarterback .
+
+ The position of quarterback is a very important" True selections were Mark Sanchez (fifth, New York 3 [' selections', ' were', ' Mark', ' Sanchez']
+864 586 In their sport, the position played by x -1 In their sport, the position played by Mark Sanchez quarterback Mark Sanchez "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' position' ' of' ' quarterback' ' is' ' a' ' very'
+ ' important']" ", the quarterback , is to be the quarterback .
+
+ The position of quarterback is a very important" True Trojans quarterback Mark Sanchez ' father, a fire 5 [' Tro', 'j', 'ans', ' quarterback', ' Mark', ' Sanchez']
+865 587 In their sport, the position played by x -1 In their sport, the position played by Zac Robinson quarterback Zac Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' title' ' in' ' 2014' ',' ' is' ' now' ' occupied' ' by' ' the']" , who was a member of the team that won the title in 2014 , is now occupied by the False rushed for 105 yards. Zac Robinson was 30 of 42 for 6 [' rushed', ' for', ' 105', ' yards', '.', ' Zac', ' Robinson']
+866 587 In their sport, the position played by x -1 In their sport, the position played by Zac Robinson quarterback Zac Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' title' ' in' ' 2014' ',' ' is' ' now' ' occupied' ' by' ' the']" , who was a member of the team that won the title in 2014 , is now occupied by the False active, with both Zac Robinson and Bobby Reid 5 [' active', ',', ' with', ' both', ' Zac', ' Robinson']
+867 587 In their sport, the position played by x -1 In their sport, the position played by Zac Robinson quarterback Zac Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' title' ' in' ' 2014' ',' ' is' ' now' ' occupied' ' by' ' the']" , who was a member of the team that won the title in 2014 , is now occupied by the False active, with both Zac Robinson and Bobby Reid completing 5 [' active', ',', ' with', ' both', ' Zac', ' Robinson']
+868 587 In their sport, the position played by x -1 In their sport, the position played by Zac Robinson quarterback Zac Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' title' ' in' ' 2014' ',' ' is' ' now' ' occupied' ' by' ' the']" , who was a member of the team that won the title in 2014 , is now occupied by the False active, with both Zac Robinson and Bobby 5 [' active', ',', ' with', ' both', ' Zac', ' Robinson']
+869 587 In their sport, the position played by x -1 In their sport, the position played by Zac Robinson quarterback Zac Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' title' ' in' ' 2014' ',' ' is' ' now' ' occupied' ' by' ' the']" , who was a member of the team that won the title in 2014 , is now occupied by the False rushed for 105 yards. Zac Robinson was 30 of 42 for 427 6 [' rushed', ' for', ' 105', ' yards', '.', ' Zac', ' Robinson']
+870 590 In their sport, the position played by x -1 In their sport, the position played by Don McPherson quarterback Don McPherson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True three years. Don McPherson was elected 7 [' three', ' years', '.', ' Don', ' Mc', 'P', 'her', 'son']
+871 590 In their sport, the position played by x -1 In their sport, the position played by Don McPherson quarterback Don McPherson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True every three years. Don McPherson was elected mayor 8 [' every', ' three', ' years', '.', ' Don', ' Mc', 'P', 'her', 'son']
+872 590 In their sport, the position played by x -1 In their sport, the position played by Don McPherson quarterback Don McPherson "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True three years. Don McPherson was elected mayor 7 [' three', ' years', '.', ' Don', ' Mc', 'P', 'her', 'son']
+873 593 In their sport, the position played by x -1 In their sport, the position played by Tony Gwynn outfielder Tony Gwynn "[',' ' the' ' San' ' Diego' ' Padres' ""'"" ' shortstop' ',' ' is' ' the'
+ ' most' ' important' ' position' ' on' ' the' ' field' '.' ' The'
+ ' Padres' ""'""]" , the San Diego Padres ' shortstop , is the most important position on the field . The Padres ' False season, he worked with Tony Gwynn on skills at 8 [' season', ',', ' he', ' worked', ' with', ' Tony', ' G', 'wyn', 'n']
+874 593 In their sport, the position played by x -1 In their sport, the position played by Tony Gwynn outfielder Tony Gwynn "[',' ' the' ' San' ' Diego' ' Padres' ""'"" ' shortstop' ',' ' is' ' the'
+ ' most' ' important' ' position' ' on' ' the' ' field' '.' ' The'
+ ' Padres' ""'""]" , the San Diego Padres ' shortstop , is the most important position on the field . The Padres ' False reactions from them. Tony Gwynn suggested that Wiggins 7 [' reactions', ' from', ' them', '.', ' Tony', ' G', 'wyn', 'n']
+875 593 In their sport, the position played by x -1 In their sport, the position played by Tony Gwynn outfielder Tony Gwynn "[',' ' the' ' San' ' Diego' ' Padres' ""'"" ' shortstop' ',' ' is' ' the'
+ ' most' ' important' ' position' ' on' ' the' ' field' '.' ' The'
+ ' Padres' ""'""]" , the San Diego Padres ' shortstop , is the most important position on the field . The Padres ' False the source), Tony Gwynn (8), Honus Wagner 6 [' the', ' source', '),', ' Tony', ' G', 'wyn', 'n']
+876 593 In their sport, the position played by x -1 In their sport, the position played by Tony Gwynn outfielder Tony Gwynn "[',' ' the' ' San' ' Diego' ' Padres' ""'"" ' shortstop' ',' ' is' ' the'
+ ' most' ' important' ' position' ' on' ' the' ' field' '.' ' The'
+ ' Padres' ""'""]" , the San Diego Padres ' shortstop , is the most important position on the field . The Padres ' False of those guys like Tony Gwynn — they never feel like 7 [' of', ' those', ' guys', ' like', ' Tony', ' G', 'wyn', 'n']
+877 593 In their sport, the position played by x -1 In their sport, the position played by Tony Gwynn outfielder Tony Gwynn "[',' ' the' ' San' ' Diego' ' Padres' ""'"" ' shortstop' ',' ' is' ' the'
+ ' most' ' important' ' position' ' on' ' the' ' field' '.' ' The'
+ ' Padres' ""'""]" , the San Diego Padres ' shortstop , is the most important position on the field . The Padres ' False " Brett (.469, 1980) and Tony Gwynn (.459, 1997).
+" 10 [' Brett', ' (.', '469', ',', ' 1980', ')', ' and', ' Tony', ' G', 'wyn', 'n']
+878 596 In their sport, the position played by x -1 In their sport, the position played by Garrett Gilbert quarterback Garrett Gilbert "[',' ' who' ' was' ' a' ' first' '-' 'round' ' pick' ' in' ' the' ' NFL'
+ ' draft' ',' ' is' ' a' ' good' ' example' '.' ' He' ' was']" , who was a first - round pick in the NFL draft , is a good example . He was False out, freshman Garrett Gilbert replaced him 4 [' out', ',', ' freshman', ' Garrett', ' Gilbert']
+879 596 In their sport, the position played by x -1 In their sport, the position played by Garrett Gilbert quarterback Garrett Gilbert "[',' ' who' ' was' ' a' ' first' '-' 'round' ' pick' ' in' ' the' ' NFL'
+ ' draft' ',' ' is' ' a' ' good' ' example' '.' ' He' ' was']" , who was a first - round pick in the NFL draft , is a good example . He was False " Buckeyes), surpassed by Garrett Gilbert in 2010
+" 6 [' Buc', 'keyes', '),', ' surpassed', ' by', ' Garrett', ' Gilbert']
+880 596 In their sport, the position played by x -1 In their sport, the position played by Garrett Gilbert quarterback Garrett Gilbert "[',' ' who' ' was' ' a' ' first' '-' 'round' ' pick' ' in' ' the' ' NFL'
+ ' draft' ',' ' is' ' a' ' good' ' example' '.' ' He' ' was']" , who was a first - round pick in the NFL draft , is a good example . He was False McCoy out, freshman Garrett Gilbert replaced him at 5 [' McCoy', ' out', ',', ' freshman', ' Garrett', ' Gilbert']
+881 596 In their sport, the position played by x -1 In their sport, the position played by Garrett Gilbert quarterback Garrett Gilbert "[',' ' who' ' was' ' a' ' first' '-' 'round' ' pick' ' in' ' the' ' NFL'
+ ' draft' ',' ' is' ' a' ' good' ' example' '.' ' He' ' was']" , who was a first - round pick in the NFL draft , is a good example . He was False " surpassed by Garrett Gilbert in 2010
+" 3 [' surpassed', ' by', ' Garrett', ' Gilbert']
+882 596 In their sport, the position played by x -1 In their sport, the position played by Garrett Gilbert quarterback Garrett Gilbert "[',' ' who' ' was' ' a' ' first' '-' 'round' ' pick' ' in' ' the' ' NFL'
+ ' draft' ',' ' is' ' a' ' good' ' example' '.' ' He' ' was']" , who was a first - round pick in the NFL draft , is a good example . He was False by quarterback Garrett Gilbert to set up a 41-yard 3 [' by', ' quarterback', ' Garrett', ' Gilbert']
+883 598 In their sport, the position played by x -1 In their sport, the position played by Jonathan Greening midfielder Jonathan Greening "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' side' ' since'
+ ' the' ' start' ' of' ' the' ' season' ',' ' is' ' a' ' key' ' one']" , who has been a regular in the side since the start of the season , is a key one False On 24 March, Jonathan Greening joined from 6 [' On', ' 24', ' March', ',', ' Jonathan', ' Green', 'ing']
+884 598 In their sport, the position played by x -1 In their sport, the position played by Jonathan Greening midfielder Jonathan Greening "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' side' ' since'
+ ' the' ' start' ' of' ' the' ' season' ',' ' is' ' a' ' key' ' one']" , who has been a regular in the side since the start of the season , is a key one False team, moving Jonathan Greening from a wide position 5 [' team', ',', ' moving', ' Jonathan', ' Green', 'ing']
+885 598 In their sport, the position played by x -1 In their sport, the position played by Jonathan Greening midfielder Jonathan Greening "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' side' ' since'
+ ' the' ' start' ' of' ' the' ' season' ',' ' is' ' a' ' key' ' one']" , who has been a regular in the side since the start of the season , is a key one False departure of captain Jonathan Greening he was given the 5 [' departure', ' of', ' captain', ' Jonathan', ' Green', 'ing']
+886 598 In their sport, the position played by x -1 In their sport, the position played by Jonathan Greening midfielder Jonathan Greening "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' side' ' since'
+ ' the' ' start' ' of' ' the' ' season' ',' ' is' ' a' ' key' ' one']" , who has been a regular in the side since the start of the season , is a key one False On 24 March, Jonathan Greening joined from 6 [' On', ' 24', ' March', ',', ' Jonathan', ' Green', 'ing']
+887 598 In their sport, the position played by x -1 In their sport, the position played by Jonathan Greening midfielder Jonathan Greening "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' side' ' since'
+ ' the' ' start' ' of' ' the' ' season' ',' ' is' ' a' ' key' ' one']" , who has been a regular in the side since the start of the season , is a key one False On 24 March, Jonathan Greening joined from York City 6 [' On', ' 24', ' March', ',', ' Jonathan', ' Green', 'ing']
+888 600 In their sport, the position played by x -1 In their sport, the position played by Dontrelle Willis pitcher Dontrelle Willis "[',' ' the' ' quarterback' ',' ' is' ' the' ' equivalent' ' of' ' the'
+ ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the equivalent of the quarterback in football . The quarterback is the leader of the False " percentage was 6.2 % by Dontrelle Willis in 2005.
+" 11 [' percentage', ' was', ' 6', '.', '2', ' %', ' by', ' D', 'ont', 'rel', 'le', ' Willis']
+889 600 In their sport, the position played by x -1 In their sport, the position played by Dontrelle Willis pitcher Dontrelle Willis "[',' ' the' ' quarterback' ',' ' is' ' the' ' equivalent' ' of' ' the'
+ ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the equivalent of the quarterback in football . The quarterback is the leader of the False " percentage was 6.2 % by Dontrelle Willis in 2005.
+" 11 [' percentage', ' was', ' 6', '.', '2', ' %', ' by', ' D', 'ont', 'rel', 'le', ' Willis']
+890 600 In their sport, the position played by x -1 In their sport, the position played by Dontrelle Willis pitcher Dontrelle Willis "[',' ' the' ' quarterback' ',' ' is' ' the' ' equivalent' ' of' ' the'
+ ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the equivalent of the quarterback in football . The quarterback is the leader of the False " percentage was 6.2 % by Dontrelle Willis in 2005.
+" 11 [' percentage', ' was', ' 6', '.', '2', ' %', ' by', ' D', 'ont', 'rel', 'le', ' Willis']
+891 600 In their sport, the position played by x -1 In their sport, the position played by Dontrelle Willis pitcher Dontrelle Willis "[',' ' the' ' quarterback' ',' ' is' ' the' ' equivalent' ' of' ' the'
+ ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the equivalent of the quarterback in football . The quarterback is the leader of the False " percentage was 6.2 % by Dontrelle Willis in 2005.
+" 11 [' percentage', ' was', ' 6', '.', '2', ' %', ' by', ' D', 'ont', 'rel', 'le', ' Willis']
+892 602 In their sport, the position played by x -1 In their sport, the position played by Rick Ankiel pitcher Rick Ankiel "[' is' ' that' ' of' ' a' ' catcher' '.' ' He' ' is' ' the' ' catcher'
+ ' of' ' the' ' team' '.' ' He' ' is' ' the' ' catcher' ' of' ' the']" is that of a catcher . He is the catcher of the team . He is the catcher of the False go-ahead run against Rick Ankiel of the St. Louis 8 [' go', '-', 'ahead', ' run', ' against', ' Rick', ' An', 'ki', 'el']
+893 602 In their sport, the position played by x -1 In their sport, the position played by Rick Ankiel pitcher Rick Ankiel "[' is' ' that' ' of' ' a' ' catcher' '.' ' He' ' is' ' the' ' catcher'
+ ' of' ' the' ' team' '.' ' He' ' is' ' the' ' catcher' ' of' ' the']" is that of a catcher . He is the catcher of the team . He is the catcher of the False Edmonds, and Rick Ankiel each drove 7 [' Ed', 'monds', ',', ' and', ' Rick', ' An', 'ki', 'el']
+894 602 In their sport, the position played by x -1 In their sport, the position played by Rick Ankiel pitcher Rick Ankiel "[' is' ' that' ' of' ' a' ' catcher' '.' ' He' ' is' ' the' ' catcher'
+ ' of' ' the' ' team' '.' ' He' ' is' ' the' ' catcher' ' of' ' the']" is that of a catcher . He is the catcher of the team . He is the catcher of the False run against Rick Ankiel of the St. Louis Cardinals 5 [' run', ' against', ' Rick', ' An', 'ki', 'el']
+895 602 In their sport, the position played by x -1 In their sport, the position played by Rick Ankiel pitcher Rick Ankiel "[' is' ' that' ' of' ' a' ' catcher' '.' ' He' ' is' ' the' ' catcher'
+ ' of' ' the' ' team' '.' ' He' ' is' ' the' ' catcher' ' of' ' the']" is that of a catcher . He is the catcher of the team . He is the catcher of the False a home run to Rick Ankiel in the 11th inning 7 [' a', ' home', ' run', ' to', ' Rick', ' An', 'ki', 'el']
+896 602 In their sport, the position played by x -1 In their sport, the position played by Rick Ankiel pitcher Rick Ankiel "[' is' ' that' ' of' ' a' ' catcher' '.' ' He' ' is' ' the' ' catcher'
+ ' of' ' the' ' team' '.' ' He' ' is' ' the' ' catcher' ' of' ' the']" is that of a catcher . He is the catcher of the team . He is the catcher of the False run against Rick Ankiel of the St. Louis 5 [' run', ' against', ' Rick', ' An', 'ki', 'el']
+897 617 In their sport, the position played by x -1 In their sport, the position played by Gary Cuozzo quarterback Gary Cuozzo "[',' ' the' ' former' ' New' ' York' ' City' ' police' ' detective' ','
+ ' is' ' a' ' familiar' ' one' '.' ' He' ' is' ' the' ' man' ' who' ',']" , the former New York City police detective , is a familiar one . He is the man who , False After starting Gary Cuozzo at quarterback in 5 [' After', ' starting', ' Gary', ' Cu', 'o', 'zzo']
+898 621 In their sport, the position played by x -1 In their sport, the position played by Rufus Alexander linebacker Rufus Alexander "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' R' 'uf' 'us' ' Alexander' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by R uf us Alexander , the
+
+" False Linebacker Rufus Alexander pulled down his first 5 [' Line', 'backer', ' R', 'uf', 'us', ' Alexander']
+899 621 In their sport, the position played by x -1 In their sport, the position played by Rufus Alexander linebacker Rufus Alexander "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' R' 'uf' 'us' ' Alexander' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by R uf us Alexander , the
+
+" False senior linebacker Rufus Alexander was named the AT 5 [' senior', ' linebacker', ' R', 'uf', 'us', ' Alexander']
+900 621 In their sport, the position played by x -1 In their sport, the position played by Rufus Alexander linebacker Rufus Alexander "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' R' 'uf' 'us' ' Alexander' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by R uf us Alexander , the
+
+" False Linebacker Rufus Alexander pulled down his 5 [' Line', 'backer', ' R', 'uf', 'us', ' Alexander']
+901 621 In their sport, the position played by x -1 In their sport, the position played by Rufus Alexander linebacker Rufus Alexander "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' R' 'uf' 'us' ' Alexander' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by R uf us Alexander , the
+
+" False 2007 NFL Draft, Rufus Alexander was chosen in 7 [' 2007', ' NFL', ' Draft', ',', ' R', 'uf', 'us', ' Alexander']
+902 621 In their sport, the position played by x -1 In their sport, the position played by Rufus Alexander linebacker Rufus Alexander "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' R' 'uf' 'us' ' Alexander' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by R uf us Alexander , the
+
+" False 2007 NFL Draft, Rufus Alexander was chosen in 7 [' 2007', ' NFL', ' Draft', ',', ' R', 'uf', 'us', ' Alexander']
+903 623 In their sport, the position played by x -1 In their sport, the position played by Juan Cuadrado midfielder Juan Cuadrado "[' is' ' that' ' of' ' a' ' �' '�' 'false' ' 9' '�' '�' ',' ' a'
+ ' position' ' that' ' is' ' not' ' as' ' well' ' known' ' as']" is that of a � � false 9 � � , a position that is not as well known as False expense of Juan Cuadrado and Oscar. Chelsea 5 [' expense', ' of', ' Juan', ' Cu', 'adr', 'ado']
+904 623 In their sport, the position played by x -1 In their sport, the position played by Juan Cuadrado midfielder Juan Cuadrado "[' is' ' that' ' of' ' a' ' �' '�' 'false' ' 9' '�' '�' ',' ' a'
+ ' position' ' that' ' is' ' not' ' as' ' well' ' known' ' as']" is that of a � � false 9 � � , a position that is not as well known as False at the expense of Juan Cuadrado and Oscar. Chelsea 7 [' at', ' the', ' expense', ' of', ' Juan', ' Cu', 'adr', 'ado']
+905 623 In their sport, the position played by x -1 In their sport, the position played by Juan Cuadrado midfielder Juan Cuadrado "[' is' ' that' ' of' ' a' ' �' '�' 'false' ' 9' '�' '�' ',' ' a'
+ ' position' ' that' ' is' ' not' ' as' ' well' ' known' ' as']" is that of a � � false 9 � � , a position that is not as well known as False expense of Juan Cuadrado and Oscar. Chelsea 5 [' expense', ' of', ' Juan', ' Cu', 'adr', 'ado']
+906 623 In their sport, the position played by x -1 In their sport, the position played by Juan Cuadrado midfielder Juan Cuadrado "[' is' ' that' ' of' ' a' ' �' '�' 'false' ' 9' '�' '�' ',' ' a'
+ ' position' ' that' ' is' ' not' ' as' ' well' ' known' ' as']" is that of a � � false 9 � � , a position that is not as well known as False visitors, and Juan Cuadrado for the hosts. In 6 [' visitors', ',', ' and', ' Juan', ' Cu', 'adr', 'ado']
+907 623 In their sport, the position played by x -1 In their sport, the position played by Juan Cuadrado midfielder Juan Cuadrado "[' is' ' that' ' of' ' a' ' �' '�' 'false' ' 9' '�' '�' ',' ' a'
+ ' position' ' that' ' is' ' not' ' as' ' well' ' known' ' as']" is that of a � � false 9 � � , a position that is not as well known as False at the expense of Juan Cuadrado and Oscar. Chelsea 7 [' at', ' the', ' expense', ' of', ' Juan', ' Cu', 'adr', 'ado']
+908 627 In their sport, the position played by x -1 In their sport, the position played by Marco Tardelli midfielder Marco Tardelli "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' only' ' player' ' in' ' the' ' squad' ' who' ' has']" , the Italian midfielder , is a key one . He is the only player in the squad who has True Bettega tackled him. Marco Tardelli was later sent 9 [' Bet', 'te', 'ga', ' tackled', ' him', '.', ' Marco', ' T', 'ard', 'elli']
+909 627 In their sport, the position played by x -1 In their sport, the position played by Marco Tardelli midfielder Marco Tardelli "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' only' ' player' ' in' ' the' ' squad' ' who' ' has']" , the Italian midfielder , is a key one . He is the only player in the squad who has True tackled him. Marco Tardelli was later 6 [' tackled', ' him', '.', ' Marco', ' T', 'ard', 'elli']
+910 627 In their sport, the position played by x -1 In their sport, the position played by Marco Tardelli midfielder Marco Tardelli "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' only' ' player' ' in' ' the' ' squad' ' who' ' has']" , the Italian midfielder , is a key one . He is the only player in the squad who has True 3 ['Marco', ' T', 'ard', 'elli']
+911 628 In their sport, the position played by x -1 In their sport, the position played by Brian Boucher goaltender Brian Boucher "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' Stanley' ' Cup' ' in' ' 2006' ',' ' is' ' a' ' bit' ' of']" , who was a member of the team that won the Stanley Cup in 2006 , is a bit of False As a result, rookie Brian Boucher began playing more. 7 [' As', ' a', ' result', ',', ' rookie', ' Brian', ' Bou', 'cher']
+912 628 In their sport, the position played by x -1 In their sport, the position played by Brian Boucher goaltender Brian Boucher "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' Stanley' ' Cup' ' in' ' 2006' ',' ' is' ' a' ' bit' ' of']" , who was a member of the team that won the Stanley Cup in 2006 , is a bit of False Philadelphia's Brian Boucher in 2000 to 4 "[' Philadelphia', ""'s"", ' Brian', ' Bou', 'cher']"
+913 628 In their sport, the position played by x -1 In their sport, the position played by Brian Boucher goaltender Brian Boucher "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' Stanley' ' Cup' ' in' ' 2006' ',' ' is' ' a' ' bit' ' of']" , who was a member of the team that won the Stanley Cup in 2006 , is a bit of False result, rookie Brian Boucher began playing more. 5 [' result', ',', ' rookie', ' Brian', ' Bou', 'cher']
+914 628 In their sport, the position played by x -1 In their sport, the position played by Brian Boucher goaltender Brian Boucher "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' Stanley' ' Cup' ' in' ' 2006' ',' ' is' ' a' ' bit' ' of']" , who was a member of the team that won the Stanley Cup in 2006 , is a bit of False since Philadelphia's Brian Boucher in 2000 to win a 5 "[' since', ' Philadelphia', ""'s"", ' Brian', ' Bou', 'cher']"
+915 628 In their sport, the position played by x -1 In their sport, the position played by Brian Boucher goaltender Brian Boucher "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' Stanley' ' Cup' ' in' ' 2006' ',' ' is' ' a' ' bit' ' of']" , who was a member of the team that won the Stanley Cup in 2006 , is a bit of False result, rookie Brian Boucher began playing more. 5 [' result', ',', ' rookie', ' Brian', ' Bou', 'cher']
+916 629 In their sport, the position played by x -1 In their sport, the position played by David Whitehurst quarterback David Whitehurst "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' bit' ' of'
+ ' a' ' mystery' '.' ' He' ' is' ' a' ' former' ' NFL' ' quarterback'
+ ' who']" , the former NFL quarterback , is a bit of a mystery . He is a former NFL quarterback who True Packers quarterback David Whitehurst threw for a career-best 4 [' Packers', ' quarterback', ' David', ' White', 'hurst']
+917 630 In their sport, the position played by x -1 In their sport, the position played by Sam Crawford outfielder Sam Crawford "[',' ' the' ' former' ' head' ' of' ' the' ' CIA' '�' '�' 's' ' National'
+ ' Cl' 'andestine' ' Service' ',' ' is' ' a' ' perfect' ' example' ' of']" , the former head of the CIA � � s National Cl andestine Service , is a perfect example of False featuring Ty Cobb, Sam Crawford and Jim Delahanty, 5 [' featuring', ' Ty', ' Cobb', ',', ' Sam', ' Crawford']
+918 633 In their sport, the position played by x -1 In their sport, the position played by Joe Montana quarterback Joe Montana "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True games based on Joe Montana (Joe Montana Wide 4 [' games', ' based', ' on', ' Joe', ' Montana']
+919 633 In their sport, the position played by x -1 In their sport, the position played by Joe Montana quarterback Joe Montana "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True games based on Joe Montana (Joe Montana Wide 4 [' games', ' based', ' on', ' Joe', ' Montana']
+920 633 In their sport, the position played by x -1 In their sport, the position played by Joe Montana quarterback Joe Montana "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True football record until Joe Montana surpassed it 4 [' football', ' record', ' until', ' Joe', ' Montana']
+921 633 In their sport, the position played by x -1 In their sport, the position played by Joe Montana quarterback Joe Montana "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True quarterback Joe Montana and running back 2 [' quarterback', ' Joe', ' Montana']
+922 633 In their sport, the position played by x -1 In their sport, the position played by Joe Montana quarterback Joe Montana "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True third quarter 49ers Joe Montana threw a 61-yard 5 [' third', ' quarter', ' 49', 'ers', ' Joe', ' Montana']
+923 634 In their sport, the position played by x -1 In their sport, the position played by Joel Zumaya pitcher Joel Zumaya "[',' ' the' ' Tigers' ' have' ' a' ' lot' ' of' ' young' ' players' ' who'
+ ' are' ' going' ' to' ' be' ' in' ' the' ' mix' ' for' ' playing' ' time']" , the Tigers have a lot of young players who are going to be in the mix for playing time False Detroit Tigers' pitcher Joel Zumaya injured himself during 7 "[' Detroit', ' Tigers', ""'"", ' pitcher', ' Joel', ' Z', 'um', 'aya']"
+924 634 In their sport, the position played by x -1 In their sport, the position played by Joel Zumaya pitcher Joel Zumaya "[',' ' the' ' Tigers' ' have' ' a' ' lot' ' of' ' young' ' players' ' who'
+ ' are' ' going' ' to' ' be' ' in' ' the' ' mix' ' for' ' playing' ' time']" , the Tigers have a lot of young players who are going to be in the mix for playing time False Detroit Tigers' pitcher Joel Zumaya injured himself 7 "[' Detroit', ' Tigers', ""'"", ' pitcher', ' Joel', ' Z', 'um', 'aya']"
+925 634 In their sport, the position played by x -1 In their sport, the position played by Joel Zumaya pitcher Joel Zumaya "[',' ' the' ' Tigers' ' have' ' a' ' lot' ' of' ' young' ' players' ' who'
+ ' are' ' going' ' to' ' be' ' in' ' the' ' mix' ' for' ' playing' ' time']" , the Tigers have a lot of young players who are going to be in the mix for playing time False Detroit Tigers' pitcher Joel Zumaya injured himself 7 "[' Detroit', ' Tigers', ""'"", ' pitcher', ' Joel', ' Z', 'um', 'aya']"
+926 634 In their sport, the position played by x -1 In their sport, the position played by Joel Zumaya pitcher Joel Zumaya "[',' ' the' ' Tigers' ' have' ' a' ' lot' ' of' ' young' ' players' ' who'
+ ' are' ' going' ' to' ' be' ' in' ' the' ' mix' ' for' ' playing' ' time']" , the Tigers have a lot of young players who are going to be in the mix for playing time False Tigers' pitcher Joel Zumaya injured himself 6 "[' Tigers', ""'"", ' pitcher', ' Joel', ' Z', 'um', 'aya']"
+927 634 In their sport, the position played by x -1 In their sport, the position played by Joel Zumaya pitcher Joel Zumaya "[',' ' the' ' Tigers' ' have' ' a' ' lot' ' of' ' young' ' players' ' who'
+ ' are' ' going' ' to' ' be' ' in' ' the' ' mix' ' for' ' playing' ' time']" , the Tigers have a lot of young players who are going to be in the mix for playing time False the return of Joel Zumaya from the disabled 6 [' the', ' return', ' of', ' Joel', ' Z', 'um', 'aya']
+928 637 In their sport, the position played by x -1 In their sport, the position played by Shane Ray linebacker Shane Ray "[',' ' who' ' was' ' a' ' defensive' ' end' ' at' ' the' ' University'
+ ' of' ' Alabama' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' well']" , who was a defensive end at the University of Alabama , is a position that is not as well False Fowler (Florida), Shane Ray (Missouri) and 5 [' Fowler', ' (', 'Florida', '),', ' Shane', ' Ray']
+929 637 In their sport, the position played by x -1 In their sport, the position played by Shane Ray linebacker Shane Ray "[',' ' who' ' was' ' a' ' defensive' ' end' ' at' ' the' ' University'
+ ' of' ' Alabama' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' well']" , who was a defensive end at the University of Alabama , is a position that is not as well False (Florida), Shane Ray (Missouri) 4 [' (', 'Florida', '),', ' Shane', ' Ray']
+930 637 In their sport, the position played by x -1 In their sport, the position played by Shane Ray linebacker Shane Ray "[',' ' who' ' was' ' a' ' defensive' ' end' ' at' ' the' ' University'
+ ' of' ' Alabama' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' well']" , who was a defensive end at the University of Alabama , is a position that is not as well False Fowler (Florida), Shane Ray (Missouri) and Randy 5 [' Fowler', ' (', 'Florida', '),', ' Shane', ' Ray']
+931 637 In their sport, the position played by x -1 In their sport, the position played by Shane Ray linebacker Shane Ray "[',' ' who' ' was' ' a' ' defensive' ' end' ' at' ' the' ' University'
+ ' of' ' Alabama' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' well']" , who was a defensive end at the University of Alabama , is a position that is not as well False Fowler (Florida), Shane Ray (Missouri) 5 [' Fowler', ' (', 'Florida', '),', ' Shane', ' Ray']
+932 638 In their sport, the position played by x -1 In their sport, the position played by LaMarr Woodley linebacker LaMarr Woodley "[' is' ' the' ' same' ' as' ' that' ' of' ' a' ' defensive' ' end' ' in'
+ ' football' '.' ' He' ' is' ' a' ' pass' ' rusher' ',' ' and' ' he']" is the same as that of a defensive end in football . He is a pass rusher , and he False quarter. After a LaMarr Woodley interception the 8 [' quarter', '.', ' After', ' a', ' La', 'M', 'arr', ' Wood', 'ley']
+933 638 In their sport, the position played by x -1 In their sport, the position played by LaMarr Woodley linebacker LaMarr Woodley "[' is' ' the' ' same' ' as' ' that' ' of' ' a' ' defensive' ' end' ' in'
+ ' football' '.' ' He' ' is' ' a' ' pass' ' rusher' ',' ' and' ' he']" is the same as that of a defensive end in football . He is a pass rusher , and he False " teammate and co-captain LaMarr Woodley was number one).
+" 10 [' teammate', ' and', ' co', '-', 'capt', 'ain', ' La', 'M', 'arr', ' Wood', 'ley']
+934 638 In their sport, the position played by x -1 In their sport, the position played by LaMarr Woodley linebacker LaMarr Woodley "[' is' ' the' ' same' ' as' ' that' ' of' ' a' ' defensive' ' end' ' in'
+ ' football' '.' ' He' ' is' ' a' ' pass' ' rusher' ',' ' and' ' he']" is the same as that of a defensive end in football . He is a pass rusher , and he False and co-captain LaMarr Woodley was number 9 [' and', ' co', '-', 'capt', 'ain', ' La', 'M', 'arr', ' Wood', 'ley']
+935 638 In their sport, the position played by x -1 In their sport, the position played by LaMarr Woodley linebacker LaMarr Woodley "[' is' ' the' ' same' ' as' ' that' ' of' ' a' ' defensive' ' end' ' in'
+ ' football' '.' ' He' ' is' ' a' ' pass' ' rusher' ',' ' and' ' he']" is the same as that of a defensive end in football . He is a pass rusher , and he False fumble was picked up by LaMarr Woodley and returned 7 yards 9 [' fumble', ' was', ' picked', ' up', ' by', ' La', 'M', 'arr', ' Wood', 'ley']
+936 638 In their sport, the position played by x -1 In their sport, the position played by LaMarr Woodley linebacker LaMarr Woodley "[' is' ' the' ' same' ' as' ' that' ' of' ' a' ' defensive' ' end' ' in'
+ ' football' '.' ' He' ' is' ' a' ' pass' ' rusher' ',' ' and' ' he']" is the same as that of a defensive end in football . He is a pass rusher , and he False quarter. After a LaMarr Woodley interception 8 [' quarter', '.', ' After', ' a', ' La', 'M', 'arr', ' Wood', 'ley']
+937 639 In their sport, the position played by x -1 In their sport, the position played by Henrik Lundqvist goaltender Henrik Lundqvist "[' is' ' the' ' equivalent' ' of' ' a' ' quarterback' ' in' ' football'
+ '.' ' He' ' is' ' the' ' quarterback' ' of' ' the' ' team' ',' ' and'
+ ' he' ' is']" is the equivalent of a quarterback in football . He is the quarterback of the team , and he is False April 7, 2010, on Henrik Lundqvist of the New 10 [' April', ' 7', ',', ' 2010', ',', ' on', ' Hen', 'rik', ' Lund', 'qv', 'ist']
+938 645 In their sport, the position played by x -1 In their sport, the position played by Jocelyn Thibault goaltender Jocelyn Thibault "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' reach' ' the'
+ ' end' ' of' ' the' ' course' ',' ' wins' '.' '\n' '\n' 'The']" ", the player who is the first to reach the end of the course , wins .
+
+ The" False goaltender Jocelyn Thibault leave as free 5 [' goaltender', ' Jo', 'ce', 'lyn', ' Thib', 'ault']
+939 645 In their sport, the position played by x -1 In their sport, the position played by Jocelyn Thibault goaltender Jocelyn Thibault "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' reach' ' the'
+ ' end' ' of' ' the' ' course' ',' ' wins' '.' '\n' '\n' 'The']" ", the player who is the first to reach the end of the course , wins .
+
+ The" False 4 ['Jo', 'ce', 'lyn', ' Thib', 'ault']
+940 645 In their sport, the position played by x -1 In their sport, the position played by Jocelyn Thibault goaltender Jocelyn Thibault "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' reach' ' the'
+ ' end' ' of' ' the' ' course' ',' ' wins' '.' '\n' '\n' 'The']" ", the player who is the first to reach the end of the course , wins .
+
+ The" False Ouellet and goaltender Jocelyn Thibault leave as free 9 [' O', 'uel', 'let', ' and', ' goaltender', ' Jo', 'ce', 'lyn', ' Thib', 'ault']
+941 645 In their sport, the position played by x -1 In their sport, the position played by Jocelyn Thibault goaltender Jocelyn Thibault "[',' ' the' ' player' ' who' ' is' ' the' ' first' ' to' ' reach' ' the'
+ ' end' ' of' ' the' ' course' ',' ' wins' '.' '\n' '\n' 'The']" ", the player who is the first to reach the end of the course , wins .
+
+ The" False 4 ['Jo', 'ce', 'lyn', ' Thib', 'ault']
+942 646 In their sport, the position played by x -1 In their sport, the position played by Les Horvath quarterback Les Horvath "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Les' ' Hor' 'v' 'ath' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Les Hor v ath , the
+
+" False roomed with Les Horvath and Don McCafferty 6 [' room', 'ed', ' with', ' Les', ' Hor', 'v', 'ath']
+943 646 In their sport, the position played by x -1 In their sport, the position played by Les Horvath quarterback Les Horvath "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Les' ' Hor' 'v' 'ath' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Les Hor v ath , the
+
+" False roomed with Les Horvath and Don McCafferty 6 [' room', 'ed', ' with', ' Les', ' Hor', 'v', 'ath']
+944 646 In their sport, the position played by x -1 In their sport, the position played by Les Horvath quarterback Les Horvath "[',' ' the' '\n' '\n' 'In' ' their' ' sport' ',' ' the' ' position'
+ ' played' ' by' ' Les' ' Hor' 'v' 'ath' ',' ' the' '\n' '\n']" ", the
+
+ In their sport , the position played by Les Hor v ath , the
+
+" False Lavelli roomed with Les Horvath and Don McCafferty 9 [' L', 'ave', 'lli', ' room', 'ed', ' with', ' Les', ' Hor', 'v', 'ath']
+945 653 In their sport, the position played by x -1 In their sport, the position played by Craig Erickson quarterback Craig Erickson "[',' ' who' ' is' ' a' ' former' ' All' '-' 'American' ' at' ' the'
+ ' University' ' of' ' Minnesota' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the']" , who is a former All - American at the University of Minnesota , is a perfect example of the False DeBerg and Craig Erickson competed for the 6 [' De', 'B', 'erg', ' and', ' Craig', ' Eric', 'kson']
+946 653 In their sport, the position played by x -1 In their sport, the position played by Craig Erickson quarterback Craig Erickson "[',' ' who' ' is' ' a' ' former' ' All' '-' 'American' ' at' ' the'
+ ' University' ' of' ' Minnesota' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the']" , who is a former All - American at the University of Minnesota , is a perfect example of the False while DeBerg and Craig Erickson competed for the starting 7 [' while', ' De', 'B', 'erg', ' and', ' Craig', ' Eric', 'kson']
+947 655 In their sport, the position played by x -1 In their sport, the position played by Alfie Moore goaltender Alfie Moore "[',' ' the' ' man' ' who' ' was' ' the' ' first' ' to' ' be' ' sent'
+ ' off' ' in' ' the' ' first' ' game' ' of' ' the' ' tournament' ','
+ ' was']" , the man who was the first to be sent off in the first game of the tournament , was False minor-league goaltender Alfie Moore after Karakas 6 [' minor', '-', 'league', ' goaltender', ' Alf', 'ie', ' Moore']
+948 657 In their sport, the position played by x -1 In their sport, the position played by Marc Overmars midfielder Marc Overmars "[',' ' the' ' Dutch' 'man' ',' ' is' ' that' ' of' ' a' ' midfielder' ','
+ ' but' ' he' ' is' ' a' ' midfielder' ' who' ' is' ' also' ' a']" , the Dutch man , is that of a midfielder , but he is a midfielder who is also a True Midfielders Marc Overmars and Emmanuel Petit 6 [' Mid', 'field', 'ers', ' Marc', ' Over', 'm', 'ars']
+949 657 In their sport, the position played by x -1 In their sport, the position played by Marc Overmars midfielder Marc Overmars "[',' ' the' ' Dutch' 'man' ',' ' is' ' that' ' of' ' a' ' midfielder' ','
+ ' but' ' he' ' is' ' a' ' midfielder' ' who' ' is' ' also' ' a']" , the Dutch man , is that of a midfielder , but he is a midfielder who is also a True minutes after Marc Overmars took advantage of 5 [' minutes', ' after', ' Marc', ' Over', 'm', 'ars']
+950 657 In their sport, the position played by x -1 In their sport, the position played by Marc Overmars midfielder Marc Overmars "[',' ' the' ' Dutch' 'man' ',' ' is' ' that' ' of' ' a' ' midfielder' ','
+ ' but' ' he' ' is' ' a' ' midfielder' ' who' ' is' ' also' ' a']" , the Dutch man , is that of a midfielder , but he is a midfielder who is also a True second-half goal by Marc Overmars gave Arsenal 8 [' second', '-', 'half', ' goal', ' by', ' Marc', ' Over', 'm', 'ars']
+951 657 In their sport, the position played by x -1 In their sport, the position played by Marc Overmars midfielder Marc Overmars "[',' ' the' ' Dutch' 'man' ',' ' is' ' that' ' of' ' a' ' midfielder' ','
+ ' but' ' he' ' is' ' a' ' midfielder' ' who' ' is' ' also' ' a']" , the Dutch man , is that of a midfielder , but he is a midfielder who is also a True including midfielders Marc Overmars and Emmanuel Petit 6 [' including', ' midfield', 'ers', ' Marc', ' Over', 'm', 'ars']
+952 657 In their sport, the position played by x -1 In their sport, the position played by Marc Overmars midfielder Marc Overmars "[',' ' the' ' Dutch' 'man' ',' ' is' ' that' ' of' ' a' ' midfielder' ','
+ ' but' ' he' ' is' ' a' ' midfielder' ' who' ' is' ' also' ' a']" , the Dutch man , is that of a midfielder , but he is a midfielder who is also a True took the lead when Marc Overmars scored 11 minutes 7 [' took', ' the', ' lead', ' when', ' Marc', ' Over', 'm', 'ars']
+953 658 In their sport, the position played by x -1 In their sport, the position played by Joel Stave quarterback Joel Stave "[',' ' a' ' 6' '-' 'foot' '-' '6' ',' ' 230' '-' 'pound' ' senior' ','
+ ' is' ' a' ' big' ' one' '.' ' He' ' is']" , a 6 - foot - 6 , 230 - pound senior , is a big one . He is False a 4-yard pass from Joel Stave to tight end 8 [' a', ' 4', '-', 'yard', ' pass', ' from', ' Joel', ' St', 'ave']
+954 658 In their sport, the position played by x -1 In their sport, the position played by Joel Stave quarterback Joel Stave "[',' ' a' ' 6' '-' 'foot' '-' '6' ',' ' 230' '-' 'pound' ' senior' ','
+ ' is' ' a' ' big' ' one' '.' ' He' ' is']" , a 6 - foot - 6 , 230 - pound senior , is a big one . He is False quarterback Joel Stave and favorite 3 [' quarterback', ' Joel', ' St', 'ave']
+955 659 In their sport, the position played by x -1 In their sport, the position played by Rivaldo midfielder Rivaldo "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the best thing" False Luis Enrique and Rivaldo, and the team won 4 [' Luis', ' Enrique', ' and', ' Riv', 'aldo']
+956 659 In their sport, the position played by x -1 In their sport, the position played by Rivaldo midfielder Rivaldo "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the best thing" False slip, not a Rivaldo shot, for a 5 [' slip', ',', ' not', ' a', ' Riv', 'aldo']
+957 659 In their sport, the position played by x -1 In their sport, the position played by Rivaldo midfielder Rivaldo "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the best thing" False runner up to Rivaldo for 1999's European 4 [' runner', ' up', ' to', ' Riv', 'aldo']
+958 659 In their sport, the position played by x -1 In their sport, the position played by Rivaldo midfielder Rivaldo "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the best thing" False Luis Enrique and Rivaldo, and the team 4 [' Luis', ' Enrique', ' and', ' Riv', 'aldo']
+959 659 In their sport, the position played by x -1 In their sport, the position played by Rivaldo midfielder Rivaldo "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' manager' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great manager .
+
+ I think the best thing" False was runner up to Rivaldo for 1999's European 5 [' was', ' runner', ' up', ' to', ' Riv', 'aldo']
+960 663 In their sport, the position played by x -1 In their sport, the position played by Mike Trout outfielder Mike Trout "[',' ' the' ' Angels' ""'"" ' center' ' fielder' ',' ' is' ' the' ' most'
+ ' important' ' position' ' in' ' baseball' '.' ' He' ' is' ' the' ' best'
+ ' player']" , the Angels ' center fielder , is the most important position in baseball . He is the best player False top prospect Mike Trout and fired hitting 3 [' top', ' prospect', ' Mike', ' Trout']
+961 663 In their sport, the position played by x -1 In their sport, the position played by Mike Trout outfielder Mike Trout "[',' ' the' ' Angels' ""'"" ' center' ' fielder' ',' ' is' ' the' ' most'
+ ' important' ' position' ' in' ' baseball' '.' ' He' ' is' ' the' ' best'
+ ' player']" , the Angels ' center fielder , is the most important position in baseball . He is the best player False top prospect Mike Trout and fired 3 [' top', ' prospect', ' Mike', ' Trout']
+962 663 In their sport, the position played by x -1 In their sport, the position played by Mike Trout outfielder Mike Trout "[',' ' the' ' Angels' ""'"" ' center' ' fielder' ',' ' is' ' the' ' most'
+ ' important' ' position' ' in' ' baseball' '.' ' He' ' is' ' the' ' best'
+ ' player']" , the Angels ' center fielder , is the most important position in baseball . He is the best player False off the bat of Mike Trout at 106 miles per hour 5 [' off', ' the', ' bat', ' of', ' Mike', ' Trout']
+963 663 In their sport, the position played by x -1 In their sport, the position played by Mike Trout outfielder Mike Trout "[',' ' the' ' Angels' ""'"" ' center' ' fielder' ',' ' is' ' the' ' most'
+ ' important' ' position' ' in' ' baseball' '.' ' He' ' is' ' the' ' best'
+ ' player']" , the Angels ' center fielder , is the most important position in baseball . He is the best player False motion off the bat of Mike Trout – one tumbling on 6 [' motion', ' off', ' the', ' bat', ' of', ' Mike', ' Trout']
+964 663 In their sport, the position played by x -1 In their sport, the position played by Mike Trout outfielder Mike Trout "[',' ' the' ' Angels' ""'"" ' center' ' fielder' ',' ' is' ' the' ' most'
+ ' important' ' position' ' in' ' baseball' '.' ' He' ' is' ' the' ' best'
+ ' player']" , the Angels ' center fielder , is the most important position in baseball . He is the best player False motion off the bat of Mike Trout – one tumbling on 6 [' motion', ' off', ' the', ' bat', ' of', ' Mike', ' Trout']
+965 665 In their sport, the position played by x -1 In their sport, the position played by Kevin Lalande goaltender Kevin Lalande "[',' ' the' ' French' '-' 'Canadian' ' actor' ' who' ' plays' ' the'
+ ' role' ' of' ' the' ' French' '-' 'Canadian' ' detective' ',' ' is' ' a'
+ ' bit']" , the French - Canadian actor who plays the role of the French - Canadian detective , is a bit False minor-league goaltender Kevin Lalande to Columbus for a 6 [' minor', '-', 'league', ' goaltender', ' Kevin', ' Lal', 'ande']
+966 665 In their sport, the position played by x -1 In their sport, the position played by Kevin Lalande goaltender Kevin Lalande "[',' ' the' ' French' '-' 'Canadian' ' actor' ' who' ' plays' ' the'
+ ' role' ' of' ' the' ' French' '-' 'Canadian' ' detective' ',' ' is' ' a'
+ ' bit']" , the French - Canadian actor who plays the role of the French - Canadian detective , is a bit False puck past Kevin Lalande for the shorthanded 4 [' puck', ' past', ' Kevin', ' Lal', 'ande']
+967 665 In their sport, the position played by x -1 In their sport, the position played by Kevin Lalande goaltender Kevin Lalande "[',' ' the' ' French' '-' 'Canadian' ' actor' ' who' ' plays' ' the'
+ ' role' ' of' ' the' ' French' '-' 'Canadian' ' detective' ',' ' is' ' a'
+ ' bit']" , the French - Canadian actor who plays the role of the French - Canadian detective , is a bit False minor-league goaltender Kevin Lalande to Columbus for a 6 [' minor', '-', 'league', ' goaltender', ' Kevin', ' Lal', 'ande']
+968 665 In their sport, the position played by x -1 In their sport, the position played by Kevin Lalande goaltender Kevin Lalande "[',' ' the' ' French' '-' 'Canadian' ' actor' ' who' ' plays' ' the'
+ ' role' ' of' ' the' ' French' '-' 'Canadian' ' detective' ',' ' is' ' a'
+ ' bit']" , the French - Canadian actor who plays the role of the French - Canadian detective , is a bit False Hugo Carpentier, Kevin Lalande and Dan Spang. 7 [' Hugo', ' Car', 'pent', 'ier', ',', ' Kevin', ' Lal', 'ande']
+969 665 In their sport, the position played by x -1 In their sport, the position played by Kevin Lalande goaltender Kevin Lalande "[',' ' the' ' French' '-' 'Canadian' ' actor' ' who' ' plays' ' the'
+ ' role' ' of' ' the' ' French' '-' 'Canadian' ' detective' ',' ' is' ' a'
+ ' bit']" , the French - Canadian actor who plays the role of the French - Canadian detective , is a bit False the puck past Kevin Lalande for the shorthanded 5 [' the', ' puck', ' past', ' Kevin', ' Lal', 'ande']
+970 670 In their sport, the position played by x -1 In their sport, the position played by Garrincha midfielder Garrincha "[',' ' the' ' Brazilian' ',' ' was' ' that' ' of' ' a' ' ""' 'liber' 'o'
+ '""' ' (' 'a' ' swe' 'eper' '),' ' and' ' he' ' was']" ", the Brazilian , was that of a "" liber o "" ( a swe eper ), and he was" False Brazilian player Garrincha thought the incident 4 [' Brazilian', ' player', ' Garr', 'inch', 'a']
+971 671 In their sport, the position played by x -1 In their sport, the position played by Curtis Sanford goaltender Curtis Sanford "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False a veteran Curtis Sanford to be their third 3 [' a', ' veteran', ' Curtis', ' Sanford']
+972 671 In their sport, the position played by x -1 In their sport, the position played by Curtis Sanford goaltender Curtis Sanford "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False did sign a veteran Curtis Sanford to be their third 5 [' did', ' sign', ' a', ' veteran', ' Curtis', ' Sanford']
+973 671 In their sport, the position played by x -1 In their sport, the position played by Curtis Sanford goaltender Curtis Sanford "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False rumors, goaltender Curtis Sanford returned from his 4 [' rumors', ',', ' goaltender', ' Curtis', ' Sanford']
+974 673 In their sport, the position played by x -1 In their sport, the position played by Michel Larocque goaltender Michel Larocque "[',' ' the' ' French' '-' 'Canadian' ',' ' is' ' that' ' of' ' a' ' ""' 'p'
+ 'oker' ' player' '""' ' who' ' is' ' not' ' a' ' ""']" ", the French - Canadian , is that of a "" p oker player "" who is not a """ False October 17 against Michel Larocque of the Montreal 6 [' October', ' 17', ' against', ' Michel', ' Lar', 'oc', 'que']
+975 673 In their sport, the position played by x -1 In their sport, the position played by Michel Larocque goaltender Michel Larocque "[',' ' the' ' French' '-' 'Canadian' ',' ' is' ' that' ' of' ' a' ' ""' 'p'
+ 'oker' ' player' '""' ' who' ' is' ' not' ' a' ' ""']" ", the French - Canadian , is that of a "" p oker player "" who is not a """ False October 17 against Michel Larocque of the Montreal Canadiens, 6 [' October', ' 17', ' against', ' Michel', ' Lar', 'oc', 'que']
+976 673 In their sport, the position played by x -1 In their sport, the position played by Michel Larocque goaltender Michel Larocque "[',' ' the' ' French' '-' 'Canadian' ',' ' is' ' that' ' of' ' a' ' ""' 'p'
+ 'oker' ' player' '""' ' who' ' is' ' not' ' a' ' ""']" ", the French - Canadian , is that of a "" p oker player "" who is not a """ False October 17 against Michel Larocque of the Montreal Canadiens, 6 [' October', ' 17', ' against', ' Michel', ' Lar', 'oc', 'que']
+977 673 In their sport, the position played by x -1 In their sport, the position played by Michel Larocque goaltender Michel Larocque "[',' ' the' ' French' '-' 'Canadian' ',' ' is' ' that' ' of' ' a' ' ""' 'p'
+ 'oker' ' player' '""' ' who' ' is' ' not' ' a' ' ""']" ", the French - Canadian , is that of a "" p oker player "" who is not a """ False October 17 against Michel Larocque of the Montreal Canadiens, 6 [' October', ' 17', ' against', ' Michel', ' Lar', 'oc', 'que']
+978 674 In their sport, the position played by x -1 In their sport, the position played by Kelly Stouffer quarterback Kelly Stouffer "[',' ' a' ' former' ' NFL' ' linebacker' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' was' ' a' ' solid' ' linebacker' ' in' ' the']" , a former NFL linebacker , is a bit of a mystery . He was a solid linebacker in the False Cardinals selection of Kelly Stouffer and the Buffalo Bills 6 [' Cardinals', ' selection', ' of', ' Kelly', ' St', 'ou', 'ffer']
+979 674 In their sport, the position played by x -1 In their sport, the position played by Kelly Stouffer quarterback Kelly Stouffer "[',' ' a' ' former' ' NFL' ' linebacker' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' was' ' a' ' solid' ' linebacker' ' in' ' the']" , a former NFL linebacker , is a bit of a mystery . He was a solid linebacker in the False Cardinals selection of Kelly Stouffer and the Buffalo Bills 6 [' Cardinals', ' selection', ' of', ' Kelly', ' St', 'ou', 'ffer']
+980 674 In their sport, the position played by x -1 In their sport, the position played by Kelly Stouffer quarterback Kelly Stouffer "[',' ' a' ' former' ' NFL' ' linebacker' ',' ' is' ' a' ' bit' ' of' ' a'
+ ' mystery' '.' ' He' ' was' ' a' ' solid' ' linebacker' ' in' ' the']" , a former NFL linebacker , is a bit of a mystery . He was a solid linebacker in the False selection of Kelly Stouffer and the Buffalo Bills 5 [' selection', ' of', ' Kelly', ' St', 'ou', 'ffer']
+981 676 In their sport, the position played by x -1 In their sport, the position played by Mason Foster linebacker Mason Foster "[',' ' who' ' is' ' a' ' former' ' NFL' ' linebacker' ',' ' is' ' a'
+ ' bit' ' of' ' a' ' mystery' '.' ' He' ' is' ' a' ' solid' ' player']" , who is a former NFL linebacker , is a bit of a mystery . He is a solid player True Buccaneers linebacker Mason Foster on March 25, 3 [' Buccaneers', ' linebacker', ' Mason', ' Foster']
+982 684 In their sport, the position played by x -1 In their sport, the position played by Mathieu Flamini midfielder Mathieu Flamini "[',' ' who' ' is' ' a' ' midfielder' ',' ' is' ' that' ' of' ' a' ' deep'
+ '-' 'lying' ' play' 'maker' '.' ' He' ' is' ' a' ' player']" , who is a midfielder , is that of a deep - lying play maker . He is a player True Emmanuel Adebayor, Mathieu Flamini and Fàbregas playing 9 [' Emmanuel', ' Ad', 'eb', 'ay', 'or', ',', ' Math', 'ieu', ' Flam', 'ini']
+983 684 In their sport, the position played by x -1 In their sport, the position played by Mathieu Flamini midfielder Mathieu Flamini "[',' ' who' ' is' ' a' ' midfielder' ',' ' is' ' that' ' of' ' a' ' deep'
+ '-' 'lying' ' play' 'maker' '.' ' He' ' is' ' a' ' player']" , who is a midfielder , is that of a deep - lying play maker . He is a player True the bench and Mathieu Flamini partnered Cesc 6 [' the', ' bench', ' and', ' Math', 'ieu', ' Flam', 'ini']
+984 684 In their sport, the position played by x -1 In their sport, the position played by Mathieu Flamini midfielder Mathieu Flamini "[',' ' who' ' is' ' a' ' midfielder' ',' ' is' ' that' ' of' ' a' ' deep'
+ '-' 'lying' ' play' 'maker' '.' ' He' ' is' ' a' ' player']" , who is a midfielder , is that of a deep - lying play maker . He is a player True Emmanuel Adebayor, Mathieu Flamini and Fàbregas playing 9 [' Emmanuel', ' Ad', 'eb', 'ay', 'or', ',', ' Math', 'ieu', ' Flam', 'ini']
+985 684 In their sport, the position played by x -1 In their sport, the position played by Mathieu Flamini midfielder Mathieu Flamini "[',' ' who' ' is' ' a' ' midfielder' ',' ' is' ' that' ' of' ' a' ' deep'
+ '-' 'lying' ' play' 'maker' '.' ' He' ' is' ' a' ' player']" , who is a midfielder , is that of a deep - lying play maker . He is a player True was on the bench and Mathieu Flamini partnered Cesc 8 [' was', ' on', ' the', ' bench', ' and', ' Math', 'ieu', ' Flam', 'ini']
+986 684 In their sport, the position played by x -1 In their sport, the position played by Mathieu Flamini midfielder Mathieu Flamini "[',' ' who' ' is' ' a' ' midfielder' ',' ' is' ' that' ' of' ' a' ' deep'
+ '-' 'lying' ' play' 'maker' '.' ' He' ' is' ' a' ' player']" , who is a midfielder , is that of a deep - lying play maker . He is a player True the bench and Mathieu Flamini partnered Cesc 6 [' the', ' bench', ' and', ' Math', 'ieu', ' Flam', 'ini']
+987 692 In their sport, the position played by x -1 In their sport, the position played by Patrick Roy goaltender Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' who']" , who was a member of the Montreal Canadiens , is a perfect example of the type of player who False NHL All-Star Game. Patrick Roy, Ray Bourque 7 [' NHL', ' All', '-', 'Star', ' Game', '.', ' Patrick', ' Roy']
+988 692 In their sport, the position played by x -1 In their sport, the position played by Patrick Roy goaltender Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' who']" , who was a member of the Montreal Canadiens , is a perfect example of the type of player who False " Lemieux and goalies Patrick Roy and Mike Vernon.
+" 7 [' Lem', 'ie', 'ux', ' and', ' goal', 'ies', ' Patrick', ' Roy']
+989 692 In their sport, the position played by x -1 In their sport, the position played by Patrick Roy goaltender Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' who']" , who was a member of the Montreal Canadiens , is a perfect example of the type of player who False Canadiens goalie Patrick Roy joined the 3 [' Canadiens', ' goalie', ' Patrick', ' Roy']
+990 692 In their sport, the position played by x -1 In their sport, the position played by Patrick Roy goaltender Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' who']" , who was a member of the Montreal Canadiens , is a perfect example of the type of player who False against goaltender Patrick Roy on the powerplay. 3 [' against', ' goaltender', ' Patrick', ' Roy']
+991 692 In their sport, the position played by x -1 In their sport, the position played by Patrick Roy goaltender Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' who']" , who was a member of the Montreal Canadiens , is a perfect example of the type of player who False champion goalies Patrick Roy and Mike Vernon. 4 [' champion', ' goal', 'ies', ' Patrick', ' Roy']
+992 696 In their sport, the position played by x -1 In their sport, the position played by Carlos Manuel midfielder Carlos Manuel "[' Santana' ',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is'
+ ' a' ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player']" Santana , who is a former world champion , is a very important one . He is the only player False taken to Professor Carlos Manuel Clinical Surgery 4 [' taken', ' to', ' Professor', ' Carlos', ' Manuel']
+993 696 In their sport, the position played by x -1 In their sport, the position played by Carlos Manuel midfielder Carlos Manuel "[' Santana' ',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is'
+ ' a' ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player']" Santana , who is a former world champion , is a very important one . He is the only player False The Cuban scholar Carlos Manuel Trelles later 4 [' The', ' Cuban', ' scholar', ' Carlos', ' Manuel']
+994 696 In their sport, the position played by x -1 In their sport, the position played by Carlos Manuel midfielder Carlos Manuel "[' Santana' ',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is'
+ ' a' ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player']" Santana , who is a former world champion , is a very important one . He is the only player False Cuban scholar Carlos Manuel Trelles later wrote 3 [' Cuban', ' scholar', ' Carlos', ' Manuel']
+995 696 In their sport, the position played by x -1 In their sport, the position played by Carlos Manuel midfielder Carlos Manuel "[' Santana' ',' ' who' ' is' ' a' ' former' ' world' ' champion' ',' ' is'
+ ' a' ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only' ' player']" Santana , who is a former world champion , is a very important one . He is the only player False taken to Professor Carlos Manuel Clinical Surgery Hospital 4 [' taken', ' to', ' Professor', ' Carlos', ' Manuel']
+996 697 In their sport, the position played by x -1 In their sport, the position played by Bob John midfielder Bob John "['stone' ',' ' the' ' former' ' England' ' international' ',' ' is' ' a'
+ ' key' ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' and' ' England']" "stone , the former England international , is a key one .
+
+ The former Manchester United and England" False had led 1 – 0 with a Bob John goal, but Newcastle's 8 [' had', ' led', ' 1', ' –', ' 0', ' with', ' a', ' Bob', ' John']
+997 699 In their sport, the position played by x -1 In their sport, the position played by Ichiro Suzuki outfielder Ichiro Suzuki "[',' ' the' ' Japanese' ' star' ' who' ' has' ' been' ' with' ' the'
+ ' Mariners' ' since' ' 2001' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.']" , the Japanese star who has been with the Mariners since 2001 , is a bit of a mystery . False chief engineer Ichiro Suzuki sought to develop 4 [' chief', ' engineer', ' Ich', 'iro', ' Suzuki']
+998 699 In their sport, the position played by x -1 In their sport, the position played by Ichiro Suzuki outfielder Ichiro Suzuki "[',' ' the' ' Japanese' ' star' ' who' ' has' ' been' ' with' ' the'
+ ' Mariners' ' since' ' 2001' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.']" , the Japanese star who has been with the Mariners since 2001 , is a bit of a mystery . False score of 5 – 4. Ichiro Suzuki won the MVP award 8 [' score', ' of', ' 5', ' –', ' 4', '.', ' Ich', 'iro', ' Suzuki']
+999 699 In their sport, the position played by x -1 In their sport, the position played by Ichiro Suzuki outfielder Ichiro Suzuki "[',' ' the' ' Japanese' ' star' ' who' ' has' ' been' ' with' ' the'
+ ' Mariners' ' since' ' 2001' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.']" , the Japanese star who has been with the Mariners since 2001 , is a bit of a mystery . False surpassed by Ichiro Suzuki of the Seattle 4 [' surpassed', ' by', ' Ich', 'iro', ' Suzuki']
+1000 699 In their sport, the position played by x -1 In their sport, the position played by Ichiro Suzuki outfielder Ichiro Suzuki "[',' ' the' ' Japanese' ' star' ' who' ' has' ' been' ' with' ' the'
+ ' Mariners' ' since' ' 2001' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.']" , the Japanese star who has been with the Mariners since 2001 , is a bit of a mystery . False previous record bid that Ichiro Suzuki had garnered in 6 [' previous', ' record', ' bid', ' that', ' Ich', 'iro', ' Suzuki']
+1001 699 In their sport, the position played by x -1 In their sport, the position played by Ichiro Suzuki outfielder Ichiro Suzuki "[',' ' the' ' Japanese' ' star' ' who' ' has' ' been' ' with' ' the'
+ ' Mariners' ' since' ' 2001' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery'
+ '.']" , the Japanese star who has been with the Mariners since 2001 , is a bit of a mystery . False to Brian Roberts, Ichiro Suzuki hit a long fly ball 6 [' to', ' Brian', ' Roberts', ',', ' Ich', 'iro', ' Suzuki']
+1002 700 In their sport, the position played by x -1 In their sport, the position played by Tommy John pitcher Tommy John "[',' ' the' ' player' ' who' ' is' ' the' ' best' ' at' ' the' ' position'
+ ',' ' is' ' the' ' best' ' player' ' at' ' the' ' position' '.' '\n']" ", the player who is the best at the position , is the best player at the position .
+" False rehabilitation from his Tommy John surgery, and as 4 [' rehabilitation', ' from', ' his', ' Tommy', ' John']
+1003 700 In their sport, the position played by x -1 In their sport, the position played by Tommy John pitcher Tommy John "[',' ' the' ' player' ' who' ' is' ' the' ' best' ' at' ' the' ' position'
+ ',' ' is' ' the' ' best' ' player' ' at' ' the' ' position' '.' '\n']" ", the player who is the best at the position , is the best player at the position .
+" False Pujols did not need Tommy John surgery at that 7 [' Pu', 'j', 'ols', ' did', ' not', ' need', ' Tommy', ' John']
+1004 700 In their sport, the position played by x -1 In their sport, the position played by Tommy John pitcher Tommy John "[',' ' the' ' player' ' who' ' is' ' the' ' best' ' at' ' the' ' position'
+ ',' ' is' ' the' ' best' ' player' ' at' ' the' ' position' '.' '\n']" ", the player who is the best at the position , is the best player at the position .
+" False that required Tommy John surgery. Though he 3 [' that', ' required', ' Tommy', ' John']
+1005 700 In their sport, the position played by x -1 In their sport, the position played by Tommy John pitcher Tommy John "[',' ' the' ' player' ' who' ' is' ' the' ' best' ' at' ' the' ' position'
+ ',' ' is' ' the' ' best' ' player' ' at' ' the' ' position' '.' '\n']" ", the player who is the best at the position , is the best player at the position .
+" False affected by his Tommy John surgery in 2010. He 4 [' affected', ' by', ' his', ' Tommy', ' John']
+1006 700 In their sport, the position played by x -1 In their sport, the position played by Tommy John pitcher Tommy John "[',' ' the' ' player' ' who' ' is' ' the' ' best' ' at' ' the' ' position'
+ ',' ' is' ' the' ' best' ' player' ' at' ' the' ' position' '.' '\n']" ", the player who is the best at the position , is the best player at the position .
+" False Lieber underwent Tommy John surgery in 2002 4 [' Lie', 'ber', ' underwent', ' Tommy', ' John']
+1007 713 In their sport, the position played by x -1 In their sport, the position played by Paul Harsley midfielder Paul Harsley "[',' ' the' ' former' ' England' ' rugby' ' player' ',' ' is' ' a' ' key'
+ ' role' '.' ' He' ' is' ' the' ' man' ' who' ',' ' as' ' captain']" , the former England rugby player , is a key role . He is the man who , as captain False Chesterfield midfielder Paul Harsley joined on loan 6 [' Chester', 'field', ' midfielder', ' Paul', ' H', 'ars', 'ley']
+1008 713 In their sport, the position played by x -1 In their sport, the position played by Paul Harsley midfielder Paul Harsley "[',' ' the' ' former' ' England' ' rugby' ' player' ',' ' is' ' a' ' key'
+ ' role' '.' ' He' ' is' ' the' ' man' ' who' ',' ' as' ' captain']" , the former England rugby player , is a key role . He is the man who , as captain False midfielder Paul Harsley joined on loan 4 [' midfielder', ' Paul', ' H', 'ars', 'ley']
+1009 714 In their sport, the position played by x -1 In their sport, the position played by Michael Conforto outfielder Michael Conforto "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' good' ' hitter' ',' ' but' ' he' '�' '�' 's']" is a bit of a mystery . He � � s a good hitter , but he � � s False the sixth when Michael Conforto drove in Yoenis Céspedes 6 [' the', ' sixth', ' when', ' Michael', ' Con', 'fort', 'o']
+1010 714 In their sport, the position played by x -1 In their sport, the position played by Michael Conforto outfielder Michael Conforto "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' good' ' hitter' ',' ' but' ' he' '�' '�' 's']" is a bit of a mystery . He � � s a good hitter , but he � � s False of the sixth when Michael Conforto drove in Yoenis 7 [' of', ' the', ' sixth', ' when', ' Michael', ' Con', 'fort', 'o']
+1011 714 In their sport, the position played by x -1 In their sport, the position played by Michael Conforto outfielder Michael Conforto "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' good' ' hitter' ',' ' but' ' he' '�' '�' 's']" is a bit of a mystery . He � � s a good hitter , but he � � s False the sixth when Michael Conforto drove in Yoenis Céspedes 6 [' the', ' sixth', ' when', ' Michael', ' Con', 'fort', 'o']
+1012 714 In their sport, the position played by x -1 In their sport, the position played by Michael Conforto outfielder Michael Conforto "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' good' ' hitter' ',' ' but' ' he' '�' '�' 's']" is a bit of a mystery . He � � s a good hitter , but he � � s False of the sixth when Michael Conforto drove in Yoenis 7 [' of', ' the', ' sixth', ' when', ' Michael', ' Con', 'fort', 'o']
+1013 714 In their sport, the position played by x -1 In their sport, the position played by Michael Conforto outfielder Michael Conforto "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' '�' '�' 's' ' a'
+ ' good' ' hitter' ',' ' but' ' he' '�' '�' 's']" is a bit of a mystery . He � � s a good hitter , but he � � s False sixth when Michael Conforto drove in Yoenis 5 [' sixth', ' when', ' Michael', ' Con', 'fort', 'o']
+1014 715 In their sport, the position played by x -1 In their sport, the position played by CC Sabathia pitcher CC Sabathia "[' is' ' that' ' of' ' a' ' pitcher' ' who' ' is' ' not' ' only' ' a'
+ ' great' ' pitcher' ',' ' but' ' also' ' a' ' great' ' teammate' '.'
+ ' He']" is that of a pitcher who is not only a great pitcher , but also a great teammate . He True agents: pitchers CC Sabathia and A. J. Burnett, 6 [' agents', ':', ' pitchers', ' CC', ' Sab', 'ath', 'ia']
+1015 715 In their sport, the position played by x -1 In their sport, the position played by CC Sabathia pitcher CC Sabathia "[' is' ' that' ' of' ' a' ' pitcher' ' who' ' is' ' not' ' only' ' a'
+ ' great' ' pitcher' ',' ' but' ' also' ' a' ' great' ' teammate' '.'
+ ' He']" is that of a pitcher who is not only a great pitcher , but also a great teammate . He True beating Yankees ace CC Sabathia four times during 6 [' beating', ' Yankees', ' ace', ' CC', ' Sab', 'ath', 'ia']
+1016 715 In their sport, the position played by x -1 In their sport, the position played by CC Sabathia pitcher CC Sabathia "[' is' ' that' ' of' ' a' ' pitcher' ' who' ' is' ' not' ' only' ' a'
+ ' great' ' pitcher' ',' ' but' ' also' ' a' ' great' ' teammate' '.'
+ ' He']" is that of a pitcher who is not only a great pitcher , but also a great teammate . He True pitchers for Game 1, CC Sabathia and Cliff Lee, 8 [' pitchers', ' for', ' Game', ' 1', ',', ' CC', ' Sab', 'ath', 'ia']
+1017 715 In their sport, the position played by x -1 In their sport, the position played by CC Sabathia pitcher CC Sabathia "[' is' ' that' ' of' ' a' ' pitcher' ' who' ' is' ' not' ' only' ' a'
+ ' great' ' pitcher' ',' ' but' ' also' ' a' ' great' ' teammate' '.'
+ ' He']" is that of a pitcher who is not only a great pitcher , but also a great teammate . He True He relieved CC Sabathia in the eighth inning 5 [' He', ' relieved', ' CC', ' Sab', 'ath', 'ia']
+1018 715 In their sport, the position played by x -1 In their sport, the position played by CC Sabathia pitcher CC Sabathia "[' is' ' that' ' of' ' a' ' pitcher' ' who' ' is' ' not' ' only' ' a'
+ ' great' ' pitcher' ',' ' but' ' also' ' a' ' great' ' teammate' '.'
+ ' He']" is that of a pitcher who is not only a great pitcher , but also a great teammate . He True free agents: pitchers CC Sabathia and A. J. Burnett, 7 [' free', ' agents', ':', ' pitchers', ' CC', ' Sab', 'ath', 'ia']
+1019 718 In their sport, the position played by x -1 In their sport, the position played by Vesa Toskala goaltender Vesa Toskala "[' is' ' that' ' of' ' a' ' goaltender' '.' ' He' ' is' ' the' ' last'
+ ' line' ' of' ' defense' ' for' ' the' ' team' '.' ' He' ' is' ' the']" is that of a goaltender . He is the last line of defense for the team . He is the True starting goaltender Vesa Toskala was injured. Serving 6 [' starting', ' goaltender', ' Ves', 'a', ' Tos', 'k', 'ala']
+1020 718 In their sport, the position played by x -1 In their sport, the position played by Vesa Toskala goaltender Vesa Toskala "[' is' ' that' ' of' ' a' ' goaltender' '.' ' He' ' is' ' the' ' last'
+ ' line' ' of' ' defense' ' for' ' the' ' team' '.' ' He' ' is' ' the']" is that of a goaltender . He is the last line of defense for the team . He is the True 2006, against Vesa Toskala in a loss to the 7 [' 2006', ',', ' against', ' Ves', 'a', ' Tos', 'k', 'ala']
+1021 718 In their sport, the position played by x -1 In their sport, the position played by Vesa Toskala goaltender Vesa Toskala "[' is' ' that' ' of' ' a' ' goaltender' '.' ' He' ' is' ' the' ' last'
+ ' line' ' of' ' defense' ' for' ' the' ' team' '.' ' He' ' is' ' the']" is that of a goaltender . He is the last line of defense for the team . He is the True competing with Vesa Toskala for the backup 6 [' competing', ' with', ' Ves', 'a', ' Tos', 'k', 'ala']
+1022 718 In their sport, the position played by x -1 In their sport, the position played by Vesa Toskala goaltender Vesa Toskala "[' is' ' that' ' of' ' a' ' goaltender' '.' ' He' ' is' ' the' ' last'
+ ' line' ' of' ' defense' ' for' ' the' ' team' '.' ' He' ' is' ' the']" is that of a goaltender . He is the last line of defense for the team . He is the True 2006, against Vesa Toskala in a loss to 7 [' 2006', ',', ' against', ' Ves', 'a', ' Tos', 'k', 'ala']
+1023 718 In their sport, the position played by x -1 In their sport, the position played by Vesa Toskala goaltender Vesa Toskala "[' is' ' that' ' of' ' a' ' goaltender' '.' ' He' ' is' ' the' ' last'
+ ' line' ' of' ' defense' ' for' ' the' ' team' '.' ' He' ' is' ' the']" is that of a goaltender . He is the last line of defense for the team . He is the True 2006, against Vesa Toskala in a loss to the 7 [' 2006', ',', ' against', ' Ves', 'a', ' Tos', 'k', 'ala']
+1024 721 In their sport, the position played by x -1 In their sport, the position played by Ben Roethlisberger quarterback Ben Roethlisberger "[',' ' the' ' Steelers' ' quarterback' ',' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader' ' of']" , the Steelers quarterback , is the most important position on the field . The quarterback is the leader of True King ('99 Bucs), Ben Roethlisberger (' 04 Steelers), 9 "[' King', "" ('"", '99', ' Bucs', '),', ' Ben', ' Ro', 'eth', 'lis', 'berger']"
+1025 721 In their sport, the position played by x -1 In their sport, the position played by Ben Roethlisberger quarterback Ben Roethlisberger "[',' ' the' ' Steelers' ' quarterback' ',' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader' ' of']" , the Steelers quarterback , is the most important position on the field . The quarterback is the leader of True cross-town team. Ben Roethlisberger attended Stanley Cup 9 [' cross', '-', 'town', ' team', '.', ' Ben', ' Ro', 'eth', 'lis', 'berger']
+1026 721 In their sport, the position played by x -1 In their sport, the position played by Ben Roethlisberger quarterback Ben Roethlisberger "[',' ' the' ' Steelers' ' quarterback' ',' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader' ' of']" , the Steelers quarterback , is the most important position on the field . The quarterback is the leader of True the Rams when Ben Roethlisberger sustained a knee 7 [' the', ' Rams', ' when', ' Ben', ' Ro', 'eth', 'lis', 'berger']
+1027 721 In their sport, the position played by x -1 In their sport, the position played by Ben Roethlisberger quarterback Ben Roethlisberger "[',' ' the' ' Steelers' ' quarterback' ',' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader' ' of']" , the Steelers quarterback , is the most important position on the field . The quarterback is the leader of True an interception of Ben Roethlisberger was driven to the 7 [' an', ' interception', ' of', ' Ben', ' Ro', 'eth', 'lis', 'berger']
+1028 721 In their sport, the position played by x -1 In their sport, the position played by Ben Roethlisberger quarterback Ben Roethlisberger "[',' ' the' ' Steelers' ' quarterback' ',' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader' ' of']" , the Steelers quarterback , is the most important position on the field . The quarterback is the leader of True the cross-town team. Ben Roethlisberger attended Stanley Cup 10 [' the', ' cross', '-', 'town', ' team', '.', ' Ben', ' Ro', 'eth', 'lis', 'berger']
+1029 723 In their sport, the position played by x -1 In their sport, the position played by Cornell Brown linebacker Cornell Brown "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' championship' ' in' ' 2011' ',' ' is' ' a' ' position'
+ ' that']" , who was a member of the team that won the NCAA championship in 2011 , is a position that False Hokie All-American Cornell Brown was injured. 6 [' Hok', 'ie', ' All', '-', 'American', ' Cornell', ' Brown']
+1030 723 In their sport, the position played by x -1 In their sport, the position played by Cornell Brown linebacker Cornell Brown "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' championship' ' in' ' 2011' ',' ' is' ' a' ' position'
+ ' that']" , who was a member of the team that won the NCAA championship in 2011 , is a position that False Virginia Tech defender Cornell Brown was not one of 4 [' Virginia', ' Tech', ' defender', ' Cornell', ' Brown']
+1031 723 In their sport, the position played by x -1 In their sport, the position played by Cornell Brown linebacker Cornell Brown "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' championship' ' in' ' 2011' ',' ' is' ' a' ' position'
+ ' that']" , who was a member of the team that won the NCAA championship in 2011 , is a position that False defensive end Cornell Brown was named the best 3 [' defensive', ' end', ' Cornell', ' Brown']
+1032 723 In their sport, the position played by x -1 In their sport, the position played by Cornell Brown linebacker Cornell Brown "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' championship' ' in' ' 2011' ',' ' is' ' a' ' position'
+ ' that']" , who was a member of the team that won the NCAA championship in 2011 , is a position that False defensive end Cornell Brown was named the best 3 [' defensive', ' end', ' Cornell', ' Brown']
+1033 723 In their sport, the position played by x -1 In their sport, the position played by Cornell Brown linebacker Cornell Brown "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' championship' ' in' ' 2011' ',' ' is' ' a' ' position'
+ ' that']" , who was a member of the team that won the NCAA championship in 2011 , is a position that False Virginia Tech defender Cornell Brown was not one of 4 [' Virginia', ' Tech', ' defender', ' Cornell', ' Brown']
+1034 724 In their sport, the position played by x -1 In their sport, the position played by Jesper Olsen midfielder Jesper Olsen "[' and' ' his' ' team' ' is' ' to' ' be' ' comm' 'ended' '.' ' They'
+ ' have' ' been' ' very' ' open' ' about' ' their' ' intentions' ' and'
+ ' have' ' been']" and his team is to be comm ended . They have been very open about their intentions and have been False Johan Cruyff and Jesper Olsen for Ajax. Having 8 [' Joh', 'an', ' Cru', 'y', 'ff', ' and', ' Jes', 'per', ' Olsen']
+1035 724 In their sport, the position played by x -1 In their sport, the position played by Jesper Olsen midfielder Jesper Olsen "[' and' ' his' ' team' ' is' ' to' ' be' ' comm' 'ended' '.' ' They'
+ ' have' ' been' ' very' ' open' ' about' ' their' ' intentions' ' and'
+ ' have' ' been']" and his team is to be comm ended . They have been very open about their intentions and have been False Johan Cruyff and Jesper Olsen for Ajax. 8 [' Joh', 'an', ' Cru', 'y', 'ff', ' and', ' Jes', 'per', ' Olsen']
+1036 724 In their sport, the position played by x -1 In their sport, the position played by Jesper Olsen midfielder Jesper Olsen "[' and' ' his' ' team' ' is' ' to' ' be' ' comm' 'ended' '.' ' They'
+ ' have' ' been' ' very' ' open' ' about' ' their' ' intentions' ' and'
+ ' have' ' been']" and his team is to be comm ended . They have been very open about their intentions and have been False by Johan Cruyff and Jesper Olsen for Ajax. Having 9 [' by', ' Joh', 'an', ' Cru', 'y', 'ff', ' and', ' Jes', 'per', ' Olsen']
+1037 724 In their sport, the position played by x -1 In their sport, the position played by Jesper Olsen midfielder Jesper Olsen "[' and' ' his' ' team' ' is' ' to' ' be' ' comm' 'ended' '.' ' They'
+ ' have' ' been' ' very' ' open' ' about' ' their' ' intentions' ' and'
+ ' have' ' been']" and his team is to be comm ended . They have been very open about their intentions and have been False Johan Cruyff and Jesper Olsen for Ajax. Having 8 [' Joh', 'an', ' Cru', 'y', 'ff', ' and', ' Jes', 'per', ' Olsen']
+1038 724 In their sport, the position played by x -1 In their sport, the position played by Jesper Olsen midfielder Jesper Olsen "[' and' ' his' ' team' ' is' ' to' ' be' ' comm' 'ended' '.' ' They'
+ ' have' ' been' ' very' ' open' ' about' ' their' ' intentions' ' and'
+ ' have' ' been']" and his team is to be comm ended . They have been very open about their intentions and have been False Johan Cruyff and Jesper Olsen for Ajax. Having 8 [' Joh', 'an', ' Cru', 'y', 'ff', ' and', ' Jes', 'per', ' Olsen']
+1039 726 In their sport, the position played by x -1 In their sport, the position played by Scott Dreisbach quarterback Scott Dreisbach "[',' ' the' ' team' ""'s"" ' quarterback' ',' ' is' ' to' ' throw' ' the'
+ ' ball' ' to' ' the' ' receiver' ',' ' who' ' is' ' running' ' a'
+ ' route']" , the team 's quarterback , is to throw the ball to the receiver , who is running a route True August 26, 1995 from Scott Dreisbach to seal an 18 – 8 [' August', ' 26', ',', ' 1995', ' from', ' Scott', ' Dre', 'is', 'bach']
+1040 726 In their sport, the position played by x -1 In their sport, the position played by Scott Dreisbach quarterback Scott Dreisbach "[',' ' the' ' team' ""'s"" ' quarterback' ',' ' is' ' to' ' throw' ' the'
+ ' ball' ' to' ' the' ' receiver' ',' ' who' ' is' ' running' ' a'
+ ' route']" , the team 's quarterback , is to throw the ball to the receiver , who is running a route True position. Junior Scott Dreisbach had started 11 6 [' position', '.', ' Junior', ' Scott', ' Dre', 'is', 'bach']
+1041 726 In their sport, the position played by x -1 In their sport, the position played by Scott Dreisbach quarterback Scott Dreisbach "[',' ' the' ' team' ""'s"" ' quarterback' ',' ' is' ' to' ' throw' ' the'
+ ' ball' ' to' ' the' ' receiver' ',' ' who' ' is' ' running' ' a'
+ ' route']" , the team 's quarterback , is to throw the ball to the receiver , who is running a route True position. Junior Scott Dreisbach had started 11 6 [' position', '.', ' Junior', ' Scott', ' Dre', 'is', 'bach']
+1042 728 In their sport, the position played by x -1 In their sport, the position played by Alex Auld goaltender Alex Auld "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'A' 'uld' ',' ' who' ' is' ' now' ' a']" ", the former England international , is a key one .
+
+ A uld , who is now a" False replaced Cloutier, Alex Auld filled in for 7 [' replaced', ' Cl', 'out', 'ier', ',', ' Alex', ' A', 'uld']
+1043 728 In their sport, the position played by x -1 In their sport, the position played by Alex Auld goaltender Alex Auld "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'A' 'uld' ',' ' who' ' is' ' now' ' a']" ", the former England international , is a key one .
+
+ A uld , who is now a" False goaltender Alex Auld and defenceman Bryan 3 [' goaltender', ' Alex', ' A', 'uld']
+1044 728 In their sport, the position played by x -1 In their sport, the position played by Alex Auld goaltender Alex Auld "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'A' 'uld' ',' ' who' ' is' ' now' ' a']" ", the former England international , is a key one .
+
+ A uld , who is now a" False favour of rookie Alex Auld after allowing 5 [' favour', ' of', ' rookie', ' Alex', ' A', 'uld']
+1045 728 In their sport, the position played by x -1 In their sport, the position played by Alex Auld goaltender Alex Auld "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'A' 'uld' ',' ' who' ' is' ' now' ' a']" ", the former England international , is a key one .
+
+ A uld , who is now a" False Martin Gerber and Alex Auld meant the team 6 [' Martin', ' Ger', 'ber', ' and', ' Alex', ' A', 'uld']
+1046 728 In their sport, the position played by x -1 In their sport, the position played by Alex Auld goaltender Alex Auld "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'A' 'uld' ',' ' who' ' is' ' now' ' a']" ", the former England international , is a key one .
+
+ A uld , who is now a" False favour of rookie Alex Auld after allowing six 5 [' favour', ' of', ' rookie', ' Alex', ' A', 'uld']
+1047 730 In their sport, the position played by x -1 In their sport, the position played by Ali Karimi midfielder Ali Karimi "[',' ' who' ' was' ' the' ' first' ' Iranian' ' to' ' win' ' the' ' title'
+ ',' ' was' ' a' ' surprise' '.' '\n' '\n' 'The' ' Iranian' ',']" ", who was the first Iranian to win the title , was a surprise .
+
+ The Iranian ," False brought in, to replace Ali Karimi in the 88th 7 [' brought', ' in', ',', ' to', ' replace', ' Ali', ' Kar', 'imi']
+1048 730 In their sport, the position played by x -1 In their sport, the position played by Ali Karimi midfielder Ali Karimi "[',' ' who' ' was' ' the' ' first' ' Iranian' ' to' ' win' ' the' ' title'
+ ',' ' was' ' a' ' surprise' '.' '\n' '\n' 'The' ' Iranian' ',']" ", who was the first Iranian to win the title , was a surprise .
+
+ The Iranian ," False in, to replace Ali Karimi in the 88th minute, 6 [' in', ',', ' to', ' replace', ' Ali', ' Kar', 'imi']
+1049 731 In their sport, the position played by x -1 In their sport, the position played by Jack Balmer forward Jack Balmer "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False Albert Stubbins and Jack Balmer with numerous assists. 7 [' Albert', ' St', 'ubb', 'ins', ' and', ' Jack', ' Bal', 'mer']
+1050 731 In their sport, the position played by x -1 In their sport, the position played by Jack Balmer forward Jack Balmer "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False Stubbins and Jack Balmer with numerous 6 [' St', 'ubb', 'ins', ' and', ' Jack', ' Bal', 'mer']
+1051 731 In their sport, the position played by x -1 In their sport, the position played by Jack Balmer forward Jack Balmer "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False attack alongside Jack Balmer and Billy Liddell. 4 [' attack', ' alongside', ' Jack', ' Bal', 'mer']
+1052 731 In their sport, the position played by x -1 In their sport, the position played by Jack Balmer forward Jack Balmer "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False attack alongside Jack Balmer and Billy Liddell. 4 [' attack', ' alongside', ' Jack', ' Bal', 'mer']
+1053 732 In their sport, the position played by x -1 In their sport, the position played by Moise Fokou linebacker Moise Fokou "[',' ' the' ' French' ' player' ',' ' is' ' called' ' the' ' ""' 'p' 'ivot'
+ '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the French player , is called the "" p ivot "" and the position played by the American player" False Game. Linebacker Moise Fokou played in the 8 [' Game', '.', ' Line', 'backer', ' Mo', 'ise', ' F', 'ok', 'ou']
+1054 732 In their sport, the position played by x -1 In their sport, the position played by Moise Fokou linebacker Moise Fokou "[',' ' the' ' French' ' player' ',' ' is' ' called' ' the' ' ""' 'p' 'ivot'
+ '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the French player , is called the "" p ivot "" and the position played by the American player" False Game. Linebacker Moise Fokou played in the Under 8 [' Game', '.', ' Line', 'backer', ' Mo', 'ise', ' F', 'ok', 'ou']
+1055 732 In their sport, the position played by x -1 In their sport, the position played by Moise Fokou linebacker Moise Fokou "[',' ' the' ' French' ' player' ',' ' is' ' called' ' the' ' ""' 'p' 'ivot'
+ '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the French player , is called the "" p ivot "" and the position played by the American player" False Linebacker Moise Fokou played in the Under 6 [' Line', 'backer', ' Mo', 'ise', ' F', 'ok', 'ou']
+1056 732 In their sport, the position played by x -1 In their sport, the position played by Moise Fokou linebacker Moise Fokou "[',' ' the' ' French' ' player' ',' ' is' ' called' ' the' ' ""' 'p' 'ivot'
+ '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the French player , is called the "" p ivot "" and the position played by the American player" False Game. Linebacker Moise Fokou played in the 8 [' Game', '.', ' Line', 'backer', ' Mo', 'ise', ' F', 'ok', 'ou']
+1057 732 In their sport, the position played by x -1 In their sport, the position played by Moise Fokou linebacker Moise Fokou "[',' ' the' ' French' ' player' ',' ' is' ' called' ' the' ' ""' 'p' 'ivot'
+ '""' ' and' ' the' ' position' ' played' ' by' ' the' ' American'
+ ' player']" ", the French player , is called the "" p ivot "" and the position played by the American player" False Game. Linebacker Moise Fokou played in the Under 8 [' Game', '.', ' Line', 'backer', ' Mo', 'ise', ' F', 'ok', 'ou']
+1058 733 In their sport, the position played by x -1 In their sport, the position played by Vladislav Tretiak goaltender Vladislav Tretiak "[',' ' the' ' goalie' ',' ' is' ' called' ' the' ' ""' 'sp' 'ider' '""'
+ ' and' ' the' ' position' ' played' ' by' ' the' ' defenseman' ' is'
+ ' called']" ", the goalie , is called the "" sp ider "" and the position played by the defenseman is called" False national goaltender Vladislav Tretiak was named Bure's 7 [' national', ' goaltender', ' Vlad', 'isl', 'av', ' T', 'ret', 'iak']
+1059 733 In their sport, the position played by x -1 In their sport, the position played by Vladislav Tretiak goaltender Vladislav Tretiak "[',' ' the' ' goalie' ',' ' is' ' called' ' the' ' ""' 'sp' 'ider' '""'
+ ' and' ' the' ' position' ' played' ' by' ' the' ' defenseman' ' is'
+ ' called']" ", the goalie , is called the "" sp ider "" and the position played by the defenseman is called" False by goaltender Vladislav Tretiak and forwards Valeri 7 [' by', ' goaltender', ' Vlad', 'isl', 'av', ' T', 'ret', 'iak']
+1060 733 In their sport, the position played by x -1 In their sport, the position played by Vladislav Tretiak goaltender Vladislav Tretiak "[',' ' the' ' goalie' ',' ' is' ' called' ' the' ' ""' 'sp' 'ider' '""'
+ ' and' ' the' ' position' ' played' ' by' ' the' ' defenseman' ' is'
+ ' called']" ", the goalie , is called the "" sp ider "" and the position played by the defenseman is called" False the venerable Vladislav Tretiak in goal. With a strong 7 [' the', ' venerable', ' Vlad', 'isl', 'av', ' T', 'ret', 'iak']
+1061 733 In their sport, the position played by x -1 In their sport, the position played by Vladislav Tretiak goaltender Vladislav Tretiak "[',' ' the' ' goalie' ',' ' is' ' called' ' the' ' ""' 'sp' 'ider' '""'
+ ' and' ' the' ' position' ' played' ' by' ' the' ' defenseman' ' is'
+ ' called']" ", the goalie , is called the "" sp ider "" and the position played by the defenseman is called" False national goaltender Vladislav Tretiak was named Bure's 7 [' national', ' goaltender', ' Vlad', 'isl', 'av', ' T', 'ret', 'iak']
+1062 733 In their sport, the position played by x -1 In their sport, the position played by Vladislav Tretiak goaltender Vladislav Tretiak "[',' ' the' ' goalie' ',' ' is' ' called' ' the' ' ""' 'sp' 'ider' '""'
+ ' and' ' the' ' position' ' played' ' by' ' the' ' defenseman' ' is'
+ ' called']" ", the goalie , is called the "" sp ider "" and the position played by the defenseman is called" False goaltender Vladislav Tretiak was named Bure's 6 [' goaltender', ' Vlad', 'isl', 'av', ' T', 'ret', 'iak']
+1063 738 In their sport, the position played by x -1 In their sport, the position played by Aaron Rodgers quarterback Aaron Rodgers "[',' ' the' ' Packers' ' quarterback' ',' ' is' ' to' ' be' ' the'
+ ' quarterback' ' of' ' the' ' team' '.' ' He' ' is' ' the' ' leader'
+ ' of' ' the']" , the Packers quarterback , is to be the quarterback of the team . He is the leader of the True that pressured Aaron Rodgers 19 times in 3 [' that', ' pressured', ' Aaron', ' Rodgers']
+1064 738 In their sport, the position played by x -1 In their sport, the position played by Aaron Rodgers quarterback Aaron Rodgers "[',' ' the' ' Packers' ' quarterback' ',' ' is' ' to' ' be' ' the'
+ ' quarterback' ' of' ' the' ' team' '.' ' He' ' is' ' the' ' leader'
+ ' of' ' the']" , the Packers quarterback , is to be the quarterback of the team . He is the leader of the True commitment to Aaron Rodgers as its new 3 [' commitment', ' to', ' Aaron', ' Rodgers']
+1065 738 In their sport, the position played by x -1 In their sport, the position played by Aaron Rodgers quarterback Aaron Rodgers "[',' ' the' ' Packers' ' quarterback' ',' ' is' ' to' ' be' ' the'
+ ' quarterback' ' of' ' the' ' team' '.' ' He' ' is' ' the' ' leader'
+ ' of' ' the']" , the Packers quarterback , is to be the quarterback of the team . He is the leader of the True California and quarterback Aaron Rodgers scored 21 unanswered 4 [' California', ' and', ' quarterback', ' Aaron', ' Rodgers']
+1066 738 In their sport, the position played by x -1 In their sport, the position played by Aaron Rodgers quarterback Aaron Rodgers "[',' ' the' ' Packers' ' quarterback' ',' ' is' ' to' ' be' ' the'
+ ' quarterback' ' of' ' the' ' team' '.' ' He' ' is' ' the' ' leader'
+ ' of' ' the']" , the Packers quarterback , is to be the quarterback of the team . He is the leader of the True the situation Aaron Rodgers endured backing up 3 [' the', ' situation', ' Aaron', ' Rodgers']
+1067 738 In their sport, the position played by x -1 In their sport, the position played by Aaron Rodgers quarterback Aaron Rodgers "[',' ' the' ' Packers' ' quarterback' ',' ' is' ' to' ' be' ' the'
+ ' quarterback' ' of' ' the' ' team' '.' ' He' ' is' ' the' ' leader'
+ ' of' ' the']" , the Packers quarterback , is to be the quarterback of the team . He is the leader of the True California a first down. Aaron Rodgers then completed 6 [' California', ' a', ' first', ' down', '.', ' Aaron', ' Rodgers']
+1068 739 In their sport, the position played by x -1 In their sport, the position played by Scott Milanovich quarterback Scott Milanovich "[',' ' who' ' was' ' a' ' quarterback' ' at' ' the' ' University' ' of'
+ ' Michigan' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous' ' as']" , who was a quarterback at the University of Michigan , is a position that is not as glamorous as True quarterbacks Scott Milanovich and John Kaleo 3 [' quarterbacks', ' Scott', ' Milan', 'ovich']
+1069 740 In their sport, the position played by x -1 In their sport, the position played by Christy Mathewson pitcher Christy Mathewson "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True first base, moving Christy Mathewson to pitcher. McGann 7 [' first', ' base', ',', ' moving', ' Christy', ' Mat', 'hew', 'son']
+1070 740 In their sport, the position played by x -1 In their sport, the position played by Christy Mathewson pitcher Christy Mathewson "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True runners-up, Christy Mathewson and Warren Spahn. 7 [' runners', '-', 'up', ',', ' Christy', ' Mat', 'hew', 'son']
+1071 740 In their sport, the position played by x -1 In their sport, the position played by Christy Mathewson pitcher Christy Mathewson "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True first base, moving Christy Mathewson to pitcher. McGann 7 [' first', ' base', ',', ' moving', ' Christy', ' Mat', 'hew', 'son']
+1072 740 In their sport, the position played by x -1 In their sport, the position played by Christy Mathewson pitcher Christy Mathewson "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True first base, moving Christy Mathewson to pitcher. McGann 7 [' first', ' base', ',', ' moving', ' Christy', ' Mat', 'hew', 'son']
+1073 740 In their sport, the position played by x -1 In their sport, the position played by Christy Mathewson pitcher Christy Mathewson "[',' ' the' ' pitcher' ',' ' is' ' the' ' same' ' as' ' that' ' of' ' the'
+ ' pitcher' ' in' ' baseball' '.' ' The' ' pitcher' ' is' ' the' ' one']" , the pitcher , is the same as that of the pitcher in baseball . The pitcher is the one True Hall of Famer Christy Mathewson in 1905 was the last 7 [' Hall', ' of', ' F', 'amer', ' Christy', ' Mat', 'hew', 'son']
+1074 741 In their sport, the position played by x -1 In their sport, the position played by Geoffrey Kondogbia midfielder Geoffrey Kondogbia "[' is' ' that' ' of' ' a' ' midfielder' ',' ' but' ' he' ' is' ' a'
+ ' midfielder' ' who' ' is' ' also' ' a' ' defender' '.' ' He' ' is' ' a']" is that of a midfielder , but he is a midfielder who is also a defender . He is a True Gary Medel and Geoffrey Kondogbia – being sent 8 [' Gary', ' Med', 'el', ' and', ' Geoffrey', ' K', 'ond', 'og', 'bia']
+1075 741 In their sport, the position played by x -1 In their sport, the position played by Geoffrey Kondogbia midfielder Geoffrey Kondogbia "[' is' ' that' ' of' ' a' ' midfielder' ',' ' but' ' he' ' is' ' a'
+ ' midfielder' ' who' ' is' ' also' ' a' ' defender' '.' ' He' ' is' ' a']" is that of a midfielder , but he is a midfielder who is also a defender . He is a True Gary Medel and Geoffrey Kondogbia – being sent off in 8 [' Gary', ' Med', 'el', ' and', ' Geoffrey', ' K', 'ond', 'og', 'bia']
+1076 741 In their sport, the position played by x -1 In their sport, the position played by Geoffrey Kondogbia midfielder Geoffrey Kondogbia "[' is' ' that' ' of' ' a' ' midfielder' ',' ' but' ' he' ' is' ' a'
+ ' midfielder' ' who' ' is' ' also' ' a' ' defender' '.' ' He' ' is' ' a']" is that of a midfielder , but he is a midfielder who is also a defender . He is a True Gary Medel and Geoffrey Kondogbia – being sent 8 [' Gary', ' Med', 'el', ' and', ' Geoffrey', ' K', 'ond', 'og', 'bia']
+1077 741 In their sport, the position played by x -1 In their sport, the position played by Geoffrey Kondogbia midfielder Geoffrey Kondogbia "[' is' ' that' ' of' ' a' ' midfielder' ',' ' but' ' he' ' is' ' a'
+ ' midfielder' ' who' ' is' ' also' ' a' ' defender' '.' ' He' ' is' ' a']" is that of a midfielder , but he is a midfielder who is also a defender . He is a True Gary Medel and Geoffrey Kondogbia – being sent off 8 [' Gary', ' Med', 'el', ' and', ' Geoffrey', ' K', 'ond', 'og', 'bia']
+1078 741 In their sport, the position played by x -1 In their sport, the position played by Geoffrey Kondogbia midfielder Geoffrey Kondogbia "[' is' ' that' ' of' ' a' ' midfielder' ',' ' but' ' he' ' is' ' a'
+ ' midfielder' ' who' ' is' ' also' ' a' ' defender' '.' ' He' ' is' ' a']" is that of a midfielder , but he is a midfielder who is also a defender . He is a True Gary Medel and Geoffrey Kondogbia – being sent off 8 [' Gary', ' Med', 'el', ' and', ' Geoffrey', ' K', 'ond', 'og', 'bia']
+1079 742 In their sport, the position played by x -1 In their sport, the position played by Fabio Capello midfielder Fabio Capello "[',' ' the' ' Italian' ' coach' ',' ' is' ' a' ' bit' ' like' ' that'
+ ' of' ' a' ' general' ' manager' '.' ' He' ' is' ' the' ' manager' ' of']" , the Italian coach , is a bit like that of a general manager . He is the manager of False appointed manager Fabio Capello left Carrick 5 [' appointed', ' manager', ' Fab', 'io', ' Cape', 'llo']
+1080 742 In their sport, the position played by x -1 In their sport, the position played by Fabio Capello midfielder Fabio Capello "[',' ' the' ' Italian' ' coach' ',' ' is' ' a' ' bit' ' like' ' that'
+ ' of' ' a' ' general' ' manager' '.' ' He' ' is' ' the' ' manager' ' of']" , the Italian coach , is a bit like that of a general manager . He is the manager of False Steve Bruce urged Fabio Capello to consider 6 [' Steve', ' Bruce', ' urged', ' Fab', 'io', ' Cape', 'llo']
+1081 742 In their sport, the position played by x -1 In their sport, the position played by Fabio Capello midfielder Fabio Capello "[',' ' the' ' Italian' ' coach' ',' ' is' ' a' ' bit' ' like' ' that'
+ ' of' ' a' ' general' ' manager' '.' ' He' ' is' ' the' ' manager' ' of']" , the Italian coach , is a bit like that of a general manager . He is the manager of False subsequently appointed Fabio Capello as the new coach and 5 [' subsequently', ' appointed', ' Fab', 'io', ' Cape', 'llo']
+1082 742 In their sport, the position played by x -1 In their sport, the position played by Fabio Capello midfielder Fabio Capello "[',' ' the' ' Italian' ' coach' ',' ' is' ' a' ' bit' ' like' ' that'
+ ' of' ' a' ' general' ' manager' '.' ' He' ' is' ' the' ' manager' ' of']" , the Italian coach , is a bit like that of a general manager . He is the manager of False 3 ['Fab', 'io', ' Cape', 'llo']
+1083 742 In their sport, the position played by x -1 In their sport, the position played by Fabio Capello midfielder Fabio Capello "[',' ' the' ' Italian' ' coach' ',' ' is' ' a' ' bit' ' like' ' that'
+ ' of' ' a' ' general' ' manager' '.' ' He' ' is' ' the' ' manager' ' of']" , the Italian coach , is a bit like that of a general manager . He is the manager of False appointed manager Fabio Capello left Carrick out 5 [' appointed', ' manager', ' Fab', 'io', ' Cape', 'llo']
+1084 743 In their sport, the position played by x -1 In their sport, the position played by Nate Robertson pitcher Nate Robertson "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False the third inning, Nate Robertson hit Denard Span. 5 [' the', ' third', ' inning', ',', ' Nate', ' Robertson']
+1085 743 In their sport, the position played by x -1 In their sport, the position played by Nate Robertson pitcher Nate Robertson "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False third inning, Nate Robertson hit Denard Span. Minnesota 4 [' third', ' inning', ',', ' Nate', ' Robertson']
+1086 743 In their sport, the position played by x -1 In their sport, the position played by Nate Robertson pitcher Nate Robertson "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False third inning, Nate Robertson hit Denard Span. 4 [' third', ' inning', ',', ' Nate', ' Robertson']
+1087 743 In their sport, the position played by x -1 In their sport, the position played by Nate Robertson pitcher Nate Robertson "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False the third inning, Nate Robertson hit Denard Span. 5 [' the', ' third', ' inning', ',', ' Nate', ' Robertson']
+1088 744 In their sport, the position played by x -1 In their sport, the position played by John Reaves quarterback John Reaves "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False Gators standout John Reaves started the 4 [' Gators', ' standout', ' John', ' Re', 'aves']
+1089 745 In their sport, the position played by x -1 In their sport, the position played by Reinhard Divis goaltender Reinhard Divis "[' is' ' a' ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only'
+ ' player' ' who' ' can' ' be' ' considered' ' a' ' true' ' ""' 'super'
+ 'star']" " is a very important one . He is the only player who can be considered a true "" super star" False goal against Reinhard Divis of the St. Louis 5 [' goal', ' against', ' Rein', 'hard', ' Div', 'is']
+1090 745 In their sport, the position played by x -1 In their sport, the position played by Reinhard Divis goaltender Reinhard Divis "[' is' ' a' ' very' ' important' ' one' '.' ' He' ' is' ' the' ' only'
+ ' player' ' who' ' can' ' be' ' considered' ' a' ' true' ' ""' 'super'
+ 'star']" " is a very important one . He is the only player who can be considered a true "" super star" False NHL goal against Reinhard Divis of the St. Louis 6 [' NHL', ' goal', ' against', ' Rein', 'hard', ' Div', 'is']
+1091 746 In their sport, the position played by x -1 In their sport, the position played by John Parker Wilson quarterback John Parker Wilson "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False offensive line; John Parker Wilson at quarterback; 5 [' offensive', ' line', ';', ' John', ' Parker', ' Wilson']
+1092 748 In their sport, the position played by x -1 In their sport, the position played by Joe Kapp quarterback Joe Kapp "[',' ' the' ' quarterback' ',' ' is' ' the' ' equivalent' ' of' ' the'
+ ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the equivalent of the quarterback in football . The quarterback is the leader of the True and the scientist Joe Kapp establish that 5 [' and', ' the', ' scientist', ' Joe', ' K', 'app']
+1093 748 In their sport, the position played by x -1 In their sport, the position played by Joe Kapp quarterback Joe Kapp "[',' ' the' ' quarterback' ',' ' is' ' the' ' equivalent' ' of' ' the'
+ ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the equivalent of the quarterback in football . The quarterback is the leader of the True which brought Joe Kapp from the British 4 [' which', ' brought', ' Joe', ' K', 'app']
+1094 748 In their sport, the position played by x -1 In their sport, the position played by Joe Kapp quarterback Joe Kapp "[',' ' the' ' quarterback' ',' ' is' ' the' ' equivalent' ' of' ' the'
+ ' quarterback' ' in' ' football' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the equivalent of the quarterback in football . The quarterback is the leader of the True leagues, which brought Joe Kapp from the British 6 [' leagues', ',', ' which', ' brought', ' Joe', ' K', 'app']
+1095 750 In their sport, the position played by x -1 In their sport, the position played by Graeme Souness midfielder Graeme Souness "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' the' ' equivalent' ' of'
+ ' the' ' English' ' midfielder' '.' '\n' '\n' 'The' ' Scottish'
+ ' midfielder' ' is' ' the']" ", the Scottish midfielder , is the equivalent of the English midfielder .
+
+ The Scottish midfielder is the" True sacking of Graeme Souness as Newcastle manager 7 [' s', 'acking', ' of', ' Gra', 'eme', ' S', 'oun', 'ess']
+1096 750 In their sport, the position played by x -1 In their sport, the position played by Graeme Souness midfielder Graeme Souness "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' the' ' equivalent' ' of'
+ ' the' ' English' ' midfielder' '.' '\n' '\n' 'The' ' Scottish'
+ ' midfielder' ' is' ' the']" ", the Scottish midfielder , is the equivalent of the English midfielder .
+
+ The Scottish midfielder is the" True players, including Graeme Souness (his captain), 7 [' players', ',', ' including', ' Gra', 'eme', ' S', 'oun', 'ess']
+1097 750 In their sport, the position played by x -1 In their sport, the position played by Graeme Souness midfielder Graeme Souness "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' the' ' equivalent' ' of'
+ ' the' ' English' ' midfielder' '.' '\n' '\n' 'The' ' Scottish'
+ ' midfielder' ' is' ' the']" ", the Scottish midfielder , is the equivalent of the English midfielder .
+
+ The Scottish midfielder is the" True captain as both Graeme Souness and Phil Thompson 7 [' captain', ' as', ' both', ' Gra', 'eme', ' S', 'oun', 'ess']
+1098 750 In their sport, the position played by x -1 In their sport, the position played by Graeme Souness midfielder Graeme Souness "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' the' ' equivalent' ' of'
+ ' the' ' English' ' midfielder' '.' '\n' '\n' 'The' ' Scottish'
+ ' midfielder' ' is' ' the']" ", the Scottish midfielder , is the equivalent of the English midfielder .
+
+ The Scottish midfielder is the" True David O 'Leary, Graeme Souness and Rafael Benítez 9 "[' David', ' O', "" '"", 'Leary', ',', ' Gra', 'eme', ' S', 'oun', 'ess']"
+1099 750 In their sport, the position played by x -1 In their sport, the position played by Graeme Souness midfielder Graeme Souness "[',' ' the' ' Scottish' ' midfielder' ',' ' is' ' the' ' equivalent' ' of'
+ ' the' ' English' ' midfielder' '.' '\n' '\n' 'The' ' Scottish'
+ ' midfielder' ' is' ' the']" ", the Scottish midfielder , is the equivalent of the English midfielder .
+
+ The Scottish midfielder is the" True Liverpool captain Graeme Souness break the jaw 6 [' Liverpool', ' captain', ' Gra', 'eme', ' S', 'oun', 'ess']
+1100 753 In their sport, the position played by x -1 In their sport, the position played by Archie Manning quarterback Archie Manning "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Mississippi quarterback Archie Manning used the good 3 [' Mississippi', ' quarterback', ' Archie', ' Manning']
+1101 753 In their sport, the position played by x -1 In their sport, the position played by Archie Manning quarterback Archie Manning "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Miss quarterback Archie Manning finished the game 3 [' Miss', ' quarterback', ' Archie', ' Manning']
+1102 753 In their sport, the position played by x -1 In their sport, the position played by Archie Manning quarterback Archie Manning "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Miss quarterback Archie Manning finished the game 3 [' Miss', ' quarterback', ' Archie', ' Manning']
+1103 753 In their sport, the position played by x -1 In their sport, the position played by Archie Manning quarterback Archie Manning "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True former NFL quarterback Archie Manning and older brother 4 [' former', ' NFL', ' quarterback', ' Archie', ' Manning']
+1104 753 In their sport, the position played by x -1 In their sport, the position played by Archie Manning quarterback Archie Manning "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Ole Miss quarterback Archie Manning finished the game 4 [' Ole', ' Miss', ' quarterback', ' Archie', ' Manning']
+1105 754 In their sport, the position played by x -1 In their sport, the position played by Paulo Bento midfielder Paulo Bento "[',' ' the' ' Portuguese' ' coach' ',' ' is' ' a' ' very' ' important'
+ ' one' '.' ' He' ' is' ' the' ' one' ' who' ' will' ' decide' ' the'
+ ' future']" , the Portuguese coach , is a very important one . He is the one who will decide the future False internationals like Paulo Bento and Amaral, but 5 [' international', 's', ' like', ' Paulo', ' Bent', 'o']
+1106 754 In their sport, the position played by x -1 In their sport, the position played by Paulo Bento midfielder Paulo Bento "[',' ' the' ' Portuguese' ' coach' ',' ' is' ' a' ' very' ' important'
+ ' one' '.' ' He' ' is' ' the' ' one' ' who' ' will' ' decide' ' the'
+ ' future']" , the Portuguese coach , is a very important one . He is the one who will decide the future False contend with Paulo Bento and Rui Bento, 4 [' contend', ' with', ' Paulo', ' Bent', 'o']
+1107 754 In their sport, the position played by x -1 In their sport, the position played by Paulo Bento midfielder Paulo Bento "[',' ' the' ' Portuguese' ' coach' ',' ' is' ' a' ' very' ' important'
+ ' one' '.' ' He' ' is' ' the' ' one' ' who' ' will' ' decide' ' the'
+ ' future']" , the Portuguese coach , is a very important one . He is the one who will decide the future False direct contend with Paulo Bento and Rui Bento, 5 [' direct', ' contend', ' with', ' Paulo', ' Bent', 'o']
+1108 754 In their sport, the position played by x -1 In their sport, the position played by Paulo Bento midfielder Paulo Bento "[',' ' the' ' Portuguese' ' coach' ',' ' is' ' a' ' very' ' important'
+ ' one' '.' ' He' ' is' ' the' ' one' ' who' ' will' ' decide' ' the'
+ ' future']" , the Portuguese coach , is a very important one . He is the one who will decide the future False direct contend with Paulo Bento and Rui Bento, 5 [' direct', ' contend', ' with', ' Paulo', ' Bent', 'o']
+1109 754 In their sport, the position played by x -1 In their sport, the position played by Paulo Bento midfielder Paulo Bento "[',' ' the' ' Portuguese' ' coach' ',' ' is' ' a' ' very' ' important'
+ ' one' '.' ' He' ' is' ' the' ' one' ' who' ' will' ' decide' ' the'
+ ' future']" , the Portuguese coach , is a very important one . He is the one who will decide the future False internationals like Paulo Bento and Amaral, but 5 [' international', 's', ' like', ' Paulo', ' Bent', 'o']
+1110 755 In their sport, the position played by x -1 In their sport, the position played by Bob Griese quarterback Bob Griese "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Fame quarterback Bob Griese said Kafka 5 [' Fame', ' quarterback', ' Bob', ' G', 'ries', 'e']
+1111 755 In their sport, the position played by x -1 In their sport, the position played by Bob Griese quarterback Bob Griese "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True quarterback Bob Griese combined with 4 [' quarterback', ' Bob', ' G', 'ries', 'e']
+1112 755 In their sport, the position played by x -1 In their sport, the position played by Bob Griese quarterback Bob Griese "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True touchdown passes from Bob Griese to Jim Mandich. 6 [' touchdown', ' passes', ' from', ' Bob', ' G', 'ries', 'e']
+1113 755 In their sport, the position played by x -1 In their sport, the position played by Bob Griese quarterback Bob Griese "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True rookie quarterback Bob Griese combined with 5 [' rookie', ' quarterback', ' Bob', ' G', 'ries', 'e']
+1114 755 In their sport, the position played by x -1 In their sport, the position played by Bob Griese quarterback Bob Griese "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True In 1967, quarterback Bob Griese and Clancy were 7 [' In', ' 1967', ',', ' quarterback', ' Bob', ' G', 'ries', 'e']
+1115 756 In their sport, the position played by x -1 In their sport, the position played by Frank Lampard midfielder Frank Lampard "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' Chelsea'
+ ' since' ' he' ' was' ' a' ' teenager' ',' ' is' ' a' ' key' ' one' '.']" , who has been a regular starter for Chelsea since he was a teenager , is a key one . False penalty, which Frank Lampard converted, in 5 [' penalty', ',', ' which', ' Frank', ' Lamp', 'ard']
+1116 756 In their sport, the position played by x -1 In their sport, the position played by Frank Lampard midfielder Frank Lampard "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' Chelsea'
+ ' since' ' he' ' was' ' a' ' teenager' ',' ' is' ' a' ' key' ' one' '.']" , who has been a regular starter for Chelsea since he was a teenager , is a key one . False Thierry Henry. Frank Lampard scored a late equaliser 7 [' Th', 'ier', 'ry', ' Henry', '.', ' Frank', ' Lamp', 'ard']
+1117 756 In their sport, the position played by x -1 In their sport, the position played by Frank Lampard midfielder Frank Lampard "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' Chelsea'
+ ' since' ' he' ' was' ' a' ' teenager' ',' ' is' ' a' ' key' ' one' '.']" , who has been a regular starter for Chelsea since he was a teenager , is a key one . False taken midfielder Frank Lampard from their own franchise 4 [' taken', ' midfielder', ' Frank', ' Lamp', 'ard']
+1118 756 In their sport, the position played by x -1 In their sport, the position played by Frank Lampard midfielder Frank Lampard "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' Chelsea'
+ ' since' ' he' ' was' ' a' ' teenager' ',' ' is' ' a' ' key' ' one' '.']" , who has been a regular starter for Chelsea since he was a teenager , is a key one . False his opponent Frank Lampard and in the 17th 4 [' his', ' opponent', ' Frank', ' Lamp', 'ard']
+1119 756 In their sport, the position played by x -1 In their sport, the position played by Frank Lampard midfielder Frank Lampard "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' Chelsea'
+ ' since' ' he' ' was' ' a' ' teenager' ',' ' is' ' a' ' key' ' one' '.']" , who has been a regular starter for Chelsea since he was a teenager , is a key one . False " difference between Frank Lampard and Sissoko.""" 4 [' difference', ' between', ' Frank', ' Lamp', 'ard']
+1120 758 In their sport, the position played by x -1 In their sport, the position played by Mike Holmgren quarterback Mike Holmgren "[',' ' the' ' Seahawks' ""'"" ' head' ' coach' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' played' ' by' ' the' ' Seahawks' ""'"" ' quarterback'
+ ',']" , the Seahawks ' head coach , is the same as the one played by the Seahawks ' quarterback , True quarterbacks, Browns president Mike Holmgren said McCoy 7 [' quarterbacks', ',', ' Browns', ' president', ' Mike', ' Hol', 'mg', 'ren']
+1121 758 In their sport, the position played by x -1 In their sport, the position played by Mike Holmgren quarterback Mike Holmgren "[',' ' the' ' Seahawks' ""'"" ' head' ' coach' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' played' ' by' ' the' ' Seahawks' ""'"" ' quarterback'
+ ',']" , the Seahawks ' head coach , is the same as the one played by the Seahawks ' quarterback , True and general manager Mike Holmgren told the media 6 [' and', ' general', ' manager', ' Mike', ' Hol', 'mg', 'ren']
+1122 758 In their sport, the position played by x -1 In their sport, the position played by Mike Holmgren quarterback Mike Holmgren "[',' ' the' ' Seahawks' ""'"" ' head' ' coach' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' played' ' by' ' the' ' Seahawks' ""'"" ' quarterback'
+ ',']" , the Seahawks ' head coach , is the same as the one played by the Seahawks ' quarterback , True Seahawks head coach Mike Holmgren attributed the penalties 6 [' Seahawks', ' head', ' coach', ' Mike', ' Hol', 'mg', 'ren']
+1123 758 In their sport, the position played by x -1 In their sport, the position played by Mike Holmgren quarterback Mike Holmgren "[',' ' the' ' Seahawks' ""'"" ' head' ' coach' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' played' ' by' ' the' ' Seahawks' ""'"" ' quarterback'
+ ',']" , the Seahawks ' head coach , is the same as the one played by the Seahawks ' quarterback , True former Packers coach Mike Holmgren as team president, 6 [' former', ' Packers', ' coach', ' Mike', ' Hol', 'mg', 'ren']
+1124 758 In their sport, the position played by x -1 In their sport, the position played by Mike Holmgren quarterback Mike Holmgren "[',' ' the' ' Seahawks' ""'"" ' head' ' coach' ',' ' is' ' the' ' same'
+ ' as' ' the' ' one' ' played' ' by' ' the' ' Seahawks' ""'"" ' quarterback'
+ ',']" , the Seahawks ' head coach , is the same as the one played by the Seahawks ' quarterback , True head coach Mike Holmgren benched starting 5 [' head', ' coach', ' Mike', ' Hol', 'mg', 'ren']
+1125 759 In their sport, the position played by x -1 In their sport, the position played by Tiago Mendes midfielder Tiago Mendes "[',' ' who' ' is' ' now' ' at' ' the' ' helm' ' of' ' the' ' Portuguese'
+ ' national' ' team' ',' ' is' ' a' ' very' ' important' ' one' '.' '\n']" ", who is now at the helm of the Portuguese national team , is a very important one .
+" False Jérémy Toulalan and Tiago Mendes as his starting 11 [' J', 'é', 'ré', 'my', ' T', 'oul', 'alan', ' and', ' Ti', 'ago', ' Mend', 'es']
+1126 760 In their sport, the position played by x -1 In their sport, the position played by Kerry Collins quarterback Kerry Collins "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Giants'
+ ',' ' is' ' the' ' same' ' position' ' that' ' the' ' Giants' ' are'
+ ' in' ' now']" , who was a great quarterback for the Giants , is the same position that the Giants are in now True time in four seasons. Kerry Collins continued to have 6 [' time', ' in', ' four', ' seasons', '.', ' Kerry', ' Collins']
+1127 760 In their sport, the position played by x -1 In their sport, the position played by Kerry Collins quarterback Kerry Collins "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Giants'
+ ',' ' is' ' the' ' same' ' position' ' that' ' the' ' Giants' ' are'
+ ' in' ' now']" , who was a great quarterback for the Giants , is the same position that the Giants are in now True Ravens'defense harassed Kerry Collins all game long, and 6 "[' Ravens', ""'d"", 'ef', 'ense', ' harassed', ' Kerry', ' Collins']"
+1128 760 In their sport, the position played by x -1 In their sport, the position played by Kerry Collins quarterback Kerry Collins "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Giants'
+ ',' ' is' ' the' ' same' ' position' ' that' ' the' ' Giants' ' are'
+ ' in' ' now']" , who was a great quarterback for the Giants , is the same position that the Giants are in now True team signed Kerry Collins out of retirement 3 [' team', ' signed', ' Kerry', ' Collins']
+1129 760 In their sport, the position played by x -1 In their sport, the position played by Kerry Collins quarterback Kerry Collins "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Giants'
+ ',' ' is' ' the' ' same' ' position' ' that' ' the' ' Giants' ' are'
+ ' in' ' now']" , who was a great quarterback for the Giants , is the same position that the Giants are in now True the team signed Kerry Collins out of retirement 4 [' the', ' team', ' signed', ' Kerry', ' Collins']
+1130 760 In their sport, the position played by x -1 In their sport, the position played by Kerry Collins quarterback Kerry Collins "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Giants'
+ ',' ' is' ' the' ' same' ' position' ' that' ' the' ' Giants' ' are'
+ ' in' ' now']" , who was a great quarterback for the Giants , is the same position that the Giants are in now True who provided Kerry Collins with yet another 3 [' who', ' provided', ' Kerry', ' Collins']
+1131 764 In their sport, the position played by x -1 In their sport, the position played by Beniamino Vignola midfielder Beniamino Vignola "[',' ' the' ' Italian' ' painter' ',' ' is' ' that' ' of' ' the' '\n' '\n'
+ 'The' ' Italian' ' painter' ',' ' Ben' 'iam' 'ino' ' V' 'ign']" ", the Italian painter , is that of the
+
+ The Italian painter , Ben iam ino V ign" False the lead through Beniamino Vignola in the 13th minute. 8 [' the', ' lead', ' through', ' Ben', 'iam', 'ino', ' V', 'ign', 'ola']
+1132 764 In their sport, the position played by x -1 In their sport, the position played by Beniamino Vignola midfielder Beniamino Vignola "[',' ' the' ' Italian' ' painter' ',' ' is' ' that' ' of' ' the' '\n' '\n'
+ 'The' ' Italian' ' painter' ',' ' Ben' 'iam' 'ino' ' V' 'ign']" ", the Italian painter , is that of the
+
+ The Italian painter , Ben iam ino V ign" False lead through Beniamino Vignola in the 13th minute. 7 [' lead', ' through', ' Ben', 'iam', 'ino', ' V', 'ign', 'ola']
+1133 764 In their sport, the position played by x -1 In their sport, the position played by Beniamino Vignola midfielder Beniamino Vignola "[',' ' the' ' Italian' ' painter' ',' ' is' ' that' ' of' ' the' '\n' '\n'
+ 'The' ' Italian' ' painter' ',' ' Ben' 'iam' 'ino' ' V' 'ign']" ", the Italian painter , is that of the
+
+ The Italian painter , Ben iam ino V ign" False lead through Beniamino Vignola in the 13th minute. 7 [' lead', ' through', ' Ben', 'iam', 'ino', ' V', 'ign', 'ola']
+1134 764 In their sport, the position played by x -1 In their sport, the position played by Beniamino Vignola midfielder Beniamino Vignola "[',' ' the' ' Italian' ' painter' ',' ' is' ' that' ' of' ' the' '\n' '\n'
+ 'The' ' Italian' ' painter' ',' ' Ben' 'iam' 'ino' ' V' 'ign']" ", the Italian painter , is that of the
+
+ The Italian painter , Ben iam ino V ign" False the lead through Beniamino Vignola in the 13th minute. 8 [' the', ' lead', ' through', ' Ben', 'iam', 'ino', ' V', 'ign', 'ola']
+1135 765 In their sport, the position played by x -1 In their sport, the position played by Matt Flynn quarterback Matt Flynn "[',' ' who' ' was' ' a' ' backup' ' for' ' the' ' Packers' ' in' ' 2011'
+ ',' ' is' ' now' ' the' ' starter' '.' '\n' '\n' 'The' ' Packers']" ", who was a backup for the Packers in 2011 , is now the starter .
+
+ The Packers" False 1 ['Matt', ' Flynn']
+1136 765 In their sport, the position played by x -1 In their sport, the position played by Matt Flynn quarterback Matt Flynn "[',' ' who' ' was' ' a' ' backup' ' for' ' the' ' Packers' ' in' ' 2011'
+ ',' ' is' ' now' ' the' ' starter' '.' '\n' '\n' 'The' ' Packers']" ", who was a backup for the Packers in 2011 , is now the starter .
+
+ The Packers" False 1 ['Matt', ' Flynn']
+1137 767 In their sport, the position played by x -1 In their sport, the position played by Von Miller linebacker Von Miller "[' is' ' a' ' bit' ' of' ' a' ' mystery' '.' ' He' ""'s"" ' a' ' great'
+ ' player' ',' ' but' ' he' ""'s"" ' not' ' a' ' great' ' player']" is a bit of a mystery . He 's a great player , but he 's not a great player False Super Bowl MVP Von Miller of the Denver Broncos. 4 [' Super', ' Bowl', ' MVP', ' Von', ' Miller']
+1138 768 In their sport, the position played by x -1 In their sport, the position played by Rasmus Lindgren midfielder Rasmus Lindgren "[',' ' who' ' is' ' a' ' former' ' NHL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the'
+ ' Oilers']" , who is a former NHL player , is a perfect example of the type of player that the Oilers False vice-captain, with Rasmus Lindgren as captain. He began 9 [' vice', '-', 'capt', 'ain', ',', ' with', ' Rasm', 'us', ' Lind', 'gren']
+1139 768 In their sport, the position played by x -1 In their sport, the position played by Rasmus Lindgren midfielder Rasmus Lindgren "[',' ' who' ' is' ' a' ' former' ' NHL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the'
+ ' Oilers']" , who is a former NHL player , is a perfect example of the type of player that the Oilers False vice-captain, with Rasmus Lindgren as captain. He 9 [' vice', '-', 'capt', 'ain', ',', ' with', ' Rasm', 'us', ' Lind', 'gren']
+1140 768 In their sport, the position played by x -1 In their sport, the position played by Rasmus Lindgren midfielder Rasmus Lindgren "[',' ' who' ' is' ' a' ' former' ' NHL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the'
+ ' Oilers']" , who is a former NHL player , is a perfect example of the type of player that the Oilers False vice-captain, with Rasmus Lindgren as captain. He 9 [' vice', '-', 'capt', 'ain', ',', ' with', ' Rasm', 'us', ' Lind', 'gren']
+1141 769 In their sport, the position played by x -1 In their sport, the position played by Trent Green quarterback Trent Green "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Packers'
+ ' in' ' the' ' 1990' 's' ',' ' and' ' the' ' Packers' ' have' ' a' ' lot']" , who was a great quarterback for the Packers in the 1990 s , and the Packers have a lot True Chiefs quarterback Trent Green serves as the color 3 [' Chiefs', ' quarterback', ' Trent', ' Green']
+1142 769 In their sport, the position played by x -1 In their sport, the position played by Trent Green quarterback Trent Green "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Packers'
+ ' in' ' the' ' 1990' 's' ',' ' and' ' the' ' Packers' ' have' ' a' ' lot']" , who was a great quarterback for the Packers in the 1990 s , and the Packers have a lot True for quarterback Trent Green and signed free 3 [' for', ' quarterback', ' Trent', ' Green']
+1143 769 In their sport, the position played by x -1 In their sport, the position played by Trent Green quarterback Trent Green "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Packers'
+ ' in' ' the' ' 1990' 's' ',' ' and' ' the' ' Packers' ' have' ' a' ' lot']" , who was a great quarterback for the Packers in the 1990 s , and the Packers have a lot True quarterback Trent Green serves as the color 2 [' quarterback', ' Trent', ' Green']
+1144 769 In their sport, the position played by x -1 In their sport, the position played by Trent Green quarterback Trent Green "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Packers'
+ ' in' ' the' ' 1990' 's' ',' ' and' ' the' ' Packers' ' have' ' a' ' lot']" , who was a great quarterback for the Packers in the 1990 s , and the Packers have a lot True 2 ['T', 'rent', ' Green']
+1145 769 In their sport, the position played by x -1 In their sport, the position played by Trent Green quarterback Trent Green "[',' ' who' ' was' ' a' ' great' ' quarterback' ' for' ' the' ' Packers'
+ ' in' ' the' ' 1990' 's' ',' ' and' ' the' ' Packers' ' have' ' a' ' lot']" , who was a great quarterback for the Packers in the 1990 s , and the Packers have a lot True 2 ['T', 'rent', ' Green']
+1146 770 In their sport, the position played by x -1 In their sport, the position played by Alessio Cerci midfielder Alessio Cerci "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Italian'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 16' ',' ' is' ' a'
+ ' key']" , who has been a regular in the Italian national team since the age of 16 , is a key False was agreed after Alessio Cerci went the opposite 6 [' was', ' agreed', ' after', ' Aless', 'io', ' Cer', 'ci']
+1147 771 In their sport, the position played by x -1 In their sport, the position played by Glenn Hall goaltender Glenn Hall "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False Chicago as punishment; Glenn Hall was included as he 5 [' Chicago', ' as', ' punishment', ';', ' Glenn', ' Hall']
+1148 771 In their sport, the position played by x -1 In their sport, the position played by Glenn Hall goaltender Glenn Hall "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False popularized by Glenn Hall and was improved 4 [' popular', 'ized', ' by', ' Glenn', ' Hall']
+1149 771 In their sport, the position played by x -1 In their sport, the position played by Glenn Hall goaltender Glenn Hall "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False Chicago as punishment; Glenn Hall was included as he 5 [' Chicago', ' as', ' punishment', ';', ' Glenn', ' Hall']
+1150 771 In their sport, the position played by x -1 In their sport, the position played by Glenn Hall goaltender Glenn Hall "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False of Ted Lindsay and Glenn Hall from Detroit. Making 5 [' of', ' Ted', ' Lindsay', ' and', ' Glenn', ' Hall']
+1151 771 In their sport, the position played by x -1 In their sport, the position played by Glenn Hall goaltender Glenn Hall "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' man' ' who' ' is' ' a' ' bit' ' of' ' a' ' l' 'oner']" in the film , the character of the hero is a man who is a bit of a l oner False moving him past Glenn Hall for 3rd place 4 [' moving', ' him', ' past', ' Glenn', ' Hall']
+1152 772 In their sport, the position played by x -1 In their sport, the position played by Stephone Anthony linebacker Stephone Anthony "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' women'
+ ""'s"" ' national' ' team' ' that' ' won' ' the' ' gold' ' medal']" , who was a member of the U . S . women 's national team that won the gold medal False (Mississippi State), Stephone Anthony (Clemson) 8 [' (', 'Miss', 'iss', 'ippi', ' State', '),', ' Steph', 'one', ' Anthony']
+1153 772 In their sport, the position played by x -1 In their sport, the position played by Stephone Anthony linebacker Stephone Anthony "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' women'
+ ""'s"" ' national' ' team' ' that' ' won' ' the' ' gold' ' medal']" , who was a member of the U . S . women 's national team that won the gold medal False (Mississippi State), Stephone Anthony (Clemson) and 8 [' (', 'Miss', 'iss', 'ippi', ' State', '),', ' Steph', 'one', ' Anthony']
+1154 772 In their sport, the position played by x -1 In their sport, the position played by Stephone Anthony linebacker Stephone Anthony "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' women'
+ ""'s"" ' national' ' team' ' that' ' won' ' the' ' gold' ' medal']" , who was a member of the U . S . women 's national team that won the gold medal False (Mississippi State), Stephone Anthony (Clemson) and Paul 8 [' (', 'Miss', 'iss', 'ippi', ' State', '),', ' Steph', 'one', ' Anthony']
+1155 772 In their sport, the position played by x -1 In their sport, the position played by Stephone Anthony linebacker Stephone Anthony "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' women'
+ ""'s"" ' national' ' team' ' that' ' won' ' the' ' gold' ' medal']" , who was a member of the U . S . women 's national team that won the gold medal False (Mississippi State), Stephone Anthony (Clemson) and Paul 8 [' (', 'Miss', 'iss', 'ippi', ' State', '),', ' Steph', 'one', ' Anthony']
+1156 775 In their sport, the position played by x -1 In their sport, the position played by Joe Paterno quarterback Joe Paterno "[' and' ' the' ' Penn' ' State' ' football' ' team' ' is' ' that' ' of'
+ ' a' ' victim' '.' ' They' ' are' ' the' ' ones' ' who' ' are' ' being'
+ ' victimized']" and the Penn State football team is that of a victim . They are the ones who are being victimized False " halftime deficit in the Joe Paterno era.
+" 7 [' halftime', ' deficit', ' in', ' the', ' Joe', ' P', 'ater', 'no']
+1157 775 In their sport, the position played by x -1 In their sport, the position played by Joe Paterno quarterback Joe Paterno "[' and' ' the' ' Penn' ' State' ' football' ' team' ' is' ' that' ' of'
+ ' a' ' victim' '.' ' They' ' are' ' the' ' ones' ' who' ' are' ' being'
+ ' victimized']" and the Penn State football team is that of a victim . They are the ones who are being victimized False final career loss for Joe Paterno as the Nittany 7 [' final', ' career', ' loss', ' for', ' Joe', ' P', 'ater', 'no']
+1158 775 In their sport, the position played by x -1 In their sport, the position played by Joe Paterno quarterback Joe Paterno "[' and' ' the' ' Penn' ' State' ' football' ' team' ' is' ' that' ' of'
+ ' a' ' victim' '.' ' They' ' are' ' the' ' ones' ' who' ' are' ' being'
+ ' victimized']" and the Penn State football team is that of a victim . They are the ones who are being victimized False " deficit in the Joe Paterno era.
+" 6 [' deficit', ' in', ' the', ' Joe', ' P', 'ater', 'no']
+1159 775 In their sport, the position played by x -1 In their sport, the position played by Joe Paterno quarterback Joe Paterno "[' and' ' the' ' Penn' ' State' ' football' ' team' ' is' ' that' ' of'
+ ' a' ' victim' '.' ' They' ' are' ' the' ' ones' ' who' ' are' ' being'
+ ' victimized']" and the Penn State football team is that of a victim . They are the ones who are being victimized False holdover from the Joe Paterno era, departed, 7 [' hold', 'over', ' from', ' the', ' Joe', ' P', 'ater', 'no']
+1160 775 In their sport, the position played by x -1 In their sport, the position played by Joe Paterno quarterback Joe Paterno "[' and' ' the' ' Penn' ' State' ' football' ' team' ' is' ' that' ' of'
+ ' a' ' victim' '.' ' They' ' are' ' the' ' ones' ' who' ' are' ' being'
+ ' victimized']" and the Penn State football team is that of a victim . They are the ones who are being victimized False " Penn State coach Joe Paterno once said ""While" 6 [' Penn', ' State', ' coach', ' Joe', ' P', 'ater', 'no']
+1161 776 In their sport, the position played by x -1 In their sport, the position played by Ubaldo Righetti defender Ubaldo Righetti "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' hat' '-' 'tr'
+ 'ick' ' in' ' the' ' tournament' ',' ' was' ' a' ' crucial' ' one']" , who was the first to score a hat - tr ick in the tournament , was a crucial one False crossbar. Souness and Ubaldo Righetti converted 10 [' cross', 'bar', '.', ' S', 'oun', 'ess', ' and', ' Ub', 'aldo', ' Rig', 'hetti']
+1162 776 In their sport, the position played by x -1 In their sport, the position played by Ubaldo Righetti defender Ubaldo Righetti "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' hat' '-' 'tr'
+ 'ick' ' in' ' the' ' tournament' ',' ' was' ' a' ' crucial' ' one']" , who was the first to score a hat - tr ick in the tournament , was a crucial one False crossbar. Souness and Ubaldo Righetti converted their 10 [' cross', 'bar', '.', ' S', 'oun', 'ess', ' and', ' Ub', 'aldo', ' Rig', 'hetti']
+1163 776 In their sport, the position played by x -1 In their sport, the position played by Ubaldo Righetti defender Ubaldo Righetti "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' hat' '-' 'tr'
+ 'ick' ' in' ' the' ' tournament' ',' ' was' ' a' ' crucial' ' one']" , who was the first to score a hat - tr ick in the tournament , was a crucial one False crossbar. Souness and Ubaldo Righetti converted their respective 10 [' cross', 'bar', '.', ' S', 'oun', 'ess', ' and', ' Ub', 'aldo', ' Rig', 'hetti']
+1164 777 In their sport, the position played by x -1 In their sport, the position played by Blaine Gabbert quarterback Blaine Gabbert "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' '\n' '\n' 'The' ' Jaguars' ' are' ' a' ' team' ' that' ' is']" " is similar to that of the quarterback in the NFL .
+
+ The Jaguars are a team that is" True 49ers quarterback Blaine Gabbert had been improving, 6 [' 49', 'ers', ' quarterback', ' Bl', 'aine', ' Gab', 'bert']
+1165 777 In their sport, the position played by x -1 In their sport, the position played by Blaine Gabbert quarterback Blaine Gabbert "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' '\n' '\n' 'The' ' Jaguars' ' are' ' a' ' team' ' that' ' is']" " is similar to that of the quarterback in the NFL .
+
+ The Jaguars are a team that is" True Hyde. 49ers quarterback Blaine Gabbert had been improving, 8 [' Hyde', '.', ' 49', 'ers', ' quarterback', ' Bl', 'aine', ' Gab', 'bert']
+1166 777 In their sport, the position played by x -1 In their sport, the position played by Blaine Gabbert quarterback Blaine Gabbert "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' '\n' '\n' 'The' ' Jaguars' ' are' ' a' ' team' ' that' ' is']" " is similar to that of the quarterback in the NFL .
+
+ The Jaguars are a team that is" True quarterback Blaine Gabbert had been improving, 4 [' quarterback', ' Bl', 'aine', ' Gab', 'bert']
+1167 777 In their sport, the position played by x -1 In their sport, the position played by Blaine Gabbert quarterback Blaine Gabbert "[' is' ' similar' ' to' ' that' ' of' ' the' ' quarterback' ' in' ' the'
+ ' NFL' '.' '\n' '\n' 'The' ' Jaguars' ' are' ' a' ' team' ' that' ' is']" " is similar to that of the quarterback in the NFL .
+
+ The Jaguars are a team that is" True 49ers quarterback Blaine Gabbert had been improving, 6 [' 49', 'ers', ' quarterback', ' Bl', 'aine', ' Gab', 'bert']
+1168 778 In their sport, the position played by x -1 In their sport, the position played by Walter Johnson pitcher Walter Johnson "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' '�']" in the first half of the season , the team has been a bit of a mess . They � False (37 in 1908), Walter Johnson (36 in 1913) 6 [' (', '37', ' in', ' 1908', '),', ' Walter', ' Johnson']
+1169 778 In their sport, the position played by x -1 In their sport, the position played by Walter Johnson pitcher Walter Johnson "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' '�']" in the first half of the season , the team has been a bit of a mess . They � False despite dueling Walter Johnson up to the 16th inning. 4 [' despite', ' duel', 'ing', ' Walter', ' Johnson']
+1170 778 In their sport, the position played by x -1 In their sport, the position played by Walter Johnson pitcher Walter Johnson "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' '�']" in the first half of the season , the team has been a bit of a mess . They � False pitchers, the likes of Walter Johnson and Christy 6 [' pitchers', ',', ' the', ' likes', ' of', ' Walter', ' Johnson']
+1171 778 In their sport, the position played by x -1 In their sport, the position played by Walter Johnson pitcher Walter Johnson "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' '�']" in the first half of the season , the team has been a bit of a mess . They � False American League pitchers Walter Johnson from the Washington 4 [' American', ' League', ' pitchers', ' Walter', ' Johnson']
+1172 778 In their sport, the position played by x -1 In their sport, the position played by Walter Johnson pitcher Walter Johnson "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' the' ' team'
+ ' has' ' been' ' a' ' bit' ' of' ' a' ' mess' '.' ' They' '�']" in the first half of the season , the team has been a bit of a mess . They � False 2 ['Wal', 'ter', ' Johnson']
+1173 785 In their sport, the position played by x -1 In their sport, the position played by Pedro Beato pitcher Pedro Beato "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' '.' '\n' '\n' 'The' ' first' ' half' ' of' ' the']" ", who was the first to score a goal in the game .
+
+ The first half of the" False solo home run against Pedro Beato in a 7 – 6 6 [' solo', ' home', ' run', ' against', ' Pedro', ' Be', 'ato']
+1174 785 In their sport, the position played by x -1 In their sport, the position played by Pedro Beato pitcher Pedro Beato "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' '.' '\n' '\n' 'The' ' first' ' half' ' of' ' the']" ", who was the first to score a goal in the game .
+
+ The first half of the" False home run against Pedro Beato in a 7 – 6 victory 5 [' home', ' run', ' against', ' Pedro', ' Be', 'ato']
+1175 785 In their sport, the position played by x -1 In their sport, the position played by Pedro Beato pitcher Pedro Beato "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' '.' '\n' '\n' 'The' ' first' ' half' ' of' ' the']" ", who was the first to score a goal in the game .
+
+ The first half of the" False home run against Pedro Beato in a 7 – 6 victory 5 [' home', ' run', ' against', ' Pedro', ' Be', 'ato']
+1176 786 In their sport, the position played by x -1 In their sport, the position played by Jeff Francis pitcher Jeff Francis "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great person .
+
+ I think the only thing" False home run against Jeff Francis in the first game 4 [' home', ' run', ' against', ' Jeff', ' Francis']
+1177 786 In their sport, the position played by x -1 In their sport, the position played by Jeff Francis pitcher Jeff Francis "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' think' ' the' ' only' ' thing']" ", who was a great player , but not a great person .
+
+ I think the only thing" False home run against Jeff Francis in the first game 4 [' home', ' run', ' against', ' Jeff', ' Francis']
+1178 790 In their sport, the position played by x -1 In their sport, the position played by Tommy Docherty midfielder Tommy Docherty "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' manager'
+ ',' ' was' ' a' ' great' ' manager' '.' ' He' ' was' ' a' ' great']" , who was a great player and a great manager , was a great manager . He was a great False 1960s Chelsea manager Tommy Docherty changed the kit 6 [' 1960', 's', ' Chelsea', ' manager', ' Tommy', ' Doc', 'herty']
+1179 790 In their sport, the position played by x -1 In their sport, the position played by Tommy Docherty midfielder Tommy Docherty "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' manager'
+ ',' ' was' ' a' ' great' ' manager' '.' ' He' ' was' ' a' ' great']" , who was a great player and a great manager , was a great manager . He was a great False replaced by Tommy Docherty in December 4 [' replaced', ' by', ' Tommy', ' Doc', 'herty']
+1180 790 In their sport, the position played by x -1 In their sport, the position played by Tommy Docherty midfielder Tommy Docherty "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' manager'
+ ',' ' was' ' a' ' great' ' manager' '.' ' He' ' was' ' a' ' great']" , who was a great player and a great manager , was a great manager . He was a great False Two days later Tommy Docherty was appointed 5 [' Two', ' days', ' later', ' Tommy', ' Doc', 'herty']
+1181 790 In their sport, the position played by x -1 In their sport, the position played by Tommy Docherty midfielder Tommy Docherty "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' manager'
+ ',' ' was' ' a' ' great' ' manager' '.' ' He' ' was' ' a' ' great']" , who was a great player and a great manager , was a great manager . He was a great False 3 ['Tom', 'my', ' Doc', 'herty']
+1182 790 In their sport, the position played by x -1 In their sport, the position played by Tommy Docherty midfielder Tommy Docherty "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' manager'
+ ',' ' was' ' a' ' great' ' manager' '.' ' He' ' was' ' a' ' great']" , who was a great player and a great manager , was a great manager . He was a great False 1968. Two days later Tommy Docherty was appointed as 7 [' 1968', '.', ' Two', ' days', ' later', ' Tommy', ' Doc', 'herty']
+1183 791 In their sport, the position played by x -1 In their sport, the position played by Ray Houghton midfielder Ray Houghton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False Kevin Richardson, Ray Houghton and Shaun Teale. 6 [' Kevin', ' Richardson', ',', ' Ray', ' H', 'ough', 'ton']
+1184 791 In their sport, the position played by x -1 In their sport, the position played by Ray Houghton midfielder Ray Houghton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False Richardson, Ray Houghton and Shaun Teale. 5 [' Richardson', ',', ' Ray', ' H', 'ough', 'ton']
+1185 791 In their sport, the position played by x -1 In their sport, the position played by Ray Houghton midfielder Ray Houghton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False 0 win after Ray Houghton secured an early 6 [' 0', ' win', ' after', ' Ray', ' H', 'ough', 'ton']
+1186 791 In their sport, the position played by x -1 In their sport, the position played by Ray Houghton midfielder Ray Houghton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False Trevor Hebberd, Ray Houghton and Jeremy Charles. 8 [' Trevor', ' Heb', 'ber', 'd', ',', ' Ray', ' H', 'ough', 'ton']
+1187 791 In their sport, the position played by x -1 In their sport, the position played by Ray Houghton midfielder Ray Houghton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' who' ' was']" , who was a member of the team that won the first World Cup in 1930 , and who was False Kevin Richardson, Ray Houghton and Shaun Teale. They 6 [' Kevin', ' Richardson', ',', ' Ray', ' H', 'ough', 'ton']
+1188 793 In their sport, the position played by x -1 In their sport, the position played by John Brantley quarterback John Brantley "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Series' ' in' ' 1903' ',' ' was' ' filled'
+ ' by']" , who was a member of the team that won the first World Series in 1903 , was filled by False intercepted a John Brantley pass. The game marked 5 [' intercepted', ' a', ' John', ' Br', 'ant', 'ley']
+1189 793 In their sport, the position played by x -1 In their sport, the position played by John Brantley quarterback John Brantley "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Series' ' in' ' 1903' ',' ' was' ' filled'
+ ' by']" , who was a member of the team that won the first World Series in 1903 , was filled by False early 7 – 0 lead when John Brantley threw a 65-yard touchdown 9 [' early', ' 7', ' –', ' 0', ' lead', ' when', ' John', ' Br', 'ant', 'ley']
+1190 793 In their sport, the position played by x -1 In their sport, the position played by John Brantley quarterback John Brantley "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Series' ' in' ' 1903' ',' ' was' ' filled'
+ ' by']" , who was a member of the team that won the first World Series in 1903 , was filled by False Kirkpatrick intercepted a John Brantley pass. The 7 [' Kirk', 'patrick', ' intercepted', ' a', ' John', ' Br', 'ant', 'ley']
+1191 795 In their sport, the position played by x -1 In their sport, the position played by Tim Tebow quarterback Tim Tebow "[' in' ' the' ' NFL' ',' ' the' ' quarterback' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader']" in the NFL , the quarterback is the most important position on the field . The quarterback is the leader True Florida quarterback Tim Tebow — and 17 touchdowns 4 [' Florida', ' quarterback', ' Tim', ' Te', 'bow']
+1192 795 In their sport, the position played by x -1 In their sport, the position played by Tim Tebow quarterback Tim Tebow "[' in' ' the' ' NFL' ',' ' the' ' quarterback' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader']" in the NFL , the quarterback is the most important position on the field . The quarterback is the leader True presence of Tim Tebow created a controversy 4 [' presence', ' of', ' Tim', ' Te', 'bow']
+1193 795 In their sport, the position played by x -1 In their sport, the position played by Tim Tebow quarterback Tim Tebow "[' in' ' the' ' NFL' ',' ' the' ' quarterback' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader']" in the NFL , the quarterback is the most important position on the field . The quarterback is the leader True Fellow quarterback Tim Tebow said of McCoy, 4 [' Fellow', ' quarterback', ' Tim', ' Te', 'bow']
+1194 795 In their sport, the position played by x -1 In their sport, the position played by Tim Tebow quarterback Tim Tebow "[' in' ' the' ' NFL' ',' ' the' ' quarterback' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader']" in the NFL , the quarterback is the most important position on the field . The quarterback is the leader True Heisman Trophy winner Tim Tebow were followed 5 [' Heisman', ' Trophy', ' winner', ' Tim', ' Te', 'bow']
+1195 795 In their sport, the position played by x -1 In their sport, the position played by Tim Tebow quarterback Tim Tebow "[' in' ' the' ' NFL' ',' ' the' ' quarterback' ' is' ' the' ' most'
+ ' important' ' position' ' on' ' the' ' field' '.' ' The' ' quarterback'
+ ' is' ' the' ' leader']" in the NFL , the quarterback is the most important position on the field . The quarterback is the leader True " Fellow quarterback Tim Tebow said of McCoy, ""I was" 4 [' Fellow', ' quarterback', ' Tim', ' Te', 'bow']
+1196 799 In their sport, the position played by x -1 In their sport, the position played by Johan Hedberg goaltender Johan Hedberg "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False game three. After Johan Hedberg initially replaced 7 [' game', ' three', '.', ' After', ' Joh', 'an', ' Hed', 'berg']
+1197 799 In their sport, the position played by x -1 In their sport, the position played by Johan Hedberg goaltender Johan Hedberg "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False three. After Johan Hedberg initially replaced 6 [' three', '.', ' After', ' Joh', 'an', ' Hed', 'berg']
+1198 799 In their sport, the position played by x -1 In their sport, the position played by Johan Hedberg goaltender Johan Hedberg "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False game three. After Johan Hedberg initially replaced 7 [' game', ' three', '.', ' After', ' Joh', 'an', ' Hed', 'berg']
+1199 799 In their sport, the position played by x -1 In their sport, the position played by Johan Hedberg goaltender Johan Hedberg "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False three. After Johan Hedberg initially replaced 6 [' three', '.', ' After', ' Joh', 'an', ' Hed', 'berg']
+1200 799 In their sport, the position played by x -1 In their sport, the position played by Johan Hedberg goaltender Johan Hedberg "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False of game three. After Johan Hedberg initially replaced 8 [' of', ' game', ' three', '.', ' After', ' Joh', 'an', ' Hed', 'berg']
+1201 800 In their sport, the position played by x -1 In their sport, the position played by Ben Mauk quarterback Ben Mauk "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference True Bearcats hoped quarterback Ben Mauk would be allowed 6 [' Bear', 'cats', ' hoped', ' quarterback', ' Ben', ' Mau', 'k']
+1202 800 In their sport, the position played by x -1 In their sport, the position played by Ben Mauk quarterback Ben Mauk "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference True quarterback Ben Mauk would be allowed to 3 [' quarterback', ' Ben', ' Mau', 'k']
+1203 800 In their sport, the position played by x -1 In their sport, the position played by Ben Mauk quarterback Ben Mauk "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference True quarterback Ben Mauk would be allowed 3 [' quarterback', ' Ben', ' Mau', 'k']
+1204 801 In their sport, the position played by x -1 In their sport, the position played by Don Majkowski quarterback Don Majkowski "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True replaced quarterback Don Majkowski with Randy Wright. 4 [' replaced', ' quarterback', ' Don', ' Maj', 'kowski']
+1205 801 In their sport, the position played by x -1 In their sport, the position played by Don Majkowski quarterback Don Majkowski "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True quarterback Don Majkowski and Favre played 3 [' quarterback', ' Don', ' Maj', 'kowski']
+1206 801 In their sport, the position played by x -1 In their sport, the position played by Don Majkowski quarterback Don Majkowski "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True replaced quarterback Don Majkowski with Randy Wright. 4 [' replaced', ' quarterback', ' Don', ' Maj', 'kowski']
+1207 801 In their sport, the position played by x -1 In their sport, the position played by Don Majkowski quarterback Don Majkowski "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' quarterback' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' thing']" ", who was a great player , but not a great quarterback .
+
+ I think the best thing" True quarterback Don Majkowski and Favre played 3 [' quarterback', ' Don', ' Maj', 'kowski']
+1208 802 In their sport, the position played by x -1 In their sport, the position played by Jock Stein midfielder Jock Stein "[',' ' the' ' manager' ' of' ' Celtic' ',' ' was' ' a' ' very'
+ ' different' ' one' '.' ' He' ' was' ' a' ' man' ' who' ' had' ' been'
+ ' a']" , the manager of Celtic , was a very different one . He was a man who had been a False 2 ['J', 'ock', ' Stein']
+1209 802 In their sport, the position played by x -1 In their sport, the position played by Jock Stein midfielder Jock Stein "[',' ' the' ' manager' ' of' ' Celtic' ',' ' was' ' a' ' very'
+ ' different' ' one' '.' ' He' ' was' ' a' ' man' ' who' ' had' ' been'
+ ' a']" , the manager of Celtic , was a very different one . He was a man who had been a False Phase 3b, the Jock Stein Stand on the 7 [' Phase', ' 3', 'b', ',', ' the', ' J', 'ock', ' Stein']
+1210 802 In their sport, the position played by x -1 In their sport, the position played by Jock Stein midfielder Jock Stein "[',' ' the' ' manager' ' of' ' Celtic' ',' ' was' ' a' ' very'
+ ' different' ' one' '.' ' He' ' was' ' a' ' man' ' who' ' had' ' been'
+ ' a']" , the manager of Celtic , was a very different one . He was a man who had been a False " match their manager Jock Stein said that, ""Celtic" 5 [' match', ' their', ' manager', ' J', 'ock', ' Stein']
+1211 802 In their sport, the position played by x -1 In their sport, the position played by Jock Stein midfielder Jock Stein "[',' ' the' ' manager' ' of' ' Celtic' ',' ' was' ' a' ' very'
+ ' different' ' one' '.' ' He' ' was' ' a' ' man' ' who' ' had' ' been'
+ ' a']" , the manager of Celtic , was a very different one . He was a man who had been a False by Phase 3b, the Jock Stein Stand on the 8 [' by', ' Phase', ' 3', 'b', ',', ' the', ' J', 'ock', ' Stein']
+1212 802 In their sport, the position played by x -1 In their sport, the position played by Jock Stein midfielder Jock Stein "[',' ' the' ' manager' ' of' ' Celtic' ',' ' was' ' a' ' very'
+ ' different' ' one' '.' ' He' ' was' ' a' ' man' ' who' ' had' ' been'
+ ' a']" , the manager of Celtic , was a very different one . He was a man who had been a False inception, with Beattie and Jock Stein occupying it in 9 [' inception', ',', ' with', ' Be', 'att', 'ie', ' and', ' J', 'ock', ' Stein']
+1213 806 In their sport, the position played by x -1 In their sport, the position played by Keith Rivers linebacker Keith Rivers "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False " Brian Cushing, Keith Rivers and Rey Maualuga.
+" 5 [' Brian', ' C', 'ushing', ',', ' Keith', ' Rivers']
+1214 806 In their sport, the position played by x -1 In their sport, the position played by Keith Rivers linebacker Keith Rivers "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False Trojans. Linebacker Keith Rivers suffered a high ankle 7 [' Tro', 'j', 'ans', '.', ' Line', 'backer', ' Keith', ' Rivers']
+1215 806 In their sport, the position played by x -1 In their sport, the position played by Keith Rivers linebacker Keith Rivers "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False " by Brian Cushing, Keith Rivers and Rey Maualuga.
+" 6 [' by', ' Brian', ' C', 'ushing', ',', ' Keith', ' Rivers']
+1216 806 In their sport, the position played by x -1 In their sport, the position played by Keith Rivers linebacker Keith Rivers "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False " led by Brian Cushing, Keith Rivers and Rey Maualuga.
+" 7 [' led', ' by', ' Brian', ' C', 'ushing', ',', ' Keith', ' Rivers']
+1217 806 In their sport, the position played by x -1 In their sport, the position played by Keith Rivers linebacker Keith Rivers "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False Trojans. Linebacker Keith Rivers suffered a high ankle 7 [' Tro', 'j', 'ans', '.', ' Line', 'backer', ' Keith', ' Rivers']
+1218 814 In their sport, the position played by x -1 In their sport, the position played by Mule Haas outfielder Mule Haas "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False inside-the-park home run by Mule Haas as the Athletics 10 [' inside', '-', 'the', '-', 'park', ' home', ' run', ' by', ' M', 'ule', ' Haas']
+1219 814 In their sport, the position played by x -1 In their sport, the position played by Mule Haas outfielder Mule Haas "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False inside-the-park home run by Mule Haas as the Athletics 10 [' inside', '-', 'the', '-', 'park', ' home', ' run', ' by', ' M', 'ule', ' Haas']
+1220 814 In their sport, the position played by x -1 In their sport, the position played by Mule Haas outfielder Mule Haas "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False inside-the-park home run by Mule Haas as the Athletics 10 [' inside', '-', 'the', '-', 'park', ' home', ' run', ' by', ' M', 'ule', ' Haas']
+1221 814 In their sport, the position played by x -1 In their sport, the position played by Mule Haas outfielder Mule Haas "[',' ' the' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n' 'The' '\n' '\n'
+ 'The' '\n' '\n' 'The' '\n' '\n' 'The']" ", the
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The
+
+ The" False Series game since Mule Haas in the 1929 World 5 [' Series', ' game', ' since', ' M', 'ule', ' Haas']
+1222 815 In their sport, the position played by x -1 In their sport, the position played by David Beckham midfielder David Beckham "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' former'
+ ' England' ' captain' ' was' ' a' ' key' ' figure' ' in' ' the' ' team'
+ ""'s""]" " in the World Cup final .
+
+ The former England captain was a key figure in the team 's" False " changing room that hit David Beckham above the left eye.
+" 5 [' changing', ' room', ' that', ' hit', ' David', ' Beckham']
+1223 815 In their sport, the position played by x -1 In their sport, the position played by David Beckham midfielder David Beckham "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' former'
+ ' England' ' captain' ' was' ' a' ' key' ' figure' ' in' ' the' ' team'
+ ""'s""]" " in the World Cup final .
+
+ The former England captain was a key figure in the team 's" False childhood idol David Beckham while at the 3 [' childhood', ' idol', ' David', ' Beckham']
+1224 815 In their sport, the position played by x -1 In their sport, the position played by David Beckham midfielder David Beckham "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' former'
+ ' England' ' captain' ' was' ' a' ' key' ' figure' ' in' ' the' ' team'
+ ""'s""]" " in the World Cup final .
+
+ The former England captain was a key figure in the team 's" False 1 ['David', ' Beckham']
+1225 815 In their sport, the position played by x -1 In their sport, the position played by David Beckham midfielder David Beckham "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' former'
+ ' England' ' captain' ' was' ' a' ' key' ' figure' ' in' ' the' ' team'
+ ""'s""]" " in the World Cup final .
+
+ The former England captain was a key figure in the team 's" False United player David Beckham was seen wearing 3 [' United', ' player', ' David', ' Beckham']
+1226 815 In their sport, the position played by x -1 In their sport, the position played by David Beckham midfielder David Beckham "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' former'
+ ' England' ' captain' ' was' ' a' ' key' ' figure' ' in' ' the' ' team'
+ ""'s""]" " in the World Cup final .
+
+ The former England captain was a key figure in the team 's" False Carlos, Raúl, David Beckham and Fabio Cannavaro. 7 [' Carlos', ',', ' Ra', 'ú', 'l', ',', ' David', ' Beckham']
+1227 816 In their sport, the position played by x -1 In their sport, the position played by Mike Richter goaltender Mike Richter "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' goalie']" ", who was a great player , but not a great goalie .
+
+ I think the best goalie" False looked up to Mike Richter of the New York 5 [' looked', ' up', ' to', ' Mike', ' Rich', 'ter']
+1228 816 In their sport, the position played by x -1 In their sport, the position played by Mike Richter goaltender Mike Richter "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' goalie']" ", who was a great player , but not a great goalie .
+
+ I think the best goalie" False Schneider looked up to Mike Richter of the New York Rangers 6 [' Schneider', ' looked', ' up', ' to', ' Mike', ' Rich', 'ter']
+1229 816 In their sport, the position played by x -1 In their sport, the position played by Mike Richter goaltender Mike Richter "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' goalie']" ", who was a great player , but not a great goalie .
+
+ I think the best goalie" False Schneider looked up to Mike Richter of the New York 6 [' Schneider', ' looked', ' up', ' to', ' Mike', ' Rich', 'ter']
+1230 816 In their sport, the position played by x -1 In their sport, the position played by Mike Richter goaltender Mike Richter "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' best' ' goalie']" ", who was a great player , but not a great goalie .
+
+ I think the best goalie" False Schneider looked up to Mike Richter of the New York 6 [' Schneider', ' looked', ' up', ' to', ' Mike', ' Rich', 'ter']
+1231 817 In their sport, the position played by x -1 In their sport, the position played by Jeff Rutledge quarterback Jeff Rutledge "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False year, although Jeff Rutledge saw considerable 5 [' year', ',', ' although', ' Jeff', ' Rut', 'ledge']
+1232 817 In their sport, the position played by x -1 In their sport, the position played by Jeff Rutledge quarterback Jeff Rutledge "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False string quarterback Jeff Rutledge then looked 4 [' string', ' quarterback', ' Jeff', ' Rut', 'ledge']
+1233 817 In their sport, the position played by x -1 In their sport, the position played by Jeff Rutledge quarterback Jeff Rutledge "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False string quarterback Jeff Rutledge then looked 4 [' string', ' quarterback', ' Jeff', ' Rut', 'ledge']
+1234 817 In their sport, the position played by x -1 In their sport, the position played by Jeff Rutledge quarterback Jeff Rutledge "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False year, although Jeff Rutledge saw considerable 5 [' year', ',', ' although', ' Jeff', ' Rut', 'ledge']
+1235 817 In their sport, the position played by x -1 In their sport, the position played by Jeff Rutledge quarterback Jeff Rutledge "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False quarterback Jeff Rutledge then looked over to 3 [' quarterback', ' Jeff', ' Rut', 'ledge']
+1236 819 In their sport, the position played by x -1 In their sport, the position played by Warren Sapp tackle Warren Sapp "[',' ' the' ' former' ' NFL' ' defensive' ' lineman' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can'
+ ' be' ' a']" , the former NFL defensive lineman , is a perfect example of the kind of player who can be a False Derek Hough and Warren Sapp partnering 5 [' Derek', ' H', 'ough', ' and', ' Warren', ' Sapp']
+1237 820 In their sport, the position played by x -1 In their sport, the position played by John Huarte quarterback John Huarte "[' in' ' the' ' film' ',' ' the' '\n' '\n' 'The' ' film' ' is' ' a'
+ ' comedy' ',' ' but' ' it' ' is' ' also' ' a' ' drama' ',']" " in the film , the
+
+ The film is a comedy , but it is also a drama ," False quarterback John Huarte and wide receiver 3 [' quarterback', ' John', ' Hu', 'arte']
+1238 822 In their sport, the position played by x -1 In their sport, the position played by Ray Emery goaltender Ray Emery "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' Oilers' ' are']" ", who was a great player , but not a great goalie .
+
+ I think the Oilers are" False birthday, against Ray Emery of the Philadelphia 5 [' birthday', ',', ' against', ' Ray', ' Emer', 'y']
+1239 822 In their sport, the position played by x -1 In their sport, the position played by Ray Emery goaltender Ray Emery "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' Oilers' ' are']" ", who was a great player , but not a great goalie .
+
+ I think the Oilers are" False 29, 2014, against Ray Emery of the Philadelphia 7 [' 29', ',', ' 2014', ',', ' against', ' Ray', ' Emer', 'y']
+1240 822 In their sport, the position played by x -1 In their sport, the position played by Ray Emery goaltender Ray Emery "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' Oilers' ' are']" ", who was a great player , but not a great goalie .
+
+ I think the Oilers are" False with rookie netminder Ray Emery as their starter. Without 7 [' with', ' rookie', ' net', 'm', 'inder', ' Ray', ' Emer', 'y']
+1241 822 In their sport, the position played by x -1 In their sport, the position played by Ray Emery goaltender Ray Emery "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' Oilers' ' are']" ", who was a great player , but not a great goalie .
+
+ I think the Oilers are" False " Ray Emery =
+" 2 [' Ray', ' Emer', 'y']
+1242 822 In their sport, the position played by x -1 In their sport, the position played by Ray Emery goaltender Ray Emery "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' goalie' '.' '\n' '\n' 'I' ' think' ' the' ' Oilers' ' are']" ", who was a great player , but not a great goalie .
+
+ I think the Oilers are" False " Ray Emery =
+" 2 [' Ray', ' Emer', 'y']
+1243 823 In their sport, the position played by x -1 In their sport, the position played by Gareth Barry midfielder Gareth Barry "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' England'
+ ' manager' ',' ' Gareth' ' South' 'gate' ',' ' has' ' been' ' criticised'
+ ' for']" " in the World Cup final .
+
+ The England manager , Gareth South gate , has been criticised for" False record held by Gareth Barry and Jamie Carragher. 4 [' record', ' held', ' by', ' Gareth', ' Barry']
+1244 823 In their sport, the position played by x -1 In their sport, the position played by Gareth Barry midfielder Gareth Barry "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' England'
+ ' manager' ',' ' Gareth' ' South' 'gate' ',' ' has' ' been' ' criticised'
+ ' for']" " in the World Cup final .
+
+ The England manager , Gareth South gate , has been criticised for" False team captain Gareth Barry and 1982 European 3 [' team', ' captain', ' Gareth', ' Barry']
+1245 823 In their sport, the position played by x -1 In their sport, the position played by Gareth Barry midfielder Gareth Barry "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' England'
+ ' manager' ',' ' Gareth' ' South' 'gate' ',' ' has' ' been' ' criticised'
+ ' for']" " in the World Cup final .
+
+ The England manager , Gareth South gate , has been criticised for" False urging captain Gareth Barry to move to 3 [' urging', ' captain', ' Gareth', ' Barry']
+1246 823 In their sport, the position played by x -1 In their sport, the position played by Gareth Barry midfielder Gareth Barry "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' England'
+ ' manager' ',' ' Gareth' ' South' 'gate' ',' ' has' ' been' ' criticised'
+ ' for']" " in the World Cup final .
+
+ The England manager , Gareth South gate , has been criticised for" False team captain Gareth Barry and 1982 European 3 [' team', ' captain', ' Gareth', ' Barry']
+1247 823 In their sport, the position played by x -1 In their sport, the position played by Gareth Barry midfielder Gareth Barry "[' in' ' the' ' World' ' Cup' ' final' '.' '\n' '\n' 'The' ' England'
+ ' manager' ',' ' Gareth' ' South' 'gate' ',' ' has' ' been' ' criticised'
+ ' for']" " in the World Cup final .
+
+ The England manager , Gareth South gate , has been criticised for" False where goals from Gareth Barry and Kevin Mirallas 4 [' where', ' goals', ' from', ' Gareth', ' Barry']
+1248 824 In their sport, the position played by x -1 In their sport, the position played by Hatem Ben Arfa midfielder Hatem Ben Arfa "[',' ' who' ' is' ' a' ' very' ' good' ' player' ',' ' but' ' he' ' is'
+ ' not' ' a' ' good' ' player' ' for' ' the' ' national' ' team' '.']" , who is a very good player , but he is not a good player for the national team . False " Ben Arfa =
+" 8 [' Ben', ' Ar', 'fa', ' =', 'H', 'atem', ' Ben', ' Ar', 'fa']
+1249 824 In their sport, the position played by x -1 In their sport, the position played by Hatem Ben Arfa midfielder Hatem Ben Arfa "[',' ' who' ' is' ' a' ' very' ' good' ' player' ',' ' but' ' he' ' is'
+ ' not' ' a' ' good' ' player' ' for' ' the' ' national' ' team' '.']" , who is a very good player , but he is not a good player for the national team . False 4 ['H', 'atem', ' Ben', ' Ar', 'fa']
+1250 824 In their sport, the position played by x -1 In their sport, the position played by Hatem Ben Arfa midfielder Hatem Ben Arfa "[',' ' who' ' is' ' a' ' very' ' good' ' player' ',' ' but' ' he' ' is'
+ ' not' ' a' ' good' ' player' ' for' ' the' ' national' ' team' '.']" , who is a very good player , but he is not a good player for the national team . False Sidney Govou, Hatem Ben Arfa and Benzema. 8 [' Sidney', ' Gov', 'ou', ',', ' Hat', 'em', ' Ben', ' Ar', 'fa']
+1251 824 In their sport, the position played by x -1 In their sport, the position played by Hatem Ben Arfa midfielder Hatem Ben Arfa "[',' ' who' ' is' ' a' ' very' ' good' ' player' ',' ' but' ' he' ' is'
+ ' not' ' a' ' good' ' player' ' for' ' the' ' national' ' team' '.']" , who is a very good player , but he is not a good player for the national team . False " Arfa =
+" 7 [' Ar', 'fa', ' =', 'H', 'atem', ' Ben', ' Ar', 'fa']
+1252 824 In their sport, the position played by x -1 In their sport, the position played by Hatem Ben Arfa midfielder Hatem Ben Arfa "[',' ' who' ' is' ' a' ' very' ' good' ' player' ',' ' but' ' he' ' is'
+ ' not' ' a' ' good' ' player' ' for' ' the' ' national' ' team' '.']" , who is a very good player , but he is not a good player for the national team . False Ribéry, Sidney Govou, Hatem Ben Arfa and Benzema. Benzema 12 [' Rib', 'é', 'ry', ',', ' Sidney', ' Gov', 'ou', ',', ' Hat', 'em', ' Ben', ' Ar', 'fa']
+1253 826 In their sport, the position played by x -1 In their sport, the position played by Jay Cutler quarterback Jay Cutler "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team']" ", the quarterback , is to be the quarterback .
+
+ The quarterback is the leader of the team" True 23. However, Jay Cutler threw a 17-yard 5 [' 23', '.', ' However', ',', ' Jay', ' Cutler']
+1254 826 In their sport, the position played by x -1 In their sport, the position played by Jay Cutler quarterback Jay Cutler "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team']" ", the quarterback , is to be the quarterback .
+
+ The quarterback is the leader of the team" True for a one-yard loss, Jay Cutler threw a short 8 [' for', ' a', ' one', '-', 'yard', ' loss', ',', ' Jay', ' Cutler']
+1255 826 In their sport, the position played by x -1 In their sport, the position played by Jay Cutler quarterback Jay Cutler "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team']" ", the quarterback , is to be the quarterback .
+
+ The quarterback is the leader of the team" True Malcolm Jenkins forced Jay Cutler to fumble, and Cameron 4 [' Malcolm', ' Jenkins', ' forced', ' Jay', ' Cutler']
+1256 826 In their sport, the position played by x -1 In their sport, the position played by Jay Cutler quarterback Jay Cutler "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team']" ", the quarterback , is to be the quarterback .
+
+ The quarterback is the leader of the team" True the Chargers, Jay Cutler threw his 138th 4 [' the', ' Chargers', ',', ' Jay', ' Cutler']
+1257 826 In their sport, the position played by x -1 In their sport, the position played by Jay Cutler quarterback Jay Cutler "[',' ' the' ' quarterback' ',' ' is' ' to' ' be' ' the' ' quarterback' '.'
+ '\n' '\n' 'The' ' quarterback' ' is' ' the' ' leader' ' of' ' the'
+ ' team']" ", the quarterback , is to be the quarterback .
+
+ The quarterback is the leader of the team" True quarterback Jay Cutler to the Titans for 2 [' quarterback', ' Jay', ' Cutler']
+1258 828 In their sport, the position played by x -1 In their sport, the position played by Craig Breslow pitcher Craig Breslow "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' World' ' Series' ' in' ' 2004' ',' ' is' ' a' ' perfect'
+ ' example']" , who was a member of the team that won the World Series in 2004 , is a perfect example False future major leaguers Craig Breslow and Adam Greenberg. 8 [' future', ' major', ' le', 'agu', 'ers', ' Craig', ' B', 'res', 'low']
+1259 828 In their sport, the position played by x -1 In their sport, the position played by Craig Breslow pitcher Craig Breslow "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' World' ' Series' ' in' ' 2004' ',' ' is' ' a' ' perfect'
+ ' example']" , who was a member of the team that won the World Series in 2004 , is a perfect example False Aceves (10) and Craig Breslow (8). Ramírez reached 9 [' Ace', 'ves', ' (', '10', ')', ' and', ' Craig', ' B', 'res', 'low']
+1260 831 In their sport, the position played by x -1 In their sport, the position played by Gary Danielson quarterback Gary Danielson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' title' ' in' ' 1991' ',' ' was' ' filled' ' by' ' the']" , who was a member of the team that won the NCAA title in 1991 , was filled by the False play-by-play commentary, Gary Danielson provided the color 9 [' play', '-', 'by', '-', 'play', ' commentary', ',', ' Gary', ' Dani', 'elson']
+1261 831 In their sport, the position played by x -1 In their sport, the position played by Gary Danielson quarterback Gary Danielson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' title' ' in' ' 1991' ',' ' was' ' filled' ' by' ' the']" , who was a member of the team that won the NCAA title in 1991 , was filled by the False Lions quarterback Gary Danielson dropped back to pass 4 [' Lions', ' quarterback', ' Gary', ' Dani', 'elson']
+1262 831 In their sport, the position played by x -1 In their sport, the position played by Gary Danielson quarterback Gary Danielson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' title' ' in' ' 1991' ',' ' was' ' filled' ' by' ' the']" , who was a member of the team that won the NCAA title in 1991 , was filled by the False Musburger, Jack Arute, and Gary Danielson were the broadcasters 11 [' Mus', 'bur', 'ger', ',', ' Jack', ' Ar', 'ute', ',', ' and', ' Gary', ' Dani', 'elson']
+1263 832 In their sport, the position played by x -1 In their sport, the position played by Dan Fouts quarterback Dan Fouts "[',' ' the' ' quarterback' ',' ' is' ' the' ' same' ' as' ' the'
+ ' position' ' played' ' by' ' the' ' quarterback' ' in' ' the' ' NFL' '.'
+ '\n' '\n']" ", the quarterback , is the same as the position played by the quarterback in the NFL .
+
+" True Chargers quarterback Dan Fouts threw five second 4 [' Chargers', ' quarterback', ' Dan', ' F', 'outs']
+1264 832 In their sport, the position played by x -1 In their sport, the position played by Dan Fouts quarterback Dan Fouts "[',' ' the' ' quarterback' ',' ' is' ' the' ' same' ' as' ' the'
+ ' position' ' played' ' by' ' the' ' quarterback' ' in' ' the' ' NFL' '.'
+ '\n' '\n']" ", the quarterback , is the same as the position played by the quarterback in the NFL .
+
+" True Chargers quarterback Dan Fouts threw five second 4 [' Chargers', ' quarterback', ' Dan', ' F', 'outs']
+1265 832 In their sport, the position played by x -1 In their sport, the position played by Dan Fouts quarterback Dan Fouts "[',' ' the' ' quarterback' ',' ' is' ' the' ' same' ' as' ' the'
+ ' position' ' played' ' by' ' the' ' quarterback' ' in' ' the' ' NFL' '.'
+ '\n' '\n']" ", the quarterback , is the same as the position played by the quarterback in the NFL .
+
+" True Chargers quarterback Dan Fouts threw five second 4 [' Chargers', ' quarterback', ' Dan', ' F', 'outs']
+1266 834 In their sport, the position played by x -1 In their sport, the position played by Mariano Rivera closer Mariano Rivera "[',' ' who' ' has' ' been' ' a' ' model' ' of' ' consistency' ' and'
+ ' durability' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who has been a model of consistency and durability , is a perfect example of the type of player False end the inning. Mariano Rivera entered in the bottom 6 [' end', ' the', ' inning', '.', ' Mar', 'iano', ' Rivera']
+1267 834 In their sport, the position played by x -1 In their sport, the position played by Mariano Rivera closer Mariano Rivera "[',' ' who' ' has' ' been' ' a' ' model' ' of' ' consistency' ' and'
+ ' durability' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who has been a model of consistency and durability , is a perfect example of the type of player False and was replaced by Mariano Rivera in the eighth. 6 [' and', ' was', ' replaced', ' by', ' Mar', 'iano', ' Rivera']
+1268 834 In their sport, the position played by x -1 In their sport, the position played by Mariano Rivera closer Mariano Rivera "[',' ' who' ' has' ' been' ' a' ' model' ' of' ' consistency' ' and'
+ ' durability' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who has been a model of consistency and durability , is a perfect example of the type of player False bottom of the sixth. Mariano Rivera recorded his 38th 7 [' bottom', ' of', ' the', ' sixth', '.', ' Mar', 'iano', ' Rivera']
+1269 834 In their sport, the position played by x -1 In their sport, the position played by Mariano Rivera closer Mariano Rivera "[',' ' who' ' has' ' been' ' a' ' model' ' of' ' consistency' ' and'
+ ' durability' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who has been a model of consistency and durability , is a perfect example of the type of player False innings, replaced by Mariano Rivera in the eighth. The 6 [' innings', ',', ' replaced', ' by', ' Mar', 'iano', ' Rivera']
+1270 834 In their sport, the position played by x -1 In their sport, the position played by Mariano Rivera closer Mariano Rivera "[',' ' who' ' has' ' been' ' a' ' model' ' of' ' consistency' ' and'
+ ' durability' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' type'
+ ' of' ' player']" , who has been a model of consistency and durability , is a perfect example of the type of player False and Rodriguez. Mariano Rivera entered in the bottom 5 [' and', ' Rodriguez', '.', ' Mar', 'iano', ' Rivera']
+1271 835 In their sport, the position played by x -1 In their sport, the position played by Rory Delap midfielder Rory Delap "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False throw-in specialist Rory Delap delivered the 6 [' throw', '-', 'in', ' specialist', ' Rory', ' Del', 'ap']
+1272 835 In their sport, the position played by x -1 In their sport, the position played by Rory Delap midfielder Rory Delap "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False injury-time goal from a Rory Delap throw-in, giving Stoke 8 [' injury', '-', 'time', ' goal', ' from', ' a', ' Rory', ' Del', 'ap']
+1273 835 In their sport, the position played by x -1 In their sport, the position played by Rory Delap midfielder Rory Delap "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False throw-in specialist Rory Delap delivered the 6 [' throw', '-', 'in', ' specialist', ' Rory', ' Del', 'ap']
+1274 835 In their sport, the position played by x -1 In their sport, the position played by Rory Delap midfielder Rory Delap "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False specialist Rory Delap delivered 3 [' specialist', ' Rory', ' Del', 'ap']
+1275 835 In their sport, the position played by x -1 In their sport, the position played by Rory Delap midfielder Rory Delap "[' in' ' the' ' first' ' half' ' of' ' the' ' season' ',' ' he' ' was'
+ ' a' ' very' ' good' ' player' '.' ' He' ' was' ' a' ' very' ' good']" in the first half of the season , he was a very good player . He was a very good False throw-in specialist Rory Delap delivered the ball 6 [' throw', '-', 'in', ' specialist', ' Rory', ' Del', 'ap']
+1276 838 In their sport, the position played by x -1 In their sport, the position played by Kari Lehtonen goaltender Kari Lehtonen "[',' ' who' ' was' ' the' ' backup' ' for' ' the' ' Dallas' ' Stars' ','
+ ' was' ' a' ' bit' ' of' ' a' ' surprise' '.' ' He' ' was' ' a']" , who was the backup for the Dallas Stars , was a bit of a surprise . He was a False goaltender Kari Lehtonen in a 5 – 3 win. The 5 [' goaltender', ' K', 'ari', ' Leh', 'ton', 'en']
+1277 838 In their sport, the position played by x -1 In their sport, the position played by Kari Lehtonen goaltender Kari Lehtonen "[',' ' who' ' was' ' the' ' backup' ' for' ' the' ' Dallas' ' Stars' ','
+ ' was' ' a' ' bit' ' of' ' a' ' surprise' '.' ' He' ' was' ' a']" , who was the backup for the Dallas Stars , was a bit of a surprise . He was a False Finnish goaltender Kari Lehtonen in a 5 – 3 win. The 6 [' Finnish', ' goaltender', ' K', 'ari', ' Leh', 'ton', 'en']
+1278 840 In their sport, the position played by x -1 In their sport, the position played by Diego Maradona midfielder Diego Maradona "[',' ' the' ' Argentine' 'an' ' footballer' ',' ' is' ' a' ' very'
+ ' important' ' one' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world']" , the Argentine an footballer , is a very important one . He is the best player in the world False 4 ['Die', 'go', ' Mar', 'ad', 'ona']
+1279 840 In their sport, the position played by x -1 In their sport, the position played by Diego Maradona midfielder Diego Maradona "[',' ' the' ' Argentine' 'an' ' footballer' ',' ' is' ' a' ' very'
+ ' important' ' one' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world']" , the Argentine an footballer , is a very important one . He is the best player in the world False Argentine midfielder Diego Maradona and Mexican forward 5 [' Argentine', ' midfielder', ' Diego', ' Mar', 'ad', 'ona']
+1280 840 In their sport, the position played by x -1 In their sport, the position played by Diego Maradona midfielder Diego Maradona "[',' ' the' ' Argentine' 'an' ' footballer' ',' ' is' ' a' ' very'
+ ' important' ' one' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world']" , the Argentine an footballer , is a very important one . He is the best player in the world False " alongside legends like Diego Maradona and Pelé.
+" 6 [' alongside', ' legends', ' like', ' Diego', ' Mar', 'ad', 'ona']
+1281 840 In their sport, the position played by x -1 In their sport, the position played by Diego Maradona midfielder Diego Maradona "[',' ' the' ' Argentine' 'an' ' footballer' ',' ' is' ' a' ' very'
+ ' important' ' one' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world']" , the Argentine an footballer , is a very important one . He is the best player in the world False " alongside legends like Diego Maradona and Pelé.
+" 6 [' alongside', ' legends', ' like', ' Diego', ' Mar', 'ad', 'ona']
+1282 840 In their sport, the position played by x -1 In their sport, the position played by Diego Maradona midfielder Diego Maradona "[',' ' the' ' Argentine' 'an' ' footballer' ',' ' is' ' a' ' very'
+ ' important' ' one' '.' ' He' ' is' ' the' ' best' ' player' ' in' ' the'
+ ' world']" , the Argentine an footballer , is a very important one . He is the best player in the world False Argentine footballer Diego Maradona despite the 5 [' Argentine', ' footballer', ' Diego', ' Mar', 'ad', 'ona']
+1283 841 In their sport, the position played by x -1 In their sport, the position played by Massimiliano Allegri midfielder Massimiliano Allegri "[',' ' who' ' has' ' been' ' in' ' charge' ' of' ' the' ' Italian'
+ ' national' ' team' ' since' ' the' ' summer' ' of' ' 2014' ',' ' is'
+ ' a' ' very']" , who has been in charge of the Italian national team since the summer of 2014 , is a very False In 2014 – 15, Massimiliano Allegri was appointed 9 [' In', ' 2014', ' –', ' 15', ',', ' Mass', 'imil', 'iano', ' Alleg', 'ri']
+1284 841 In their sport, the position played by x -1 In their sport, the position played by Massimiliano Allegri midfielder Massimiliano Allegri "[',' ' who' ' has' ' been' ' in' ' charge' ' of' ' the' ' Italian'
+ ' national' ' team' ' since' ' the' ' summer' ' of' ' 2014' ',' ' is'
+ ' a' ' very']" , who has been in charge of the Italian national team since the summer of 2014 , is a very False Stadium. In 2014 – 15, Massimiliano Allegri was appointed as manager, 11 [' Stadium', '.', ' In', ' 2014', ' –', ' 15', ',', ' Mass', 'imil', 'iano', ' Alleg', 'ri']
+1285 842 In their sport, the position played by x -1 In their sport, the position played by Agostino Di Bartolomei midfielder Agostino Di Bartolomei "[',' ' the' ' Italian' ' goalkeeper' ',' ' is' ' called' ' the' ' ""'
+ 'keeper' '""' ' and' ' the' ' position' ' played' ' by' ' the' ' defender'
+ ' is' ' called']" ", the Italian goalkeeper , is called the "" keeper "" and the position played by the defender is called" False but Roma captain Agostino Di Bartolomei took the ball from 10 [' but', ' Roma', ' captain', ' Ag', 'ost', 'ino', ' Di', ' Bart', 'ol', 'ome', 'i']
+1286 842 In their sport, the position played by x -1 In their sport, the position played by Agostino Di Bartolomei midfielder Agostino Di Bartolomei "[',' ' the' ' Italian' ' goalkeeper' ',' ' is' ' called' ' the' ' ""'
+ 'keeper' '""' ' and' ' the' ' position' ' played' ' by' ' the' ' defender'
+ ' is' ' called']" ", the Italian goalkeeper , is called the "" keeper "" and the position played by the defender is called" False Roma captain Agostino Di Bartolomei took the ball 9 [' Roma', ' captain', ' Ag', 'ost', 'ino', ' Di', ' Bart', 'ol', 'ome', 'i']
+1287 842 In their sport, the position played by x -1 In their sport, the position played by Agostino Di Bartolomei midfielder Agostino Di Bartolomei "[',' ' the' ' Italian' ' goalkeeper' ',' ' is' ' called' ' the' ' ""'
+ 'keeper' '""' ' and' ' the' ' position' ' played' ' by' ' the' ' defender'
+ ' is' ' called']" ", the Italian goalkeeper , is called the "" keeper "" and the position played by the defender is called" False Roma captain Agostino Di Bartolomei took the ball from 9 [' Roma', ' captain', ' Ag', 'ost', 'ino', ' Di', ' Bart', 'ol', 'ome', 'i']
+1288 845 In their sport, the position played by x -1 In their sport, the position played by Frankie Albert quarterback Frankie Albert "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False of quarterback Frankie Albert and end Alyn Beals 3 [' of', ' quarterback', ' Frankie', ' Albert']
+1289 845 In their sport, the position played by x -1 In their sport, the position played by Frankie Albert quarterback Frankie Albert "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False sharing the honor with Frankie Albert of the 49ers. 5 [' sharing', ' the', ' honor', ' with', ' Frankie', ' Albert']
+1290 845 In their sport, the position played by x -1 In their sport, the position played by Frankie Albert quarterback Frankie Albert "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False award with Frankie Albert of the San Francisco 3 [' award', ' with', ' Frankie', ' Albert']
+1291 845 In their sport, the position played by x -1 In their sport, the position played by Frankie Albert quarterback Frankie Albert "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False featured quarterback Frankie Albert and end Alyn Beals. 3 [' featured', ' quarterback', ' Frankie', ' Albert']
+1292 845 In their sport, the position played by x -1 In their sport, the position played by Frankie Albert quarterback Frankie Albert "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL player , is a perfect example of the kind of player that the NFL False production of quarterback Frankie Albert and end Alyn Beals 4 [' production', ' of', ' quarterback', ' Frankie', ' Albert']
+1293 847 In their sport, the position played by x -1 In their sport, the position played by Scott Sinclair midfielder Scott Sinclair "[',' ' who' ' was' ' a' ' key' ' part' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' league' ' title' ' in' ' 2014' '.' '\n' '\n' 'The']" ", who was a key part of the team that won the league title in 2014 .
+
+ The" False the deal taking Scott Sinclair to Chelsea after 4 [' the', ' deal', ' taking', ' Scott', ' Sinclair']
+1294 847 In their sport, the position played by x -1 In their sport, the position played by Scott Sinclair midfielder Scott Sinclair "[',' ' who' ' was' ' a' ' key' ' part' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' league' ' title' ' in' ' 2014' '.' '\n' '\n' 'The']" ", who was a key part of the team that won the league title in 2014 .
+
+ The" False replacing Scott Sinclair for the last 2 [' replacing', ' Scott', ' Sinclair']
+1295 847 In their sport, the position played by x -1 In their sport, the position played by Scott Sinclair midfielder Scott Sinclair "[',' ' who' ' was' ' a' ' key' ' part' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' league' ' title' ' in' ' 2014' '.' '\n' '\n' 'The']" ", who was a key part of the team that won the league title in 2014 .
+
+ The" False of the deal taking Scott Sinclair to Chelsea 5 [' of', ' the', ' deal', ' taking', ' Scott', ' Sinclair']
+1296 847 In their sport, the position played by x -1 In their sport, the position played by Scott Sinclair midfielder Scott Sinclair "[',' ' who' ' was' ' a' ' key' ' part' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' league' ' title' ' in' ' 2014' '.' '\n' '\n' 'The']" ", who was a key part of the team that won the league title in 2014 .
+
+ The" False of the deal taking Scott Sinclair to Chelsea after he 5 [' of', ' the', ' deal', ' taking', ' Scott', ' Sinclair']
+1297 847 In their sport, the position played by x -1 In their sport, the position played by Scott Sinclair midfielder Scott Sinclair "[',' ' who' ' was' ' a' ' key' ' part' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' league' ' title' ' in' ' 2014' '.' '\n' '\n' 'The']" ", who was a key part of the team that won the league title in 2014 .
+
+ The" False Micah Richards and Scott Sinclair in place of Kolarov 5 [' Mic', 'ah', ' Richards', ' and', ' Scott', ' Sinclair']
+1298 848 In their sport, the position played by x -1 In their sport, the position played by Pekka Rinne goaltender Pekka Rinne "[' is' ' the' ' same' ' as' ' the' ' one' ' he' ' has' ' been' ' in'
+ ' for' ' the' ' past' ' few' ' seasons' '.' ' He' ' is' ' the' ' goalie']" is the same as the one he has been in for the past few seasons . He is the goalie False 2009, against Pekka Rinne of the Nashville 7 [' 2009', ',', ' against', ' P', 'ek', 'ka', ' Rin', 'ne']
+1299 848 In their sport, the position played by x -1 In their sport, the position played by Pekka Rinne goaltender Pekka Rinne "[' is' ' the' ' same' ' as' ' the' ' one' ' he' ' has' ' been' ' in'
+ ' for' ' the' ' past' ' few' ' seasons' '.' ' He' ' is' ' the' ' goalie']" is the same as the one he has been in for the past few seasons . He is the goalie False Trophy nominee Pekka Rinne and the Nashville 6 [' Trophy', ' nominee', ' P', 'ek', 'ka', ' Rin', 'ne']
+1300 848 In their sport, the position played by x -1 In their sport, the position played by Pekka Rinne goaltender Pekka Rinne "[' is' ' the' ' same' ' as' ' the' ' one' ' he' ' has' ' been' ' in'
+ ' for' ' the' ' past' ' few' ' seasons' '.' ' He' ' is' ' the' ' goalie']" is the same as the one he has been in for the past few seasons . He is the goalie False 17, 2009, against Pekka Rinne of the Nashville 9 [' 17', ',', ' 2009', ',', ' against', ' P', 'ek', 'ka', ' Rin', 'ne']
+1301 848 In their sport, the position played by x -1 In their sport, the position played by Pekka Rinne goaltender Pekka Rinne "[' is' ' the' ' same' ' as' ' the' ' one' ' he' ' has' ' been' ' in'
+ ' for' ' the' ' past' ' few' ' seasons' '.' ' He' ' is' ' the' ' goalie']" is the same as the one he has been in for the past few seasons . He is the goalie False shot against Pekka Rinne of the Nashville Predators. 6 [' shot', ' against', ' P', 'ek', 'ka', ' Rin', 'ne']
+1302 848 In their sport, the position played by x -1 In their sport, the position played by Pekka Rinne goaltender Pekka Rinne "[' is' ' the' ' same' ' as' ' the' ' one' ' he' ' has' ' been' ' in'
+ ' for' ' the' ' past' ' few' ' seasons' '.' ' He' ' is' ' the' ' goalie']" is the same as the one he has been in for the past few seasons . He is the goalie False 2009, against Pekka Rinne of the Nashville 7 [' 2009', ',', ' against', ' P', 'ek', 'ka', ' Rin', 'ne']
+1303 851 In their sport, the position played by x -1 In their sport, the position played by Ryan Leaf quarterback Ryan Leaf "[',' ' the' ' former' ' Heisman' ' Trophy' ' winner' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' who' ' can'
+ ' be' ' a']" , the former Heisman Trophy winner , is a perfect example of the type of player who can be a False " Ryan Leaf =
+" 1 [' Ryan', ' Leaf']
+1304 851 In their sport, the position played by x -1 In their sport, the position played by Ryan Leaf quarterback Ryan Leaf "[',' ' the' ' former' ' Heisman' ' Trophy' ' winner' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' who' ' can'
+ ' be' ' a']" , the former Heisman Trophy winner , is a perfect example of the type of player who can be a False featured quarterback Ryan Leaf and a five-wideout 3 [' featured', ' quarterback', ' Ryan', ' Leaf']
+1305 851 In their sport, the position played by x -1 In their sport, the position played by Ryan Leaf quarterback Ryan Leaf "[',' ' the' ' former' ' Heisman' ' Trophy' ' winner' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' who' ' can'
+ ' be' ' a']" , the former Heisman Trophy winner , is a perfect example of the type of player who can be a False featured quarterback Ryan Leaf and a five-wideout 3 [' featured', ' quarterback', ' Ryan', ' Leaf']
+1306 851 In their sport, the position played by x -1 In their sport, the position played by Ryan Leaf quarterback Ryan Leaf "[',' ' the' ' former' ' Heisman' ' Trophy' ' winner' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' who' ' can'
+ ' be' ' a']" , the former Heisman Trophy winner , is a perfect example of the type of player who can be a False " Ryan Leaf =
+" 1 [' Ryan', ' Leaf']
+1307 851 In their sport, the position played by x -1 In their sport, the position played by Ryan Leaf quarterback Ryan Leaf "[',' ' the' ' former' ' Heisman' ' Trophy' ' winner' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' type' ' of' ' player' ' who' ' can'
+ ' be' ' a']" , the former Heisman Trophy winner , is a perfect example of the type of player who can be a False many considered Ryan Leaf his rival for the 3 [' many', ' considered', ' Ryan', ' Leaf']
+1308 859 In their sport, the position played by x -1 In their sport, the position played by Josh Gibson catcher Josh Gibson "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", who was a great player , but not a great person .
+
+ I 'm not sure if" False their star catcher, Josh Gibson dropped a pop fly, 5 [' their', ' star', ' catcher', ',', ' Josh', ' Gibson']
+1309 859 In their sport, the position played by x -1 In their sport, the position played by Josh Gibson catcher Josh Gibson "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", who was a great player , but not a great person .
+
+ I 'm not sure if" False talents Satchel Paige and Josh Gibson were upset when 6 [' talents', ' Sat', 'chel', ' Paige', ' and', ' Josh', ' Gibson']
+1310 859 In their sport, the position played by x -1 In their sport, the position played by Josh Gibson catcher Josh Gibson "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", who was a great player , but not a great person .
+
+ I 'm not sure if" False League Museum's Josh Gibson Legacy Award, 4 "[' League', ' Museum', ""'s"", ' Josh', ' Gibson']"
+1311 859 In their sport, the position played by x -1 In their sport, the position played by Josh Gibson catcher Josh Gibson "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", who was a great player , but not a great person .
+
+ I 'm not sure if" False their star catcher, Josh Gibson dropped a pop 5 [' their', ' star', ' catcher', ',', ' Josh', ' Gibson']
+1312 859 In their sport, the position played by x -1 In their sport, the position played by Josh Gibson catcher Josh Gibson "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if']" ", who was a great player , but not a great person .
+
+ I 'm not sure if" False against catcher Josh Gibson and pitcher Ray Brown 3 [' against', ' catcher', ' Josh', ' Gibson']
+1313 864 In their sport, the position played by x -1 In their sport, the position played by Sean Glennon quarterback Sean Glennon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL True quarterbacks Sean Glennon (174 yards) and 3 [' quarterbacks', ' Sean', ' Glenn', 'on']
+1314 864 In their sport, the position played by x -1 In their sport, the position played by Sean Glennon quarterback Sean Glennon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL True as quarterback Sean Glennon completed a 16-yard 4 [' as', ' quarterback', ' Sean', ' Glenn', 'on']
+1315 864 In their sport, the position played by x -1 In their sport, the position played by Sean Glennon quarterback Sean Glennon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL True quarter, quarterback Sean Glennon threw for the 5 [' quarter', ',', ' quarterback', ' Sean', ' Glenn', 'on']
+1316 864 In their sport, the position played by x -1 In their sport, the position played by Sean Glennon quarterback Sean Glennon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL True Tech quarterback Sean Glennon led the Hokies 4 [' Tech', ' quarterback', ' Sean', ' Glenn', 'on']
+1317 864 In their sport, the position played by x -1 In their sport, the position played by Sean Glennon quarterback Sean Glennon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the'
+ ' NFL']" , who is a former NFL quarterback , is a perfect example of the kind of player that the NFL True Furman. Both Taylor and Sean Glennon performed well against 8 [' Fur', 'man', '.', ' Both', ' Taylor', ' and', ' Sean', ' Glenn', 'on']
+1318 866 In their sport, the position played by x -1 In their sport, the position played by Francisco Cervelli catcher Francisco Cervelli "[',' ' who' ' has' ' been' ' a' ' great' ' asset' ' to' ' the' ' team' '.'
+ ' He' ' has' ' been' ' a' ' great' ' leader' ' and' ' a' ' great']" , who has been a great asset to the team . He has been a great leader and a great False He also robbed both Francisco Cervelli and Michael Morse of 7 [' He', ' also', ' robbed', ' both', ' Francisco', ' C', 'erve', 'lli']
+1319 866 In their sport, the position played by x -1 In their sport, the position played by Francisco Cervelli catcher Francisco Cervelli "[',' ' who' ' has' ' been' ' a' ' great' ' asset' ' to' ' the' ' team' '.'
+ ' He' ' has' ' been' ' a' ' great' ' leader' ' and' ' a' ' great']" , who has been a great asset to the team . He has been a great leader and a great False robbed both Francisco Cervelli and Michael 5 [' robbed', ' both', ' Francisco', ' C', 'erve', 'lli']
+1320 866 In their sport, the position played by x -1 In their sport, the position played by Francisco Cervelli catcher Francisco Cervelli "[',' ' who' ' has' ' been' ' a' ' great' ' asset' ' to' ' the' ' team' '.'
+ ' He' ' has' ' been' ' a' ' great' ' leader' ' and' ' a' ' great']" , who has been a great asset to the team . He has been a great leader and a great False He also robbed both Francisco Cervelli and Michael Morse 7 [' He', ' also', ' robbed', ' both', ' Francisco', ' C', 'erve', 'lli']
+1321 866 In their sport, the position played by x -1 In their sport, the position played by Francisco Cervelli catcher Francisco Cervelli "[',' ' who' ' has' ' been' ' a' ' great' ' asset' ' to' ' the' ' team' '.'
+ ' He' ' has' ' been' ' a' ' great' ' leader' ' and' ' a' ' great']" , who has been a great asset to the team . He has been a great leader and a great False robbed both Francisco Cervelli and Michael Morse 5 [' robbed', ' both', ' Francisco', ' C', 'erve', 'lli']
+1322 866 In their sport, the position played by x -1 In their sport, the position played by Francisco Cervelli catcher Francisco Cervelli "[',' ' who' ' has' ' been' ' a' ' great' ' asset' ' to' ' the' ' team' '.'
+ ' He' ' has' ' been' ' a' ' great' ' leader' ' and' ' a' ' great']" , who has been a great asset to the team . He has been a great leader and a great False robbed both Francisco Cervelli and Michael Morse of 5 [' robbed', ' both', ' Francisco', ' C', 'erve', 'lli']
+1323 868 In their sport, the position played by x -1 In their sport, the position played by Ryan Mallett quarterback Ryan Mallett "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Patriots'
+ ' in' ' the' ' Super' ' Bowl' '.' '\n' '\n' 'The' ' Patriots' ' have'
+ ' a']" ", who was a backup quarterback for the Patriots in the Super Bowl .
+
+ The Patriots have a" True friends with Ryan Mallett of the Baltimore 4 [' friends', ' with', ' Ryan', ' Mal', 'lett']
+1324 868 In their sport, the position played by x -1 In their sport, the position played by Ryan Mallett quarterback Ryan Mallett "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Patriots'
+ ' in' ' the' ' Super' ' Bowl' '.' '\n' '\n' 'The' ' Patriots' ' have'
+ ' a']" ", who was a backup quarterback for the Patriots in the Super Bowl .
+
+ The Patriots have a" True opening game. When Ryan Mallett played in place 6 [' opening', ' game', '.', ' When', ' Ryan', ' Mal', 'lett']
+1325 868 In their sport, the position played by x -1 In their sport, the position played by Ryan Mallett quarterback Ryan Mallett "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Patriots'
+ ' in' ' the' ' Super' ' Bowl' '.' '\n' '\n' 'The' ' Patriots' ' have'
+ ' a']" ", who was a backup quarterback for the Patriots in the Super Bowl .
+
+ The Patriots have a" True 14 – 7 after Ryan Mallett hit Greg Childs 6 [' 14', ' –', ' 7', ' after', ' Ryan', ' Mal', 'lett']
+1326 868 In their sport, the position played by x -1 In their sport, the position played by Ryan Mallett quarterback Ryan Mallett "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Patriots'
+ ' in' ' the' ' Super' ' Bowl' '.' '\n' '\n' 'The' ' Patriots' ' have'
+ ' a']" ", who was a backup quarterback for the Patriots in the Super Bowl .
+
+ The Patriots have a" True lead to 14 – 7 after Ryan Mallett hit Greg Childs 8 [' lead', ' to', ' 14', ' –', ' 7', ' after', ' Ryan', ' Mal', 'lett']
+1327 868 In their sport, the position played by x -1 In their sport, the position played by Ryan Mallett quarterback Ryan Mallett "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Patriots'
+ ' in' ' the' ' Super' ' Bowl' '.' '\n' '\n' 'The' ' Patriots' ' have'
+ ' a']" ", who was a backup quarterback for the Patriots in the Super Bowl .
+
+ The Patriots have a" True opening game. When Ryan Mallett played in place 6 [' opening', ' game', '.', ' When', ' Ryan', ' Mal', 'lett']
+1328 869 In their sport, the position played by x -1 In their sport, the position played by Willie McFaul goalkeeper Willie McFaul "[',' ' who' ' was' ' a' ' great' ' player' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of']" ", who was a great player , and a great person .
+
+ I was a huge fan of" False 3 ['Will', 'ie', ' McF', 'aul']
+1329 869 In their sport, the position played by x -1 In their sport, the position played by Willie McFaul goalkeeper Willie McFaul "[',' ' who' ' was' ' a' ' great' ' player' ',' ' and' ' a' ' great'
+ ' person' '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of']" ", who was a great player , and a great person .
+
+ I was a huge fan of" False 3 ['Will', 'ie', ' McF', 'aul']
+1330 871 In their sport, the position played by x -1 In their sport, the position played by Darren Daulton catcher Darren Daulton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False Phillies catcher Darren Daulton followed Gooden and 5 [' Phillies', ' catcher', ' Darren', ' D', 'ault', 'on']
+1331 871 In their sport, the position played by x -1 In their sport, the position played by Darren Daulton catcher Darren Daulton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False Phillies catcher Darren Daulton followed Gooden 5 [' Phillies', ' catcher', ' Darren', ' D', 'ault', 'on']
+1332 871 In their sport, the position played by x -1 In their sport, the position played by Darren Daulton catcher Darren Daulton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False Phillies catcher Darren Daulton followed Gooden and 5 [' Phillies', ' catcher', ' Darren', ' D', 'ault', 'on']
+1333 874 In their sport, the position played by x -1 In their sport, the position played by Nolan Ryan pitcher Nolan Ryan "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Texas' ' Rangers' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' that']" , who was a pitcher for the Texas Rangers , is a perfect example of the type of player that True total at the time. Nolan Ryan struck out 383 batters 6 [' total', ' at', ' the', ' time', '.', ' Nolan', ' Ryan']
+1334 874 In their sport, the position played by x -1 In their sport, the position played by Nolan Ryan pitcher Nolan Ryan "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Texas' ' Rangers' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' that']" , who was a pitcher for the Texas Rangers , is a perfect example of the type of player that True and 1,606 innings. Nolan Ryan broke his Astros 7 [' and', ' 1', ',', '606', ' innings', '.', ' Nolan', ' Ryan']
+1335 874 In their sport, the position played by x -1 In their sport, the position played by Nolan Ryan pitcher Nolan Ryan "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Texas' ' Rangers' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' that']" , who was a pitcher for the Texas Rangers , is a perfect example of the type of player that True and 1,606 innings. Nolan Ryan broke his Astros 7 [' and', ' 1', ',', '606', ' innings', '.', ' Nolan', ' Ryan']
+1336 874 In their sport, the position played by x -1 In their sport, the position played by Nolan Ryan pitcher Nolan Ryan "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Texas' ' Rangers' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' that']" , who was a pitcher for the Texas Rangers , is a perfect example of the type of player that True test involving Nolan Ryan would be evaluated 3 [' test', ' involving', ' Nolan', ' Ryan']
+1337 874 In their sport, the position played by x -1 In their sport, the position played by Nolan Ryan pitcher Nolan Ryan "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Texas' ' Rangers' ','
+ ' is' ' a' ' perfect' ' example' ' of' ' the' ' type' ' of' ' player'
+ ' that']" , who was a pitcher for the Texas Rangers , is a perfect example of the type of player that True Pedro Martínez, and Nolan Ryan are the only 8 [' Pedro', ' Mart', 'í', 'ne', 'z', ',', ' and', ' Nolan', ' Ryan']
+1338 875 In their sport, the position played by x -1 In their sport, the position played by Braden Holtby goaltender Braden Holtby "[',' ' who' ' has' ' been' ' the' ' Capitals' '�' '�' ' best' ' player'
+ ' this' ' season' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Capitals � � best player this season , is a bit of a mystery . False Capitals goaltender Braden Holtby in Game 3 of 5 [' Capitals', ' goaltender', ' Brad', 'en', ' Holt', 'by']
+1339 875 In their sport, the position played by x -1 In their sport, the position played by Braden Holtby goaltender Braden Holtby "[',' ' who' ' has' ' been' ' the' ' Capitals' '�' '�' ' best' ' player'
+ ' this' ' season' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Capitals � � best player this season , is a bit of a mystery . False out of the lineup. Braden Holtby was recalled to 8 [' out', ' of', ' the', ' lineup', '.', ' Brad', 'en', ' Holt', 'by']
+1340 875 In their sport, the position played by x -1 In their sport, the position played by Braden Holtby goaltender Braden Holtby "[',' ' who' ' has' ' been' ' the' ' Capitals' '�' '�' ' best' ' player'
+ ' this' ' season' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Capitals � � best player this season , is a bit of a mystery . False goaltender Braden Holtby and defenseman 4 [' goaltender', ' Brad', 'en', ' Holt', 'by']
+1341 875 In their sport, the position played by x -1 In their sport, the position played by Braden Holtby goaltender Braden Holtby "[',' ' who' ' has' ' been' ' the' ' Capitals' '�' '�' ' best' ' player'
+ ' this' ' season' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Capitals � � best player this season , is a bit of a mystery . False forced to start Braden Holtby in Game 1. Through 6 [' forced', ' to', ' start', ' Brad', 'en', ' Holt', 'by']
+1342 875 In their sport, the position played by x -1 In their sport, the position played by Braden Holtby goaltender Braden Holtby "[',' ' who' ' has' ' been' ' the' ' Capitals' '�' '�' ' best' ' player'
+ ' this' ' season' ',' ' is' ' a' ' bit' ' of' ' a' ' mystery' '.']" , who has been the Capitals � � best player this season , is a bit of a mystery . False forced to start Braden Holtby in Game 1. Through 6 [' forced', ' to', ' start', ' Brad', 'en', ' Holt', 'by']
+1343 877 In their sport, the position played by x -1 In their sport, the position played by Enos Slaughter outfielder Enos Slaughter "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a']" , the former NFL running back , is a perfect example of the kind of player who can be a False bench player, as Enos Slaughter joined the team, 6 [' bench', ' player', ',', ' as', ' En', 'os', ' Slaughter']
+1344 877 In their sport, the position played by x -1 In their sport, the position played by Enos Slaughter outfielder Enos Slaughter "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a']" , the former NFL running back , is a perfect example of the kind of player who can be a False Musial struck out, Enos Slaughter singled and 7 [' Mus', 'ial', ' struck', ' out', ',', ' En', 'os', ' Slaughter']
+1345 877 In their sport, the position played by x -1 In their sport, the position played by Enos Slaughter outfielder Enos Slaughter "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a']" , the former NFL running back , is a perfect example of the kind of player who can be a False a bench player, as Enos Slaughter joined the team, 7 [' a', ' bench', ' player', ',', ' as', ' En', 'os', ' Slaughter']
+1346 877 In their sport, the position played by x -1 In their sport, the position played by Enos Slaughter outfielder Enos Slaughter "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a']" , the former NFL running back , is a perfect example of the kind of player who can be a False player, as Enos Slaughter joined the team, 5 [' player', ',', ' as', ' En', 'os', ' Slaughter']
+1347 877 In their sport, the position played by x -1 In their sport, the position played by Enos Slaughter outfielder Enos Slaughter "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a']" , the former NFL running back , is a perfect example of the kind of player who can be a False Musial struck out, Enos Slaughter singled and Whitey 7 [' Mus', 'ial', ' struck', ' out', ',', ' En', 'os', ' Slaughter']
+1348 878 In their sport, the position played by x -1 In their sport, the position played by Atlee Hammaker pitcher Atlee Hammaker "[',' ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about' ' the' ' new' ' �' '�' 'The' ' Last' ' of' ' Us' '�']" ", the
+
+ The first thing that strikes you about the new � � The Last of Us �" False José Uribe and Atlee Hammaker got on base with 7 [' José', ' U', 'ribe', ' and', ' At', 'lee', ' Ham', 'maker']
+1349 878 In their sport, the position played by x -1 In their sport, the position played by Atlee Hammaker pitcher Atlee Hammaker "[',' ' the' '\n' '\n' 'The' ' first' ' thing' ' that' ' strikes' ' you'
+ ' about' ' the' ' new' ' �' '�' 'The' ' Last' ' of' ' Us' '�']" ", the
+
+ The first thing that strikes you about the new � � The Last of Us �" False José Uribe and Atlee Hammaker got on base with consecutive 7 [' José', ' U', 'ribe', ' and', ' At', 'lee', ' Ham', 'maker']
+1350 879 In their sport, the position played by x -1 In their sport, the position played by Don Larsen pitcher Don Larsen "[',' ' the' ' pitcher' ',' ' was' ' to' ' be' ' the' ' first' ' to'
+ ' throw' ' a' ' no' '-' 'h' 'itter' ' in' ' the' ' World' ' Series']" , the pitcher , was to be the first to throw a no - h itter in the World Series True for the season. Don Larsen entered to relieve 6 [' for', ' the', ' season', '.', ' Don', ' Lars', 'en']
+1351 879 In their sport, the position played by x -1 In their sport, the position played by Don Larsen pitcher Don Larsen "[',' ' the' ' pitcher' ',' ' was' ' to' ' be' ' the' ' first' ' to'
+ ' throw' ' a' ' no' '-' 'h' 'itter' ' in' ' the' ' World' ' Series']" , the pitcher , was to be the first to throw a no - h itter in the World Series True 2, Yankees pitcher Don Larsen threw a perfect 6 [' 2', ',', ' Yankees', ' pitcher', ' Don', ' Lars', 'en']
+1352 879 In their sport, the position played by x -1 In their sport, the position played by Don Larsen pitcher Don Larsen "[',' ' the' ' pitcher' ',' ' was' ' to' ' be' ' the' ' first' ' to'
+ ' throw' ' a' ' no' '-' 'h' 'itter' ' in' ' the' ' World' ' Series']" , the pitcher , was to be the first to throw a no - h itter in the World Series True Yankees pitcher Don Larsen — who threw the 4 [' Yankees', ' pitcher', ' Don', ' Lars', 'en']
+1353 879 In their sport, the position played by x -1 In their sport, the position played by Don Larsen pitcher Don Larsen "[',' ' the' ' pitcher' ',' ' was' ' to' ' be' ' the' ' first' ' to'
+ ' throw' ' a' ' no' '-' 'h' 'itter' ' in' ' the' ' World' ' Series']" , the pitcher , was to be the first to throw a no - h itter in the World Series True Giants along with Don Larsen in exchange for 5 [' Giants', ' along', ' with', ' Don', ' Lars', 'en']
+1354 879 In their sport, the position played by x -1 In their sport, the position played by Don Larsen pitcher Don Larsen "[',' ' the' ' pitcher' ',' ' was' ' to' ' be' ' the' ' first' ' to'
+ ' throw' ' a' ' no' '-' 'h' 'itter' ' in' ' the' ' World' ' Series']" , the pitcher , was to be the first to throw a no - h itter in the World Series True Pierce and Don Larsen for Bob Farley, Eddie 4 [' Pierce', ' and', ' Don', ' Lars', 'en']
+1355 883 In their sport, the position played by x -1 In their sport, the position played by Josh McCown quarterback Josh McCown "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Jets'
+ ' in' ' the' ' NFL' ' for' ' the' ' past' ' two' ' seasons' ',' ' is'
+ ' a']" , who was a backup quarterback for the Jets in the NFL for the past two seasons , is a True Concannon in 1972, Josh McCown was named the 7 [' Conc', 'annon', ' in', ' 1972', ',', ' Josh', ' McC', 'own']
+1356 883 In their sport, the position played by x -1 In their sport, the position played by Josh McCown quarterback Josh McCown "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Jets'
+ ' in' ' the' ' NFL' ' for' ' the' ' past' ' two' ' seasons' ',' ' is'
+ ' a']" , who was a backup quarterback for the Jets in the NFL for the past two seasons , is a True quarterback spot behind Josh McCown and Mike Glennon. 5 [' quarterback', ' spot', ' behind', ' Josh', ' McC', 'own']
+1357 883 In their sport, the position played by x -1 In their sport, the position played by Josh McCown quarterback Josh McCown "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Jets'
+ ' in' ' the' ' NFL' ' for' ' the' ' past' ' two' ' seasons' ',' ' is'
+ ' a']" , who was a backup quarterback for the Jets in the NFL for the past two seasons , is a True that fans favored Josh McCown with 66.87 5 [' that', ' fans', ' favored', ' Josh', ' McC', 'own']
+1358 883 In their sport, the position played by x -1 In their sport, the position played by Josh McCown quarterback Josh McCown "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Jets'
+ ' in' ' the' ' NFL' ' for' ' the' ' past' ' two' ' seasons' ',' ' is'
+ ' a']" , who was a backup quarterback for the Jets in the NFL for the past two seasons , is a True Bears quarterback Josh McCown had signed a 4 [' Bears', ' quarterback', ' Josh', ' McC', 'own']
+1359 883 In their sport, the position played by x -1 In their sport, the position played by Josh McCown quarterback Josh McCown "[',' ' who' ' was' ' a' ' backup' ' quarterback' ' for' ' the' ' Jets'
+ ' in' ' the' ' NFL' ' for' ' the' ' past' ' two' ' seasons' ',' ' is'
+ ' a']" , who was a backup quarterback for the Jets in the NFL for the past two seasons , is a True with 2: 22 remaining, Josh McCown entered the 8 [' with', ' 2', ':', ' 22', ' remaining', ',', ' Josh', ' McC', 'own']
+1360 886 In their sport, the position played by x -1 In their sport, the position played by Drake Britton pitcher Drake Britton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2014' '.' '\n'
+ '\n']" ", who was a member of the team that won the NCAA Division I Championship in 2014 .
+
+" False Che-Hsuan Lin and Drake Britton were added to 9 [' Che', '-', 'H', 'su', 'an', ' Lin', ' and', ' Drake', ' Britt', 'on']
+1361 886 In their sport, the position played by x -1 In their sport, the position played by Drake Britton pitcher Drake Britton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2014' '.' '\n'
+ '\n']" ", who was a member of the team that won the NCAA Division I Championship in 2014 .
+
+" False Che-Hsuan Lin and Drake Britton were added 9 [' Che', '-', 'H', 'su', 'an', ' Lin', ' and', ' Drake', ' Britt', 'on']
+1362 886 In their sport, the position played by x -1 In their sport, the position played by Drake Britton pitcher Drake Britton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2014' '.' '\n'
+ '\n']" ", who was a member of the team that won the NCAA Division I Championship in 2014 .
+
+" False Che-Hsuan Lin and Drake Britton were added to 9 [' Che', '-', 'H', 'su', 'an', ' Lin', ' and', ' Drake', ' Britt', 'on']
+1363 886 In their sport, the position played by x -1 In their sport, the position played by Drake Britton pitcher Drake Britton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' NCAA' ' Division' ' I' ' Championship' ' in' ' 2014' '.' '\n'
+ '\n']" ", who was a member of the team that won the NCAA Division I Championship in 2014 .
+
+" False Middlebrooks, Che-Hsuan Lin and Drake Britton were added to 13 [' Middle', 'brook', 's', ',', ' Che', '-', 'H', 'su', 'an', ' Lin', ' and', ' Drake', ' Britt', 'on']
+1364 888 In their sport, the position played by x -1 In their sport, the position played by Marvin Mitchell linebacker Marvin Mitchell "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' two' ' games' ' of' ' the' ' series' ',' ' and' ' the']" , who was a member of the team that won the first two games of the series , and the False Audie Cole, with Marvin Mitchell recovering. Despite 6 [' Aud', 'ie', ' Cole', ',', ' with', ' Marvin', ' Mitchell']
+1365 888 In their sport, the position played by x -1 In their sport, the position played by Marvin Mitchell linebacker Marvin Mitchell "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' two' ' games' ' of' ' the' ' series' ',' ' and' ' the']" , who was a member of the team that won the first two games of the series , and the False Cole, with Marvin Mitchell recovering. Despite 4 [' Cole', ',', ' with', ' Marvin', ' Mitchell']
+1366 888 In their sport, the position played by x -1 In their sport, the position played by Marvin Mitchell linebacker Marvin Mitchell "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' two' ' games' ' of' ' the' ' series' ',' ' and' ' the']" , who was a member of the team that won the first two games of the series , and the False Audie Cole, with Marvin Mitchell recovering. Despite 6 [' Aud', 'ie', ' Cole', ',', ' with', ' Marvin', ' Mitchell']
+1367 888 In their sport, the position played by x -1 In their sport, the position played by Marvin Mitchell linebacker Marvin Mitchell "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' two' ' games' ' of' ' the' ' series' ',' ' and' ' the']" , who was a member of the team that won the first two games of the series , and the False Cole, with Marvin Mitchell recovering. Despite 4 [' Cole', ',', ' with', ' Marvin', ' Mitchell']
+1368 890 In their sport, the position played by x -1 In their sport, the position played by Dan Conners linebacker Dan Conners "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference False Telltale CEO Dan Conners noted that 5 [' Tell', 'tale', ' CEO', ' Dan', ' Con', 'ners']
+1369 890 In their sport, the position played by x -1 In their sport, the position played by Dan Conners linebacker Dan Conners "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference False later. Telltale CEO Dan Conners noted that the game 7 [' later', '.', ' Tell', 'tale', ' CEO', ' Dan', ' Con', 'ners']
+1370 890 In their sport, the position played by x -1 In their sport, the position played by Dan Conners linebacker Dan Conners "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' perfect'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a perfect example of the kind of player who can make a difference False later. Telltale CEO Dan Conners noted that the 7 [' later', '.', ' Tell', 'tale', ' CEO', ' Dan', ' Con', 'ners']
+1371 891 In their sport, the position played by x -1 In their sport, the position played by Calum Chambers defender Calum Chambers "[',' ' who' ' was' ' a' ' late' ' substitute' ' for' ' the' ' injured'
+ ' Kier' 'an' ' Gibbs' ',' ' was' ' a' ' surprise' '.' '\n' '\n' 'The']" ", who was a late substitute for the injured Kier an Gibbs , was a surprise .
+
+ The" False Mathieu Debuchy, Calum Chambers and Alexis Sánchez. 8 [' Math', 'ieu', ' Deb', 'uch', 'y', ',', ' Cal', 'um', ' Chambers']
+1372 891 In their sport, the position played by x -1 In their sport, the position played by Calum Chambers defender Calum Chambers "[',' ' who' ' was' ' a' ' late' ' substitute' ' for' ' the' ' injured'
+ ' Kier' 'an' ' Gibbs' ',' ' was' ' a' ' surprise' '.' '\n' '\n' 'The']" ", who was a late substitute for the injured Kier an Gibbs , was a surprise .
+
+ The" False Mathieu Debuchy, Calum Chambers and Alexis 8 [' Math', 'ieu', ' Deb', 'uch', 'y', ',', ' Cal', 'um', ' Chambers']
+1373 891 In their sport, the position played by x -1 In their sport, the position played by Calum Chambers defender Calum Chambers "[',' ' who' ' was' ' a' ' late' ' substitute' ' for' ' the' ' injured'
+ ' Kier' 'an' ' Gibbs' ',' ' was' ' a' ' surprise' '.' '\n' '\n' 'The']" ", who was a late substitute for the injured Kier an Gibbs , was a surprise .
+
+ The" False Mathieu Debuchy, Calum Chambers and Alexis 8 [' Math', 'ieu', ' Deb', 'uch', 'y', ',', ' Cal', 'um', ' Chambers']
+1374 891 In their sport, the position played by x -1 In their sport, the position played by Calum Chambers defender Calum Chambers "[',' ' who' ' was' ' a' ' late' ' substitute' ' for' ' the' ' injured'
+ ' Kier' 'an' ' Gibbs' ',' ' was' ' a' ' surprise' '.' '\n' '\n' 'The']" ", who was a late substitute for the injured Kier an Gibbs , was a surprise .
+
+ The" False Mathieu Debuchy, Calum Chambers and Alexis Sánchez. 8 [' Math', 'ieu', ' Deb', 'uch', 'y', ',', ' Cal', 'um', ' Chambers']
+1375 894 In their sport, the position played by x -1 In their sport, the position played by Paulo Sousa midfielder Paulo Sousa "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False keep him, and that Paulo Sousa had been the club's 8 [' keep', ' him', ',', ' and', ' that', ' Paulo', ' S', 'ous', 'a']
+1376 894 In their sport, the position played by x -1 In their sport, the position played by Paulo Sousa midfielder Paulo Sousa "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False keep him, and that Paulo Sousa had been the club's 8 [' keep', ' him', ',', ' and', ' that', ' Paulo', ' S', 'ous', 'a']
+1377 894 In their sport, the position played by x -1 In their sport, the position played by Paulo Sousa midfielder Paulo Sousa "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False keep him, and that Paulo Sousa had been the 8 [' keep', ' him', ',', ' and', ' that', ' Paulo', ' S', 'ous', 'a']
+1378 894 In their sport, the position played by x -1 In their sport, the position played by Paulo Sousa midfielder Paulo Sousa "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False him, and that Paulo Sousa had been the club's 7 [' him', ',', ' and', ' that', ' Paulo', ' S', 'ous', 'a']
+1379 894 In their sport, the position played by x -1 In their sport, the position played by Paulo Sousa midfielder Paulo Sousa "[',' ' who' ' was' ' the' ' first' ' to' ' score' ' a' ' goal' ' in'
+ ' the' ' game' ',' ' was' ' a' ' great' ' example' ' of' ' how' ' the']" , who was the first to score a goal in the game , was a great example of how the False 9 league games, Paulo Sousa was sacked by the 7 [' 9', ' league', ' games', ',', ' Paulo', ' S', 'ous', 'a']
+1380 897 In their sport, the position played by x -1 In their sport, the position played by Stephen McGee quarterback Stephen McGee "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False quarterback Stephen McGee and TCU defensive 2 [' quarterback', ' Stephen', ' McGee']
+1381 897 In their sport, the position played by x -1 In their sport, the position played by Stephen McGee quarterback Stephen McGee "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False 2010, when he sacked QB Stephen McGee for an 11-yard loss 7 [' 2010', ',', ' when', ' he', ' sacked', ' QB', ' Stephen', ' McGee']
+1382 897 In their sport, the position played by x -1 In their sport, the position played by Stephen McGee quarterback Stephen McGee "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False screen pass from Stephen McGee to Mike Goodson. 4 [' screen', ' pass', ' from', ' Stephen', ' McGee']
+1383 897 In their sport, the position played by x -1 In their sport, the position played by Stephen McGee quarterback Stephen McGee "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False plays later, Stephen McGee scampered into the 4 [' plays', ' later', ',', ' Stephen', ' McGee']
+1384 897 In their sport, the position played by x -1 In their sport, the position played by Stephen McGee quarterback Stephen McGee "[' in' ' the' ' film' ',' ' the' ' character' ' of' ' the' ' hero' ' is'
+ ' a' ' bit' ' more' ' complex' '.' ' He' ' is' ' a' ' man' ' who']" in the film , the character of the hero is a bit more complex . He is a man who False moving the ball. Stephen McGee threw a 44-yard 5 [' moving', ' the', ' ball', '.', ' Stephen', ' McGee']
+1385 898 In their sport, the position played by x -1 In their sport, the position played by Matt Holland midfielder Matt Holland "[',' ' who' ' is' ' a' ' former' ' All' '-' 'American' ' at' ' the'
+ ' University' ' of' ' Michigan' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the']" , who is a former All - American at the University of Michigan , is a perfect example of the False players such as Matt Holland accepted the 4 [' players', ' such', ' as', ' Matt', ' Holland']
+1386 898 In their sport, the position played by x -1 In their sport, the position played by Matt Holland midfielder Matt Holland "[',' ' who' ' is' ' a' ' former' ' All' '-' 'American' ' at' ' the'
+ ' University' ' of' ' Michigan' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the']" , who is a former All - American at the University of Michigan , is a perfect example of the False players such as Matt Holland accepted the chant 4 [' players', ' such', ' as', ' Matt', ' Holland']
+1387 898 In their sport, the position played by x -1 In their sport, the position played by Matt Holland midfielder Matt Holland "[',' ' who' ' is' ' a' ' former' ' All' '-' 'American' ' at' ' the'
+ ' University' ' of' ' Michigan' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the']" , who is a former All - American at the University of Michigan , is a perfect example of the False while players such as Matt Holland accepted the 5 [' while', ' players', ' such', ' as', ' Matt', ' Holland']
+1388 898 In their sport, the position played by x -1 In their sport, the position played by Matt Holland midfielder Matt Holland "[',' ' who' ' is' ' a' ' former' ' All' '-' 'American' ' at' ' the'
+ ' University' ' of' ' Michigan' ',' ' is' ' a' ' perfect' ' example'
+ ' of' ' the']" , who is a former All - American at the University of Michigan , is a perfect example of the False players such as Matt Holland accepted the 4 [' players', ' such', ' as', ' Matt', ' Holland']
+1389 901 In their sport, the position played by x -1 In their sport, the position played by Mike Hampton pitcher Mike Hampton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' coach' '.' '\n' '\n' 'I' ' think' ' the' ' reason' ' why']" ", who was a great player , but not a great coach .
+
+ I think the reason why" False against pitcher Mike Hampton in an 8 – 0 loss. 3 [' against', ' pitcher', ' Mike', ' Hampton']
+1390 901 In their sport, the position played by x -1 In their sport, the position played by Mike Hampton pitcher Mike Hampton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' coach' '.' '\n' '\n' 'I' ' think' ' the' ' reason' ' why']" ", who was a great player , but not a great coach .
+
+ I think the reason why" False however, he tied Mike Hampton for the most 5 [' however', ',', ' he', ' tied', ' Mike', ' Hampton']
+1391 901 In their sport, the position played by x -1 In their sport, the position played by Mike Hampton pitcher Mike Hampton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' coach' '.' '\n' '\n' 'I' ' think' ' the' ' reason' ' why']" ", who was a great player , but not a great coach .
+
+ I think the reason why" False however, he tied Mike Hampton for the most runs 5 [' however', ',', ' he', ' tied', ' Mike', ' Hampton']
+1392 901 In their sport, the position played by x -1 In their sport, the position played by Mike Hampton pitcher Mike Hampton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' coach' '.' '\n' '\n' 'I' ' think' ' the' ' reason' ' why']" ", who was a great player , but not a great coach .
+
+ I think the reason why" False however, he tied Mike Hampton for the most runs 5 [' however', ',', ' he', ' tied', ' Mike', ' Hampton']
+1393 901 In their sport, the position played by x -1 In their sport, the position played by Mike Hampton pitcher Mike Hampton "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' coach' '.' '\n' '\n' 'I' ' think' ' the' ' reason' ' why']" ", who was a great player , but not a great coach .
+
+ I think the reason why" False single against pitcher Mike Hampton in an 8 – 0 4 [' single', ' against', ' pitcher', ' Mike', ' Hampton']
+1394 902 In their sport, the position played by x -1 In their sport, the position played by Egil Olsen midfielder Egil Olsen "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False international level, Egil Olsen called up King 5 [' international', ' level', ',', ' Eg', 'il', ' Olsen']
+1395 905 In their sport, the position played by x -1 In their sport, the position played by Jim Zorn quarterback Jim Zorn "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True former teammate Jim Zorn offered his services 4 [' former', ' teammate', ' Jim', ' Z', 'orn']
+1396 905 In their sport, the position played by x -1 In their sport, the position played by Jim Zorn quarterback Jim Zorn "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True two field goals. Jim Zorn had been signed to 6 [' two', ' field', ' goals', '.', ' Jim', ' Z', 'orn']
+1397 905 In their sport, the position played by x -1 In their sport, the position played by Jim Zorn quarterback Jim Zorn "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True When former teammate Jim Zorn offered his services 5 [' When', ' former', ' teammate', ' Jim', ' Z', 'orn']
+1398 905 In their sport, the position played by x -1 In their sport, the position played by Jim Zorn quarterback Jim Zorn "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Bombers quarterback Jim Zorn returned to the 5 [' Bom', 'bers', ' quarterback', ' Jim', ' Z', 'orn']
+1399 905 In their sport, the position played by x -1 In their sport, the position played by Jim Zorn quarterback Jim Zorn "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True to two field goals. Jim Zorn had been signed 7 [' to', ' two', ' field', ' goals', '.', ' Jim', ' Z', 'orn']
+1400 909 In their sport, the position played by x -1 In their sport, the position played by Abner Dalrymple outfielder Abner Dalrymple "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False receive one was Abner Dalrymple in 1881. During 8 [' receive', ' one', ' was', ' Ab', 'ner', ' Dal', 'ry', 'm', 'ple']
+1401 909 In their sport, the position played by x -1 In their sport, the position played by Abner Dalrymple outfielder Abner Dalrymple "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False an 11 – 7 game after Abner Dalrymple in 1881. For 11 [' an', ' 11', ' –', ' 7', ' game', ' after', ' Ab', 'ner', ' Dal', 'ry', 'm', 'ple']
+1402 909 In their sport, the position played by x -1 In their sport, the position played by Abner Dalrymple outfielder Abner Dalrymple "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False to receive one was Abner Dalrymple in 1881. During 9 [' to', ' receive', ' one', ' was', ' Ab', 'ner', ' Dal', 'ry', 'm', 'ple']
+1403 909 In their sport, the position played by x -1 In their sport, the position played by Abner Dalrymple outfielder Abner Dalrymple "[',' ' the' '\n' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
+ ' ' ' ' ' ']" ", the
+ " False 11 – 7 game after Abner Dalrymple in 1881. For the 1901 10 [' 11', ' –', ' 7', ' game', ' after', ' Ab', 'ner', ' Dal', 'ry', 'm', 'ple']
+1404 910 In their sport, the position played by x -1 In their sport, the position played by Cory Schneider goaltender Cory Schneider "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Calder'
+ ' Trophy' ' as' ' the' ' NHL' ""'s"" ' rookie' ' of' ' the' ' year' ',']" , who was the first goalie to win the Calder Trophy as the NHL 's rookie of the year , False acquired goaltender Cory Schneider from Vancouver in 3 [' acquired', ' goaltender', ' Cory', ' Schneider']
+1405 910 In their sport, the position played by x -1 In their sport, the position played by Cory Schneider goaltender Cory Schneider "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Calder'
+ ' Trophy' ' as' ' the' ' NHL' ""'s"" ' rookie' ' of' ' the' ' year' ',']" , who was the first goalie to win the Calder Trophy as the NHL 's rookie of the year , False re-signing of Cory Schneider to a three-year 6 [' re', '-', 'sign', 'ing', ' of', ' Cory', ' Schneider']
+1406 910 In their sport, the position played by x -1 In their sport, the position played by Cory Schneider goaltender Cory Schneider "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Calder'
+ ' Trophy' ' as' ' the' ' NHL' ""'s"" ' rookie' ' of' ' the' ' year' ',']" , who was the first goalie to win the Calder Trophy as the NHL 's rookie of the year , False acquired goaltender Cory Schneider from Vancouver 3 [' acquired', ' goaltender', ' Cory', ' Schneider']
+1407 910 In their sport, the position played by x -1 In their sport, the position played by Cory Schneider goaltender Cory Schneider "[',' ' who' ' was' ' the' ' first' ' goalie' ' to' ' win' ' the' ' Calder'
+ ' Trophy' ' as' ' the' ' NHL' ""'s"" ' rookie' ' of' ' the' ' year' ',']" , who was the first goalie to win the Calder Trophy as the NHL 's rookie of the year , False " Schneider =
+" 4 [' Schneider', ' =', 'C', 'ory', ' Schneider']
+1408 911 In their sport, the position played by x -1 In their sport, the position played by Darcy Kuemper goaltender Darcy Kuemper "[',' ' the' ' goalie' ',' ' is' ' the' ' most' ' important' ' position'
+ ' on' ' the' ' ice' '.' ' He' ' is' ' the' ' last' ' line' ' of'
+ ' defense']" , the goalie , is the most important position on the ice . He is the last line of defense False Sean Couturier and Darcy Kuemper for the latter. 9 [' Sean', ' Cout', 'ur', 'ier', ' and', ' Dar', 'cy', ' Ku', 'em', 'per']
+1409 912 In their sport, the position played by x -1 In their sport, the position played by Ryan Fraser midfielder Ryan Fraser "[',' ' who' ' has' ' been' ' a' ' revelation' ' for' ' the' ' club'
+ ' since' ' his' ' arrival' ' from' ' B' 'ourn' 'emouth' ' in' ' the'
+ ' summer' ',']" , who has been a revelation for the club since his arrival from B ourn emouth in the summer , False with fellow winger Ryan Fraser as Bournemouth 4 [' with', ' fellow', ' winger', ' Ryan', ' Fraser']
+1410 914 In their sport, the position played by x -1 In their sport, the position played by Brian Labone defender Brian Labone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False matches for Leeds; Brian Labone would take his 6 [' matches', ' for', ' Leeds', ';', ' Brian', ' Lab', 'one']
+1411 914 In their sport, the position played by x -1 In their sport, the position played by Brian Labone defender Brian Labone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False matches for Leeds; Brian Labone would take his 6 [' matches', ' for', ' Leeds', ';', ' Brian', ' Lab', 'one']
+1412 914 In their sport, the position played by x -1 In their sport, the position played by Brian Labone defender Brian Labone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False matches for Leeds; Brian Labone would take his place 6 [' matches', ' for', ' Leeds', ';', ' Brian', ' Lab', 'one']
+1413 914 In their sport, the position played by x -1 In their sport, the position played by Brian Labone defender Brian Labone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False matches for Leeds; Brian Labone would take his 6 [' matches', ' for', ' Leeds', ';', ' Brian', ' Lab', 'one']
+1414 915 In their sport, the position played by x -1 In their sport, the position played by Derrick Thomas linebacker Derrick Thomas "[',' ' who' ' was' ' a' ' defensive' ' end' ' for' ' the' ' New' ' York'
+ ' Giants' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous']" , who was a defensive end for the New York Giants , is a position that is not as glamorous False Broncos, and Derrick Thomas was paralyzed from 4 [' Broncos', ',', ' and', ' Derrick', ' Thomas']
+1415 915 In their sport, the position played by x -1 In their sport, the position played by Derrick Thomas linebacker Derrick Thomas "[',' ' who' ' was' ' a' ' defensive' ' end' ' for' ' the' ' New' ' York'
+ ' Giants' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous']" , who was a defensive end for the New York Giants , is a position that is not as glamorous False Broncos, and Derrick Thomas was paralyzed from 4 [' Broncos', ',', ' and', ' Derrick', ' Thomas']
+1416 915 In their sport, the position played by x -1 In their sport, the position played by Derrick Thomas linebacker Derrick Thomas "[',' ' who' ' was' ' a' ' defensive' ' end' ' for' ' the' ' New' ' York'
+ ' Giants' ',' ' is' ' a' ' position' ' that' ' is' ' not' ' as'
+ ' glamorous']" , who was a defensive end for the New York Giants , is a position that is not as glamorous False MVP and won the Derrick Thomas Community Award. His 5 [' MVP', ' and', ' won', ' the', ' Derrick', ' Thomas']
+1417 919 In their sport, the position played by x -1 In their sport, the position played by Jean Makoun midfielder Jean Makoun "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' French' ' national'
+ ' team' ',' ' was' ' a' ' key' ' factor' ' in' ' the' ' team' ""'s""
+ ' success']" , who was a member of the French national team , was a key factor in the team 's success False followed by goals from Jean Makoun and Juninho either 6 [' followed', ' by', ' goals', ' from', ' Jean', ' Mak', 'oun']
+1418 920 In their sport, the position played by x -1 In their sport, the position played by Carey Price goaltender Carey Price "[',' ' who' ' was' ' named' ' the' ' NHL' ""'s"" ' top' ' goaltender' ' in'
+ ' the' ' regular' ' season' ',' ' is' ' a' ' big' ' reason' ' why' ' the']" , who was named the NHL 's top goaltender in the regular season , is a big reason why the True " season: 44, Carey Price (2014 – 15)
+" 5 [' season', ':', ' 44', ',', ' Carey', ' Price']
+1419 920 In their sport, the position played by x -1 In their sport, the position played by Carey Price goaltender Carey Price "[',' ' who' ' was' ' named' ' the' ' NHL' ""'s"" ' top' ' goaltender' ' in'
+ ' the' ' regular' ' season' ',' ' is' ' a' ' big' ' reason' ' why' ' the']" , who was named the NHL 's top goaltender in the regular season , is a big reason why the True Kredier ran into Carey Price injuring his leg. 6 [' K', 'red', 'ier', ' ran', ' into', ' Carey', ' Price']
+1420 920 In their sport, the position played by x -1 In their sport, the position played by Carey Price goaltender Carey Price "[',' ' who' ' was' ' named' ' the' ' NHL' ""'s"" ' top' ' goaltender' ' in'
+ ' the' ' regular' ' season' ',' ' is' ' a' ' big' ' reason' ' why' ' the']" , who was named the NHL 's top goaltender in the regular season , is a big reason why the True " a season: 44, Carey Price (2014 – 15)
+" 6 [' a', ' season', ':', ' 44', ',', ' Carey', ' Price']
+1421 920 In their sport, the position played by x -1 In their sport, the position played by Carey Price goaltender Carey Price "[',' ' who' ' was' ' named' ' the' ' NHL' ""'s"" ' top' ' goaltender' ' in'
+ ' the' ' regular' ' season' ',' ' is' ' a' ' big' ' reason' ' why' ' the']" , who was named the NHL 's top goaltender in the regular season , is a big reason why the True Kredier ran into Carey Price injuring his leg. 6 [' K', 'red', 'ier', ' ran', ' into', ' Carey', ' Price']
+1422 920 In their sport, the position played by x -1 In their sport, the position played by Carey Price goaltender Carey Price "[',' ' who' ' was' ' named' ' the' ' NHL' ""'s"" ' top' ' goaltender' ' in'
+ ' the' ' regular' ' season' ',' ' is' ' a' ' big' ' reason' ' why' ' the']" , who was named the NHL 's top goaltender in the regular season , is a big reason why the True Canadiens' goaltender Carey Price made several 4 "[' Canadiens', ""'"", ' goaltender', ' Carey', ' Price']"
+1423 921 In their sport, the position played by x -1 In their sport, the position played by Rudy Carpenter quarterback Rudy Carpenter "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' 1930' ',' ' and' ' the' ' first']" , who was a member of the team that won the first World Cup in 1930 , and the first False season. ASU ’ s Rudy Carpenter fumbled the ball as 8 [' season', '.', ' AS', 'U', ' �', '�', ' s', ' Rudy', ' Carpenter']
+1424 923 In their sport, the position played by x -1 In their sport, the position played by Kirk Rueter pitcher Kirk Rueter "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False Giants'rotation as Kirk Rueter was demoted to 7 "[' Giants', ""'"", 'rot', 'ation', ' as', ' Kirk', ' Ru', 'eter']"
+1425 923 In their sport, the position played by x -1 In their sport, the position played by Kirk Rueter pitcher Kirk Rueter "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False Giants'rotation as Kirk Rueter was demoted 7 "[' Giants', ""'"", 'rot', 'ation', ' as', ' Kirk', ' Ru', 'eter']"
+1426 923 In their sport, the position played by x -1 In their sport, the position played by Kirk Rueter pitcher Kirk Rueter "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Rifle'
+ ' Association' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind'
+ ' of' ' person' ' who']" , the former head of the National Rifle Association , is a perfect example of the kind of person who False appearance since Kirk Rueter did so on August 4 [' appearance', ' since', ' Kirk', ' Ru', 'eter']
+1427 924 In their sport, the position played by x -1 In their sport, the position played by Manny Ramirez outfielder Manny Ramirez "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' teammate' '.' '\n' '\n' 'I' ' think' ' the' ' Dodgers' ' are']" ", who was a great player , but not a great teammate .
+
+ I think the Dodgers are" False audible, center Manny Ramirez snapped the ball 4 [' audible', ',', ' center', ' Manny', ' Ramirez']
+1428 924 In their sport, the position played by x -1 In their sport, the position played by Manny Ramirez outfielder Manny Ramirez "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' teammate' '.' '\n' '\n' 'I' ' think' ' the' ' Dodgers' ' are']" ", who was a great player , but not a great teammate .
+
+ I think the Dodgers are" False " role"". However, with Manny Ramirez sudden retirement," 6 "[' role', '"".', ' However', ',', ' with', ' Manny', ' Ramirez']"
+1429 924 In their sport, the position played by x -1 In their sport, the position played by Manny Ramirez outfielder Manny Ramirez "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' teammate' '.' '\n' '\n' 'I' ' think' ' the' ' Dodgers' ' are']" ", who was a great player , but not a great teammate .
+
+ I think the Dodgers are" False David Ortiz, and Manny Ramirez received a 50-game 5 [' David', ' Ortiz', ',', ' and', ' Manny', ' Ramirez']
+1430 924 In their sport, the position played by x -1 In their sport, the position played by Manny Ramirez outfielder Manny Ramirez "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' teammate' '.' '\n' '\n' 'I' ' think' ' the' ' Dodgers' ' are']" ", who was a great player , but not a great teammate .
+
+ I think the Dodgers are" False " Ramirez =
+" 4 [' Ramirez', ' =', 'M', 'anny', ' Ramirez']
+1431 924 In their sport, the position played by x -1 In their sport, the position played by Manny Ramirez outfielder Manny Ramirez "[',' ' who' ' was' ' a' ' great' ' player' ',' ' but' ' not' ' a' ' great'
+ ' teammate' '.' '\n' '\n' 'I' ' think' ' the' ' Dodgers' ' are']" ", who was a great player , but not a great teammate .
+
+ I think the Dodgers are" False reportedly showed Manny Ramirez taking a swing 3 [' reportedly', ' showed', ' Manny', ' Ramirez']
+1432 925 In their sport, the position played by x -1 In their sport, the position played by Henk ten Cate midfielder Henk ten Cate "[',' ' the' ' Dutch' 'man' ' who' ' has' ' been' ' the' ' team' '�' '�'
+ 's' ' coach' ' since' ' the' ' start' ' of' ' the' ' season' ',']" , the Dutch man who has been the team � � s coach since the start of the season , False assistant manager Henk ten Cate admitted that 6 [' assistant', ' manager', ' Hen', 'k', ' ten', ' C', 'ate']
+1433 928 In their sport, the position played by x -1 In their sport, the position played by Daunte Culpepper quarterback Daunte Culpepper "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True offense, led by Daunte Culpepper and receivers Randy 9 [' offense', ',', ' led', ' by', ' D', 'aun', 'te', ' Cul', 'pe', 'pper']
+1434 928 In their sport, the position played by x -1 In their sport, the position played by Daunte Culpepper quarterback Daunte Culpepper "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True the mark set by Daunte Culpepper for Central Florida, 9 [' the', ' mark', ' set', ' by', ' D', 'aun', 'te', ' Cul', 'pe', 'pper']
+1435 928 In their sport, the position played by x -1 In their sport, the position played by Daunte Culpepper quarterback Daunte Culpepper "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True offense, led by Daunte Culpepper and receivers 9 [' offense', ',', ' led', ' by', ' D', 'aun', 'te', ' Cul', 'pe', 'pper']
+1436 928 In their sport, the position played by x -1 In their sport, the position played by Daunte Culpepper quarterback Daunte Culpepper "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True All-Pro quarterback Daunte Culpepper was signed to a 9 [' All', '-', 'Pro', ' quarterback', ' D', 'aun', 'te', ' Cul', 'pe', 'pper']
+1437 929 In their sport, the position played by x -1 In their sport, the position played by Lee Sharpe midfielder Lee Sharpe "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True Loan signing Lee Sharpe and Dean Windass 4 [' Loan', ' signing', ' Lee', ' Shar', 'pe']
+1438 929 In their sport, the position played by x -1 In their sport, the position played by Lee Sharpe midfielder Lee Sharpe "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True " Cantona and Lee Sharpe of Manchester United.
+" 5 [' Cant', 'ona', ' and', ' Lee', ' Shar', 'pe']
+1439 929 In their sport, the position played by x -1 In their sport, the position played by Lee Sharpe midfielder Lee Sharpe "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True Loan signing Lee Sharpe and Dean Windass 4 [' Loan', ' signing', ' Lee', ' Shar', 'pe']
+1440 929 In their sport, the position played by x -1 In their sport, the position played by Lee Sharpe midfielder Lee Sharpe "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True " in favour of Lee Sharpe on the left wing.
+" 5 [' in', ' favour', ' of', ' Lee', ' Shar', 'pe']
+1441 929 In their sport, the position played by x -1 In their sport, the position played by Lee Sharpe midfielder Lee Sharpe "[',' ' the' ' former' ' England' ' international' ',' ' is' ' a' ' key'
+ ' one' '.' '\n' '\n' 'The' ' former' ' Manchester' ' United'
+ ' midfielder' ' is' ' a']" ", the former England international , is a key one .
+
+ The former Manchester United midfielder is a" True picked in favour of Lee Sharpe on the left 6 [' picked', ' in', ' favour', ' of', ' Lee', ' Shar', 'pe']
+1442 931 In their sport, the position played by x -1 In their sport, the position played by Pud Galvin pitcher Pud Galvin "[' in' ' the' ' film' ' is' ' a' ' bit' ' of' ' a' ' stretch' ',' ' but'
+ ' it' '�' '�' 's' ' a' ' fun' ' one' '.' '\n']" " in the film is a bit of a stretch , but it � � s a fun one .
+" False pitcher. At the time, Pud Galvin had the second most 9 [' pitcher', '.', ' At', ' the', ' time', ',', ' P', 'ud', ' Gal', 'vin']
+1443 931 In their sport, the position played by x -1 In their sport, the position played by Pud Galvin pitcher Pud Galvin "[' in' ' the' ' film' ' is' ' a' ' bit' ' of' ' a' ' stretch' ',' ' but'
+ ' it' '�' '�' 's' ' a' ' fun' ' one' '.' '\n']" " in the film is a bit of a stretch , but it � � s a fun one .
+" False pitcher. At the time, Pud Galvin had the second 9 [' pitcher', '.', ' At', ' the', ' time', ',', ' P', 'ud', ' Gal', 'vin']
+1444 931 In their sport, the position played by x -1 In their sport, the position played by Pud Galvin pitcher Pud Galvin "[' in' ' the' ' film' ' is' ' a' ' bit' ' of' ' a' ' stretch' ',' ' but'
+ ' it' '�' '�' 's' ' a' ' fun' ' one' '.' '\n']" " in the film is a bit of a stretch , but it � � s a fun one .
+" False pitcher. At the time, Pud Galvin had the second most 9 [' pitcher', '.', ' At', ' the', ' time', ',', ' P', 'ud', ' Gal', 'vin']
+1445 933 In their sport, the position played by x -1 In their sport, the position played by Woodrow Dantzler quarterback Woodrow Dantzler "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' in' ' the' ' audience' ' at' ' the' '\n']" ", the
+
+ The first time I saw the movie , I was in the audience at the
+" False quarterback Woodrow Dantzler was removed 5 [' quarterback', ' Wood', 'row', ' D', 'antz', 'ler']
+1446 933 In their sport, the position played by x -1 In their sport, the position played by Woodrow Dantzler quarterback Woodrow Dantzler "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' in' ' the' ' audience' ' at' ' the' '\n']" ", the
+
+ The first time I saw the movie , I was in the audience at the
+" False starting quarterback Woodrow Dantzler was replaced by backup 6 [' starting', ' quarterback', ' Wood', 'row', ' D', 'antz', 'ler']
+1447 933 In their sport, the position played by x -1 In their sport, the position played by Woodrow Dantzler quarterback Woodrow Dantzler "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' in' ' the' ' audience' ' at' ' the' '\n']" ", the
+
+ The first time I saw the movie , I was in the audience at the
+" False starting quarterback Woodrow Dantzler was replaced 6 [' starting', ' quarterback', ' Wood', 'row', ' D', 'antz', 'ler']
+1448 933 In their sport, the position played by x -1 In their sport, the position played by Woodrow Dantzler quarterback Woodrow Dantzler "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' in' ' the' ' audience' ' at' ' the' '\n']" ", the
+
+ The first time I saw the movie , I was in the audience at the
+" False starting quarterback Woodrow Dantzler was replaced 6 [' starting', ' quarterback', ' Wood', 'row', ' D', 'antz', 'ler']
+1449 934 In their sport, the position played by x -1 In their sport, the position played by Drew Brees quarterback Drew Brees "[' and' ' the' ' Saints' ',' ' the' ' Saints' ' have' ' a' ' lot' ' of'
+ ' talent' ',' ' but' ' they' ' are' ' not' ' a' ' team' ' that' ' can']" and the Saints , the Saints have a lot of talent , but they are not a team that can False as a Buccaneer on a Drew Brees pass that 8 [' as', ' a', ' Bucc', 'ane', 'er', ' on', ' a', ' Drew', ' Brees']
+1450 934 In their sport, the position played by x -1 In their sport, the position played by Drew Brees quarterback Drew Brees "[' and' ' the' ' Saints' ',' ' the' ' Saints' ' have' ' a' ' lot' ' of'
+ ' talent' ',' ' but' ' they' ' are' ' not' ' a' ' team' ' that' ' can']" and the Saints , the Saints have a lot of talent , but they are not a team that can False " (surpassed by Drew Brees in 2011)
+" 6 [' (', 'sur', 'pass', 'ed', ' by', ' Drew', ' Brees']
+1451 934 In their sport, the position played by x -1 In their sport, the position played by Drew Brees quarterback Drew Brees "[' and' ' the' ' Saints' ',' ' the' ' Saints' ' have' ' a' ' lot' ' of'
+ ' talent' ',' ' but' ' they' ' are' ' not' ' a' ' team' ' that' ' can']" and the Saints , the Saints have a lot of talent , but they are not a team that can False the Bears, as Drew Brees was 0 – 4 in 5 [' the', ' Bears', ',', ' as', ' Drew', ' Brees']
+1452 934 In their sport, the position played by x -1 In their sport, the position played by Drew Brees quarterback Drew Brees "[' and' ' the' ' Saints' ',' ' the' ' Saints' ' have' ' a' ' lot' ' of'
+ ' talent' ',' ' but' ' they' ' are' ' not' ' a' ' team' ' that' ' can']" and the Saints , the Saints have a lot of talent , but they are not a team that can False " completions: 4 (surpassed by Drew Brees in 2011)
+" 10 [' comple', 'tions', ':', ' 4', ' (', 'sur', 'pass', 'ed', ' by', ' Drew', ' Brees']
+1453 934 In their sport, the position played by x -1 In their sport, the position played by Drew Brees quarterback Drew Brees "[' and' ' the' ' Saints' ',' ' the' ' Saints' ' have' ' a' ' lot' ' of'
+ ' talent' ',' ' but' ' they' ' are' ' not' ' a' ' team' ' that' ' can']" and the Saints , the Saints have a lot of talent , but they are not a team that can False for the Bears, as Drew Brees was 0 – 4 6 [' for', ' the', ' Bears', ',', ' as', ' Drew', ' Brees']
+1454 935 In their sport, the position played by x -1 In their sport, the position played by Garry Unger centre Garry Unger "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False Mahovlich, Garry Unger and Pete Stemkowski. 6 [' Mah', 'ov', 'lich', ',', ' Garry', ' Ung', 'er']
+1455 935 In their sport, the position played by x -1 In their sport, the position played by Garry Unger centre Garry Unger "[',' ' the' ' former' ' head' ' of' ' the' ' National' ' Security'
+ ' Agency' ',' ' is' ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of'
+ ' person' ' who']" , the former head of the National Security Agency , is a perfect example of the kind of person who False Frank Mahovlich, Garry Unger and Pete Stemkowski. 7 [' Frank', ' Mah', 'ov', 'lich', ',', ' Garry', ' Ung', 'er']
+1456 936 In their sport, the position played by x -1 In their sport, the position played by Bobby Layne quarterback Bobby Layne "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True 17 – 16 on a 33-yard Bobby Layne touchdown pass to 10 [' 17', ' –', ' 16', ' on', ' a', ' 33', '-', 'yard', ' Bobby', ' Lay', 'ne']
+1457 936 In their sport, the position played by x -1 In their sport, the position played by Bobby Layne quarterback Bobby Layne "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Johnny Lujack and Bobby Layne to join Kindt, 6 [' Johnny', ' Lu', 'jack', ' and', ' Bobby', ' Lay', 'ne']
+1458 936 In their sport, the position played by x -1 In their sport, the position played by Bobby Layne quarterback Bobby Layne "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Lujack and Bobby Layne to join Kindt, Luckman 5 [' Lu', 'jack', ' and', ' Bobby', ' Lay', 'ne']
+1459 936 In their sport, the position played by x -1 In their sport, the position played by Bobby Layne quarterback Bobby Layne "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True Johnny Lujack and Bobby Layne to join Kindt, Luckman 6 [' Johnny', ' Lu', 'jack', ' and', ' Bobby', ' Lay', 'ne']
+1460 936 In their sport, the position played by x -1 In their sport, the position played by Bobby Layne quarterback Bobby Layne "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True – 16 on a 33-yard Bobby Layne touchdown pass to 9 [' –', ' 16', ' on', ' a', ' 33', '-', 'yard', ' Bobby', ' Lay', 'ne']
+1461 943 In their sport, the position played by x -1 In their sport, the position played by Joe Theismann quarterback Joe Theismann "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True signing of a young Joe Theismann (and other 7 [' signing', ' of', ' a', ' young', ' Joe', ' The', 'is', 'mann']
+1462 943 In their sport, the position played by x -1 In their sport, the position played by Joe Theismann quarterback Joe Theismann "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True signing of a young Joe Theismann (and other 7 [' signing', ' of', ' a', ' young', ' Joe', ' The', 'is', 'mann']
+1463 943 In their sport, the position played by x -1 In their sport, the position played by Joe Theismann quarterback Joe Theismann "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True signing of a young Joe Theismann (and other American 7 [' signing', ' of', ' a', ' young', ' Joe', ' The', 'is', 'mann']
+1464 943 In their sport, the position played by x -1 In their sport, the position played by Joe Theismann quarterback Joe Theismann "[',' ' the' ' quarterback' ',' ' is' ' the' ' most' ' important'
+ ' position' ' on' ' the' ' field' '.' ' The' ' quarterback' ' is' ' the'
+ ' leader' ' of' ' the']" , the quarterback , is the most important position on the field . The quarterback is the leader of the True signing of a young Joe Theismann (and other American 7 [' signing', ' of', ' a', ' young', ' Joe', ' The', 'is', 'mann']
+1465 944 In their sport, the position played by x -1 In their sport, the position played by Bob Meusel outfielder Bob Meusel "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' a' ' little' ' confused' '.' ' I' ' thought']" ", the
+
+ The first time I saw the movie , I was a little confused . I thought" False featured fellow Yankees Bob Meusel and Bill Piercy (who 6 [' featured', ' fellow', ' Yankees', ' Bob', ' Me', 'u', 'sel']
+1466 944 In their sport, the position played by x -1 In their sport, the position played by Bob Meusel outfielder Bob Meusel "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' a' ' little' ' confused' '.' ' I' ' thought']" ", the
+
+ The first time I saw the movie , I was a little confused . I thought" False " Bob Meusel =
+" 3 [' Bob', ' Me', 'u', 'sel']
+1467 944 In their sport, the position played by x -1 In their sport, the position played by Bob Meusel outfielder Bob Meusel "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' a' ' little' ' confused' '.' ' I' ' thought']" ", the
+
+ The first time I saw the movie , I was a little confused . I thought" False Babe Ruth. Bob Meusel was then walked, 6 [' Babe', ' Ruth', '.', ' Bob', ' Me', 'u', 'sel']
+1468 944 In their sport, the position played by x -1 In their sport, the position played by Bob Meusel outfielder Bob Meusel "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' a' ' little' ' confused' '.' ' I' ' thought']" ", the
+
+ The first time I saw the movie , I was a little confused . I thought" False fellow outfielder Bob Meusel and third baseman 5 [' fellow', ' outfielder', ' Bob', ' Me', 'u', 'sel']
+1469 944 In their sport, the position played by x -1 In their sport, the position played by Bob Meusel outfielder Bob Meusel "[',' ' the' '\n' '\n' 'The' ' first' ' time' ' I' ' saw' ' the' ' movie'
+ ',' ' I' ' was' ' a' ' little' ' confused' '.' ' I' ' thought']" ", the
+
+ The first time I saw the movie , I was a little confused . I thought" False Ruth, Joe Dugan, and Bob Meusel remained in the 10 [' Ruth', ',', ' Joe', ' Dug', 'an', ',', ' and', ' Bob', ' Me', 'u', 'sel']
+1470 945 In their sport, the position played by x -1 In their sport, the position played by Ben Bennett quarterback Ben Bennett "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False prospector named Ben Bennett on his claim of 80 4 [' prospect', 'or', ' named', ' Ben', ' Bennett']
+1471 945 In their sport, the position played by x -1 In their sport, the position played by Ben Bennett quarterback Ben Bennett "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' team' ' that' ' won'
+ ' the' ' first' ' World' ' Cup' ' in' ' the' ' United' ' States' ' in'
+ ' 1991']" , who was a member of the team that won the first World Cup in the United States in 1991 False prospector named Ben Bennett on his claim of 80 4 [' prospect', 'or', ' named', ' Ben', ' Bennett']
+1472 947 In their sport, the position played by x -1 In their sport, the position played by Andrea Pirlo midfielder Andrea Pirlo "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Italian midfielder , is a key one . He is the ful cr um of the team , True Liverpool 1 – 0 up. Andrea Pirlo was next for Milan, 9 [' Liverpool', ' 1', ' –', ' 0', ' up', '.', ' Andrea', ' P', 'irl', 'o']
+1473 947 In their sport, the position played by x -1 In their sport, the position played by Andrea Pirlo midfielder Andrea Pirlo "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Italian midfielder , is a key one . He is the ful cr um of the team , True Liverpool 1 – 0 up. Andrea Pirlo was next for Milan, 9 [' Liverpool', ' 1', ' –', ' 0', ' up', '.', ' Andrea', ' P', 'irl', 'o']
+1474 947 In their sport, the position played by x -1 In their sport, the position played by Andrea Pirlo midfielder Andrea Pirlo "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Italian midfielder , is a key one . He is the ful cr um of the team , True 81st minute for Andrea Pirlo as Milan won 7 [' 81', 'st', ' minute', ' for', ' Andrea', ' P', 'irl', 'o']
+1475 947 In their sport, the position played by x -1 In their sport, the position played by Andrea Pirlo midfielder Andrea Pirlo "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Italian midfielder , is a key one . He is the ful cr um of the team , True Liverpool 1 – 0 up. Andrea Pirlo was next for 9 [' Liverpool', ' 1', ' –', ' 0', ' up', '.', ' Andrea', ' P', 'irl', 'o']
+1476 947 In their sport, the position played by x -1 In their sport, the position played by Andrea Pirlo midfielder Andrea Pirlo "[',' ' the' ' Italian' ' midfielder' ',' ' is' ' a' ' key' ' one' '.'
+ ' He' ' is' ' the' ' ful' 'cr' 'um' ' of' ' the' ' team' ',']" , the Italian midfielder , is a key one . He is the ful cr um of the team , True free-kick taken by Andrea Pirlo deflected off 8 [' free', '-', 'kick', ' taken', ' by', ' Andrea', ' P', 'irl', 'o']
diff --git a/patchscopes/code/preprocessed_data/factual/person_plays_pro_sport.tsv b/patchscopes/code/preprocessed_data/factual/person_plays_pro_sport.tsv
new file mode 100644
index 00000000..d5a23611
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/person_plays_pro_sport.tsv
@@ -0,0 +1,3814 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The professional sport played by x -1 The professional sport played by Andreas Ivanschitz soccer Andreas Ivanschitz "[',' ' a' ' German' '-' 'born' ' American' ' who' ' has' ' been' ' living'
+ ' in' ' the' ' United' ' States' ' since' ' the' ' age' ' of' ' three'
+ '.']" , a German - born American who has been living in the United States since the age of three . False international signings Andreas Ivanschitz and Román Torres 6 [' international', ' signings', ' Andreas', ' Iv', 'ans', 'ch', 'itz']
+1 0 The professional sport played by x -1 The professional sport played by Andreas Ivanschitz soccer Andreas Ivanschitz "[',' ' a' ' German' '-' 'born' ' American' ' who' ' has' ' been' ' living'
+ ' in' ' the' ' United' ' States' ' since' ' the' ' age' ' of' ' three'
+ '.']" , a German - born American who has been living in the United States since the age of three . False international signings Andreas Ivanschitz and Román Torres 6 [' international', ' signings', ' Andreas', ' Iv', 'ans', 'ch', 'itz']
+2 1 The professional sport played by x -1 The professional sport played by Roberto Clemente baseball Roberto Clemente "[',' ' Jr' '.' '\n' '\n' 'The' ' game' ' is' ' a' ' baseball' ' game'
+ ' played' ' by' ' two' ' teams' ' of' ' nine' ' players' ' each' '.']" ", Jr .
+
+ The game is a baseball game played by two teams of nine players each ." True Pirates' right-fielder Roberto Clemente got his 3,000th hit 8 "[' Pirates', ""'"", ' right', '-', 'field', 'er', ' Roberto', ' Clement', 'e']"
+3 1 The professional sport played by x -1 The professional sport played by Roberto Clemente baseball Roberto Clemente "[',' ' Jr' '.' '\n' '\n' 'The' ' game' ' is' ' a' ' baseball' ' game'
+ ' played' ' by' ' two' ' teams' ' of' ' nine' ' players' ' each' '.']" ", Jr .
+
+ The game is a baseball game played by two teams of nine players each ." True Jeter won the Roberto Clemente Award and was named 6 [' J', 'eter', ' won', ' the', ' Roberto', ' Clement', 'e']
+4 1 The professional sport played by x -1 The professional sport played by Roberto Clemente baseball Roberto Clemente "[',' ' Jr' '.' '\n' '\n' 'The' ' game' ' is' ' a' ' baseball' ' game'
+ ' played' ' by' ' two' ' teams' ' of' ' nine' ' players' ' each' '.']" ", Jr .
+
+ The game is a baseball game played by two teams of nine players each ." True Tour in 2001 at the Roberto Clemente Coliseum in San 7 [' Tour', ' in', ' 2001', ' at', ' the', ' Roberto', ' Clement', 'e']
+5 1 The professional sport played by x -1 The professional sport played by Roberto Clemente baseball Roberto Clemente "[',' ' Jr' '.' '\n' '\n' 'The' ' game' ' is' ' a' ' baseball' ' game'
+ ' played' ' by' ' two' ' teams' ' of' ' nine' ' players' ' each' '.']" ", Jr .
+
+ The game is a baseball game played by two teams of nine players each ." True Illustrated, and won the Roberto Clemente Award, Hank Aaron 7 [' Illustrated', ',', ' and', ' won', ' the', ' Roberto', ' Clement', 'e']
+6 1 The professional sport played by x -1 The professional sport played by Roberto Clemente baseball Roberto Clemente "[',' ' Jr' '.' '\n' '\n' 'The' ' game' ' is' ' a' ' baseball' ' game'
+ ' played' ' by' ' two' ' teams' ' of' ' nine' ' players' ' each' '.']" ", Jr .
+
+ The game is a baseball game played by two teams of nine players each ." True 3 ['Rober', 'to', ' Clement', 'e']
+7 3 The professional sport played by x -1 The professional sport played by Hank Aaron baseball Hank Aaron "[',' ' the' ' first' ' black' ' player' ' to' ' hit' ' a' ' home' ' run'
+ ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' first' ' black']" ", the first black player to hit a home run in the World Series .
+
+ The first black" False Award, he won the Hank Aaron Award, given for 6 [' Award', ',', ' he', ' won', ' the', ' Hank', ' Aaron']
+8 3 The professional sport played by x -1 The professional sport played by Hank Aaron baseball Hank Aaron "[',' ' the' ' first' ' black' ' player' ' to' ' hit' ' a' ' home' ' run'
+ ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' first' ' black']" ", the first black player to hit a home run in the World Series .
+
+ The first black" False Koufax walked Hank Aaron on four pitches to 4 [' Kou', 'fax', ' walked', ' Hank', ' Aaron']
+9 3 The professional sport played by x -1 The professional sport played by Hank Aaron baseball Hank Aaron "[',' ' the' ' first' ' black' ' player' ' to' ' hit' ' a' ' home' ' run'
+ ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' first' ' black']" ", the first black player to hit a home run in the World Series .
+
+ The first black" False been Tony Gwynn. Hank Aaron has been the 7 [' been', ' Tony', ' G', 'wyn', 'n', '.', ' Hank', ' Aaron']
+10 3 The professional sport played by x -1 The professional sport played by Hank Aaron baseball Hank Aaron "[',' ' the' ' first' ' black' ' player' ' to' ' hit' ' a' ' home' ' run'
+ ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' first' ' black']" ", the first black player to hit a home run in the World Series .
+
+ The first black" False and steals since Hank Aaron in 1963. He 4 [' and', ' steals', ' since', ' Hank', ' Aaron']
+11 3 The professional sport played by x -1 The professional sport played by Hank Aaron baseball Hank Aaron "[',' ' the' ' first' ' black' ' player' ' to' ' hit' ' a' ' home' ' run'
+ ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' first' ' black']" ", the first black player to hit a home run in the World Series .
+
+ The first black" False named winners of the Hank Aaron Award for their 5 [' named', ' winners', ' of', ' the', ' Hank', ' Aaron']
+12 4 The professional sport played by x -1 The professional sport played by Jari Kurri hockey Jari Kurri "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Finnish' ' national'
+ ' team' ' from' ' the' ' late' ' 1980' 's' ' to' ' the' ' early' ' 1990'
+ 's']" , who was a member of the Finnish national team from the late 1980 s to the early 1990 s False – 86 season. Finns Jari Kurri and Esa Tikkanen 9 [' –', ' 86', ' season', '.', ' Finn', 's', ' J', 'ari', ' Kur', 'ri']
+13 4 The professional sport played by x -1 The professional sport played by Jari Kurri hockey Jari Kurri "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Finnish' ' national'
+ ' team' ' from' ' the' ' late' ' 1980' 's' ' to' ' the' ' early' ' 1990'
+ 's']" , who was a member of the Finnish national team from the late 1980 s to the early 1990 s False the NHL. He idolized Jari Kurri and was more 9 [' the', ' NHL', '.', ' He', ' idol', 'ized', ' J', 'ari', ' Kur', 'ri']
+14 4 The professional sport played by x -1 The professional sport played by Jari Kurri hockey Jari Kurri "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Finnish' ' national'
+ ' team' ' from' ' the' ' late' ' 1980' 's' ' to' ' the' ' early' ' 1990'
+ 's']" , who was a member of the Finnish national team from the late 1980 s to the early 1990 s False 86 season. Finns Jari Kurri and Esa Tikkanen 8 [' 86', ' season', '.', ' Finn', 's', ' J', 'ari', ' Kur', 'ri']
+15 4 The professional sport played by x -1 The professional sport played by Jari Kurri hockey Jari Kurri "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Finnish' ' national'
+ ' team' ' from' ' the' ' late' ' 1980' 's' ' to' ' the' ' early' ' 1990'
+ 's']" , who was a member of the Finnish national team from the late 1980 s to the early 1990 s False medal for Russia and Jari Kurri won the bronze 7 [' medal', ' for', ' Russia', ' and', ' J', 'ari', ' Kur', 'ri']
+16 4 The professional sport played by x -1 The professional sport played by Jari Kurri hockey Jari Kurri "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Finnish' ' national'
+ ' team' ' from' ' the' ' late' ' 1980' 's' ' to' ' the' ' early' ' 1990'
+ 's']" , who was a member of the Finnish national team from the late 1980 s to the early 1990 s False 1998 – 99 and the Jari Kurri trophy as the 8 [' 1998', ' –', ' 99', ' and', ' the', ' J', 'ari', ' Kur', 'ri']
+17 5 The professional sport played by x -1 The professional sport played by Zdeno Chára hockey Zdeno Chára "['i' ',' ' a' ' former' ' professional' ' basketball' ' player' ',' ' is'
+ ' a' ' popular' ' sport' ' in' ' the' ' Czech' ' Republic' '.' ' It'
+ ' is' ' played']" i , a former professional basketball player , is a popular sport in the Czech Republic . It is played False finalists — along with Zdeno Chára and Nicklas Lidström 9 [' finalists', ' —', ' along', ' with', ' Z', 'den', 'o', ' Ch', 'á', 'ra']
+18 5 The professional sport played by x -1 The professional sport played by Zdeno Chára hockey Zdeno Chára "['i' ',' ' a' ' former' ' professional' ' basketball' ' player' ',' ' is'
+ ' a' ' popular' ' sport' ' in' ' the' ' Czech' ' Republic' '.' ' It'
+ ' is' ' played']" i , a former professional basketball player , is a popular sport in the Czech Republic . It is played False finalists — along with Zdeno Chára and Nicklas Lidström 9 [' finalists', ' —', ' along', ' with', ' Z', 'den', 'o', ' Ch', 'á', 'ra']
+19 5 The professional sport played by x -1 The professional sport played by Zdeno Chára hockey Zdeno Chára "['i' ',' ' a' ' former' ' professional' ' basketball' ' player' ',' ' is'
+ ' a' ' popular' ' sport' ' in' ' the' ' Czech' ' Republic' '.' ' It'
+ ' is' ' played']" i , a former professional basketball player , is a popular sport in the Czech Republic . It is played False finalists were Zdeno Chára and Duncan Keith, 7 [' finalists', ' were', ' Z', 'den', 'o', ' Ch', 'á', 'ra']
+20 7 The professional sport played by x -1 The professional sport played by Peter Šťastný hockey Peter Šťastný "[',' ' a' ' former' ' professional' ' ice' ' hockey' ' player' ',' ' and'
+ ' his' ' wife' ',' ' Mart' 'ina' ' �' '�' '�' '�' 'ast' 'n']" , a former professional ice hockey player , and his wife , Mart ina � � � � ast n True former Nordiques mentor Peter Šťastný [1]), while 12 [' former', ' Nord', 'iques', ' mentor', ' Peter', ' �', '�', '�', '�', 'ast', 'n', '�', '�']
+21 7 The professional sport played by x -1 The professional sport played by Peter Šťastný hockey Peter Šťastný "[',' ' a' ' former' ' professional' ' ice' ' hockey' ' player' ',' ' and'
+ ' his' ' wife' ',' ' Mart' 'ina' ' �' '�' '�' '�' 'ast' 'n']" , a former professional ice hockey player , and his wife , Mart ina � � � � ast n True former Nordiques mentor Peter Šťastný [1]), while also being 12 [' former', ' Nord', 'iques', ' mentor', ' Peter', ' �', '�', '�', '�', 'ast', 'n', '�', '�']
+22 7 The professional sport played by x -1 The professional sport played by Peter Šťastný hockey Peter Šťastný "[',' ' a' ' former' ' professional' ' ice' ' hockey' ' player' ',' ' and'
+ ' his' ' wife' ',' ' Mart' 'ina' ' �' '�' '�' '�' 'ast' 'n']" , a former professional ice hockey player , and his wife , Mart ina � � � � ast n True former Nordiques mentor Peter Šťastný [1]), while also 12 [' former', ' Nord', 'iques', ' mentor', ' Peter', ' �', '�', '�', '�', 'ast', 'n', '�', '�']
+23 7 The professional sport played by x -1 The professional sport played by Peter Šťastný hockey Peter Šťastný "[',' ' a' ' former' ' professional' ' ice' ' hockey' ' player' ',' ' and'
+ ' his' ' wife' ',' ' Mart' 'ina' ' �' '�' '�' '�' 'ast' 'n']" , a former professional ice hockey player , and his wife , Mart ina � � � � ast n True franchise, passing Peter Šťastný as the all-time 11 [' franchise', ',', ' passing', ' Peter', ' �', '�', '�', '�', 'ast', 'n', '�', '�']
+24 7 The professional sport played by x -1 The professional sport played by Peter Šťastný hockey Peter Šťastný "[',' ' a' ' former' ' professional' ' ice' ' hockey' ' player' ',' ' and'
+ ' his' ' wife' ',' ' Mart' 'ina' ' �' '�' '�' '�' 'ast' 'n']" , a former professional ice hockey player , and his wife , Mart ina � � � � ast n True Nordiques mentor Peter Šťastný [1]), while also 11 [' Nord', 'iques', ' mentor', ' Peter', ' �', '�', '�', '�', 'ast', 'n', '�', '�']
+25 8 The professional sport played by x -1 The professional sport played by R.A. Dickey baseball R.A. Dickey "[',' ' the' ' former' ' New' ' York' ' Mets' ' pitcher' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' be']" , the former New York Mets pitcher , is a great example of the kind of person who can be False Knuckleball pitcher R.A. Dickey won the Pacific 9 [' Kn', 'uckle', 'ball', ' pitcher', ' R', '.', 'A', '.', ' D', 'ickey']
+26 8 The professional sport played by x -1 The professional sport played by R.A. Dickey baseball R.A. Dickey "[',' ' the' ' former' ' New' ' York' ' Mets' ' pitcher' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' be']" , the former New York Mets pitcher , is a great example of the kind of person who can be False fifth (tied with R.A. Dickey and Adam Wainwright) 10 [' fifth', ' (', 't', 'ied', ' with', ' R', '.', 'A', '.', ' D', 'ickey']
+27 8 The professional sport played by x -1 The professional sport played by R.A. Dickey baseball R.A. Dickey "[',' ' the' ' former' ' New' ' York' ' Mets' ' pitcher' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' be']" , the former New York Mets pitcher , is a great example of the kind of person who can be False Valuable Pitcher Award. R.A. Dickey (2007), Johnny 11 [' Val', 'uable', ' Pit', 'cher', ' Award', '.', ' R', '.', 'A', '.', ' D', 'ickey']
+28 8 The professional sport played by x -1 The professional sport played by R.A. Dickey baseball R.A. Dickey "[',' ' the' ' former' ' New' ' York' ' Mets' ' pitcher' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' be']" , the former New York Mets pitcher , is a great example of the kind of person who can be False Pitcher Award. R.A. Dickey (2007), Johnny 9 [' Pit', 'cher', ' Award', '.', ' R', '.', 'A', '.', ' D', 'ickey']
+29 8 The professional sport played by x -1 The professional sport played by R.A. Dickey baseball R.A. Dickey "[',' ' the' ' former' ' New' ' York' ' Mets' ' pitcher' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' be']" , the former New York Mets pitcher , is a great example of the kind of person who can be False fifth (tied with R.A. Dickey and Adam Wainwright) 10 [' fifth', ' (', 't', 'ied', ' with', ' R', '.', 'A', '.', ' D', 'ickey']
+30 9 The professional sport played by x -1 The professional sport played by Luis Scola basketball Luis Scola "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NBA']" , who is a former NBA player , is a great example of the kind of player that the NBA False veteran forward Luis Scola using their one-time 4 [' veteran', ' forward', ' Luis', ' Sc', 'ola']
+31 9 The professional sport played by x -1 The professional sport played by Luis Scola basketball Luis Scola "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NBA']" , who is a former NBA player , is a great example of the kind of player that the NBA False veteran forward Luis Scola using their one-time 4 [' veteran', ' forward', ' Luis', ' Sc', 'ola']
+32 9 The professional sport played by x -1 The professional sport played by Luis Scola basketball Luis Scola "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NBA']" , who is a former NBA player , is a great example of the kind of player that the NBA False veteran forward Luis Scola using their one-time 4 [' veteran', ' forward', ' Luis', ' Sc', 'ola']
+33 9 The professional sport played by x -1 The professional sport played by Luis Scola basketball Luis Scola "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NBA']" , who is a former NBA player , is a great example of the kind of player that the NBA False " Biyombo, and Luis Scola via free agency.
+" 7 [' B', 'iy', 'ombo', ',', ' and', ' Luis', ' Sc', 'ola']
+34 9 The professional sport played by x -1 The professional sport played by Luis Scola basketball Luis Scola "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NBA']" , who is a former NBA player , is a great example of the kind of player that the NBA False veteran forward Luis Scola using their one-time 4 [' veteran', ' forward', ' Luis', ' Sc', 'ola']
+35 10 The professional sport played by x -1 The professional sport played by Satchel Paige baseball Satchel Paige "[',' ' the' ' Negro' ' Le' 'agues' ',' ' and' ' the' ' Negro' ' National'
+ ' League' '.' '\n' '\n' 'The' ' Negro' ' Le' 'agues' ' were' ' the']" ", the Negro Le agues , and the Negro National League .
+
+ The Negro Le agues were the" False and black talents Satchel Paige and Josh Gibson 5 [' and', ' black', ' talents', ' Sat', 'chel', ' Paige']
+36 10 The professional sport played by x -1 The professional sport played by Satchel Paige baseball Satchel Paige "[',' ' the' ' Negro' ' Le' 'agues' ',' ' and' ' the' ' Negro' ' National'
+ ' League' '.' '\n' '\n' 'The' ' Negro' ' Le' 'agues' ' were' ' the']" ", the Negro Le agues , and the Negro National League .
+
+ The Negro Le agues were the" False league legend Satchel Paige in a road game 4 [' league', ' legend', ' Sat', 'chel', ' Paige']
+37 10 The professional sport played by x -1 The professional sport played by Satchel Paige baseball Satchel Paige "[',' ' the' ' Negro' ' Le' 'agues' ',' ' and' ' the' ' Negro' ' National'
+ ' League' '.' '\n' '\n' 'The' ' Negro' ' Le' 'agues' ' were' ' the']" ", the Negro Le agues , and the Negro National League .
+
+ The Negro Le agues were the" False Doby and teammate Satchel Paige were the first African-American 6 [' D', 'oby', ' and', ' teammate', ' Sat', 'chel', ' Paige']
+38 10 The professional sport played by x -1 The professional sport played by Satchel Paige baseball Satchel Paige "[',' ' the' ' Negro' ' Le' 'agues' ',' ' and' ' the' ' Negro' ' National'
+ ' League' '.' '\n' '\n' 'The' ' Negro' ' Le' 'agues' ' were' ' the']" ", the Negro Le agues , and the Negro National League .
+
+ The Negro Le agues were the" False and black talents Satchel Paige and Josh Gibson were 5 [' and', ' black', ' talents', ' Sat', 'chel', ' Paige']
+39 10 The professional sport played by x -1 The professional sport played by Satchel Paige baseball Satchel Paige "[',' ' the' ' Negro' ' Le' 'agues' ',' ' and' ' the' ' Negro' ' National'
+ ' League' '.' '\n' '\n' 'The' ' Negro' ' Le' 'agues' ' were' ' the']" ", the Negro Le agues , and the Negro National League .
+
+ The Negro Le agues were the" False and teammate Satchel Paige were the first African-American 4 [' and', ' teammate', ' Sat', 'chel', ' Paige']
+40 11 The professional sport played by x -1 The professional sport played by Elton Brand basketball Elton Brand "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' '.' '\n' '\n' 'The' ' NBA' ' is' ' a']" ", who is a former NBA player , is a former NBA player .
+
+ The NBA is a" False Mayo, Chris Kaman, Elton Brand and Darren Collison 8 [' Mayo', ',', ' Chris', ' K', 'aman', ',', ' El', 'ton', ' Brand']
+41 11 The professional sport played by x -1 The professional sport played by Elton Brand basketball Elton Brand "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' '.' '\n' '\n' 'The' ' NBA' ' is' ' a']" ", who is a former NBA player , is a former NBA player .
+
+ The NBA is a" False Mayo, Chris Kaman, Elton Brand and Darren 8 [' Mayo', ',', ' Chris', ' K', 'aman', ',', ' El', 'ton', ' Brand']
+42 11 The professional sport played by x -1 The professional sport played by Elton Brand basketball Elton Brand "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' '.' '\n' '\n' 'The' ' NBA' ' is' ' a']" ", who is a former NBA player , is a former NBA player .
+
+ The NBA is a" False Mayo, Chris Kaman, Elton Brand and Darren Collison 8 [' Mayo', ',', ' Chris', ' K', 'aman', ',', ' El', 'ton', ' Brand']
+43 11 The professional sport played by x -1 The professional sport played by Elton Brand basketball Elton Brand "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' '.' '\n' '\n' 'The' ' NBA' ' is' ' a']" ", who is a former NBA player , is a former NBA player .
+
+ The NBA is a" False Chris Kaman, Elton Brand and Darren 6 [' Chris', ' K', 'aman', ',', ' El', 'ton', ' Brand']
+44 12 The professional sport played by x -1 The professional sport played by Terrell Owens football Terrell Owens "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys'
+ ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'The' ' Cowboys' ' were'
+ ' the']" ", who was a member of the Dallas Cowboys from 1995 to 2004 .
+
+ The Cowboys were the" False of wide receiver Terrell Owens and the departure 5 [' of', ' wide', ' receiver', ' Ter', 'rell', ' Owens']
+45 12 The professional sport played by x -1 The professional sport played by Terrell Owens football Terrell Owens "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys'
+ ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'The' ' Cowboys' ' were'
+ ' the']" ", who was a member of the Dallas Cowboys from 1995 to 2004 .
+
+ The Cowboys were the" False presence of Terrell Owens in 2004 led 4 [' presence', ' of', ' Ter', 'rell', ' Owens']
+46 12 The professional sport played by x -1 The professional sport played by Terrell Owens football Terrell Owens "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys'
+ ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'The' ' Cowboys' ' were'
+ ' the']" ", who was a member of the Dallas Cowboys from 1995 to 2004 .
+
+ The Cowboys were the" False The presence of Terrell Owens in 2004 led to limited 5 [' The', ' presence', ' of', ' Ter', 'rell', ' Owens']
+47 12 The professional sport played by x -1 The professional sport played by Terrell Owens football Terrell Owens "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys'
+ ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'The' ' Cowboys' ' were'
+ ' the']" ", who was a member of the Dallas Cowboys from 1995 to 2004 .
+
+ The Cowboys were the" False The presence of Terrell Owens in 2004 led to limited 5 [' The', ' presence', ' of', ' Ter', 'rell', ' Owens']
+48 12 The professional sport played by x -1 The professional sport played by Terrell Owens football Terrell Owens "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys'
+ ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'The' ' Cowboys' ' were'
+ ' the']" ", who was a member of the Dallas Cowboys from 1995 to 2004 .
+
+ The Cowboys were the" False Eagles'acquisition of wide receiver Terrell Owens and the departure 9 "[' Eagles', ""'"", 'acqu', 'isition', ' of', ' wide', ' receiver', ' Ter', 'rell', ' Owens']"
+49 14 The professional sport played by x -1 The professional sport played by Tom Harmon football Tom Harmon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' former' ' NFL' ' quarterback' '.' '\n' '\n' 'The' ' NFL' ' is' ' a']" ", who is a former NFL quarterback , is a former NFL quarterback .
+
+ The NFL is a" False Forest Evashevski and Tom Harmon joining the varsity 8 [' Forest', ' Ev', 'as', 'he', 'vs', 'ki', ' and', ' Tom', ' Harmon']
+50 14 The professional sport played by x -1 The professional sport played by Tom Harmon football Tom Harmon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' former' ' NFL' ' quarterback' '.' '\n' '\n' 'The' ' NFL' ' is' ' a']" ", who is a former NFL quarterback , is a former NFL quarterback .
+
+ The NFL is a" False Trophy winner Tom Harmon and sportscasters 3 [' Trophy', ' winner', ' Tom', ' Harmon']
+51 14 The professional sport played by x -1 The professional sport played by Tom Harmon football Tom Harmon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' former' ' NFL' ' quarterback' '.' '\n' '\n' 'The' ' NFL' ' is' ' a']" ", who is a former NFL quarterback , is a former NFL quarterback .
+
+ The NFL is a" False the first since Tom Harmon to have two return 4 [' the', ' first', ' since', ' Tom', ' Harmon']
+52 14 The professional sport played by x -1 The professional sport played by Tom Harmon football Tom Harmon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' former' ' NFL' ' quarterback' '.' '\n' '\n' 'The' ' NFL' ' is' ' a']" ", who is a former NFL quarterback , is a former NFL quarterback .
+
+ The NFL is a" False and the first since Tom Harmon to have two return 5 [' and', ' the', ' first', ' since', ' Tom', ' Harmon']
+53 14 The professional sport played by x -1 The professional sport played by Tom Harmon football Tom Harmon "[',' ' who' ' is' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a'
+ ' former' ' NFL' ' quarterback' '.' '\n' '\n' 'The' ' NFL' ' is' ' a']" ", who is a former NFL quarterback , is a former NFL quarterback .
+
+ The NFL is a" False and team star Tom Harmon also stopped to visit 4 [' and', ' team', ' star', ' Tom', ' Harmon']
+54 15 The professional sport played by x -1 The professional sport played by Mariano Rivera baseball Mariano Rivera "[',' ' the' ' greatest' ' closer' ' in' ' baseball' ' history' '.' '\n'
+ '\n' 'The' ' Yankees' ' have' ' been' ' the' ' best' ' team' ' in'
+ ' baseball' ' for']" ", the greatest closer in baseball history .
+
+ The Yankees have been the best team in baseball for" True " Yankees closer Mariano Rivera remarked, ""If" 4 [' Yankees', ' closer', ' Mar', 'iano', ' Rivera']
+55 15 The professional sport played by x -1 The professional sport played by Mariano Rivera baseball Mariano Rivera "[',' ' the' ' greatest' ' closer' ' in' ' baseball' ' history' '.' '\n'
+ '\n' 'The' ' Yankees' ' have' ' been' ' the' ' best' ' team' ' in'
+ ' baseball' ' for']" ", the greatest closer in baseball history .
+
+ The Yankees have been the best team in baseball for" True Posada, and Mariano Rivera — that became 6 [' Pos', 'ada', ',', ' and', ' Mar', 'iano', ' Rivera']
+56 15 The professional sport played by x -1 The professional sport played by Mariano Rivera baseball Mariano Rivera "[',' ' the' ' greatest' ' closer' ' in' ' baseball' ' history' '.' '\n'
+ '\n' 'The' ' Yankees' ' have' ' been' ' the' ' best' ' team' ' in'
+ ' baseball' ' for']" ", the greatest closer in baseball history .
+
+ The Yankees have been the best team in baseball for" True end the inning. Mariano Rivera entered in the 6 [' end', ' the', ' inning', '.', ' Mar', 'iano', ' Rivera']
+57 15 The professional sport played by x -1 The professional sport played by Mariano Rivera baseball Mariano Rivera "[',' ' the' ' greatest' ' closer' ' in' ' baseball' ' history' '.' '\n'
+ '\n' 'The' ' Yankees' ' have' ' been' ' the' ' best' ' team' ' in'
+ ' baseball' ' for']" ", the greatest closer in baseball history .
+
+ The Yankees have been the best team in baseball for" True 3 ['M', 'arian', 'o', ' Rivera']
+58 15 The professional sport played by x -1 The professional sport played by Mariano Rivera baseball Mariano Rivera "[',' ' the' ' greatest' ' closer' ' in' ' baseball' ' history' '.' '\n'
+ '\n' 'The' ' Yankees' ' have' ' been' ' the' ' best' ' team' ' in'
+ ' baseball' ' for']" ", the greatest closer in baseball history .
+
+ The Yankees have been the best team in baseball for" True 2014, former Yankee Mariano Rivera debuted his new 6 [' 2014', ',', ' former', ' Yankee', ' Mar', 'iano', ' Rivera']
+59 17 The professional sport played by x -1 The professional sport played by Gary Carter baseball Gary Carter "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Ex' 'pos' ','
+ ' and' ' the' ' first' ' player' ' to' ' hit' ' a' ' home' ' run']" , who was a member of the Montreal Ex pos , and the first player to hit a home run False ahead of the Mets' Gary Carter and Keith Hernandez 6 "[' ahead', ' of', ' the', ' Mets', ""'"", ' Gary', ' Carter']"
+60 17 The professional sport played by x -1 The professional sport played by Gary Carter baseball Gary Carter "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Ex' 'pos' ','
+ ' and' ' the' ' first' ' player' ' to' ' hit' ' a' ' home' ' run']" , who was a member of the Montreal Ex pos , and the first player to hit a home run False ahead of the Mets' Gary Carter and Keith Hernandez 6 "[' ahead', ' of', ' the', ' Mets', ""'"", ' Gary', ' Carter']"
+61 17 The professional sport played by x -1 The professional sport played by Gary Carter baseball Gary Carter "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Ex' 'pos' ','
+ ' and' ' the' ' first' ' player' ' to' ' hit' ' a' ' home' ' run']" , who was a member of the Montreal Ex pos , and the first player to hit a home run False of the Mets' Gary Carter and Keith Hernandez 5 "[' of', ' the', ' Mets', ""'"", ' Gary', ' Carter']"
+62 18 The professional sport played by x -1 The professional sport played by Igor Larionov hockey Igor Larionov "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Soviet' ' Union'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' at' ' the'
+ ' 1980' ' Winter']" , who was a member of the Soviet Union national team that won the gold medal at the 1980 Winter False Krutov and Igor Larionov played for the 7 [' Kr', 'ut', 'ov', ' and', ' Igor', ' Lar', 'ion', 'ov']
+63 18 The professional sport played by x -1 The professional sport played by Igor Larionov hockey Igor Larionov "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Soviet' ' Union'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' at' ' the'
+ ' 1980' ' Winter']" , who was a member of the Soviet Union national team that won the gold medal at the 1980 Winter False Vladimir Krutov and Igor Larionov played for the 8 [' Vladimir', ' Kr', 'ut', 'ov', ' and', ' Igor', ' Lar', 'ion', 'ov']
+64 18 The professional sport played by x -1 The professional sport played by Igor Larionov hockey Igor Larionov "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Soviet' ' Union'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' at' ' the'
+ ' 1980' ' Winter']" , who was a member of the Soviet Union national team that won the gold medal at the 1980 Winter False by fellow Soviets Igor Larionov and Anatoli Semenov 6 [' by', ' fellow', ' Soviets', ' Igor', ' Lar', 'ion', 'ov']
+65 18 The professional sport played by x -1 The professional sport played by Igor Larionov hockey Igor Larionov "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Soviet' ' Union'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' at' ' the'
+ ' 1980' ' Winter']" , who was a member of the Soviet Union national team that won the gold medal at the 1980 Winter False Vladimir Krutov, Igor Larionov and Sergei Makarov 8 [' Vladimir', ' Kr', 'ut', 'ov', ',', ' Igor', ' Lar', 'ion', 'ov']
+66 18 The professional sport played by x -1 The professional sport played by Igor Larionov hockey Igor Larionov "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Soviet' ' Union'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' at' ' the'
+ ' 1980' ' Winter']" , who was a member of the Soviet Union national team that won the gold medal at the 1980 Winter False players, including Igor Larionov and Viacheslav 6 [' players', ',', ' including', ' Igor', ' Lar', 'ion', 'ov']
+67 19 The professional sport played by x -1 The professional sport played by Ashley Cole soccer Ashley Cole "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False and teammate Ashley Cole received abuse from 3 [' and', ' teammate', ' Ashley', ' Cole']
+68 19 The professional sport played by x -1 The professional sport played by Ashley Cole soccer Ashley Cole "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False and a minute later Ashley Cole received a booking 5 [' and', ' a', ' minute', ' later', ' Ashley', ' Cole']
+69 19 The professional sport played by x -1 The professional sport played by Ashley Cole soccer Ashley Cole "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False after defender Ashley Cole scored via a 3 [' after', ' defender', ' Ashley', ' Cole']
+70 19 The professional sport played by x -1 The professional sport played by Ashley Cole soccer Ashley Cole "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False was later dropped. Ashley Cole was not suspended 5 [' was', ' later', ' dropped', '.', ' Ashley', ' Cole']
+71 19 The professional sport played by x -1 The professional sport played by Ashley Cole soccer Ashley Cole "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False " ""tapping-up"" Arsenal defender Ashley Cole in January, and" 9 "[' ""', 't', 'apping', '-', 'up', '""', ' Arsenal', ' defender', ' Ashley', ' Cole']"
+72 20 The professional sport played by x -1 The professional sport played by Dennis Rodman basketball Dennis Rodman "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' big' ' fan' ' of'
+ ' the' ' sport' '.' ' He' ' has' ' been' ' a' ' fan' ' of']" , the former NBA star , is a big fan of the sport . He has been a fan of False Rebound: The Dennis Rodman Story. ISBN 6 [' Re', 'bound', ':', ' The', ' Dennis', ' Rod', 'man']
+73 20 The professional sport played by x -1 The professional sport played by Dennis Rodman basketball Dennis Rodman "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' big' ' fan' ' of'
+ ' the' ' sport' '.' ' He' ' has' ' been' ' a' ' fan' ' of']" , the former NBA star , is a big fan of the sport . He has been a fan of False Scottie Pippen and Dennis Rodman looming, and being 8 [' Scott', 'ie', ' P', 'ipp', 'en', ' and', ' Dennis', ' Rod', 'man']
+74 20 The professional sport played by x -1 The professional sport played by Dennis Rodman basketball Dennis Rodman "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' big' ' fan' ' of'
+ ' the' ' sport' '.' ' He' ' has' ' been' ' a' ' fan' ' of']" , the former NBA star , is a big fan of the sport . He has been a fan of False that was leading Dennis Rodman to a place he didn 5 [' that', ' was', ' leading', ' Dennis', ' Rod', 'man']
+75 20 The professional sport played by x -1 The professional sport played by Dennis Rodman basketball Dennis Rodman "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' big' ' fan' ' of'
+ ' the' ' sport' '.' ' He' ' has' ' been' ' a' ' fan' ' of']" , the former NBA star , is a big fan of the sport . He has been a fan of False former NBA star Dennis Rodman visited North Korea, 5 [' former', ' NBA', ' star', ' Dennis', ' Rod', 'man']
+76 20 The professional sport played by x -1 The professional sport played by Dennis Rodman basketball Dennis Rodman "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' big' ' fan' ' of'
+ ' the' ' sport' '.' ' He' ' has' ' been' ' a' ' fan' ' of']" , the former NBA star , is a big fan of the sport . He has been a fan of False former NBA star Dennis Rodman visited North Korea, 5 [' former', ' NBA', ' star', ' Dennis', ' Rod', 'man']
+77 22 The professional sport played by x -1 The professional sport played by Marián Hossa hockey Marián Hossa "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Czech'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 15' ',' ' has' ' been'
+ ' a']" , who has been a member of the Czech national team since the age of 15 , has been a False acquired star winger Marián Hossa from the Atlanta 7 [' acquired', ' star', ' winger', ' Mari', 'án', ' H', 'oss', 'a']
+78 23 The professional sport played by x -1 The professional sport played by Moe Berg baseball Moe Berg "[',' ' the' ' former' ' Major' ' League' ' Baseball' ' pitcher' ',' ' was'
+ ' a' ' great' ' American' ' hero' '.' ' He' ' was' ' a' ' great'
+ ' American' ' hero']" , the former Major League Baseball pitcher , was a great American hero . He was a great American hero False " Berg =
+" 4 [' Berg', ' =', 'M', 'oe', ' Berg']
+79 23 The professional sport played by x -1 The professional sport played by Moe Berg baseball Moe Berg "[',' ' the' ' former' ' Major' ' League' ' Baseball' ' pitcher' ',' ' was'
+ ' a' ' great' ' American' ' hero' '.' ' He' ' was' ' a' ' great'
+ ' American' ' hero']" , the former Major League Baseball pitcher , was a great American hero . He was a great American hero False 2 ['M', 'oe', ' Berg']
+80 23 The professional sport played by x -1 The professional sport played by Moe Berg baseball Moe Berg "[',' ' the' ' former' ' Major' ' League' ' Baseball' ' pitcher' ',' ' was'
+ ' a' ' great' ' American' ' hero' '.' ' He' ' was' ' a' ' great'
+ ' American' ' hero']" , the former Major League Baseball pitcher , was a great American hero . He was a great American hero False " = Moe Berg =
+" 2 [' =', ' Moe', ' Berg']
+81 23 The professional sport played by x -1 The professional sport played by Moe Berg baseball Moe Berg "[',' ' the' ' former' ' Major' ' League' ' Baseball' ' pitcher' ',' ' was'
+ ' a' ' great' ' American' ' hero' '.' ' He' ' was' ' a' ' great'
+ ' American' ' hero']" , the former Major League Baseball pitcher , was a great American hero . He was a great American hero False 2 ['M', 'oe', ' Berg']
+82 23 The professional sport played by x -1 The professional sport played by Moe Berg baseball Moe Berg "[',' ' the' ' former' ' Major' ' League' ' Baseball' ' pitcher' ',' ' was'
+ ' a' ' great' ' American' ' hero' '.' ' He' ' was' ' a' ' great'
+ ' American' ' hero']" , the former Major League Baseball pitcher , was a great American hero . He was a great American hero False sent the spy Moe Berg to Switzerland 4 [' sent', ' the', ' spy', ' Moe', ' Berg']
+83 24 The professional sport played by x -1 The professional sport played by Arvydas Sabonis basketball Arvydas Sabonis "[',' ' a' ' Lithuan' 'ian' ' basketball' ' player' ',' ' is' ' a' ' game'
+ ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' five' ' players'
+ ' each']" , a Lithuan ian basketball player , is a game that is played by two teams of five players each True " Kersey, Terry Porter and Arvydas Sabonis in the 1980s.
+" 11 [' Ker', 'sey', ',', ' Terry', ' Porter', ' and', ' Ar', 'v', 'yd', 'as', ' Sab', 'onis']
+84 24 The professional sport played by x -1 The professional sport played by Arvydas Sabonis basketball Arvydas Sabonis "[',' ' a' ' Lithuan' 'ian' ' basketball' ' player' ',' ' is' ' a' ' game'
+ ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' five' ' players'
+ ' each']" , a Lithuan ian basketball player , is a game that is played by two teams of five players each True " Terry Porter and Arvydas Sabonis in the 1980s.
+" 8 [' Terry', ' Porter', ' and', ' Ar', 'v', 'yd', 'as', ' Sab', 'onis']
+85 24 The professional sport played by x -1 The professional sport played by Arvydas Sabonis basketball Arvydas Sabonis "[',' ' a' ' Lithuan' 'ian' ' basketball' ' player' ',' ' is' ' a' ' game'
+ ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' five' ' players'
+ ' each']" , a Lithuan ian basketball player , is a game that is played by two teams of five players each True " Terry Porter and Arvydas Sabonis in the 1980s.
+" 8 [' Terry', ' Porter', ' and', ' Ar', 'v', 'yd', 'as', ' Sab', 'onis']
+86 24 The professional sport played by x -1 The professional sport played by Arvydas Sabonis basketball Arvydas Sabonis "[',' ' a' ' Lithuan' 'ian' ' basketball' ' player' ',' ' is' ' a' ' game'
+ ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' five' ' players'
+ ' each']" , a Lithuan ian basketball player , is a game that is played by two teams of five players each True " Terry Porter and Arvydas Sabonis in the 1980s.
+" 8 [' Terry', ' Porter', ' and', ' Ar', 'v', 'yd', 'as', ' Sab', 'onis']
+87 25 The professional sport played by x -1 The professional sport played by Willie Mays baseball Willie Mays "[',' ' the' ' greatest' ' baseball' ' player' ' of' ' all' ' time' '.'
+ '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only']" ", the greatest baseball player of all time .
+
+ I 'm not sure if I 'm the only" True After the third out, Willie Mays walked to 7 [' After', ' the', ' third', ' out', ',', ' Willie', ' M', 'ays']
+88 25 The professional sport played by x -1 The professional sport played by Willie Mays baseball Willie Mays "[',' ' the' ' greatest' ' baseball' ' player' ' of' ' all' ' time' '.'
+ '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only']" ", the greatest baseball player of all time .
+
+ I 'm not sure if I 'm the only" True performance. Willie Mays walked to open 4 [' performance', '.', ' Willie', ' M', 'ays']
+89 25 The professional sport played by x -1 The professional sport played by Willie Mays baseball Willie Mays "[',' ' the' ' greatest' ' baseball' ' player' ' of' ' all' ' time' '.'
+ '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only']" ", the greatest baseball player of all time .
+
+ I 'm not sure if I 'm the only" True Hall of Famer Willie Mays and Jim Davenport. 6 [' Hall', ' of', ' F', 'amer', ' Willie', ' M', 'ays']
+90 25 The professional sport played by x -1 The professional sport played by Willie Mays baseball Willie Mays "[',' ' the' ' greatest' ' baseball' ' player' ' of' ' all' ' time' '.'
+ '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only']" ", the greatest baseball player of all time .
+
+ I 'm not sure if I 'm the only" True list, surpassing Willie Mays by beating out 6 [' list', ',', ' surpass', 'ing', ' Willie', ' M', 'ays']
+91 25 The professional sport played by x -1 The professional sport played by Willie Mays baseball Willie Mays "[',' ' the' ' greatest' ' baseball' ' player' ' of' ' all' ' time' '.'
+ '\n' '\n' 'I' ""'m"" ' not' ' sure' ' if' ' I' ""'m"" ' the' ' only']" ", the greatest baseball player of all time .
+
+ I 'm not sure if I 'm the only" True third out, Willie Mays walked to 5 [' third', ' out', ',', ' Willie', ' M', 'ays']
+92 26 The professional sport played by x -1 The professional sport played by Landon Donovan soccer Landon Donovan "[',' ' the' ' former' ' U' '.' 'S' '.' ' national' ' team' ' star' ','
+ ' is' ' a' ' big' ' fan' ' of' ' the' ' sport' '.' '\n']" ", the former U . S . national team star , is a big fan of the sport .
+" False goal scored by Landon Donovan for the US 5 [' goal', ' scored', ' by', ' L', 'andon', ' Donovan']
+93 26 The professional sport played by x -1 The professional sport played by Landon Donovan soccer Landon Donovan "[',' ' the' ' former' ' U' '.' 'S' '.' ' national' ' team' ' star' ','
+ ' is' ' a' ' big' ' fan' ' of' ' the' ' sport' '.' '\n']" ", the former U . S . national team star , is a big fan of the sport .
+" False goal scored by Landon Donovan for the US National 5 [' goal', ' scored', ' by', ' L', 'andon', ' Donovan']
+94 27 The professional sport played by x -1 The professional sport played by Tim Tebow football Tim Tebow "[',' ' the' ' quarterback' ' for' ' the' ' Denver' ' Broncos' ',' ' is'
+ ' a' ' Christian' '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' multi' '-']" ", the quarterback for the Denver Broncos , is a Christian .
+
+ The NFL is a multi -" False NFL quarterback Tim Tebow and professional 4 [' NFL', ' quarterback', ' Tim', ' Te', 'bow']
+95 27 The professional sport played by x -1 The professional sport played by Tim Tebow football Tim Tebow "[',' ' the' ' quarterback' ' for' ' the' ' Denver' ' Broncos' ',' ' is'
+ ' a' ' Christian' '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' multi' '-']" ", the quarterback for the Denver Broncos , is a Christian .
+
+ The NFL is a multi -" False (1st overall), Tim Tebow (25th overall), 7 [' (', '1', 'st', ' overall', '),', ' Tim', ' Te', 'bow']
+96 27 The professional sport played by x -1 The professional sport played by Tim Tebow football Tim Tebow "[',' ' the' ' quarterback' ' for' ' the' ' Denver' ' Broncos' ',' ' is'
+ ' a' ' Christian' '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' multi' '-']" ", the quarterback for the Denver Broncos , is a Christian .
+
+ The NFL is a multi -" False Texas'offense. Much like Tim Tebow last season, 9 "[' Texas', ""'"", 'off', 'ense', '.', ' Much', ' like', ' Tim', ' Te', 'bow']"
+97 27 The professional sport played by x -1 The professional sport played by Tim Tebow football Tim Tebow "[',' ' the' ' quarterback' ' for' ' the' ' Denver' ' Broncos' ',' ' is'
+ ' a' ' Christian' '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' multi' '-']" ", the quarterback for the Denver Broncos , is a Christian .
+
+ The NFL is a multi -" False Florida quarterback Tim Tebow — and 17 touchdowns 4 [' Florida', ' quarterback', ' Tim', ' Te', 'bow']
+98 27 The professional sport played by x -1 The professional sport played by Tim Tebow football Tim Tebow "[',' ' the' ' quarterback' ' for' ' the' ' Denver' ' Broncos' ',' ' is'
+ ' a' ' Christian' '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' multi' '-']" ", the quarterback for the Denver Broncos , is a Christian .
+
+ The NFL is a multi -" False The presence of Tim Tebow created a controversy 5 [' The', ' presence', ' of', ' Tim', ' Te', 'bow']
+99 28 The professional sport played by x -1 The professional sport played by David Beckham soccer David Beckham "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' he' ' is' '.' ' He'
+ ' is']" , the former England captain , is a great example of the kind of player he is . He is False the reasons David Beckham and Robbie Keane 3 [' the', ' reasons', ' David', ' Beckham']
+100 28 The professional sport played by x -1 The professional sport played by David Beckham soccer David Beckham "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' he' ' is' '.' ' He'
+ ' is']" , the former England captain , is a great example of the kind of player he is . He is False first-half lead. David Beckham was runner 6 [' first', '-', 'half', ' lead', '.', ' David', ' Beckham']
+101 28 The professional sport played by x -1 The professional sport played by David Beckham soccer David Beckham "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' he' ' is' '.' ' He'
+ ' is']" , the former England captain , is a great example of the kind of player he is . He is False changing room that hit David Beckham above the left 5 [' changing', ' room', ' that', ' hit', ' David', ' Beckham']
+102 28 The professional sport played by x -1 The professional sport played by David Beckham soccer David Beckham "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' he' ' is' '.' ' He'
+ ' is']" , the former England captain , is a great example of the kind of player he is . He is False when midfielder David Beckham scored from a 3 [' when', ' midfielder', ' David', ' Beckham']
+103 28 The professional sport played by x -1 The professional sport played by David Beckham soccer David Beckham "[',' ' the' ' former' ' England' ' captain' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' player' ' he' ' is' '.' ' He'
+ ' is']" , the former England captain , is a great example of the kind of player he is . He is False 1 ['David', ' Beckham']
+104 29 The professional sport played by x -1 The professional sport played by Magic Johnson basketball Magic Johnson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' he']" , the NBA 's all - time leading scorer , is a great example of the kind of person he False able to humble rookie Magic Johnson in practice, 5 [' able', ' to', ' humble', ' rookie', ' Magic', ' Johnson']
+105 29 The professional sport played by x -1 The professional sport played by Magic Johnson basketball Magic Johnson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' he']" , the NBA 's all - time leading scorer , is a great example of the kind of person he False Johnson runs Magic Johnson Enterprises, 3 [' Johnson', ' runs', ' Magic', ' Johnson']
+106 29 The professional sport played by x -1 The professional sport played by Magic Johnson basketball Magic Johnson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' he']" , the NBA 's all - time leading scorer , is a great example of the kind of person he False sports team. While Magic Johnson is considered 5 [' sports', ' team', '.', ' While', ' Magic', ' Johnson']
+107 29 The professional sport played by x -1 The professional sport played by Magic Johnson basketball Magic Johnson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' he']" , the NBA 's all - time leading scorer , is a great example of the kind of person he False franchise history to face Magic Johnson and James 5 [' franchise', ' history', ' to', ' face', ' Magic', ' Johnson']
+108 29 The professional sport played by x -1 The professional sport played by Magic Johnson basketball Magic Johnson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' he']" , the NBA 's all - time leading scorer , is a great example of the kind of person he False " Johnson Award: 2014
+" 5 [' Johnson', ' Award', ':', ' 2014', 'Magic', ' Johnson']
+109 31 The professional sport played by x -1 The professional sport played by Andrea Pirlo soccer Andrea Pirlo "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Italian'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been'
+ ' a']" , who has been a regular in the Italian national team since the age of 16 , has been a False free-kick taken by Andrea Pirlo deflected off the 8 [' free', '-', 'kick', ' taken', ' by', ' Andrea', ' P', 'irl', 'o']
+110 31 The professional sport played by x -1 The professional sport played by Andrea Pirlo soccer Andrea Pirlo "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Italian'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been'
+ ' a']" , who has been a regular in the Italian national team since the age of 16 , has been a False minute for Andrea Pirlo as Milan won the match 5 [' minute', ' for', ' Andrea', ' P', 'irl', 'o']
+111 31 The professional sport played by x -1 The professional sport played by Andrea Pirlo soccer Andrea Pirlo "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Italian'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been'
+ ' a']" , who has been a regular in the Italian national team since the age of 16 , has been a False free-kick taken by Andrea Pirlo deflected off 8 [' free', '-', 'kick', ' taken', ' by', ' Andrea', ' P', 'irl', 'o']
+112 31 The professional sport played by x -1 The professional sport played by Andrea Pirlo soccer Andrea Pirlo "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Italian'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been'
+ ' a']" , who has been a regular in the Italian national team since the age of 16 , has been a False Maldini volleyed in an Andrea Pirlo free kick that had 9 [' Mald', 'ini', ' volley', 'ed', ' in', ' an', ' Andrea', ' P', 'irl', 'o']
+113 31 The professional sport played by x -1 The professional sport played by Andrea Pirlo soccer Andrea Pirlo "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Italian'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been'
+ ' a']" , who has been a regular in the Italian national team since the age of 16 , has been a False Liverpool 1 – 0 up. Andrea Pirlo was next for Milan, 9 [' Liverpool', ' 1', ' –', ' 0', ' up', '.', ' Andrea', ' P', 'irl', 'o']
+114 32 The professional sport played by x -1 The professional sport played by Cy Young baseball Cy Young "[' Award' ' winner' ',' ' and' ' the' ' first' ' woman' ' to' ' win'
+ ' the' ' Cy' ' Young' ' Award' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" " Award winner , and the first woman to win the Cy Young Award .
+
+ The first woman to" False Drysdale won both the Cy Young and Sporting News 7 [' D', 'rys', 'dale', ' won', ' both', ' the', ' Cy', ' Young']
+115 32 The professional sport played by x -1 The professional sport played by Cy Young baseball Cy Young "[' Award' ' winner' ',' ' and' ' the' ' first' ' woman' ' to' ' win'
+ ' the' ' Cy' ' Young' ' Award' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" " Award winner , and the first woman to win the Cy Young Award .
+
+ The first woman to" False fourth in the Cy Young Award balloting. 4 [' fourth', ' in', ' the', ' Cy', ' Young']
+116 32 The professional sport played by x -1 The professional sport played by Cy Young baseball Cy Young "[' Award' ' winner' ',' ' and' ' the' ' first' ' woman' ' to' ' win'
+ ' the' ' Cy' ' Young' ' Award' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" " Award winner , and the first woman to win the Cy Young Award .
+
+ The first woman to" False off eventual Cy Young winner Don 3 [' off', ' eventual', ' Cy', ' Young']
+117 32 The professional sport played by x -1 The professional sport played by Cy Young baseball Cy Young "[' Award' ' winner' ',' ' and' ' the' ' first' ' woman' ' to' ' win'
+ ' the' ' Cy' ' Young' ' Award' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" " Award winner , and the first woman to win the Cy Young Award .
+
+ The first woman to" False won the 1988 NL Cy Young Award. Hershiser's 5 [' won', ' the', ' 1988', ' NL', ' Cy', ' Young']
+118 32 The professional sport played by x -1 The professional sport played by Cy Young baseball Cy Young "[' Award' ' winner' ',' ' and' ' the' ' first' ' woman' ' to' ' win'
+ ' the' ' Cy' ' Young' ' Award' '.' '\n' '\n' 'The' ' first' ' woman'
+ ' to']" " Award winner , and the first woman to win the Cy Young Award .
+
+ The first woman to" False Marshall won the Cy Young Award. Alston received 4 [' Marshall', ' won', ' the', ' Cy', ' Young']
+119 33 The professional sport played by x -1 The professional sport played by Bud Grant football Bud Grant "[',' ' the' ' former' ' head' ' coach' ' of' ' the' ' Minnesota'
+ ' Vikings' ',' ' was' ' a' ' member' ' of' ' the' ' Vikings' ""'""
+ ' coaching' ' staff' ' in']" , the former head coach of the Minnesota Vikings , was a member of the Vikings ' coaching staff in False " = Bud Grant =
+" 2 [' =', ' Bud', ' Grant']
+120 33 The professional sport played by x -1 The professional sport played by Bud Grant football Bud Grant "[',' ' the' ' former' ' head' ' coach' ' of' ' the' ' Minnesota'
+ ' Vikings' ',' ' was' ' a' ' member' ' of' ' the' ' Vikings' ""'""
+ ' coaching' ' staff' ' in']" , the former head coach of the Minnesota Vikings , was a member of the Vikings ' coaching staff in False " Bud Grant =
+" 1 [' Bud', ' Grant']
+121 33 The professional sport played by x -1 The professional sport played by Bud Grant football Bud Grant "[',' ' the' ' former' ' head' ' coach' ' of' ' the' ' Minnesota'
+ ' Vikings' ',' ' was' ' a' ' member' ' of' ' the' ' Vikings' ""'""
+ ' coaching' ' staff' ' in']" , the former head coach of the Minnesota Vikings , was a member of the Vikings ' coaching staff in False " Bud Grant =
+" 1 [' Bud', ' Grant']
+122 33 The professional sport played by x -1 The professional sport played by Bud Grant football Bud Grant "[',' ' the' ' former' ' head' ' coach' ' of' ' the' ' Minnesota'
+ ' Vikings' ',' ' was' ' a' ' member' ' of' ' the' ' Vikings' ""'""
+ ' coaching' ' staff' ' in']" , the former head coach of the Minnesota Vikings , was a member of the Vikings ' coaching staff in False " Bud Grant =
+" 1 [' Bud', ' Grant']
+123 34 The professional sport played by x -1 The professional sport played by Jesse Hibbs football Jesse Hibbs "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False directed by Jesse Hibbs with an on-screen 4 [' directed', ' by', ' Jesse', ' Hib', 'bs']
+124 34 The professional sport played by x -1 The professional sport played by Jesse Hibbs football Jesse Hibbs "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False was directed by Jesse Hibbs with an on-screen 5 [' was', ' directed', ' by', ' Jesse', ' Hib', 'bs']
+125 34 The professional sport played by x -1 The professional sport played by Jesse Hibbs football Jesse Hibbs "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False film was directed by Jesse Hibbs with an on-screen 6 [' film', ' was', ' directed', ' by', ' Jesse', ' Hib', 'bs']
+126 34 The professional sport played by x -1 The professional sport played by Jesse Hibbs football Jesse Hibbs "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False was directed by Jesse Hibbs with an on-screen 5 [' was', ' directed', ' by', ' Jesse', ' Hib', 'bs']
+127 34 The professional sport played by x -1 The professional sport played by Jesse Hibbs football Jesse Hibbs "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False more than once. Jesse Hibbs who directed 6 [' more', ' than', ' once', '.', ' Jesse', ' Hib', 'bs']
+128 35 The professional sport played by x -1 The professional sport played by Jim Bunning baseball Jim Bunning "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Philadelphia'
+ ' Phillies' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the'
+ ' first' ' week']" ", who was a pitcher for the Philadelphia Phillies .
+
+ The game was played in the first week" False doubleheader, Koufax faced Jim Bunning for the second 8 [' double', 'header', ',', ' Kou', 'fax', ' faced', ' Jim', ' B', 'unning']
+129 35 The professional sport played by x -1 The professional sport played by Jim Bunning baseball Jim Bunning "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Philadelphia'
+ ' Phillies' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the'
+ ' first' ' week']" ", who was a pitcher for the Philadelphia Phillies .
+
+ The game was played in the first week" False Ford (236 – 106), Jim Bunning (224 – 184) 8 [' Ford', ' (', '236', ' –', ' 106', '),', ' Jim', ' B', 'unning']
+130 35 The professional sport played by x -1 The professional sport played by Jim Bunning baseball Jim Bunning "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Philadelphia'
+ ' Phillies' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the'
+ ' first' ' week']" ", who was a pitcher for the Philadelphia Phillies .
+
+ The game was played in the first week" False seasons; he tied Jim Bunning for the league 6 [' seasons', ';', ' he', ' tied', ' Jim', ' B', 'unning']
+131 35 The professional sport played by x -1 The professional sport played by Jim Bunning baseball Jim Bunning "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Philadelphia'
+ ' Phillies' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the'
+ ' first' ' week']" ", who was a pitcher for the Philadelphia Phillies .
+
+ The game was played in the first week" False incumbent Senator Jim Bunning in the 2010 4 [' incumbent', ' Senator', ' Jim', ' B', 'unning']
+132 35 The professional sport played by x -1 The professional sport played by Jim Bunning baseball Jim Bunning "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Philadelphia'
+ ' Phillies' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the'
+ ' first' ' week']" ", who was a pitcher for the Philadelphia Phillies .
+
+ The game was played in the first week" False Republican nominee Jim Bunning tried to enlist 4 [' Republican', ' nominee', ' Jim', ' B', 'unning']
+133 36 The professional sport played by x -1 The professional sport played by Carlos Arroyo basketball Carlos Arroyo "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' Olympic'
+ ' team' ' in' ' the' ' 2004' ' Summer' ' Olympics' ' in' ' Athens']" , who was a member of the U . S . Olympic team in the 2004 Summer Olympics in Athens False Barea and Carlos Arroyo to play for Puerto 6 [' B', 'area', ' and', ' Carlos', ' Ar', 'roy', 'o']
+134 36 The professional sport played by x -1 The professional sport played by Carlos Arroyo basketball Carlos Arroyo "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.' ' Olympic'
+ ' team' ' in' ' the' ' 2004' ' Summer' ' Olympics' ' in' ' Athens']" , who was a member of the U . S . Olympic team in the 2004 Summer Olympics in Athens False Barea and Carlos Arroyo to play for Puerto 6 [' B', 'area', ' and', ' Carlos', ' Ar', 'roy', 'o']
+135 37 The professional sport played by x -1 The professional sport played by Alessandro Nesta soccer Alessandro Nesta "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' '.'
+ '\n' '\n' 'The' ' game' ' is' ' played' ' on' ' a' ' rectangular'
+ ' field' ' with']" ", who is a former professional soccer player .
+
+ The game is played on a rectangular field with" True turned down after Alessandro Nesta allegedly handballed. 6 [' turned', ' down', ' after', ' Aless', 'andro', ' N', 'esta']
+136 37 The professional sport played by x -1 The professional sport played by Alessandro Nesta soccer Alessandro Nesta "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' '.'
+ '\n' '\n' 'The' ' game' ' is' ' played' ' on' ' a' ' rectangular'
+ ' field' ' with']" ", who is a former professional soccer player .
+
+ The game is played on a rectangular field with" True turned down after Alessandro Nesta allegedly handballed. 6 [' turned', ' down', ' after', ' Aless', 'andro', ' N', 'esta']
+137 37 The professional sport played by x -1 The professional sport played by Alessandro Nesta soccer Alessandro Nesta "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' '.'
+ '\n' '\n' 'The' ' game' ' is' ' played' ' on' ' a' ' rectangular'
+ ' field' ' with']" ", who is a former professional soccer player .
+
+ The game is played on a rectangular field with" True turned down after Alessandro Nesta allegedly 6 [' turned', ' down', ' after', ' Aless', 'andro', ' N', 'esta']
+138 37 The professional sport played by x -1 The professional sport played by Alessandro Nesta soccer Alessandro Nesta "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' '.'
+ '\n' '\n' 'The' ' game' ' is' ' played' ' on' ' a' ' rectangular'
+ ' field' ' with']" ", who is a former professional soccer player .
+
+ The game is played on a rectangular field with" True down after Alessandro Nesta allegedly handballed. 5 [' down', ' after', ' Aless', 'andro', ' N', 'esta']
+139 38 The professional sport played by x -1 The professional sport played by Ryan Smyth hockey Ryan Smyth "[',' ' who' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the' ' other']" , who is a former NHL player , and the other is a former NHL player , and the other False players Joe Sakic and Ryan Smyth to injuries, Ryan 7 [' players', ' Joe', ' Sak', 'ic', ' and', ' Ryan', ' Smy', 'th']
+140 38 The professional sport played by x -1 The professional sport played by Ryan Smyth hockey Ryan Smyth "[',' ' who' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the' ' other']" , who is a former NHL player , and the other is a former NHL player , and the other False After injuries to Ryan Smyth and Steve Stamkos 5 [' After', ' injuries', ' to', ' Ryan', ' Smy', 'th']
+141 38 The professional sport played by x -1 The professional sport played by Ryan Smyth hockey Ryan Smyth "[',' ' who' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the' ' other']" , who is a former NHL player , and the other is a former NHL player , and the other False month, top scorer Ryan Smyth was traded to 6 [' month', ',', ' top', ' scorer', ' Ryan', ' Smy', 'th']
+142 38 The professional sport played by x -1 The professional sport played by Ryan Smyth hockey Ryan Smyth "[',' ' who' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the' ' other']" , who is a former NHL player , and the other is a former NHL player , and the other False players Joe Sakic and Ryan Smyth to injuries, Ryan 7 [' players', ' Joe', ' Sak', 'ic', ' and', ' Ryan', ' Smy', 'th']
+143 38 The professional sport played by x -1 The professional sport played by Ryan Smyth hockey Ryan Smyth "[',' ' who' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NHL' ' player' ',' ' and' ' the' ' other']" , who is a former NHL player , and the other is a former NHL player , and the other False Joe Sakic and Ryan Smyth to injuries, Ryan 6 [' Joe', ' Sak', 'ic', ' and', ' Ryan', ' Smy', 'th']
+144 39 The professional sport played by x -1 The professional sport played by Troy Aikman football Troy Aikman "[',' ' the' ' former' ' Dallas' ' Cowboys' ' quarterback' ',' ' is' ' a'
+ ' big' ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a' ' fan']" , the former Dallas Cowboys quarterback , is a big fan of the game . He has been a fan False Rosey Grier, Troy Aikman and Dan Marino 8 [' Rose', 'y', ' G', 'rier', ',', ' Troy', ' A', 'ik', 'man']
+145 39 The professional sport played by x -1 The professional sport played by Troy Aikman football Troy Aikman "[',' ' the' ' former' ' Dallas' ' Cowboys' ' quarterback' ',' ' is' ' a'
+ ' big' ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a' ' fan']" , the former Dallas Cowboys quarterback , is a big fan of the game . He has been a fan False Taylor batted a Troy Aikman pass high into 6 [' Taylor', ' batted', ' a', ' Troy', ' A', 'ik', 'man']
+146 39 The professional sport played by x -1 The professional sport played by Troy Aikman football Troy Aikman "[',' ' the' ' former' ' Dallas' ' Cowboys' ' quarterback' ',' ' is' ' a'
+ ' big' ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a' ' fan']" , the former Dallas Cowboys quarterback , is a big fan of the game . He has been a fan False Taylor batted a Troy Aikman pass high into 6 [' Taylor', ' batted', ' a', ' Troy', ' A', 'ik', 'man']
+147 39 The professional sport played by x -1 The professional sport played by Troy Aikman football Troy Aikman "[',' ' the' ' former' ' Dallas' ' Cowboys' ' quarterback' ',' ' is' ' a'
+ ' big' ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a' ' fan']" , the former Dallas Cowboys quarterback , is a big fan of the game . He has been a fan False quarterback after Troy Aikman was injured 5 [' quarterback', ' after', ' Troy', ' A', 'ik', 'man']
+148 39 The professional sport played by x -1 The professional sport played by Troy Aikman football Troy Aikman "[',' ' the' ' former' ' Dallas' ' Cowboys' ' quarterback' ',' ' is' ' a'
+ ' big' ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a' ' fan']" , the former Dallas Cowboys quarterback , is a big fan of the game . He has been a fan False Taylor batted a Troy Aikman pass high into 6 [' Taylor', ' batted', ' a', ' Troy', ' A', 'ik', 'man']
+149 40 The professional sport played by x -1 The professional sport played by Joakim Noah basketball Joakim Noah "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' big'
+ ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a']" , who is a former NBA player , is a big fan of the game . He has been a False Derrick Rose and Joakim Noah to injuries, and the 6 [' Derrick', ' Rose', ' and', ' Jo', 'ak', 'im', ' Noah']
+150 40 The professional sport played by x -1 The professional sport played by Joakim Noah basketball Joakim Noah "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' big'
+ ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a']" , who is a former NBA player , is a big fan of the game . He has been a False established by Joakim Noah in the 2006 NCAA 5 [' established', ' by', ' Jo', 'ak', 'im', ' Noah']
+151 40 The professional sport played by x -1 The professional sport played by Joakim Noah basketball Joakim Noah "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' big'
+ ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a']" , who is a former NBA player , is a big fan of the game . He has been a False Derrick Rose and Joakim Noah to injuries, and 6 [' Derrick', ' Rose', ' and', ' Jo', 'ak', 'im', ' Noah']
+152 40 The professional sport played by x -1 The professional sport played by Joakim Noah basketball Joakim Noah "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' big'
+ ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a']" , who is a former NBA player , is a big fan of the game . He has been a False established by Joakim Noah in the 2006 5 [' established', ' by', ' Jo', 'ak', 'im', ' Noah']
+153 40 The professional sport played by x -1 The professional sport played by Joakim Noah basketball Joakim Noah "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' big'
+ ' fan' ' of' ' the' ' game' '.' ' He' ' has' ' been' ' a']" , who is a former NBA player , is a big fan of the game . He has been a False established by Joakim Noah in the 2006 NCAA 5 [' established', ' by', ' Jo', 'ak', 'im', ' Noah']
+154 42 The professional sport played by x -1 The professional sport played by Sergei Fedorov hockey Sergei Fedorov "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Soviet' ' Union'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' at' ' the'
+ ' 1980' ' Winter']" , who was a member of the Soviet Union national team that won the gold medal at the 1980 Winter False including linemates Sergei Fedorov and Alexander 6 [' including', ' linem', 'ates', ' Sergei', ' Fed', 'or', 'ov']
+155 42 The professional sport played by x -1 The professional sport played by Sergei Fedorov hockey Sergei Fedorov "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Soviet' ' Union'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' at' ' the'
+ ' 1980' ' Winter']" , who was a member of the Soviet Union national team that won the gold medal at the 1980 Winter False including linemates Sergei Fedorov and Alexander 6 [' including', ' linem', 'ates', ' Sergei', ' Fed', 'or', 'ov']
+156 42 The professional sport played by x -1 The professional sport played by Sergei Fedorov hockey Sergei Fedorov "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Soviet' ' Union'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' at' ' the'
+ ' 1980' ' Winter']" , who was a member of the Soviet Union national team that won the gold medal at the 1980 Winter False Alexander Mogilny and Sergei Fedorov led the Soviet Union 8 [' Alexander', ' Mog', 'il', 'ny', ' and', ' Sergei', ' Fed', 'or', 'ov']
+157 43 The professional sport played by x -1 The professional sport played by Chris Paul basketball Chris Paul "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False Forest team led by Chris Paul in double overtime 5 [' Forest', ' team', ' led', ' by', ' Chris', ' Paul']
+158 43 The professional sport played by x -1 The professional sport played by Chris Paul basketball Chris Paul "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False points while Chris Paul added a game high 3 [' points', ' while', ' Chris', ' Paul']
+159 43 The professional sport played by x -1 The professional sport played by Chris Paul basketball Chris Paul "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False " = Chris Paul =
+" 2 [' =', ' Chris', ' Paul']
+160 43 The professional sport played by x -1 The professional sport played by Chris Paul basketball Chris Paul "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False " Chris Paul =
+" 1 [' Chris', ' Paul']
+161 43 The professional sport played by x -1 The professional sport played by Chris Paul basketball Chris Paul "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False Clippers after Chris Paul suffered a separated 3 [' Clippers', ' after', ' Chris', ' Paul']
+162 44 The professional sport played by x -1 The professional sport played by David Robinson basketball David Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' first' ' player' ' to']" , who was a member of the NBA 's all - time great s , and the first player to False " film historian David Robinson claims that ""the cylinder" 3 [' film', ' historian', ' David', ' Robinson']
+163 44 The professional sport played by x -1 The professional sport played by David Robinson basketball David Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' first' ' player' ' to']" , who was a member of the NBA 's all - time great s , and the first player to False Bird, and David Robinson and was dubbed the 4 [' Bird', ',', ' and', ' David', ' Robinson']
+164 44 The professional sport played by x -1 The professional sport played by David Robinson basketball David Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' first' ' player' ' to']" , who was a member of the NBA 's all - time great s , and the first player to False Spurs. With David Robinson guarding him, Olajuwon 4 [' Spurs', '.', ' With', ' David', ' Robinson']
+165 44 The professional sport played by x -1 The professional sport played by David Robinson basketball David Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' first' ' player' ' to']" , who was a member of the NBA 's all - time great s , and the first player to False theatre. However, David Robinson notes even the performances 5 [' theatre', '.', ' However', ',', ' David', ' Robinson']
+166 44 The professional sport played by x -1 The professional sport played by David Robinson basketball David Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' first' ' player' ' to']" , who was a member of the NBA 's all - time great s , and the first player to False such as Patrick Ewing, David Robinson, Shaquille O 'Neal, 7 [' such', ' as', ' Patrick', ' E', 'wing', ',', ' David', ' Robinson']
+167 45 The professional sport played by x -1 The professional sport played by Gale Sayers football Gale Sayers "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False retirement of the iconic Gale Sayers in 1972. Payton's 6 [' retirement', ' of', ' the', ' iconic', ' Gale', ' S', 'ayers']
+168 45 The professional sport played by x -1 The professional sport played by Gale Sayers football Gale Sayers "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False 3 ['G', 'ale', ' S', 'ayers']
+169 45 The professional sport played by x -1 The professional sport played by Gale Sayers football Gale Sayers "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False Dungy, Jim Brown, Gale Sayers and the Manning family: 8 [' Dun', 'gy', ',', ' Jim', ' Brown', ',', ' Gale', ' S', 'ayers']
+170 45 The professional sport played by x -1 The professional sport played by Gale Sayers football Gale Sayers "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False Nevers (1929) and Gale Sayers (1965). The Browns 9 [' Ne', 'vers', ' (', '19', '29', ')', ' and', ' Gale', ' S', 'ayers']
+171 45 The professional sport played by x -1 The professional sport played by Gale Sayers football Gale Sayers "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False Dungy, Jim Brown, Gale Sayers and the Manning 8 [' Dun', 'gy', ',', ' Jim', ' Brown', ',', ' Gale', ' S', 'ayers']
+172 46 The professional sport played by x -1 The professional sport played by Bobby Orr hockey Bobby Orr "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' career'
+ ' goals' ',' ' assists' ',' ' and' ' points' '.' '\n' '\n' 'The']" ", the NHL 's all - time leader in career goals , assists , and points .
+
+ The" False (1979 – 92), Bobby Orr (1969 – 72) and Stan 7 [' (', '1979', ' –', ' 92', '),', ' Bobby', ' Or', 'r']
+173 46 The professional sport played by x -1 The professional sport played by Bobby Orr hockey Bobby Orr "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' career'
+ ' goals' ',' ' assists' ',' ' and' ' points' '.' '\n' '\n' 'The']" ", the NHL 's all - time leader in career goals , assists , and points .
+
+ The" False would have voted Bobby Orr or Gordie Howe as 5 [' would', ' have', ' voted', ' Bobby', ' Or', 'r']
+174 46 The professional sport played by x -1 The professional sport played by Bobby Orr hockey Bobby Orr "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' career'
+ ' goals' ',' ' assists' ',' ' and' ' points' '.' '\n' '\n' 'The']" ", the NHL 's all - time leader in career goals , assists , and points .
+
+ The" False on behalf of Bobby Orr with the Bruins 5 [' on', ' behalf', ' of', ' Bobby', ' Or', 'r']
+175 46 The professional sport played by x -1 The professional sport played by Bobby Orr hockey Bobby Orr "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' career'
+ ' goals' ',' ' assists' ',' ' and' ' points' '.' '\n' '\n' 'The']" ", the NHL 's all - time leader in career goals , assists , and points .
+
+ The" False museum, called the Bobby Orr Hall of Fame, where 6 [' museum', ',', ' called', ' the', ' Bobby', ' Or', 'r']
+176 46 The professional sport played by x -1 The professional sport played by Bobby Orr hockey Bobby Orr "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' career'
+ ' goals' ',' ' assists' ',' ' and' ' points' '.' '\n' '\n' 'The']" ", the NHL 's all - time leader in career goals , assists , and points .
+
+ The" False would have voted Bobby Orr or Gordie Howe as 5 [' would', ' have', ' voted', ' Bobby', ' Or', 'r']
+177 47 The professional sport played by x -1 The professional sport played by Herschel Walker football Herschel Walker "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a']" , the former NFL running back , is a great example of the kind of person who can make a False 1995, running back Herschel Walker received more passing 6 [' 1995', ',', ' running', ' back', ' Hers', 'chel', ' Walker']
+178 47 The professional sport played by x -1 The professional sport played by Herschel Walker football Herschel Walker "[',' ' the' ' former' ' NFL' ' running' ' back' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a']" , the former NFL running back , is a great example of the kind of person who can make a False " point, first set by Herschel Walker in 1980.
+" 7 [' point', ',', ' first', ' set', ' by', ' Hers', 'chel', ' Walker']
+179 48 The professional sport played by x -1 The professional sport played by Bobby Hull hockey Bobby Hull "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Blackhawks' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Chicago Blackhawks .
+
+" False until broken by Bobby Hull in 1965. The 4 [' until', ' broken', ' by', ' Bobby', ' Hull']
+180 48 The professional sport played by x -1 The professional sport played by Bobby Hull hockey Bobby Hull "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Blackhawks' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Chicago Blackhawks .
+
+" False and players. Bobby Hull was the most famous 4 [' and', ' players', '.', ' Bobby', ' Hull']
+181 48 The professional sport played by x -1 The professional sport played by Bobby Hull hockey Bobby Hull "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Blackhawks' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Chicago Blackhawks .
+
+" False Both Mikita and Bobby Hull experimented with 5 [' Both', ' Mik', 'ita', ' and', ' Bobby', ' Hull']
+182 48 The professional sport played by x -1 The professional sport played by Bobby Hull hockey Bobby Hull "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Blackhawks' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Chicago Blackhawks .
+
+" False coup was to lure Bobby Hull from the Black 5 [' coup', ' was', ' to', ' lure', ' Bobby', ' Hull']
+183 48 The professional sport played by x -1 The professional sport played by Bobby Hull hockey Bobby Hull "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Blackhawks' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Chicago Blackhawks .
+
+" False fought Hall of Famer Bobby Hull and in the process 6 [' fought', ' Hall', ' of', ' F', 'amer', ' Bobby', ' Hull']
+184 49 The professional sport played by x -1 The professional sport played by Larry Doby baseball Larry Doby "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' first' ' black' ' player' ' in' ' the' ' major'
+ ' leagues']" ", the first black player in the major leagues .
+
+ The first black player in the major leagues" False tribute to Doby on Larry Doby Day by collectively 7 [' tribute', ' to', ' D', 'oby', ' on', ' Larry', ' D', 'oby']
+185 49 The professional sport played by x -1 The professional sport played by Larry Doby baseball Larry Doby "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' first' ' black' ' player' ' in' ' the' ' major'
+ ' leagues']" ", the first black player in the major leagues .
+
+ The first black player in the major leagues" False " baseball field ""Larry Doby Field"" on June" 5 "[' baseball', ' field', ' ""', 'Larry', ' D', 'oby']"
+186 49 The professional sport played by x -1 The professional sport played by Larry Doby baseball Larry Doby "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' first' ' black' ' player' ' in' ' the' ' major'
+ ' leagues']" ", the first black player in the major leagues .
+
+ The first black player in the major leagues" False " Livingston wrote, ""The Larry Doby way of pioneering" 7 "[' Livingston', ' wrote', ',', ' ""', 'The', ' Larry', ' D', 'oby']"
+187 49 The professional sport played by x -1 The professional sport played by Larry Doby baseball Larry Doby "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' first' ' black' ' player' ' in' ' the' ' major'
+ ' leagues']" ", the first black player in the major leagues .
+
+ The first black player in the major leagues" False often forgotten ... Larry Doby integrated all those 5 [' often', ' forgotten', '...', ' Larry', ' D', 'oby']
+188 49 The professional sport played by x -1 The professional sport played by Larry Doby baseball Larry Doby "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' first' ' black' ' player' ' in' ' the' ' major'
+ ' leagues']" ", the first black player in the major leagues .
+
+ The first black player in the major leagues" False outfield alongside Larry Doby and Dale Mitchell. 4 [' outfield', ' alongside', ' Larry', ' D', 'oby']
+189 50 The professional sport played by x -1 The professional sport played by Ed O'Neill football Ed O'Neill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' Broadway' ' musical' ' ""' 'The' ' Music' ' Man' '""' ' and'
+ ' who']" ", who was a member of the original cast of the Broadway musical "" The Music Man "" and who" False the Boston Red Sox. Ed O'Neill also auditioned 8 "[' the', ' Boston', ' Red', ' Sox', '.', ' Ed', ' O', ""'"", 'Neill']"
+190 50 The professional sport played by x -1 The professional sport played by Ed O'Neill football Ed O'Neill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' Broadway' ' musical' ' ""' 'The' ' Music' ' Man' '""' ' and'
+ ' who']" ", who was a member of the original cast of the Broadway musical "" The Music Man "" and who" False Boston Red Sox. Ed O'Neill also auditioned for 7 "[' Boston', ' Red', ' Sox', '.', ' Ed', ' O', ""'"", 'Neill']"
+191 50 The professional sport played by x -1 The professional sport played by Ed O'Neill football Ed O'Neill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' Broadway' ' musical' ' ""' 'The' ' Music' ' Man' '""' ' and'
+ ' who']" ", who was a member of the original cast of the Broadway musical "" The Music Man "" and who" False Boston Red Sox. Ed O'Neill also auditioned for 7 "[' Boston', ' Red', ' Sox', '.', ' Ed', ' O', ""'"", 'Neill']"
+192 50 The professional sport played by x -1 The professional sport played by Ed O'Neill football Ed O'Neill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' Broadway' ' musical' ' ""' 'The' ' Music' ' Man' '""' ' and'
+ ' who']" ", who was a member of the original cast of the Broadway musical "" The Music Man "" and who" False Boston Red Sox. Ed O'Neill also auditioned for 7 "[' Boston', ' Red', ' Sox', '.', ' Ed', ' O', ""'"", 'Neill']"
+193 50 The professional sport played by x -1 The professional sport played by Ed O'Neill football Ed O'Neill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' Broadway' ' musical' ' ""' 'The' ' Music' ' Man' '""' ' and'
+ ' who']" ", who was a member of the original cast of the Broadway musical "" The Music Man "" and who" False Boston Red Sox. Ed O'Neill also auditioned for 7 "[' Boston', ' Red', ' Sox', '.', ' Ed', ' O', ""'"", 'Neill']"
+194 51 The professional sport played by x -1 The professional sport played by Kareem Abdul-Jabbar basketball Kareem Abdul-Jabbar "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False Magic Johnson, Kareem Abdul-Jabbar and Michael Cooper 9 [' Magic', ' Johnson', ',', ' Kare', 'em', ' Abdul', '-', 'J', 'ab', 'bar']
+195 51 The professional sport played by x -1 The professional sport played by Kareem Abdul-Jabbar basketball Kareem Abdul-Jabbar "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False Wilt Chamberlain and Kareem Abdul-Jabbar to a one-on-one 10 [' W', 'ilt', ' Chamberlain', ' and', ' Kare', 'em', ' Abdul', '-', 'J', 'ab', 'bar']
+196 51 The professional sport played by x -1 The professional sport played by Kareem Abdul-Jabbar basketball Kareem Abdul-Jabbar "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False Magic Johnson, Kareem Abdul-Jabbar and Michael Cooper 9 [' Magic', ' Johnson', ',', ' Kare', 'em', ' Abdul', '-', 'J', 'ab', 'bar']
+197 51 The professional sport played by x -1 The professional sport played by Kareem Abdul-Jabbar basketball Kareem Abdul-Jabbar "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False ten finalists for the Kareem Abdul-Jabbar Award, also 10 [' ten', ' finalists', ' for', ' the', ' Kare', 'em', ' Abdul', '-', 'J', 'ab', 'bar']
+198 51 The professional sport played by x -1 The professional sport played by Kareem Abdul-Jabbar basketball Kareem Abdul-Jabbar "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False basketball player Kareem Abdul-Jabbar all befriended 8 [' basketball', ' player', ' Kare', 'em', ' Abdul', '-', 'J', 'ab', 'bar']
+199 52 The professional sport played by x -1 The professional sport played by Ty Cobb baseball Ty Cobb "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' hit'
+ ' a' ' home' ' run' ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The'
+ ' first']" ", the first professional baseball player to hit a home run in the World Series .
+
+ The first" True revised. Previously, Ty Cobb and George Sisler 5 [' revised', '.', ' Previously', ',', ' Ty', ' Cobb']
+200 52 The professional sport played by x -1 The professional sport played by Ty Cobb baseball Ty Cobb "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' hit'
+ ' a' ' home' ' run' ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The'
+ ' first']" ", the first professional baseball player to hit a home run in the World Series .
+
+ The first" True induction only Ty Cobb (98.2 %), 3 [' induction', ' only', ' Ty', ' Cobb']
+201 52 The professional sport played by x -1 The professional sport played by Ty Cobb baseball Ty Cobb "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' hit'
+ ' a' ' home' ' run' ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The'
+ ' first']" ", the first professional baseball player to hit a home run in the World Series .
+
+ The first" True and often violent Ty Cobb epitomized this 4 [' and', ' often', ' violent', ' Ty', ' Cobb']
+202 52 The professional sport played by x -1 The professional sport played by Ty Cobb baseball Ty Cobb "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' hit'
+ ' a' ' home' ' run' ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The'
+ ' first']" ", the first professional baseball player to hit a home run in the World Series .
+
+ The first" True " == Rivalry with Ty Cobb ==
+" 5 [' ==', ' R', 'ivalry', ' with', ' Ty', ' Cobb']
+203 52 The professional sport played by x -1 The professional sport played by Ty Cobb baseball Ty Cobb "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' hit'
+ ' a' ' home' ' run' ' in' ' the' ' World' ' Series' '.' '\n' '\n' 'The'
+ ' first']" ", the first professional baseball player to hit a home run in the World Series .
+
+ The first" True people suspect Ty Cobb as well. I think that 3 [' people', ' suspect', ' Ty', ' Cobb']
+204 55 The professional sport played by x -1 The professional sport played by Joe DiMaggio baseball Joe DiMaggio "[',' ' the' ' greatest' ' hitter' ' in' ' baseball' ' history' ',' ' was'
+ ' born' ' in' ' San' ' Francisco' ',' ' California' ',' ' on' ' this'
+ ' day' ' in']" , the greatest hitter in baseball history , was born in San Francisco , California , on this day in True Yankees star Joe DiMaggio was among those 6 [' Yankees', ' star', ' Joe', ' Di', 'M', 'agg', 'io']
+205 55 The professional sport played by x -1 The professional sport played by Joe DiMaggio baseball Joe DiMaggio "[',' ' the' ' greatest' ' hitter' ' in' ' baseball' ' history' ',' ' was'
+ ' born' ' in' ' San' ' Francisco' ',' ' California' ',' ' on' ' this'
+ ' day' ' in']" , the greatest hitter in baseball history , was born in San Francisco , California , on this day in True from 1935 to 1941. Joe DiMaggio played in the first-ever 9 [' from', ' 1935', ' to', ' 1941', '.', ' Joe', ' Di', 'M', 'agg', 'io']
+206 55 The professional sport played by x -1 The professional sport played by Joe DiMaggio baseball Joe DiMaggio "[',' ' the' ' greatest' ' hitter' ' in' ' baseball' ' history' ',' ' was'
+ ' born' ' in' ' San' ' Francisco' ',' ' California' ',' ' on' ' this'
+ ' day' ' in']" , the greatest hitter in baseball history , was born in San Francisco , California , on this day in True Charlie Keller and Joe DiMaggio in the ninth 7 [' Charlie', ' Keller', ' and', ' Joe', ' Di', 'M', 'agg', 'io']
+207 55 The professional sport played by x -1 The professional sport played by Joe DiMaggio baseball Joe DiMaggio "[',' ' the' ' greatest' ' hitter' ' in' ' baseball' ' history' ',' ' was'
+ ' born' ' in' ' San' ' Francisco' ',' ' California' ',' ' on' ' this'
+ ' day' ' in']" , the greatest hitter in baseball history , was born in San Francisco , California , on this day in True the mid-1930s; Joe DiMaggio broke his marks 11 [' the', ' mid', '-', '19', '30', 's', ';', ' Joe', ' Di', 'M', 'agg', 'io']
+208 55 The professional sport played by x -1 The professional sport played by Joe DiMaggio baseball Joe DiMaggio "[',' ' the' ' greatest' ' hitter' ' in' ' baseball' ' history' ',' ' was'
+ ' born' ' in' ' San' ' Francisco' ',' ' California' ',' ' on' ' this'
+ ' day' ' in']" , the greatest hitter in baseball history , was born in San Francisco , California , on this day in True ahead of third-place Joe DiMaggio and Mark McGwire (77). 9 [' ahead', ' of', ' third', '-', 'place', ' Joe', ' Di', 'M', 'agg', 'io']
+209 57 The professional sport played by x -1 The professional sport played by Frédéric Piquionne soccer Frédéric Piquionne "[',' ' a' ' French' ' professional' ' road' ' bicycle' ' racer' '.' ' He'
+ ' is' ' a' ' former' ' professional' ' rider' ' for' ' the' ' French'
+ ' team' ' FD' 'J']" , a French professional road bicycle racer . He is a former professional rider for the French team FD J False both teams. When Frédéric Piquionne equalised for West 12 [' both', ' teams', '.', ' When', ' Fr', 'é', 'd', 'é', 'ric', ' P', 'iqu', 'ion', 'ne']
+210 57 The professional sport played by x -1 The professional sport played by Frédéric Piquionne soccer Frédéric Piquionne "[',' ' a' ' French' ' professional' ' road' ' bicycle' ' racer' '.' ' He'
+ ' is' ' a' ' former' ' professional' ' rider' ' for' ' the' ' French'
+ ' team' ' FD' 'J']" , a French professional road bicycle racer . He is a former professional rider for the French team FD J False both teams. When Frédéric Piquionne equalised 12 [' both', ' teams', '.', ' When', ' Fr', 'é', 'd', 'é', 'ric', ' P', 'iqu', 'ion', 'ne']
+211 58 The professional sport played by x -1 The professional sport played by Tony Gwynn baseball Tony Gwynn "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of' ' his']" ", who was a great player and a great person .
+
+ I was a huge fan of his" False he worked with Tony Gwynn on skills at the 6 [' he', ' worked', ' with', ' Tony', ' G', 'wyn', 'n']
+212 58 The professional sport played by x -1 The professional sport played by Tony Gwynn baseball Tony Gwynn "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of' ' his']" ", who was a great player and a great person .
+
+ I was a huge fan of his" False short right field that Tony Gwynn was unable to see. 7 [' short', ' right', ' field', ' that', ' Tony', ' G', 'wyn', 'n']
+213 58 The professional sport played by x -1 The professional sport played by Tony Gwynn baseball Tony Gwynn "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of' ' his']" ", who was a great player and a great person .
+
+ I was a huge fan of his" False right field that Tony Gwynn was unable to 6 [' right', ' field', ' that', ' Tony', ' G', 'wyn', 'n']
+214 58 The professional sport played by x -1 The professional sport played by Tony Gwynn baseball Tony Gwynn "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of' ' his']" ", who was a great player and a great person .
+
+ I was a huge fan of his" False those guys like Tony Gwynn — they never 6 [' those', ' guys', ' like', ' Tony', ' G', 'wyn', 'n']
+215 58 The professional sport played by x -1 The professional sport played by Tony Gwynn baseball Tony Gwynn "[',' ' who' ' was' ' a' ' great' ' player' ' and' ' a' ' great' ' person'
+ '.' '\n' '\n' 'I' ' was' ' a' ' huge' ' fan' ' of' ' his']" ", who was a great player and a great person .
+
+ I was a huge fan of his" False " (.469, 1980) and Tony Gwynn (.459, 1997).
+" 9 [' (.', '469', ',', ' 1980', ')', ' and', ' Tony', ' G', 'wyn', 'n']
+216 59 The professional sport played by x -1 The professional sport played by Bart Starr football Bart Starr "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' was' ' a' ' great' ' quarterback' '.' ' He' ' was' ' a' ' great'
+ ' quarterback' '.']" , the quarterback for the Green Bay Packers , was a great quarterback . He was a great quarterback . False Athletes in Action / Bart Starr Award winner 6 [' Athlet', 'es', ' in', ' Action', ' /', ' Bart', ' Starr']
+217 59 The professional sport played by x -1 The professional sport played by Bart Starr football Bart Starr "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' was' ' a' ' great' ' quarterback' '.' ' He' ' was' ' a' ' great'
+ ' quarterback' '.']" , the quarterback for the Green Bay Packers , was a great quarterback . He was a great quarterback . False Athletes in Action / Bart Starr Award winner 6 [' Athlet', 'es', ' in', ' Action', ' /', ' Bart', ' Starr']
+218 59 The professional sport played by x -1 The professional sport played by Bart Starr football Bart Starr "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' was' ' a' ' great' ' quarterback' '.' ' He' ' was' ' a' ' great'
+ ' quarterback' '.']" , the quarterback for the Green Bay Packers , was a great quarterback . He was a great quarterback . False Athletes in Action / Bart Starr Award winner after 6 [' Athlet', 'es', ' in', ' Action', ' /', ' Bart', ' Starr']
+219 59 The professional sport played by x -1 The professional sport played by Bart Starr football Bart Starr "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' was' ' a' ' great' ' quarterback' '.' ' He' ' was' ' a' ' great'
+ ' quarterback' '.']" , the quarterback for the Green Bay Packers , was a great quarterback . He was a great quarterback . False Athletes in Action / Bart Starr Award winner after 6 [' Athlet', 'es', ' in', ' Action', ' /', ' Bart', ' Starr']
+220 59 The professional sport played by x -1 The professional sport played by Bart Starr football Bart Starr "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' was' ' a' ' great' ' quarterback' '.' ' He' ' was' ' a' ' great'
+ ' quarterback' '.']" , the quarterback for the Green Bay Packers , was a great quarterback . He was a great quarterback . False Athletes in Action / Bart Starr Award winner 6 [' Athlet', 'es', ' in', ' Action', ' /', ' Bart', ' Starr']
+221 60 The professional sport played by x -1 The professional sport played by Lou Gehrig baseball Lou Gehrig "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' die'
+ ' of' ' ALS' '.' '\n' '\n' 'The' ' first' ' professional' ' baseball'
+ ' player' ' to' ' die']" ", the first professional baseball player to die of ALS .
+
+ The first professional baseball player to die" True Zosky's debut to Lou Gehrig in 1925, asking 8 "[' Z', 'os', 'ky', ""'s"", ' debut', ' to', ' Lou', ' Geh', 'rig']"
+222 60 The professional sport played by x -1 The professional sport played by Lou Gehrig baseball Lou Gehrig "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' die'
+ ' of' ' ALS' '.' '\n' '\n' 'The' ' first' ' professional' ' baseball'
+ ' player' ' to' ' die']" ", the first professional baseball player to die of ALS .
+
+ The first professional baseball player to die" True notably, Pipp with Lou Gehrig at first base, 7 [' notably', ',', ' P', 'ipp', ' with', ' Lou', ' Geh', 'rig']
+223 60 The professional sport played by x -1 The professional sport played by Lou Gehrig baseball Lou Gehrig "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' die'
+ ' of' ' ALS' '.' '\n' '\n' 'The' ' first' ' professional' ' baseball'
+ ' player' ' to' ' die']" ", the first professional baseball player to die of ALS .
+
+ The first professional baseball player to die" True center field, and Lou Gehrig followed by 6 [' center', ' field', ',', ' and', ' Lou', ' Geh', 'rig']
+224 60 The professional sport played by x -1 The professional sport played by Lou Gehrig baseball Lou Gehrig "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' die'
+ ' of' ' ALS' '.' '\n' '\n' 'The' ' first' ' professional' ' baseball'
+ ' player' ' to' ' die']" ", the first professional baseball player to die of ALS .
+
+ The first professional baseball player to die" True honored with the Lou Gehrig Memorial Award, 5 [' honored', ' with', ' the', ' Lou', ' Geh', 'rig']
+225 60 The professional sport played by x -1 The professional sport played by Lou Gehrig baseball Lou Gehrig "[',' ' the' ' first' ' professional' ' baseball' ' player' ' to' ' die'
+ ' of' ' ALS' '.' '\n' '\n' 'The' ' first' ' professional' ' baseball'
+ ' player' ' to' ' die']" ", the first professional baseball player to die of ALS .
+
+ The first professional baseball player to die" True to third base. Lou Gehrig drew a walk 6 [' to', ' third', ' base', '.', ' Lou', ' Geh', 'rig']
+226 61 The professional sport played by x -1 The professional sport played by Grant Hill basketball Grant Hill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' NBA' ""'s"" ' all']" , who was a member of the NBA 's all - time great s , and the NBA 's all False ball, freshman Grant Hill split time 4 [' ball', ',', ' freshman', ' Grant', ' Hill']
+227 61 The professional sport played by x -1 The professional sport played by Grant Hill basketball Grant Hill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' NBA' ""'s"" ' all']" , who was a member of the NBA 's all - time great s , and the NBA 's all False the ball, freshman Grant Hill split time with starter 5 [' the', ' ball', ',', ' freshman', ' Grant', ' Hill']
+228 61 The professional sport played by x -1 The professional sport played by Grant Hill basketball Grant Hill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' NBA' ""'s"" ' all']" , who was a member of the NBA 's all - time great s , and the NBA 's all False ball, freshman Grant Hill split time with 4 [' ball', ',', ' freshman', ' Grant', ' Hill']
+229 61 The professional sport played by x -1 The professional sport played by Grant Hill basketball Grant Hill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' NBA' ""'s"" ' all']" , who was a member of the NBA 's all - time great s , and the NBA 's all False Holy Cross and Grant Hill had eight for 4 [' Holy', ' Cross', ' and', ' Grant', ' Hill']
+230 61 The professional sport played by x -1 The professional sport played by Grant Hill basketball Grant Hill "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' great' 's' ',' ' and' ' the' ' NBA' ""'s"" ' all']" , who was a member of the NBA 's all - time great s , and the NBA 's all False the ball, freshman Grant Hill split time 5 [' the', ' ball', ',', ' freshman', ' Grant', ' Hill']
+231 62 The professional sport played by x -1 The professional sport played by O. J. Simpson football O. J. Simpson "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' found' ' not' ' guilty'
+ ' of' ' murder' ' in' ' the' ' 1994' ' murders' ' of' ' his' ' ex' '-']" , the former NFL star , was found not guilty of murder in the 1994 murders of his ex - False spotlight when he defended O. J. Simpson for the murders of 8 [' spotlight', ' when', ' he', ' defended', ' O', '.', ' J', '.', ' Simpson']
+232 62 The professional sport played by x -1 The professional sport played by O. J. Simpson football O. J. Simpson "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' found' ' not' ' guilty'
+ ' of' ' murder' ' in' ' the' ' 1994' ' murders' ' of' ' his' ' ex' '-']" , the former NFL star , was found not guilty of murder in the 1994 murders of his ex - False spotlight when he defended O. J. Simpson for the murders of 8 [' spotlight', ' when', ' he', ' defended', ' O', '.', ' J', '.', ' Simpson']
+233 62 The professional sport played by x -1 The professional sport played by O. J. Simpson football O. J. Simpson "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' found' ' not' ' guilty'
+ ' of' ' murder' ' in' ' the' ' 1994' ' murders' ' of' ' his' ' ex' '-']" , the former NFL star , was found not guilty of murder in the 1994 murders of his ex - False referencing his anger after O. J. Simpson was acquitted 8 [' referencing', ' his', ' anger', ' after', ' O', '.', ' J', '.', ' Simpson']
+234 62 The professional sport played by x -1 The professional sport played by O. J. Simpson football O. J. Simpson "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' found' ' not' ' guilty'
+ ' of' ' murder' ' in' ' the' ' 1994' ' murders' ' of' ' his' ' ex' '-']" , the former NFL star , was found not guilty of murder in the 1994 murders of his ex - False he defended O. J. Simpson for the murders of 6 [' he', ' defended', ' O', '.', ' J', '.', ' Simpson']
+235 62 The professional sport played by x -1 The professional sport played by O. J. Simpson football O. J. Simpson "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' found' ' not' ' guilty'
+ ' of' ' murder' ' in' ' the' ' 1994' ' murders' ' of' ' his' ' ex' '-']" , the former NFL star , was found not guilty of murder in the 1994 murders of his ex - False attorneys at the O. J. Simpson murder case. 7 [' attorneys', ' at', ' the', ' O', '.', ' J', '.', ' Simpson']
+236 63 The professional sport played by x -1 The professional sport played by Bill Bradley basketball Bill Bradley "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NBA star , is a great example of the kind of person who can make a difference False (write-in) - 1,025 10 [' (', 'write', '-', 'in', ')', ' -', ' 1', ',', '02', 'Bill', ' Bradley']
+237 63 The professional sport played by x -1 The professional sport played by Bill Bradley basketball Bill Bradley "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NBA star , is a great example of the kind of person who can make a difference False well-known Senator Bill Bradley due to Democratic 5 [' well', '-', 'known', ' Senator', ' Bill', ' Bradley']
+238 63 The professional sport played by x -1 The professional sport played by Bill Bradley basketball Bill Bradley "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NBA star , is a great example of the kind of person who can make a difference False Jersey Senator Bill Bradley published a volume 3 [' Jersey', ' Senator', ' Bill', ' Bradley']
+239 63 The professional sport played by x -1 The professional sport played by Bill Bradley basketball Bill Bradley "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NBA star , is a great example of the kind of person who can make a difference False Princeton athletes, Bill Bradley and Dick Kazmaier. 4 [' Princeton', ' athletes', ',', ' Bill', ' Bradley']
+240 63 The professional sport played by x -1 The professional sport played by Bill Bradley basketball Bill Bradley "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NBA star , is a great example of the kind of person who can make a difference False third baseman Bill Bradley which Bradley bobbled 3 [' third', ' baseman', ' Bill', ' Bradley']
+241 64 The professional sport played by x -1 The professional sport played by Michael Jordan basketball Michael Jordan "[',' ' the' ' NBA' ',' ' and' ' the' ' NBA' '.' '\n' '\n' 'The' ' NBA'
+ ' is' ' a' ' league' ' that' ' has' ' been' ' around' ' for']" ", the NBA , and the NBA .
+
+ The NBA is a league that has been around for" False the exposure of Michael Jordan; the marketing of 4 [' the', ' exposure', ' of', ' Michael', ' Jordan']
+242 64 The professional sport played by x -1 The professional sport played by Michael Jordan basketball Michael Jordan "[',' ' the' ' NBA' ',' ' and' ' the' ' NBA' '.' '\n' '\n' 'The' ' NBA'
+ ' is' ' a' ' league' ' that' ' has' ' been' ' around' ' for']" ", the NBA , and the NBA .
+
+ The NBA is a league that has been around for" False Bulls during the Michael Jordan and Phil Jackson 4 [' Bulls', ' during', ' the', ' Michael', ' Jordan']
+243 64 The professional sport played by x -1 The professional sport played by Michael Jordan basketball Michael Jordan "[',' ' the' ' NBA' ',' ' and' ' the' ' NBA' '.' '\n' '\n' 'The' ' NBA'
+ ' is' ' a' ' league' ' that' ' has' ' been' ' around' ' for']" ", the NBA , and the NBA .
+
+ The NBA is a league that has been around for" False for Keeps: Michael Jordan and the World 5 [' for', ' Keep', 's', ':', ' Michael', ' Jordan']
+244 64 The professional sport played by x -1 The professional sport played by Michael Jordan basketball Michael Jordan "[',' ' the' ' NBA' ',' ' and' ' the' ' NBA' '.' '\n' '\n' 'The' ' NBA'
+ ' is' ' a' ' league' ' that' ' has' ' been' ' around' ' for']" ", the NBA , and the NBA .
+
+ The NBA is a league that has been around for" False previously held by Michael Jordan for most All-Star 4 [' previously', ' held', ' by', ' Michael', ' Jordan']
+245 64 The professional sport played by x -1 The professional sport played by Michael Jordan basketball Michael Jordan "[',' ' the' ' NBA' ',' ' and' ' the' ' NBA' '.' '\n' '\n' 'The' ' NBA'
+ ' is' ' a' ' league' ' that' ' has' ' been' ' around' ' for']" ", the NBA , and the NBA .
+
+ The NBA is a league that has been around for" False Olympics. Michael Jordan was an unofficial 3 [' Olympics', '.', ' Michael', ' Jordan']
+246 66 The professional sport played by x -1 The professional sport played by Torsten Frings soccer Torsten Frings "[',' ' who' ' is' ' a' ' former' ' professional' ' footballer' ' and' ' a'
+ ' former' ' German' ' international' '.' '\n' '\n' 'The' ' game' ' is'
+ ' played' ' on']" ", who is a former professional footballer and a former German international .
+
+ The game is played on" False players, signing Torsten Frings and Danny Koevermans 6 [' players', ',', ' signing', ' Tor', 'sten', ' Fr', 'ings']
+247 67 The professional sport played by x -1 The professional sport played by Gonzalo Higuaín soccer Gonzalo Higuaín "[',' ' who' ' has' ' been' ' on' ' the' ' move' ' since' ' he' ' was' ' a'
+ ' teenager' ',' ' has' ' been' ' a' ' revelation' ' for' ' the' ' club']" , who has been on the move since he was a teenager , has been a revelation for the club False Pellegrini preferred Gonzalo Higuaín in the lead 9 [' Pel', 'leg', 'r', 'ini', ' preferred', ' Gonz', 'alo', ' Hig', 'ua', 'ín']
+248 67 The professional sport played by x -1 The professional sport played by Gonzalo Higuaín soccer Gonzalo Higuaín "[',' ' who' ' has' ' been' ' on' ' the' ' move' ' since' ' he' ' was' ' a'
+ ' teenager' ',' ' has' ' been' ' a' ' revelation' ' for' ' the' ' club']" , who has been on the move since he was a teenager , has been a revelation for the club False Karim Benzema and Gonzalo Higuaín who scored 89 goals 9 [' Kar', 'im', ' Benz', 'ema', ' and', ' Gonz', 'alo', ' Hig', 'ua', 'ín']
+249 67 The professional sport played by x -1 The professional sport played by Gonzalo Higuaín soccer Gonzalo Higuaín "[',' ' who' ' has' ' been' ' on' ' the' ' move' ' since' ' he' ' was' ' a'
+ ' teenager' ',' ' has' ' been' ' a' ' revelation' ' for' ' the' ' club']" , who has been on the move since he was a teenager , has been a revelation for the club False 2016, Juventus signing Gonzalo Higuaín became the third 8 [' 2016', ',', ' Juventus', ' signing', ' Gonz', 'alo', ' Hig', 'ua', 'ín']
+250 67 The professional sport played by x -1 The professional sport played by Gonzalo Higuaín soccer Gonzalo Higuaín "[',' ' who' ' has' ' been' ' on' ' the' ' move' ' since' ' he' ' was' ' a'
+ ' teenager' ',' ' has' ' been' ' a' ' revelation' ' for' ' the' ' club']" , who has been on the move since he was a teenager , has been a revelation for the club False Pellegrini preferred Gonzalo Higuaín in the lead 9 [' Pel', 'leg', 'r', 'ini', ' preferred', ' Gonz', 'alo', ' Hig', 'ua', 'ín']
+251 67 The professional sport played by x -1 The professional sport played by Gonzalo Higuaín soccer Gonzalo Higuaín "[',' ' who' ' has' ' been' ' on' ' the' ' move' ' since' ' he' ' was' ' a'
+ ' teenager' ',' ' has' ' been' ' a' ' revelation' ' for' ' the' ' club']" , who has been on the move since he was a teenager , has been a revelation for the club False preferred Gonzalo Higuaín in the lead striker 5 [' preferred', ' Gonz', 'alo', ' Hig', 'ua', 'ín']
+252 68 The professional sport played by x -1 The professional sport played by Patrick Roy hockey Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' and' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins']" , who was a member of the Montreal Canadiens , and the NHL 's all - time leader in wins False Bell Centre to tie Patrick Roy for the most in 5 [' Bell', ' Centre', ' to', ' tie', ' Patrick', ' Roy']
+253 68 The professional sport played by x -1 The professional sport played by Patrick Roy hockey Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' and' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins']" , who was a member of the Montreal Canadiens , and the NHL 's all - time leader in wins False rookie goaltender Patrick Roy led Montreal 3 [' rookie', ' goaltender', ' Patrick', ' Roy']
+254 68 The professional sport played by x -1 The professional sport played by Patrick Roy hockey Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' and' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins']" , who was a member of the Montreal Canadiens , and the NHL 's all - time leader in wins False " rookie goaltender Patrick Roy in five games.
+" 3 [' rookie', ' goaltender', ' Patrick', ' Roy']
+255 68 The professional sport played by x -1 The professional sport played by Patrick Roy hockey Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' and' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins']" , who was a member of the Montreal Canadiens , and the NHL 's all - time leader in wins False Montreal goaltender Patrick Roy to give the Flames 3 [' Montreal', ' goaltender', ' Patrick', ' Roy']
+256 68 The professional sport played by x -1 The professional sport played by Patrick Roy hockey Patrick Roy "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens'
+ ',' ' and' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins']" , who was a member of the Montreal Canadiens , and the NHL 's all - time leader in wins False against goaltender Patrick Roy on the powerplay. 3 [' against', ' goaltender', ' Patrick', ' Roy']
+257 69 The professional sport played by x -1 The professional sport played by Alex Rodriguez baseball Alex Rodriguez "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' MVP' ' of'
+ ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' Yankees' ' are']" ", who is a former All - Star and MVP of the World Series .
+
+ The Yankees are" False million fans. Alex Rodriguez was the leading vote-getter 4 [' million', ' fans', '.', ' Alex', ' Rodriguez']
+258 69 The professional sport played by x -1 The professional sport played by Alex Rodriguez baseball Alex Rodriguez "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' MVP' ' of'
+ ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' Yankees' ' are']" ", who is a former All - Star and MVP of the World Series .
+
+ The Yankees are" False million fans. Alex Rodriguez was the leading 4 [' million', ' fans', '.', ' Alex', ' Rodriguez']
+259 69 The professional sport played by x -1 The professional sport played by Alex Rodriguez baseball Alex Rodriguez "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' MVP' ' of'
+ ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' Yankees' ' are']" ", who is a former All - Star and MVP of the World Series .
+
+ The Yankees are" False shortstops, along with Alex Rodriguez and Nomar Garciaparra, 7 [' short', 'st', 'ops', ',', ' along', ' with', ' Alex', ' Rodriguez']
+260 69 The professional sport played by x -1 The professional sport played by Alex Rodriguez baseball Alex Rodriguez "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' MVP' ' of'
+ ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' Yankees' ' are']" ", who is a former All - Star and MVP of the World Series .
+
+ The Yankees are" False (the Yankees' Alex Rodriguez earned the highest, 5 "[' (', 'the', ' Yankees', ""'"", ' Alex', ' Rodriguez']"
+261 69 The professional sport played by x -1 The professional sport played by Alex Rodriguez baseball Alex Rodriguez "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' MVP' ' of'
+ ' the' ' World' ' Series' '.' '\n' '\n' 'The' ' Yankees' ' are']" ", who is a former All - Star and MVP of the World Series .
+
+ The Yankees are" False and tied with Alex Rodriguez for the MLB lead 4 [' and', ' tied', ' with', ' Alex', ' Rodriguez']
+262 71 The professional sport played by x -1 The professional sport played by Bernie Casey football Bernie Casey "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True 1 ['Bernie', ' Casey']
+263 71 The professional sport played by x -1 The professional sport played by Bernie Casey football Bernie Casey "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True 1 ['Bernie', ' Casey']
+264 71 The professional sport played by x -1 The professional sport played by Bernie Casey football Bernie Casey "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True Afro-American actor Bernie Casey in one of the non-Eon 6 [' Af', 'ro', '-', 'American', ' actor', ' Bernie', ' Casey']
+265 71 The professional sport played by x -1 The professional sport played by Bernie Casey football Bernie Casey "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True portray Leiter was Bernie Casey in Never Say Never 5 [' portray', ' Le', 'iter', ' was', ' Bernie', ' Casey']
+266 71 The professional sport played by x -1 The professional sport played by Bernie Casey football Bernie Casey "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True portray Leiter was Bernie Casey in Never Say 5 [' portray', ' Le', 'iter', ' was', ' Bernie', ' Casey']
+267 72 The professional sport played by x -1 The professional sport played by Lothar Matthäus soccer Lothar Matthäus "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ' from' ' 1974' ' to' ' 1986' '.' '\n' '\n' 'The' ' German'
+ ' national']" ", who was a member of the German national team from 1974 to 1986 .
+
+ The German national" False matches, with only Lothar Matthäus remaining with 9 [' matches', ',', ' with', ' only', ' L', 'oth', 'ar', ' Matth', 'ä', 'us']
+268 72 The professional sport played by x -1 The professional sport played by Lothar Matthäus soccer Lothar Matthäus "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ' from' ' 1974' ' to' ' 1986' '.' '\n' '\n' 'The' ' German'
+ ' national']" ", who was a member of the German national team from 1974 to 1986 .
+
+ The German national" False Substituted Bayern legend Lothar Matthäus removed his runner-up 10 [' Subst', 'it', 'uted', ' Bayern', ' legend', ' L', 'oth', 'ar', ' Matth', 'ä', 'us']
+269 72 The professional sport played by x -1 The professional sport played by Lothar Matthäus soccer Lothar Matthäus "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ' from' ' 1974' ' to' ' 1986' '.' '\n' '\n' 'The' ' German'
+ ' national']" ", who was a member of the German national team from 1974 to 1986 .
+
+ The German national" False 1966) and Germany's Lothar Matthäus (1982 – 1998) 10 "[' 1966', ')', ' and', ' Germany', ""'s"", ' L', 'oth', 'ar', ' Matth', 'ä', 'us']"
+270 72 The professional sport played by x -1 The professional sport played by Lothar Matthäus soccer Lothar Matthäus "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ' from' ' 1974' ' to' ' 1986' '.' '\n' '\n' 'The' ' German'
+ ' national']" ", who was a member of the German national team from 1974 to 1986 .
+
+ The German national" False matches, with only Lothar Matthäus remaining with more 9 [' matches', ',', ' with', ' only', ' L', 'oth', 'ar', ' Matth', 'ä', 'us']
+271 72 The professional sport played by x -1 The professional sport played by Lothar Matthäus soccer Lothar Matthäus "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' German' ' national'
+ ' team' ' from' ' 1974' ' to' ' 1986' '.' '\n' '\n' 'The' ' German'
+ ' national']" ", who was a member of the German national team from 1974 to 1986 .
+
+ The German national" False matches, with only Lothar Matthäus remaining with more 9 [' matches', ',', ' with', ' only', ' L', 'oth', 'ar', ' Matth', 'ä', 'us']
+272 73 The professional sport played by x -1 The professional sport played by Drew Brees football Drew Brees "[',' ' the' ' Saints' ""'"" ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the' ' Saints'
+ ' should' ' be']" , the Saints ' quarterback , is a great example of the type of player that the Saints should be False for the Bears, as Drew Brees was 0 – 4 in 6 [' for', ' the', ' Bears', ',', ' as', ' Drew', ' Brees']
+273 73 The professional sport played by x -1 The professional sport played by Drew Brees football Drew Brees "[',' ' the' ' Saints' ""'"" ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the' ' Saints'
+ ' should' ' be']" , the Saints ' quarterback , is a great example of the type of player that the Saints should be False " completions: 4 (surpassed by Drew Brees in 2011)
+" 10 [' comple', 'tions', ':', ' 4', ' (', 'sur', 'pass', 'ed', ' by', ' Drew', ' Brees']
+274 73 The professional sport played by x -1 The professional sport played by Drew Brees football Drew Brees "[',' ' the' ' Saints' ""'"" ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the' ' Saints'
+ ' should' ' be']" , the Saints ' quarterback , is a great example of the type of player that the Saints should be False the Bears, as Drew Brees was 0 – 4 in Chicago, 5 [' the', ' Bears', ',', ' as', ' Drew', ' Brees']
+275 73 The professional sport played by x -1 The professional sport played by Drew Brees football Drew Brees "[',' ' the' ' Saints' ""'"" ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the' ' Saints'
+ ' should' ' be']" , the Saints ' quarterback , is a great example of the type of player that the Saints should be False for the Bears, as Drew Brees was 0 – 4 in 6 [' for', ' the', ' Bears', ',', ' as', ' Drew', ' Brees']
+276 73 The professional sport played by x -1 The professional sport played by Drew Brees football Drew Brees "[',' ' the' ' Saints' ""'"" ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' type' ' of' ' player' ' that' ' the' ' Saints'
+ ' should' ' be']" , the Saints ' quarterback , is a great example of the type of player that the Saints should be False for the Bears, as Drew Brees was 0 – 4 6 [' for', ' the', ' Bears', ',', ' as', ' Drew', ' Brees']
+277 74 The professional sport played by x -1 The professional sport played by Howie Morenz hockey Howie Morenz "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NHL' ' should'
+ ' be']" , a former NHL player , is a great example of the kind of player that the NHL should be False " Morenz ===
+" 6 [' More', 'nz', ' ===', 'How', 'ie', ' More', 'nz']
+278 74 The professional sport played by x -1 The professional sport played by Howie Morenz hockey Howie Morenz "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NHL' ' should'
+ ' be']" , a former NHL player , is a great example of the kind of player that the NHL should be False vote behind Howie Morenz for the Hart Trophy 5 [' vote', ' behind', ' How', 'ie', ' More', 'nz']
+279 74 The professional sport played by x -1 The professional sport played by Howie Morenz hockey Howie Morenz "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NHL' ' should'
+ ' be']" , a former NHL player , is a great example of the kind of player that the NHL should be False including two Canadiens: Howie Morenz and Georges Vezina. 7 [' including', ' two', ' Canadiens', ':', ' How', 'ie', ' More', 'nz']
+280 74 The professional sport played by x -1 The professional sport played by Howie Morenz hockey Howie Morenz "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NHL' ' should'
+ ' be']" , a former NHL player , is a great example of the kind of player that the NHL should be False All-Stars in the Howie Morenz Memorial Game, 8 [' All', '-', 'Stars', ' in', ' the', ' How', 'ie', ' More', 'nz']
+281 74 The professional sport played by x -1 The professional sport played by Howie Morenz hockey Howie Morenz "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' NHL' ' should'
+ ' be']" , a former NHL player , is a great example of the kind of player that the NHL should be False " Morenz =
+" 6 [' More', 'nz', ' =', 'How', 'ie', ' More', 'nz']
+282 75 The professional sport played by x -1 The professional sport played by Michael Bradley soccer Michael Bradley "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' ','
+ ' and' ' his' ' wife' ',' ' who' ' is' ' a' ' former' ' professional'
+ ' soccer' ' player']" , who is a former professional soccer player , and his wife , who is a former professional soccer player True New England and Michael Bradley who returned from Italy 4 [' New', ' England', ' and', ' Michael', ' Bradley']
+283 75 The professional sport played by x -1 The professional sport played by Michael Bradley soccer Michael Bradley "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' ','
+ ' and' ' his' ' wife' ',' ' who' ' is' ' a' ' former' ' professional'
+ ' soccer' ' player']" , who is a former professional soccer player , and his wife , who is a former professional soccer player True States international Michael Bradley of A.S. Roma, 3 [' States', ' international', ' Michael', ' Bradley']
+284 75 The professional sport played by x -1 The professional sport played by Michael Bradley soccer Michael Bradley "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' ','
+ ' and' ' his' ' wife' ',' ' who' ' is' ' a' ' former' ' professional'
+ ' soccer' ' player']" , who is a former professional soccer player , and his wife , who is a former professional soccer player True 1 ['Michael', ' Bradley']
+285 75 The professional sport played by x -1 The professional sport played by Michael Bradley soccer Michael Bradley "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' ','
+ ' and' ' his' ' wife' ',' ' who' ' is' ' a' ' former' ' professional'
+ ' soccer' ' player']" , who is a former professional soccer player , and his wife , who is a former professional soccer player True 1 ['Michael', ' Bradley']
+286 75 The professional sport played by x -1 The professional sport played by Michael Bradley soccer Michael Bradley "[',' ' who' ' is' ' a' ' former' ' professional' ' soccer' ' player' ','
+ ' and' ' his' ' wife' ',' ' who' ' is' ' a' ' former' ' professional'
+ ' soccer' ' player']" , who is a former professional soccer player , and his wife , who is a former professional soccer player True 1 ['Michael', ' Bradley']
+287 76 The professional sport played by x -1 The professional sport played by Warren Spahn baseball Warren Spahn "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Boston' ' Braves' '.'
+ '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the' ' old' ' Yankee']" ", who was a pitcher for the Boston Braves .
+
+ The game was played in the old Yankee" False with 266 wins until Warren Spahn surpassed his total 6 [' with', ' 266', ' wins', ' until', ' Warren', ' Sp', 'ahn']
+288 76 The professional sport played by x -1 The professional sport played by Warren Spahn baseball Warren Spahn "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Boston' ' Braves' '.'
+ '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the' ' old' ' Yankee']" ", who was a pitcher for the Boston Braves .
+
+ The game was played in the old Yankee" False Braves pitchers Warren Spahn and Lew Burdette 4 [' Braves', ' pitchers', ' Warren', ' Sp', 'ahn']
+289 76 The professional sport played by x -1 The professional sport played by Warren Spahn baseball Warren Spahn "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Boston' ' Braves' '.'
+ '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the' ' old' ' Yankee']" ", who was a pitcher for the Boston Braves .
+
+ The game was played in the old Yankee" False retirement, trailing only Warren Spahn (2,583) among 6 [' retirement', ',', ' trailing', ' only', ' Warren', ' Sp', 'ahn']
+290 76 The professional sport played by x -1 The professional sport played by Warren Spahn baseball Warren Spahn "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Boston' ' Braves' '.'
+ '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the' ' old' ' Yankee']" ", who was a pitcher for the Boston Braves .
+
+ The game was played in the old Yankee" False for the NL lead with Warren Spahn and the MLB lead 7 [' for', ' the', ' NL', ' lead', ' with', ' Warren', ' Sp', 'ahn']
+291 76 The professional sport played by x -1 The professional sport played by Warren Spahn baseball Warren Spahn "[',' ' who' ' was' ' a' ' pitcher' ' for' ' the' ' Boston' ' Braves' '.'
+ '\n' '\n' 'The' ' game' ' was' ' played' ' in' ' the' ' old' ' Yankee']" ", who was a pitcher for the Boston Braves .
+
+ The game was played in the old Yankee" False wins until Warren Spahn surpassed his total 4 [' wins', ' until', ' Warren', ' Sp', 'ahn']
+292 77 The professional sport played by x -1 The professional sport played by Oscar Robertson basketball Oscar Robertson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' 50' 'th' ' Anniversary']" , the NBA 's all - time leading scorer , was a member of the NBA 's 50 th Anniversary False were named to the Oscar Robertson Trophy (USBWA National 5 [' were', ' named', ' to', ' the', ' Oscar', ' Robertson']
+293 77 The professional sport played by x -1 The professional sport played by Oscar Robertson basketball Oscar Robertson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' 50' 'th' ' Anniversary']" , the NBA 's all - time leading scorer , was a member of the NBA 's 50 th Anniversary False named the winner of the Oscar Robertson Trophy by the 6 [' named', ' the', ' winner', ' of', ' the', ' Oscar', ' Robertson']
+294 77 The professional sport played by x -1 The professional sport played by Oscar Robertson basketball Oscar Robertson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' 50' 'th' ' Anniversary']" , the NBA 's all - time leading scorer , was a member of the NBA 's 50 th Anniversary False January 27 and Oscar Robertson National Player 4 [' January', ' 27', ' and', ' Oscar', ' Robertson']
+295 77 The professional sport played by x -1 The professional sport played by Oscar Robertson basketball Oscar Robertson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' 50' 'th' ' Anniversary']" , the NBA 's all - time leading scorer , was a member of the NBA 's 50 th Anniversary False listed on the Oscar Robertson Award preseason 4 [' listed', ' on', ' the', ' Oscar', ' Robertson']
+296 77 The professional sport played by x -1 The professional sport played by Oscar Robertson basketball Oscar Robertson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' 50' 'th' ' Anniversary']" , the NBA 's all - time leading scorer , was a member of the NBA 's 50 th Anniversary False and the Top 15 Oscar Robertson Trophy candidates. 5 [' and', ' the', ' Top', ' 15', ' Oscar', ' Robertson']
+297 79 The professional sport played by x -1 The professional sport played by Jermain Defoe soccer Jermain Defoe "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False international Jermain Defoe for a reported fee 4 [' international', ' Jer', 'main', ' Def', 'oe']
+298 79 The professional sport played by x -1 The professional sport played by Jermain Defoe soccer Jermain Defoe "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False " Drogba and Jermain Defoe ""because of their" 7 [' D', 'rog', 'ba', ' and', ' Jer', 'main', ' Def', 'oe']
+299 79 The professional sport played by x -1 The professional sport played by Jermain Defoe soccer Jermain Defoe "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False lead as a Jermain Defoe free kick 6 [' lead', ' as', ' a', ' Jer', 'main', ' Def', 'oe']
+300 79 The professional sport played by x -1 The professional sport played by Jermain Defoe soccer Jermain Defoe "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False " Didier Drogba and Jermain Defoe ""because of their" 9 [' Did', 'ier', ' D', 'rog', 'ba', ' and', ' Jer', 'main', ' Def', 'oe']
+301 79 The professional sport played by x -1 The professional sport played by Jermain Defoe soccer Jermain Defoe "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False " Didier Drogba and Jermain Defoe ""because of" 9 [' Did', 'ier', ' D', 'rog', 'ba', ' and', ' Jer', 'main', ' Def', 'oe']
+302 80 The professional sport played by x -1 The professional sport played by Ken Dryden hockey Ken Dryden "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins' ','
+ ' shut' 'outs' ',' ' and' ' games' ' played' '.' '\n' '\n']" ", the NHL 's all - time leader in wins , shut outs , and games played .
+
+" False " breathtaking ""by Ken Dryden of Allmusic. The first" 5 "[' breathtaking', ' ""', 'by', ' Ken', ' Dry', 'den']"
+303 80 The professional sport played by x -1 The professional sport played by Ken Dryden hockey Ken Dryden "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins' ','
+ ' shut' 'outs' ',' ' and' ' games' ' played' '.' '\n' '\n']" ", the NHL 's all - time leader in wins , shut outs , and games played .
+
+" False " as"" breathtaking ""by Ken Dryden of Allmusic. The" 7 "[' as', '""', ' breathtaking', ' ""', 'by', ' Ken', ' Dry', 'den']"
+304 80 The professional sport played by x -1 The professional sport played by Ken Dryden hockey Ken Dryden "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins' ','
+ ' shut' 'outs' ',' ' and' ' games' ' played' '.' '\n' '\n']" ", the NHL 's all - time leader in wins , shut outs , and games played .
+
+" False " breathtaking ""by Ken Dryden of Allmusic. The" 5 "[' breathtaking', ' ""', 'by', ' Ken', ' Dry', 'den']"
+305 80 The professional sport played by x -1 The professional sport played by Ken Dryden hockey Ken Dryden "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins' ','
+ ' shut' 'outs' ',' ' and' ' games' ' played' '.' '\n' '\n']" ", the NHL 's all - time leader in wins , shut outs , and games played .
+
+" False Canadian players Ken Dryden and Brad Park turned 4 [' Canadian', ' players', ' Ken', ' Dry', 'den']
+306 80 The professional sport played by x -1 The professional sport played by Ken Dryden hockey Ken Dryden "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' wins' ','
+ ' shut' 'outs' ',' ' and' ' games' ' played' '.' '\n' '\n']" ", the NHL 's all - time leader in wins , shut outs , and games played .
+
+" False been largely positive. Ken Dryden of Allmusic called 6 [' been', ' largely', ' positive', '.', ' Ken', ' Dry', 'den']
+307 81 The professional sport played by x -1 The professional sport played by Woody Strode football Woody Strode "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Kenny Washington and Woody Strode joined the Los Angeles 5 [' Kenny', ' Washington', ' and', ' Woody', ' Stro', 'de']
+308 81 The professional sport played by x -1 The professional sport played by Woody Strode football Woody Strode "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Charlie 5 [' Char', 'li', 'Wood', 'y', ' Stro', 'de']
+309 81 The professional sport played by x -1 The professional sport played by Woody Strode football Woody Strode "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Washington and Woody Strode joined the Los Angeles 4 [' Washington', ' and', ' Woody', ' Stro', 'de']
+310 81 The professional sport played by x -1 The professional sport played by Woody Strode football Woody Strode "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Washington and Woody Strode joined the Los 4 [' Washington', ' and', ' Woody', ' Stro', 'de']
+311 81 The professional sport played by x -1 The professional sport played by Woody Strode football Woody Strode "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Washington and Woody Strode joined the Los Angeles 4 [' Washington', ' and', ' Woody', ' Stro', 'de']
+312 82 The professional sport played by x -1 The professional sport played by Russell Wilson football Russell Wilson "[',' ' the' ' Seattle' ' Seahawks' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to'
+ ' make' ' the']" , the Seattle Seahawks quarterback , is a great example of a player who has been able to make the False Robinson. Russell Wilson who currently 3 [' Robinson', '.', ' Russell', ' Wilson']
+313 82 The professional sport played by x -1 The professional sport played by Russell Wilson football Russell Wilson "[',' ' the' ' Seattle' ' Seahawks' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to'
+ ' make' ' the']" , the Seattle Seahawks quarterback , is a great example of a player who has been able to make the False " passes: 26 (tied by Russell Wilson in 2012)
+" 8 [' passes', ':', ' 26', ' (', 't', 'ied', ' by', ' Russell', ' Wilson']
+314 82 The professional sport played by x -1 The professional sport played by Russell Wilson football Russell Wilson "[',' ' the' ' Seattle' ' Seahawks' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to'
+ ' make' ' the']" , the Seattle Seahawks quarterback , is a great example of a player who has been able to make the False " Wilson ===
+" 3 [' Wilson', ' ===', 'Russell', ' Wilson']
+315 82 The professional sport played by x -1 The professional sport played by Russell Wilson football Russell Wilson "[',' ' the' ' Seattle' ' Seahawks' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to'
+ ' make' ' the']" , the Seattle Seahawks quarterback , is a great example of a player who has been able to make the False " Wilson ===
+" 3 [' Wilson', ' ===', 'Russell', ' Wilson']
+316 82 The professional sport played by x -1 The professional sport played by Russell Wilson football Russell Wilson "[',' ' the' ' Seattle' ' Seahawks' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to'
+ ' make' ' the']" , the Seattle Seahawks quarterback , is a great example of a player who has been able to make the False equalized once more with a Russell Wilson pass, which 7 [' equal', 'ized', ' once', ' more', ' with', ' a', ' Russell', ' Wilson']
+317 83 The professional sport played by x -1 The professional sport played by Eric Lindros hockey Eric Lindros "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Philadelphia' ' Flyers'
+ ',' ' and' ' the' ' NHL' ""'s"" ' first' ' superstar' '.' '\n' '\n' 'The']" ", who was a member of the Philadelphia Flyers , and the NHL 's first superstar .
+
+ The" False " NHL, including on the Eric Lindros trade.
+" 7 [' NHL', ',', ' including', ' on', ' the', ' Eric', ' Lind', 'ros']
+318 83 The professional sport played by x -1 The professional sport played by Eric Lindros hockey Eric Lindros "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Philadelphia' ' Flyers'
+ ',' ' and' ' the' ' NHL' ""'s"" ' first' ' superstar' '.' '\n' '\n' 'The']" ", who was a member of the Philadelphia Flyers , and the NHL 's first superstar .
+
+ The" False tournament since Eric Lindros participated in 4 [' tournament', ' since', ' Eric', ' Lind', 'ros']
+319 83 The professional sport played by x -1 The professional sport played by Eric Lindros hockey Eric Lindros "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Philadelphia' ' Flyers'
+ ',' ' and' ' the' ' NHL' ""'s"" ' first' ' superstar' '.' '\n' '\n' 'The']" ", who was a member of the Philadelphia Flyers , and the NHL 's first superstar .
+
+ The" False trade which brought Eric Lindros to the Philadelphia 5 [' trade', ' which', ' brought', ' Eric', ' Lind', 'ros']
+320 83 The professional sport played by x -1 The professional sport played by Eric Lindros hockey Eric Lindros "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Philadelphia' ' Flyers'
+ ',' ' and' ' the' ' NHL' ""'s"" ' first' ' superstar' '.' '\n' '\n' 'The']" ", who was a member of the Philadelphia Flyers , and the NHL 's first superstar .
+
+ The" False star centre Eric Lindros was skating alone 4 [' star', ' centre', ' Eric', ' Lind', 'ros']
+321 83 The professional sport played by x -1 The professional sport played by Eric Lindros hockey Eric Lindros "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Philadelphia' ' Flyers'
+ ',' ' and' ' the' ' NHL' ""'s"" ' first' ' superstar' '.' '\n' '\n' 'The']" ", who was a member of the Philadelphia Flyers , and the NHL 's first superstar .
+
+ The" False firm in the Eric Lindros holdout issue. 5 [' firm', ' in', ' the', ' Eric', ' Lind', 'ros']
+322 84 The professional sport played by x -1 The professional sport played by Otto Graham basketball Otto Graham "[',' ' the' ' first' ' African' '-' 'American' ' to' ' win' ' the'
+ ' Heisman' ' Trophy' '.' '\n' '\n' '19' '40' 's' '\n' '\n' '19']" ", the first African - American to win the Heisman Trophy .
+
+ 19 40 s
+
+ 19" False quarterback Otto Graham and fullback Marion 2 [' quarterback', ' Otto', ' Graham']
+323 84 The professional sport played by x -1 The professional sport played by Otto Graham basketball Otto Graham "[',' ' the' ' first' ' African' '-' 'American' ' to' ' win' ' the'
+ ' Heisman' ' Trophy' '.' '\n' '\n' '19' '40' 's' '\n' '\n' '19']" ", the first African - American to win the Heisman Trophy .
+
+ 19 40 s
+
+ 19" False accidentally invented when Otto Graham tripped while dropping 4 [' accidentally', ' invented', ' when', ' Otto', ' Graham']
+324 84 The professional sport played by x -1 The professional sport played by Otto Graham basketball Otto Graham "[',' ' the' ' first' ' African' '-' 'American' ' to' ' win' ' the'
+ ' Heisman' ' Trophy' '.' '\n' '\n' '19' '40' 's' '\n' '\n' '19']" ", the first African - American to win the Heisman Trophy .
+
+ 19 40 s
+
+ 19" False from quarterback Otto Graham to Dean Sensanbaugher 3 [' from', ' quarterback', ' Otto', ' Graham']
+325 84 The professional sport played by x -1 The professional sport played by Otto Graham basketball Otto Graham "[',' ' the' ' first' ' African' '-' 'American' ' to' ' win' ' the'
+ ' Heisman' ' Trophy' '.' '\n' '\n' '19' '40' 's' '\n' '\n' '19']" ", the first African - American to win the Heisman Trophy .
+
+ 19 40 s
+
+ 19" False quarterback Otto Graham had his best 2 [' quarterback', ' Otto', ' Graham']
+326 84 The professional sport played by x -1 The professional sport played by Otto Graham basketball Otto Graham "[',' ' the' ' first' ' African' '-' 'American' ' to' ' win' ' the'
+ ' Heisman' ' Trophy' '.' '\n' '\n' '19' '40' 's' '\n' '\n' '19']" ", the first African - American to win the Heisman Trophy .
+
+ 19 40 s
+
+ 19" False Browns quarterback Otto Graham during the second 3 [' Browns', ' quarterback', ' Otto', ' Graham']
+327 85 The professional sport played by x -1 The professional sport played by John Olerud baseball John Olerud "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . True inning before John Olerud singled to leadoff 5 [' inning', ' before', ' John', ' O', 'ler', 'ud']
+328 85 The professional sport played by x -1 The professional sport played by John Olerud baseball John Olerud "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . True against Welch, scoring John Olerud and Candy Maldonado 7 [' against', ' Welch', ',', ' scoring', ' John', ' O', 'ler', 'ud']
+329 85 The professional sport played by x -1 The professional sport played by John Olerud baseball John Olerud "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . True record set by John Olerud in 1998, by reaching 6 [' record', ' set', ' by', ' John', ' O', 'ler', 'ud']
+330 85 The professional sport played by x -1 The professional sport played by John Olerud baseball John Olerud "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . True Welch, scoring John Olerud and Candy Maldonado 6 [' Welch', ',', ' scoring', ' John', ' O', 'ler', 'ud']
+331 85 The professional sport played by x -1 The professional sport played by John Olerud baseball John Olerud "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . True record set by John Olerud in 1998, by reaching 6 [' record', ' set', ' by', ' John', ' O', 'ler', 'ud']
+332 86 The professional sport played by x -1 The professional sport played by Terry Bradshaw football Terry Bradshaw "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False Peyton Manning, and Terry Bradshaw during their first 6 [' Peyton', ' Manning', ',', ' and', ' Terry', ' Brad', 'shaw']
+333 86 The professional sport played by x -1 The professional sport played by Terry Bradshaw football Terry Bradshaw "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False " Steelers Quarterback Terry Bradshaw recalled, ""[h]" 5 [' Steelers', ' Quarter', 'back', ' Terry', ' Brad', 'shaw']
+334 86 The professional sport played by x -1 The professional sport played by Terry Bradshaw football Terry Bradshaw "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False sack of quarterback Terry Bradshaw added fuel to the heated 5 [' sack', ' of', ' quarterback', ' Terry', ' Brad', 'shaw']
+335 86 The professional sport played by x -1 The professional sport played by Terry Bradshaw football Terry Bradshaw "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False Quarterback Terry Bradshaw recalled, 4 [' Quarter', 'back', ' Terry', ' Brad', 'shaw']
+336 87 The professional sport played by x -1 The professional sport played by Pavel Datsyuk hockey Pavel Datsyuk "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Detroit' ' Red'
+ ' Wings' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1997' '.'
+ ' He']" , who has been a member of the Detroit Red Wings since the team 's inception in 1997 . He False finalist along with Pavel Datsyuk of the Detroit Red 8 [' final', 'ist', ' along', ' with', ' Pavel', ' D', 'ats', 'y', 'uk']
+337 87 The professional sport played by x -1 The professional sport played by Pavel Datsyuk hockey Pavel Datsyuk "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Detroit' ' Red'
+ ' Wings' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1997' '.'
+ ' He']" , who has been a member of the Detroit Red Wings since the team 's inception in 1997 . He False season, opposite Pavel Datsyuk and Jordan Staal 7 [' season', ',', ' opposite', ' Pavel', ' D', 'ats', 'y', 'uk']
+338 87 The professional sport played by x -1 The professional sport played by Pavel Datsyuk hockey Pavel Datsyuk "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Detroit' ' Red'
+ ' Wings' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1997' '.'
+ ' He']" , who has been a member of the Detroit Red Wings since the team 's inception in 1997 . He False eventually awarded to Pavel Datsyuk of the Detroit 7 [' eventually', ' awarded', ' to', ' Pavel', ' D', 'ats', 'y', 'uk']
+339 87 The professional sport played by x -1 The professional sport played by Pavel Datsyuk hockey Pavel Datsyuk "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Detroit' ' Red'
+ ' Wings' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1997' '.'
+ ' He']" , who has been a member of the Detroit Red Wings since the team 's inception in 1997 . He False along with Pavel Datsyuk of the Detroit Red 6 [' along', ' with', ' Pavel', ' D', 'ats', 'y', 'uk']
+340 87 The professional sport played by x -1 The professional sport played by Pavel Datsyuk hockey Pavel Datsyuk "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Detroit' ' Red'
+ ' Wings' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1997' '.'
+ ' He']" , who has been a member of the Detroit Red Wings since the team 's inception in 1997 . He False along with Pavel Datsyuk of the Detroit Red 6 [' along', ' with', ' Pavel', ' D', 'ats', 'y', 'uk']
+341 88 The professional sport played by x -1 The professional sport played by Christy Mathewson baseball Christy Mathewson "[',' ' a' ' former' ' All' '-' 'Star' ' pitcher' ' for' ' the' ' New'
+ ' York' ' Giants' ',' ' was' ' a' ' member' ' of' ' the' ' New' ' York']" , a former All - Star pitcher for the New York Giants , was a member of the New York False batter – 17 in 2014. Christy Mathewson of the New York Giants 9 [' batter', ' –', ' 17', ' in', ' 2014', '.', ' Christy', ' Mat', 'hew', 'son']
+342 88 The professional sport played by x -1 The professional sport played by Christy Mathewson baseball Christy Mathewson "[',' ' a' ' former' ' All' '-' 'Star' ' pitcher' ' for' ' the' ' New'
+ ' York' ' Giants' ',' ' was' ' a' ' member' ' of' ' the' ' New' ' York']" , a former All - Star pitcher for the New York Giants , was a member of the New York False Cubs, Giants pitcher Christy Mathewson reportedly 7 [' Cubs', ',', ' Giants', ' pitcher', ' Christy', ' Mat', 'hew', 'son']
+343 88 The professional sport played by x -1 The professional sport played by Christy Mathewson baseball Christy Mathewson "[',' ' a' ' former' ' All' '-' 'Star' ' pitcher' ' for' ' the' ' New'
+ ' York' ' Giants' ',' ' was' ' a' ' member' ' of' ' the' ' New' ' York']" , a former All - Star pitcher for the New York Giants , was a member of the New York False fellow pitcher Christy Mathewson accounted for 73 5 [' fellow', ' pitcher', ' Christy', ' Mat', 'hew', 'son']
+344 88 The professional sport played by x -1 The professional sport played by Christy Mathewson baseball Christy Mathewson "[',' ' a' ' former' ' All' '-' 'Star' ' pitcher' ' for' ' the' ' New'
+ ' York' ' Giants' ',' ' was' ' a' ' member' ' of' ' the' ' New' ' York']" , a former All - Star pitcher for the New York Giants , was a member of the New York False run against Christy Mathewson of the New 5 [' run', ' against', ' Christy', ' Mat', 'hew', 'son']
+345 88 The professional sport played by x -1 The professional sport played by Christy Mathewson baseball Christy Mathewson "[',' ' a' ' former' ' All' '-' 'Star' ' pitcher' ' for' ' the' ' New'
+ ' York' ' Giants' ',' ' was' ' a' ' member' ' of' ' the' ' New' ' York']" , a former All - Star pitcher for the New York Giants , was a member of the New York False office: one of Christy Mathewson and one of Youngs. 7 [' office', ':', ' one', ' of', ' Christy', ' Mat', 'hew', 'son']
+346 89 The professional sport played by x -1 The professional sport played by Timothy Brown football Timothy Brown "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Charbonneau including Timothy Brown and others, and the 5 [' Char', 'bon', 'neau', ' including', ' Timothy', ' Brown']
+347 89 The professional sport played by x -1 The professional sport played by Timothy Brown football Timothy Brown "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False including Timothy Brown and others, and 2 [' including', ' Timothy', ' Brown']
+348 89 The professional sport played by x -1 The professional sport played by Timothy Brown football Timothy Brown "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Charbonneau including Timothy Brown and others, and 5 [' Char', 'bon', 'neau', ' including', ' Timothy', ' Brown']
+349 89 The professional sport played by x -1 The professional sport played by Timothy Brown football Timothy Brown "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Charbonneau including Timothy Brown and others, 5 [' Char', 'bon', 'neau', ' including', ' Timothy', ' Brown']
+350 89 The professional sport played by x -1 The professional sport played by Timothy Brown football Timothy Brown "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Charbonneau including Timothy Brown and others, and 5 [' Char', 'bon', 'neau', ' including', ' Timothy', ' Brown']
+351 91 The professional sport played by x -1 The professional sport played by Shaquille O'Neal basketball Shaquille O'Neal "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' since' ' the' ' early']" , the NBA 's all - time leading scorer , is a game that has been played since the early False superstars Shaquille O'Neal and Kobe Bryant, 7 "[' superst', 'ars', ' Sha', 'qu', 'ille', ' O', ""'"", 'Neal']"
+352 91 The professional sport played by x -1 The professional sport played by Shaquille O'Neal basketball Shaquille O'Neal "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' since' ' the' ' early']" , the NBA 's all - time leading scorer , is a game that has been played since the early False superstars Shaquille O'Neal and Kobe Bryant, 7 "[' superst', 'ars', ' Sha', 'qu', 'ille', ' O', ""'"", 'Neal']"
+353 91 The professional sport played by x -1 The professional sport played by Shaquille O'Neal basketball Shaquille O'Neal "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' since' ' the' ' early']" , the NBA 's all - time leading scorer , is a game that has been played since the early False Kenny Smith and Shaquille O'Neal recap and comment 8 "[' Kenny', ' Smith', ' and', ' Sha', 'qu', 'ille', ' O', ""'"", 'Neal']"
+354 91 The professional sport played by x -1 The professional sport played by Shaquille O'Neal basketball Shaquille O'Neal "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' since' ' the' ' early']" , the NBA 's all - time leading scorer , is a game that has been played since the early False Bryant and Shaquille O'Neal were named joint-winners 7 "[' Bryant', ' and', ' Sha', 'qu', 'ille', ' O', ""'"", 'Neal']"
+355 91 The professional sport played by x -1 The professional sport played by Shaquille O'Neal basketball Shaquille O'Neal "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' since' ' the' ' early']" , the NBA 's all - time leading scorer , is a game that has been played since the early False West's Kobe Bryant and Shaquille O'Neal were named joint-winners 10 "[' West', ""'s"", ' Kobe', ' Bryant', ' and', ' Sha', 'qu', 'ille', ' O', ""'"", 'Neal']"
+356 94 The professional sport played by x -1 The professional sport played by Dirk Nowitzki basketball Dirk Nowitzki "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False Gasol (Spain), Dirk Nowitzki (Germany) and Tony 8 [' Gas', 'ol', ' (', 'Spain', '),', ' Dirk', ' Now', 'itz', 'ki']
+357 94 The professional sport played by x -1 The professional sport played by Dirk Nowitzki basketball Dirk Nowitzki "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False Frühwirth: Einfach Er – Dirk Nowitzki – Aus Würzburg an die 13 [' Fr', 'ü', 'hw', 'irth', ':', ' E', 'inf', 'ach', ' Er', ' –', ' Dirk', ' Now', 'itz', 'ki']
+358 94 The professional sport played by x -1 The professional sport played by Dirk Nowitzki basketball Dirk Nowitzki "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False the roster by Dirk Nowitzki and in the 6 [' the', ' roster', ' by', ' Dirk', ' Now', 'itz', 'ki']
+359 94 The professional sport played by x -1 The professional sport played by Dirk Nowitzki basketball Dirk Nowitzki "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False established NBA stars Dirk Nowitzki and Peja Stojaković, 6 [' established', ' NBA', ' stars', ' Dirk', ' Now', 'itz', 'ki']
+360 94 The professional sport played by x -1 The professional sport played by Dirk Nowitzki basketball Dirk Nowitzki "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False established NBA stars Dirk Nowitzki and Peja Stojaković, 6 [' established', ' NBA', ' stars', ' Dirk', ' Now', 'itz', 'ki']
+361 96 The professional sport played by x -1 The professional sport played by Tim Horton hockey Tim Horton "[""'s"" ',' ' the' ' Canadian' ' fast' ' food' ' chain' ',' ' is' ' a'
+ ' Canadian' ' institution' '.' ' It' ""'s"" ' a' ' Canadian' ' institution'
+ '.' ' It']" 's , the Canadian fast food chain , is a Canadian institution . It 's a Canadian institution . It False their captain Tim Horton died in an automobile 3 [' their', ' captain', ' Tim', ' Horton']
+362 96 The professional sport played by x -1 The professional sport played by Tim Horton hockey Tim Horton "[""'s"" ',' ' the' ' Canadian' ' fast' ' food' ' chain' ',' ' is' ' a'
+ ' Canadian' ' institution' '.' ' It' ""'s"" ' a' ' Canadian' ' institution'
+ '.' ' It']" 's , the Canadian fast food chain , is a Canadian institution . It 's a Canadian institution . It False their captain Tim Horton died in an automobile 3 [' their', ' captain', ' Tim', ' Horton']
+363 96 The professional sport played by x -1 The professional sport played by Tim Horton hockey Tim Horton "[""'s"" ',' ' the' ' Canadian' ' fast' ' food' ' chain' ',' ' is' ' a'
+ ' Canadian' ' institution' '.' ' It' ""'s"" ' a' ' Canadian' ' institution'
+ '.' ' It']" 's , the Canadian fast food chain , is a Canadian institution . It 's a Canadian institution . It False which their captain Tim Horton died in an automobile 4 [' which', ' their', ' captain', ' Tim', ' Horton']
+364 96 The professional sport played by x -1 The professional sport played by Tim Horton hockey Tim Horton "[""'s"" ',' ' the' ' Canadian' ' fast' ' food' ' chain' ',' ' is' ' a'
+ ' Canadian' ' institution' '.' ' It' ""'s"" ' a' ' Canadian' ' institution'
+ '.' ' It']" 's , the Canadian fast food chain , is a Canadian institution . It 's a Canadian institution . It False 1 ['Tim', ' Horton']
+365 96 The professional sport played by x -1 The professional sport played by Tim Horton hockey Tim Horton "[""'s"" ',' ' the' ' Canadian' ' fast' ' food' ' chain' ',' ' is' ' a'
+ ' Canadian' ' institution' '.' ' It' ""'s"" ' a' ' Canadian' ' institution'
+ '.' ' It']" 's , the Canadian fast food chain , is a Canadian institution . It 's a Canadian institution . It False their captain Tim Horton died in an 3 [' their', ' captain', ' Tim', ' Horton']
+366 97 The professional sport played by x -1 The professional sport played by Phil Esposito hockey Phil Esposito "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Boston' ' Red' ' Sox' ','
+ ' and' ' the' ' New' ' York' ' Yankees' '.' '\n' '\n' 'The']" ", who was a member of the Boston Red Sox , and the New York Yankees .
+
+ The" False redirected by Phil Esposito hit Plante in the 5 [' redirected', ' by', ' Phil', ' Es', 'pos', 'ito']
+367 97 The professional sport played by x -1 The professional sport played by Phil Esposito hockey Phil Esposito "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Boston' ' Red' ' Sox' ','
+ ' and' ' the' ' New' ' York' ' Yankees' '.' '\n' '\n' 'The']" ", who was a member of the Boston Red Sox , and the New York Yankees .
+
+ The" False the first period Phil Esposito scored two goals 6 [' the', ' first', ' period', ' Phil', ' Es', 'pos', 'ito']
+368 97 The professional sport played by x -1 The professional sport played by Phil Esposito hockey Phil Esposito "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Boston' ' Red' ' Sox' ','
+ ' and' ' the' ' New' ' York' ' Yankees' '.' '\n' '\n' 'The']" ", who was a member of the Boston Red Sox , and the New York Yankees .
+
+ The" False Mike Bossy and Phil Esposito had ever scored more 7 [' Mike', ' Boss', 'y', ' and', ' Phil', ' Es', 'pos', 'ito']
+369 97 The professional sport played by x -1 The professional sport played by Phil Esposito hockey Phil Esposito "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Boston' ' Red' ' Sox' ','
+ ' and' ' the' ' New' ' York' ' Yankees' '.' '\n' '\n' 'The']" ", who was a member of the Boston Red Sox , and the New York Yankees .
+
+ The" False " conceded that Phil Esposito would ""be difficult" 5 [' conceded', ' that', ' Phil', ' Es', 'pos', 'ito']
+370 97 The professional sport played by x -1 The professional sport played by Phil Esposito hockey Phil Esposito "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Boston' ' Red' ' Sox' ','
+ ' and' ' the' ' New' ' York' ' Yankees' '.' '\n' '\n' 'The']" ", who was a member of the Boston Red Sox , and the New York Yankees .
+
+ The" False four minutes. Phil Esposito complained that one 6 [' four', ' minutes', '.', ' Phil', ' Es', 'pos', 'ito']
+371 98 The professional sport played by x -1 The professional sport played by Bob Cousy basketball Bob Cousy "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' ' for' ' the' ' first' ' half']" , the NBA 's all - time leading scorer , was a member of the Celtics for the first half False England point guard Bob Cousy in the 1950 5 [' England', ' point', ' guard', ' Bob', ' Cous', 'y']
+372 98 The professional sport played by x -1 The professional sport played by Bob Cousy basketball Bob Cousy "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' ' for' ' the' ' first' ' half']" , the NBA 's all - time leading scorer , was a member of the Celtics for the first half False for the 2010 Bob Cousy Award and the 5 [' for', ' the', ' 2010', ' Bob', ' Cous', 'y']
+373 98 The professional sport played by x -1 The professional sport played by Bob Cousy basketball Bob Cousy "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' ' for' ' the' ' first' ' half']" , the NBA 's all - time leading scorer , was a member of the Celtics for the first half False " = Bob Cousy =
+" 3 [' =', ' Bob', ' Cous', 'y']
+374 98 The professional sport played by x -1 The professional sport played by Bob Cousy basketball Bob Cousy "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' ' for' ' the' ' first' ' half']" , the NBA 's all - time leading scorer , was a member of the Celtics for the first half False finalists for the Bob Cousy Award. He was 5 [' finalists', ' for', ' the', ' Bob', ' Cous', 'y']
+375 98 The professional sport played by x -1 The professional sport played by Bob Cousy basketball Bob Cousy "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' ' for' ' the' ' first' ' half']" , the NBA 's all - time leading scorer , was a member of the Celtics for the first half False Tisdale Award, Bob Cousy Award, UPI 7 [' T', 'isd', 'ale', ' Award', ',', ' Bob', ' Cous', 'y']
+376 100 The professional sport played by x -1 The professional sport played by Guus Hiddink soccer Guus Hiddink "[',' ' the' ' Dutch' 'man' ' who' ' has' ' been' ' in' ' charge' ' of'
+ ' the' ' national' ' team' ' since' ' the' ' summer' ' of' ' 2014' ','
+ ' has']" , the Dutch man who has been in charge of the national team since the summer of 2014 , has False 4 ['Gu', 'us', ' H', 'idd', 'ink']
+377 100 The professional sport played by x -1 The professional sport played by Guus Hiddink soccer Guus Hiddink "[',' ' the' ' Dutch' 'man' ' who' ' has' ' been' ' in' ' charge' ' of'
+ ' the' ' national' ' team' ' since' ' the' ' summer' ' of' ' 2014' ','
+ ' has']" , the Dutch man who has been in charge of the national team since the summer of 2014 , has False interim manager Guus Hiddink insisted that Hazard 6 [' interim', ' manager', ' Gu', 'us', ' H', 'idd', 'ink']
+378 100 The professional sport played by x -1 The professional sport played by Guus Hiddink soccer Guus Hiddink "[',' ' the' ' Dutch' 'man' ' who' ' has' ' been' ' in' ' charge' ' of'
+ ' the' ' national' ' team' ' since' ' the' ' summer' ' of' ' 2014' ','
+ ' has']" , the Dutch man who has been in charge of the national team since the summer of 2014 , has False " organised by the Guus Hiddink Foundation.
+" 7 [' organised', ' by', ' the', ' Gu', 'us', ' H', 'idd', 'ink']
+379 100 The professional sport played by x -1 The professional sport played by Guus Hiddink soccer Guus Hiddink "[',' ' the' ' Dutch' 'man' ' who' ' has' ' been' ' in' ' charge' ' of'
+ ' the' ' national' ' team' ' since' ' the' ' summer' ' of' ' 2014' ','
+ ' has']" , the Dutch man who has been in charge of the national team since the summer of 2014 , has False result interim manager Guus Hiddink insisted that Hazard 7 [' result', ' interim', ' manager', ' Gu', 'us', ' H', 'idd', 'ink']
+380 100 The professional sport played by x -1 The professional sport played by Guus Hiddink soccer Guus Hiddink "[',' ' the' ' Dutch' 'man' ' who' ' has' ' been' ' in' ' charge' ' of'
+ ' the' ' national' ' team' ' since' ' the' ' summer' ' of' ' 2014' ','
+ ' has']" , the Dutch man who has been in charge of the national team since the summer of 2014 , has False succeeding manager Guus Hiddink who had left 6 [' succeeding', ' manager', ' Gu', 'us', ' H', 'idd', 'ink']
+381 101 The professional sport played by x -1 The professional sport played by Connie Mack baseball Connie Mack "[',' ' the' ' former' ' manager' ' of' ' the' ' Philadelphia' ' Athletics'
+ ',' ' was' ' a' ' great' ' baseball' ' player' '.' ' He' ' was' ' a'
+ ' great' ' baseball']" , the former manager of the Philadelphia Athletics , was a great baseball player . He was a great baseball True Amateur Baseball Congress Connie Mack World Series. Both 4 [' Amateur', ' Baseball', ' Congress', ' Connie', ' Mack']
+382 101 The professional sport played by x -1 The professional sport played by Connie Mack baseball Connie Mack "[',' ' the' ' former' ' manager' ' of' ' the' ' Philadelphia' ' Athletics'
+ ',' ' was' ' a' ' great' ' baseball' ' player' '.' ' He' ' was' ' a'
+ ' great' ' baseball']" , the former manager of the Philadelphia Athletics , was a great baseball player . He was a great baseball True " had on my ballcub"". Connie Mack called him a ""magician"".
+" 8 "[' had', ' on', ' my', ' ball', 'c', 'ub', '"".', ' Connie', ' Mack']"
+383 101 The professional sport played by x -1 The professional sport played by Connie Mack baseball Connie Mack "[',' ' the' ' former' ' manager' ' of' ' the' ' Philadelphia' ' Athletics'
+ ',' ' was' ' a' ' great' ' baseball' ' player' '.' ' He' ' was' ' a'
+ ' great' ' baseball']" , the former manager of the Philadelphia Athletics , was a great baseball player . He was a great baseball True League (Philadelphia's Connie Mack Stadium was oldest, 5 "[' League', ' (', 'Philadelphia', ""'s"", ' Connie', ' Mack']"
+384 101 The professional sport played by x -1 The professional sport played by Connie Mack baseball Connie Mack "[',' ' the' ' former' ' manager' ' of' ' the' ' Philadelphia' ' Athletics'
+ ',' ' was' ' a' ' great' ' baseball' ' player' '.' ' He' ' was' ' a'
+ ' great' ' baseball']" , the former manager of the Philadelphia Athletics , was a great baseball player . He was a great baseball True accused Philadelphia's Connie Mack of underhanded 4 "[' accused', ' Philadelphia', ""'s"", ' Connie', ' Mack']"
+385 101 The professional sport played by x -1 The professional sport played by Connie Mack baseball Connie Mack "[',' ' the' ' former' ' manager' ' of' ' the' ' Philadelphia' ' Athletics'
+ ',' ' was' ' a' ' great' ' baseball' ' player' '.' ' He' ' was' ' a'
+ ' great' ' baseball']" , the former manager of the Philadelphia Athletics , was a great baseball player . He was a great baseball True national champions in the Connie Mack World Series, 5 [' national', ' champions', ' in', ' the', ' Connie', ' Mack']
+386 102 The professional sport played by x -1 The professional sport played by Ernie Davis football Ernie Davis "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the']" , who was a member of the New York Yankees , and the first black player to play in the False rights to Ernie Davis, a Heisman Trophy-winning 4 [' rights', ' to', ' Er', 'nie', ' Davis']
+387 102 The professional sport played by x -1 The professional sport played by Ernie Davis football Ernie Davis "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the']" , who was a member of the New York Yankees , and the first black player to play in the False rights to Ernie Davis, a Heisman Trophy-winning 4 [' rights', ' to', ' Er', 'nie', ' Davis']
+388 102 The professional sport played by x -1 The professional sport played by Ernie Davis football Ernie Davis "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the']" , who was a member of the New York Yankees , and the first black player to play in the False running back Ernie Davis who, in 1961, became 4 [' running', ' back', ' Er', 'nie', ' Davis']
+389 102 The professional sport played by x -1 The professional sport played by Ernie Davis football Ernie Davis "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the']" , who was a member of the New York Yankees , and the first black player to play in the False Heisman Trophy winner Ernie Davis of Syracuse, but Davis 5 [' Heisman', ' Trophy', ' winner', ' Er', 'nie', ' Davis']
+390 102 The professional sport played by x -1 The professional sport played by Ernie Davis football Ernie Davis "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' black' ' player' ' to' ' play' ' in' ' the']" , who was a member of the New York Yankees , and the first black player to play in the False Heisman Trophy winner Ernie Davis of Syracuse, but 5 [' Heisman', ' Trophy', ' winner', ' Er', 'nie', ' Davis']
+391 103 The professional sport played by x -1 The professional sport played by Tim Cahill soccer Tim Cahill "[',' ' who' ' has' ' been' ' a' ' professional' ' soccer' ' player' ' for'
+ ' the' ' past' ' 15' ' years' ',' ' and' ' is' ' currently' ' playing'
+ ' for' ' the']" , who has been a professional soccer player for the past 15 years , and is currently playing for the True created by Tim Cahill and Julie McNally Cahill 4 [' created', ' by', ' Tim', ' Cah', 'ill']
+392 103 The professional sport played by x -1 The professional sport played by Tim Cahill soccer Tim Cahill "[',' ' who' ' has' ' been' ' a' ' professional' ' soccer' ' player' ' for'
+ ' the' ' past' ' 15' ' years' ',' ' and' ' is' ' currently' ' playing'
+ ' for' ' the']" , who has been a professional soccer player for the past 15 years , and is currently playing for the True season. I liken him to Tim Cahill of Everton. 9 [' season', '.', ' I', ' lik', 'en', ' him', ' to', ' Tim', ' Cah', 'ill']
+393 103 The professional sport played by x -1 The professional sport played by Tim Cahill soccer Tim Cahill "[',' ' who' ' has' ' been' ' a' ' professional' ' soccer' ' player' ' for'
+ ' the' ' past' ' 15' ' years' ',' ' and' ' is' ' currently' ' playing'
+ ' for' ' the']" , who has been a professional soccer player for the past 15 years , and is currently playing for the True that was created by Tim Cahill and Julie McNally 6 [' that', ' was', ' created', ' by', ' Tim', ' Cah', 'ill']
+394 103 The professional sport played by x -1 The professional sport played by Tim Cahill soccer Tim Cahill "[',' ' who' ' has' ' been' ' a' ' professional' ' soccer' ' player' ' for'
+ ' the' ' past' ' 15' ' years' ',' ' and' ' is' ' currently' ' playing'
+ ' for' ' the']" , who has been a professional soccer player for the past 15 years , and is currently playing for the True McNally-Cahill and Tim Cahill developed Littlest 9 [' McN', 'ally', '-', 'C', 'ah', 'ill', ' and', ' Tim', ' Cah', 'ill']
+395 103 The professional sport played by x -1 The professional sport played by Tim Cahill soccer Tim Cahill "[',' ' who' ' has' ' been' ' a' ' professional' ' soccer' ' player' ' for'
+ ' the' ' past' ' 15' ' years' ',' ' and' ' is' ' currently' ' playing'
+ ' for' ' the']" , who has been a professional soccer player for the past 15 years , and is currently playing for the True McNally-Cahill and Tim Cahill for Hasbro Studios. 9 [' McN', 'ally', '-', 'C', 'ah', 'ill', ' and', ' Tim', ' Cah', 'ill']
+396 104 The professional sport played by x -1 The professional sport played by George Best soccer George Best "[',' ' the' ' former' ' Manchester' ' United' ' and' ' England'
+ ' footballer' ',' ' was' ' a' ' member' ' of' ' the' ' team' ' that'
+ ' won' ' the' ' FA' ' Cup']" , the former Manchester United and England footballer , was a member of the team that won the FA Cup False in the world. George Best was something 5 [' in', ' the', ' world', '.', ' George', ' Best']
+397 104 The professional sport played by x -1 The professional sport played by George Best soccer George Best "[',' ' the' ' former' ' Manchester' ' United' ' and' ' England'
+ ' footballer' ',' ' was' ' a' ' member' ' of' ' the' ' team' ' that'
+ ' won' ' the' ' FA' ' Cup']" , the former Manchester United and England footballer , was a member of the team that won the FA Cup False to the now-retired George Best, and did little 7 [' to', ' the', ' now', '-', 'ret', 'ired', ' George', ' Best']
+398 104 The professional sport played by x -1 The professional sport played by George Best soccer George Best "[',' ' the' ' former' ' Manchester' ' United' ' and' ' England'
+ ' footballer' ',' ' was' ' a' ' member' ' of' ' the' ' team' ' that'
+ ' won' ' the' ' FA' ' Cup']" , the former Manchester United and England footballer , was a member of the team that won the FA Cup False in soon after George Best establishing the Wagga 4 [' in', ' soon', ' after', ' George', ' Best']
+399 104 The professional sport played by x -1 The professional sport played by George Best soccer George Best "[',' ' the' ' former' ' Manchester' ' United' ' and' ' England'
+ ' footballer' ',' ' was' ' a' ' member' ' of' ' the' ' team' ' that'
+ ' won' ' the' ' FA' ' Cup']" , the former Manchester United and England footballer , was a member of the team that won the FA Cup False players – including George Best – to win the FA Cup 4 [' players', ' –', ' including', ' George', ' Best']
+400 104 The professional sport played by x -1 The professional sport played by George Best soccer George Best "[',' ' the' ' former' ' Manchester' ' United' ' and' ' England'
+ ' footballer' ',' ' was' ' a' ' member' ' of' ' the' ' team' ' that'
+ ' won' ' the' ' FA' ' Cup']" , the former Manchester United and England footballer , was a member of the team that won the FA Cup False Belfast-born George Best and Sammy McIlroy 4 [' Belfast', '-', 'born', ' George', ' Best']
+401 105 The professional sport played by x -1 The professional sport played by Robbie Rogers soccer Robbie Rogers "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True substitution bringing on Robbie Rogers for Gaven. Shortly 4 [' substitution', ' bringing', ' on', ' Robbie', ' Rogers']
+402 105 The professional sport played by x -1 The professional sport played by Robbie Rogers soccer Robbie Rogers "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True substitution bringing on Robbie Rogers for Gaven. 4 [' substitution', ' bringing', ' on', ' Robbie', ' Rogers']
+403 105 The professional sport played by x -1 The professional sport played by Robbie Rogers soccer Robbie Rogers "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True substitution bringing on Robbie Rogers for Gaven. Shortly 4 [' substitution', ' bringing', ' on', ' Robbie', ' Rogers']
+404 105 The professional sport played by x -1 The professional sport played by Robbie Rogers soccer Robbie Rogers "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True substitution bringing on Robbie Rogers for Gaven. Shortly 4 [' substitution', ' bringing', ' on', ' Robbie', ' Rogers']
+405 105 The professional sport played by x -1 The professional sport played by Robbie Rogers soccer Robbie Rogers "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True bringing on Robbie Rogers for Gaven. Shortly 3 [' bringing', ' on', ' Robbie', ' Rogers']
+406 106 The professional sport played by x -1 The professional sport played by Ichiro Suzuki baseball Ichiro Suzuki "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Japanese'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 15' ',' ' has' ' been'
+ ' a']" , who has been a member of the Japanese national team since the age of 15 , has been a False Brian Roberts, Ichiro Suzuki hit a long fly 5 [' Brian', ' Roberts', ',', ' Ich', 'iro', ' Suzuki']
+407 106 The professional sport played by x -1 The professional sport played by Ichiro Suzuki baseball Ichiro Suzuki "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Japanese'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 15' ',' ' has' ' been'
+ ' a']" , who has been a member of the Japanese national team since the age of 15 , has been a False being surpassed by Ichiro Suzuki of the Seattle 5 [' being', ' surpassed', ' by', ' Ich', 'iro', ' Suzuki']
+408 106 The professional sport played by x -1 The professional sport played by Ichiro Suzuki baseball Ichiro Suzuki "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Japanese'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 15' ',' ' has' ' been'
+ ' a']" , who has been a member of the Japanese national team since the age of 15 , has been a False previous record bid that Ichiro Suzuki had garnered in 2000 6 [' previous', ' record', ' bid', ' that', ' Ich', 'iro', ' Suzuki']
+409 106 The professional sport played by x -1 The professional sport played by Ichiro Suzuki baseball Ichiro Suzuki "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Japanese'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 15' ',' ' has' ' been'
+ ' a']" , who has been a member of the Japanese national team since the age of 15 , has been a False chief engineer Ichiro Suzuki sought to develop 4 [' chief', ' engineer', ' Ich', 'iro', ' Suzuki']
+410 106 The professional sport played by x -1 The professional sport played by Ichiro Suzuki baseball Ichiro Suzuki "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Japanese'
+ ' national' ' team' ' since' ' the' ' age' ' of' ' 15' ',' ' has' ' been'
+ ' a']" , who has been a member of the Japanese national team since the age of 15 , has been a False by a score of 5 – 4. Ichiro Suzuki won the MVP award 10 [' by', ' a', ' score', ' of', ' 5', ' –', ' 4', '.', ' Ich', 'iro', ' Suzuki']
+411 107 The professional sport played by x -1 The professional sport played by LeBron James basketball LeBron James "[',' ' the' ' NBA' '�' '�' 's' ' most' ' valuable' ' player' ',' ' is'
+ ' a' ' basketball' ' player' '.' '\n' '\n' 'The' ' NBA' ' is']" ", the NBA � � s most valuable player , is a basketball player .
+
+ The NBA is" True 2011 offseason, LeBron James flew to Houston 4 [' 2011', ' offseason', ',', ' LeBron', ' James']
+412 107 The professional sport played by x -1 The professional sport played by LeBron James basketball LeBron James "[',' ' the' ' NBA' '�' '�' 's' ' most' ' valuable' ' player' ',' ' is'
+ ' a' ' basketball' ' player' '.' '\n' '\n' 'The' ' NBA' ' is']" ", the NBA � � s most valuable player , is a basketball player .
+
+ The NBA is" True offense with LeBron James and Chris Bosh, but 3 [' offense', ' with', ' LeBron', ' James']
+413 107 The professional sport played by x -1 The professional sport played by LeBron James basketball LeBron James "[',' ' the' ' NBA' '�' '�' 's' ' most' ' valuable' ' player' ',' ' is'
+ ' a' ' basketball' ' player' '.' '\n' '\n' 'The' ' NBA' ' is']" ", the NBA � � s most valuable player , is a basketball player .
+
+ The NBA is" True players such as LeBron James and Chris Bosh. In 4 [' players', ' such', ' as', ' LeBron', ' James']
+414 107 The professional sport played by x -1 The professional sport played by LeBron James basketball LeBron James "[',' ' the' ' NBA' '�' '�' 's' ' most' ' valuable' ' player' ',' ' is'
+ ' a' ' basketball' ' player' '.' '\n' '\n' 'The' ' NBA' ' is']" ", the NBA � � s most valuable player , is a basketball player .
+
+ The NBA is" True growing up, including LeBron James and Dwyane Wade. 5 [' growing', ' up', ',', ' including', ' LeBron', ' James']
+415 107 The professional sport played by x -1 The professional sport played by LeBron James basketball LeBron James "[',' ' the' ' NBA' '�' '�' 's' ' most' ' valuable' ' player' ',' ' is'
+ ' a' ' basketball' ' player' '.' '\n' '\n' 'The' ' NBA' ' is']" ", the NBA � � s most valuable player , is a basketball player .
+
+ The NBA is" True the 2011 NBA Finals, LeBron James called Nowitzki's 6 [' the', ' 2011', ' NBA', ' Finals', ',', ' LeBron', ' James']
+416 109 The professional sport played by x -1 The professional sport played by Bubba Smith football Bubba Smith "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Wayne Fontes, Bubba Smith and Drew Stanton. 6 [' Wayne', ' Font', 'es', ',', ' Bub', 'ba', ' Smith']
+417 109 The professional sport played by x -1 The professional sport played by Bubba Smith football Bubba Smith "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Shula, John Unitas, Bubba Smith and Tom Matte were 9 [' Sh', 'ula', ',', ' John', ' Unit', 'as', ',', ' Bub', 'ba', ' Smith']
+418 109 The professional sport played by x -1 The professional sport played by Bubba Smith football Bubba Smith "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Shula, John Unitas, Bubba Smith and Tom Matte were 9 [' Sh', 'ula', ',', ' John', ' Unit', 'as', ',', ' Bub', 'ba', ' Smith']
+419 109 The professional sport played by x -1 The professional sport played by Bubba Smith football Bubba Smith "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Wayne Fontes, Bubba Smith and Drew Stanton. 6 [' Wayne', ' Font', 'es', ',', ' Bub', 'ba', ' Smith']
+420 110 The professional sport played by x -1 The professional sport played by Alexander Ovechkin hockey Alexander Ovechkin "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' goal' ' scorer' ','
+ ' is' ' a' ' great' ' example' ' of' ' the' ' type' ' of' ' player']" , the NHL 's all - time leading goal scorer , is a great example of the type of player False the Penguins and Alexander Ovechkin and Alexander Semin 7 [' the', ' Penguins', ' and', ' Alexander', ' O', 've', 'ch', 'kin']
+421 110 The professional sport played by x -1 The professional sport played by Alexander Ovechkin hockey Alexander Ovechkin "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' goal' ' scorer' ','
+ ' is' ' a' ' great' ' example' ' of' ' the' ' type' ' of' ' player']" , the NHL 's all - time leading goal scorer , is a great example of the type of player False scoring, behind only Alexander Ovechkin and Jeff Carter. 8 [' scoring', ',', ' behind', ' only', ' Alexander', ' O', 've', 'ch', 'kin']
+422 110 The professional sport played by x -1 The professional sport played by Alexander Ovechkin hockey Alexander Ovechkin "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' goal' ' scorer' ','
+ ' is' ' a' ' great' ' example' ' of' ' the' ' type' ' of' ' player']" , the NHL 's all - time leading goal scorer , is a great example of the type of player False season, rookies Alexander Ovechkin and Sidney Crosby 7 [' season', ',', ' rookies', ' Alexander', ' O', 've', 'ch', 'kin']
+423 110 The professional sport played by x -1 The professional sport played by Alexander Ovechkin hockey Alexander Ovechkin "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' goal' ' scorer' ','
+ ' is' ' a' ' great' ' example' ' of' ' the' ' type' ' of' ' player']" , the NHL 's all - time leading goal scorer , is a great example of the type of player False and NHL leader Alexander Ovechkin with 56, while 7 [' and', ' NHL', ' leader', ' Alexander', ' O', 've', 'ch', 'kin']
+424 110 The professional sport played by x -1 The professional sport played by Alexander Ovechkin hockey Alexander Ovechkin "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' goal' ' scorer' ','
+ ' is' ' a' ' great' ' example' ' of' ' the' ' type' ' of' ' player']" , the NHL 's all - time leading goal scorer , is a great example of the type of player False leaving star forward Alexander Ovechkin on the bench. The 7 [' leaving', ' star', ' forward', ' Alexander', ' O', 've', 'ch', 'kin']
+425 112 The professional sport played by x -1 The professional sport played by Edgaras Jankauskas soccer Edgaras Jankauskas "[',' ' a' ' Lithuan' 'ian' '-' 'born' ' American' ' basketball' ' player'
+ ',' ' is' ' a' ' game' ' that' ' is' ' played' ' by' ' two' ' teams'
+ ' of']" , a Lithuan ian - born American basketball player , is a game that is played by two teams of False Caneira and Edgaras Jankauskas also left, as 9 [' C', 'ane', 'ira', ' and', ' Edgar', 'as', ' J', 'ank', 'aus', 'kas']
+426 112 The professional sport played by x -1 The professional sport played by Edgaras Jankauskas soccer Edgaras Jankauskas "[',' ' a' ' Lithuan' 'ian' '-' 'born' ' American' ' basketball' ' player'
+ ',' ' is' ' a' ' game' ' that' ' is' ' played' ' by' ' two' ' teams'
+ ' of']" , a Lithuan ian - born American basketball player , is a game that is played by two teams of False Marco Caneira and Edgaras Jankauskas also left, as Benfica 10 [' Marco', ' C', 'ane', 'ira', ' and', ' Edgar', 'as', ' J', 'ank', 'aus', 'kas']
+427 112 The professional sport played by x -1 The professional sport played by Edgaras Jankauskas soccer Edgaras Jankauskas "[',' ' a' ' Lithuan' 'ian' '-' 'born' ' American' ' basketball' ' player'
+ ',' ' is' ' a' ' game' ' that' ' is' ' played' ' by' ' two' ' teams'
+ ' of']" , a Lithuan ian - born American basketball player , is a game that is played by two teams of False Marco Caneira and Edgaras Jankauskas also left, 10 [' Marco', ' C', 'ane', 'ira', ' and', ' Edgar', 'as', ' J', 'ank', 'aus', 'kas']
+428 114 The professional sport played by x -1 The professional sport played by Scottie Pippen basketball Scottie Pippen "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Bulls' ' from' ' 1989' ' to']" , the NBA 's all - time leading scorer , was a member of the Chicago Bulls from 1989 to False trade to bring in Scottie Pippen to take his place. 8 [' trade', ' to', ' bring', ' in', ' Scott', 'ie', ' P', 'ipp', 'en']
+429 114 The professional sport played by x -1 The professional sport played by Scottie Pippen basketball Scottie Pippen "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Bulls' ' from' ' 1989' ' to']" , the NBA 's all - time leading scorer , was a member of the Chicago Bulls from 1989 to False Michael Jordan and Scottie Pippen accomplished the feat 7 [' Michael', ' Jordan', ' and', ' Scott', 'ie', ' P', 'ipp', 'en']
+430 114 The professional sport played by x -1 The professional sport played by Scottie Pippen basketball Scottie Pippen "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Bulls' ' from' ' 1989' ' to']" , the NBA 's all - time leading scorer , was a member of the Chicago Bulls from 1989 to False you've got Scottie Pippen right in the middle. 7 "[' you', ""'ve"", ' got', ' Scott', 'ie', ' P', 'ipp', 'en']"
+431 114 The professional sport played by x -1 The professional sport played by Scottie Pippen basketball Scottie Pippen "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Bulls' ' from' ' 1989' ' to']" , the NBA 's all - time leading scorer , was a member of the Chicago Bulls from 1989 to False regular season. With Scottie Pippen developing into 8 [' regular', ' season', '.', ' With', ' Scott', 'ie', ' P', 'ipp', 'en']
+432 114 The professional sport played by x -1 The professional sport played by Scottie Pippen basketball Scottie Pippen "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Chicago' ' Bulls' ' from' ' 1989' ' to']" , the NBA 's all - time leading scorer , was a member of the Chicago Bulls from 1989 to False Bulls forward Scottie Pippen defended effectively 6 [' Bulls', ' forward', ' Scott', 'ie', ' P', 'ipp', 'en']
+433 115 The professional sport played by x -1 The professional sport played by Marc Gasol basketball Marc Gasol "[',' ' who' ' is' ' a' ' big' ' fan' ' of' ' the' ' NBA' ',' ' and' ' the'
+ ' NBA' ',' ' who' ' is' ' a' ' big' ' fan' ' of']" , who is a big fan of the NBA , and the NBA , who is a big fan of False McKie, the rights to Marc Gasol (Pau's younger 8 [' McK', 'ie', ',', ' the', ' rights', ' to', ' Marc', ' Gas', 'ol']
+434 116 The professional sport played by x -1 The professional sport played by Pat Tillman football Pat Tillman "[',' ' the' ' former' ' NFL' ' player' ' who' ' was' ' killed' ' in'
+ ' Afghanistan' ' in' ' 2004' ',' ' was' ' a' ' great' ' American' ' hero'
+ '.' ' He']" , the former NFL player who was killed in Afghanistan in 2004 , was a great American hero . He False 'Callaghan – Pat Tillman Memorial Bridge, 6 "["" '"", 'Call', 'aghan', ' –', ' Pat', ' Till', 'man']"
+435 116 The professional sport played by x -1 The professional sport played by Pat Tillman football Pat Tillman "[',' ' the' ' former' ' NFL' ' player' ' who' ' was' ' killed' ' in'
+ ' Afghanistan' ' in' ' 2004' ',' ' was' ' a' ' great' ' American' ' hero'
+ '.' ' He']" , the former NFL player who was killed in Afghanistan in 2004 , was a great American hero . He False aftermath of the Pat Tillman friendly fire 5 [' aftermath', ' of', ' the', ' Pat', ' Till', 'man']
+436 116 The professional sport played by x -1 The professional sport played by Pat Tillman football Pat Tillman "[',' ' the' ' former' ' NFL' ' player' ' who' ' was' ' killed' ' in'
+ ' Afghanistan' ' in' ' 2004' ',' ' was' ' a' ' great' ' American' ' hero'
+ '.' ' He']" , the former NFL player who was killed in Afghanistan in 2004 , was a great American hero . He False aftermath of the Pat Tillman friendly fire 5 [' aftermath', ' of', ' the', ' Pat', ' Till', 'man']
+437 116 The professional sport played by x -1 The professional sport played by Pat Tillman football Pat Tillman "[',' ' the' ' former' ' NFL' ' player' ' who' ' was' ' killed' ' in'
+ ' Afghanistan' ' in' ' 2004' ',' ' was' ' a' ' great' ' American' ' hero'
+ '.' ' He']" , the former NFL player who was killed in Afghanistan in 2004 , was a great American hero . He False aftermath of the Pat Tillman friendly fire incident 5 [' aftermath', ' of', ' the', ' Pat', ' Till', 'man']
+438 116 The professional sport played by x -1 The professional sport played by Pat Tillman football Pat Tillman "[',' ' the' ' former' ' NFL' ' player' ' who' ' was' ' killed' ' in'
+ ' Afghanistan' ' in' ' 2004' ',' ' was' ' a' ' great' ' American' ' hero'
+ '.' ' He']" , the former NFL player who was killed in Afghanistan in 2004 , was a great American hero . He False aftermath of the Pat Tillman friendly fire 5 [' aftermath', ' of', ' the', ' Pat', ' Till', 'man']
+439 118 The professional sport played by x -1 The professional sport played by Bob Hayes football Bob Hayes "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True Hall of Famers Bob Hayes and Mike Ditka 5 [' Hall', ' of', ' Fam', 'ers', ' Bob', ' Hayes']
+440 118 The professional sport played by x -1 The professional sport played by Bob Hayes football Bob Hayes "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True of Famers Bob Hayes and Mike Ditka 4 [' of', ' Fam', 'ers', ' Bob', ' Hayes']
+441 118 The professional sport played by x -1 The professional sport played by Bob Hayes football Bob Hayes "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True Hall of Famers Bob Hayes and Mike Ditka in 5 [' Hall', ' of', ' Fam', 'ers', ' Bob', ' Hayes']
+442 118 The professional sport played by x -1 The professional sport played by Bob Hayes football Bob Hayes "[',' ' a' ' former' ' professional' ' football' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional football player , is a great example of the kind of person who can make a True Football Hall of Famers Bob Hayes and Mike Ditka 6 [' Football', ' Hall', ' of', ' Fam', 'ers', ' Bob', ' Hayes']
+443 119 The professional sport played by x -1 The professional sport played by Corey Perry hockey Corey Perry "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' NHL' ' since'
+ ' the' ' league' '�' '�' 's' ' inception' ' in' ' 1917' ',' ' is']" , who has been a member of the NHL since the league � � s inception in 1917 , is False Getzlaf joined with Corey Perry (21) and Dustin Penner 6 [' Get', 'zl', 'af', ' joined', ' with', ' Corey', ' Perry']
+444 119 The professional sport played by x -1 The professional sport played by Corey Perry hockey Corey Perry "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' NHL' ' since'
+ ' the' ' league' '�' '�' 's' ' inception' ' in' ' 1917' ',' ' is']" , who has been a member of the NHL since the league � � s inception in 1917 , is False Team Staal's Corey Perry as he was the only 5 "[' Team', ' Sta', 'al', ""'s"", ' Corey', ' Perry']"
+445 119 The professional sport played by x -1 The professional sport played by Corey Perry hockey Corey Perry "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' NHL' ' since'
+ ' the' ' league' '�' '�' 's' ' inception' ' in' ' 1917' ',' ' is']" , who has been a member of the NHL since the league � � s inception in 1917 , is False joined with Corey Perry (21) and Dustin 3 [' joined', ' with', ' Corey', ' Perry']
+446 119 The professional sport played by x -1 The professional sport played by Corey Perry hockey Corey Perry "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' NHL' ' since'
+ ' the' ' league' '�' '�' 's' ' inception' ' in' ' 1917' ',' ' is']" , who has been a member of the NHL since the league � � s inception in 1917 , is False opposing forward Corey Perry to score an empty-netter 3 [' opposing', ' forward', ' Corey', ' Perry']
+447 119 The professional sport played by x -1 The professional sport played by Corey Perry hockey Corey Perry "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' NHL' ' since'
+ ' the' ' league' '�' '�' 's' ' inception' ' in' ' 1917' ',' ' is']" , who has been a member of the NHL since the league � � s inception in 1917 , is False games. Knights forward Corey Perry was awarded 5 [' games', '.', ' Knights', ' forward', ' Corey', ' Perry']
+448 122 The professional sport played by x -1 The professional sport played by Karl Malone basketball Karl Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' Utah' ' Jazz'
+ ',' ' and' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading']" , who was a member of the NBA 's Utah Jazz , and the NBA 's all - time leading False points, 11 assists) and Karl Malone (27 points, 10 assists) 7 [' points', ',', ' 11', ' assists', ')', ' and', ' Karl', ' Malone']
+449 122 The professional sport played by x -1 The professional sport played by Karl Malone basketball Karl Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' Utah' ' Jazz'
+ ',' ' and' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading']" , who was a member of the NBA 's Utah Jazz , and the NBA 's all - time leading False signed two-time MVP Karl Malone formerly of the 6 [' signed', ' two', '-', 'time', ' MVP', ' Karl', ' Malone']
+450 122 The professional sport played by x -1 The professional sport played by Karl Malone basketball Karl Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' Utah' ' Jazz'
+ ',' ' and' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading']" , who was a member of the NBA 's Utah Jazz , and the NBA 's all - time leading False Bulls defeated Karl Malone and the Utah 3 [' Bulls', ' defeated', ' Karl', ' Malone']
+451 122 The professional sport played by x -1 The professional sport played by Karl Malone basketball Karl Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' Utah' ' Jazz'
+ ',' ' and' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading']" , who was a member of the NBA 's Utah Jazz , and the NBA 's all - time leading False Oldest player: Karl Malone (Los Angeles 5 [' Old', 'est', ' player', ':', ' Karl', ' Malone']
+452 122 The professional sport played by x -1 The professional sport played by Karl Malone basketball Karl Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' Utah' ' Jazz'
+ ',' ' and' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading']" , who was a member of the NBA 's Utah Jazz , and the NBA 's all - time leading False career, passing Karl Malone (14,968) for sixth 4 [' career', ',', ' passing', ' Karl', ' Malone']
+453 125 The professional sport played by x -1 The professional sport played by Florent Sinama Pongolle soccer Florent Sinama Pongolle "[',' ' a' ' former' ' professional' ' footballer' ' who' ' played' ' as'
+ ' a' ' midfielder' '.' '\n' '\n' 'References' '\n' '\n' 'External'
+ ' links' '\n' '\n']" ", a former professional footballer who played as a midfielder .
+
+ References
+
+ External links
+
+" False substitution of the match; Florent Sinama Pongolle replaced Finnan. 12 [' substitution', ' of', ' the', ' match', ';', ' Flore', 'nt', ' Sin', 'ama', ' P', 'ong', 'ol', 'le']
+454 125 The professional sport played by x -1 The professional sport played by Florent Sinama Pongolle soccer Florent Sinama Pongolle "[',' ' a' ' former' ' professional' ' footballer' ' who' ' played' ' as'
+ ' a' ' midfielder' '.' '\n' '\n' 'References' '\n' '\n' 'External'
+ ' links' '\n' '\n']" ", a former professional footballer who played as a midfielder .
+
+ References
+
+ External links
+
+" False of the match; Florent Sinama Pongolle replaced Finnan. 11 [' of', ' the', ' match', ';', ' Flore', 'nt', ' Sin', 'ama', ' P', 'ong', 'ol', 'le']
+455 125 The professional sport played by x -1 The professional sport played by Florent Sinama Pongolle soccer Florent Sinama Pongolle "[',' ' a' ' former' ' professional' ' footballer' ' who' ' played' ' as'
+ ' a' ' midfielder' '.' '\n' '\n' 'References' '\n' '\n' 'External'
+ ' links' '\n' '\n']" ", a former professional footballer who played as a midfielder .
+
+ References
+
+ External links
+
+" False the match; Florent Sinama Pongolle replaced Finnan. 10 [' the', ' match', ';', ' Flore', 'nt', ' Sin', 'ama', ' P', 'ong', 'ol', 'le']
+456 125 The professional sport played by x -1 The professional sport played by Florent Sinama Pongolle soccer Florent Sinama Pongolle "[',' ' a' ' former' ' professional' ' footballer' ' who' ' played' ' as'
+ ' a' ' midfielder' '.' '\n' '\n' 'References' '\n' '\n' 'External'
+ ' links' '\n' '\n']" ", a former professional footballer who played as a midfielder .
+
+ References
+
+ External links
+
+" False substitution of the match; Florent Sinama Pongolle replaced Finnan. 12 [' substitution', ' of', ' the', ' match', ';', ' Flore', 'nt', ' Sin', 'ama', ' P', 'ong', 'ol', 'le']
+457 125 The professional sport played by x -1 The professional sport played by Florent Sinama Pongolle soccer Florent Sinama Pongolle "[',' ' a' ' former' ' professional' ' footballer' ' who' ' played' ' as'
+ ' a' ' midfielder' '.' '\n' '\n' 'References' '\n' '\n' 'External'
+ ' links' '\n' '\n']" ", a former professional footballer who played as a midfielder .
+
+ References
+
+ External links
+
+" False John Arne Riise and Florent Sinama Pongolle replaced Stephen 13 [' John', ' Ar', 'ne', ' Ri', 'ise', ' and', ' Flore', 'nt', ' Sin', 'ama', ' P', 'ong', 'ol', 'le']
+458 127 The professional sport played by x -1 The professional sport played by Kevin Durant basketball Kevin Durant "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False Williams (2005), Kevin Durant (2006), Kevin 5 [' Williams', ' (', '2005', '),', ' Kevin', ' Durant']
+459 127 The professional sport played by x -1 The professional sport played by Kevin Durant basketball Kevin Durant "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False Basketball Reference's Kevin Durant page unless noted 4 "[' Basketball', ' Reference', ""'s"", ' Kevin', ' Durant']"
+460 127 The professional sport played by x -1 The professional sport played by Kevin Durant basketball Kevin Durant "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False " Durant =
+" 3 [' Durant', ' =', 'Kevin', ' Durant']
+461 127 The professional sport played by x -1 The professional sport played by Kevin Durant basketball Kevin Durant "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False Williams (2005), Kevin Durant (2006), Kevin Love 5 [' Williams', ' (', '2005', '),', ' Kevin', ' Durant']
+462 127 The professional sport played by x -1 The professional sport played by Kevin Durant basketball Kevin Durant "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ',' ' is' ' a'
+ ' former' ' NBA' ' All' '-' 'Star' ',' ' and' ' a']" , who is a former NBA All - Star , is a former NBA All - Star , and a False Basketball Reference's Kevin Durant page unless noted 4 "[' Basketball', ' Reference', ""'s"", ' Kevin', ' Durant']"
+463 128 The professional sport played by x -1 The professional sport played by Bo Jackson baseball Bo Jackson "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NFL star , is a great example of the kind of person who can make a difference False " ====
+" 3 [' =', '===', 'Bo', ' Jackson']
+464 128 The professional sport played by x -1 The professional sport played by Bo Jackson baseball Bo Jackson "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NFL star , is a great example of the kind of person who can make a difference False Buccaneers selected Bo Jackson with the top pick 3 [' Buccaneers', ' selected', ' Bo', ' Jackson']
+465 128 The professional sport played by x -1 The professional sport played by Bo Jackson baseball Bo Jackson "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NFL star , is a great example of the kind of person who can make a difference False Buccaneers selected Bo Jackson with the top 3 [' Buccaneers', ' selected', ' Bo', ' Jackson']
+466 128 The professional sport played by x -1 The professional sport played by Bo Jackson baseball Bo Jackson "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NFL star , is a great example of the kind of person who can make a difference False University running back Bo Jackson was selected 4 [' University', ' running', ' back', ' Bo', ' Jackson']
+467 128 The professional sport played by x -1 The professional sport played by Bo Jackson baseball Bo Jackson "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NFL star , is a great example of the kind of person who can make a difference False offered draft pick Bo Jackson a five-year, 4 [' offered', ' draft', ' pick', ' Bo', ' Jackson']
+468 130 The professional sport played by x -1 The professional sport played by Tom Brady football Tom Brady "[' and' ' the' ' New' ' England' ' Patriots' '.' '\n' '\n' 'The'
+ ' Patriots' ' are' ' the' ' only' ' team' ' to' ' have' ' won' ' the'
+ ' Super' ' Bowl']" " and the New England Patriots .
+
+ The Patriots are the only team to have won the Super Bowl" False after NFL quarterbacks Tom Brady and Russell 4 [' after', ' NFL', ' quarterbacks', ' Tom', ' Brady']
+469 130 The professional sport played by x -1 The professional sport played by Tom Brady football Tom Brady "[' and' ' the' ' New' ' England' ' Patriots' '.' '\n' '\n' 'The'
+ ' Patriots' ' are' ' the' ' only' ' team' ' to' ' have' ' won' ' the'
+ ' Super' ' Bowl']" " and the New England Patriots .
+
+ The Patriots are the only team to have won the Super Bowl" False in 1987 and Tom Brady in 2000. Past recipients 4 [' in', ' 1987', ' and', ' Tom', ' Brady']
+470 130 The professional sport played by x -1 The professional sport played by Tom Brady football Tom Brady "[' and' ' the' ' New' ' England' ' Patriots' '.' '\n' '\n' 'The'
+ ' Patriots' ' are' ' the' ' only' ' team' ' to' ' have' ' won' ' the'
+ ' Super' ' Bowl']" " and the New England Patriots .
+
+ The Patriots are the only team to have won the Super Bowl" False and freshman Tom Brady served as an understudy. 3 [' and', ' freshman', ' Tom', ' Brady']
+471 130 The professional sport played by x -1 The professional sport played by Tom Brady football Tom Brady "[' and' ' the' ' New' ' England' ' Patriots' '.' '\n' '\n' 'The'
+ ' Patriots' ' are' ' the' ' only' ' team' ' to' ' have' ' won' ' the'
+ ' Super' ' Bowl']" " and the New England Patriots .
+
+ The Patriots are the only team to have won the Super Bowl" False Belichick in mind, but Tom Brady was chosen to replace 6 [' Belichick', ' in', ' mind', ',', ' but', ' Tom', ' Brady']
+472 130 The professional sport played by x -1 The professional sport played by Tom Brady football Tom Brady "[' and' ' the' ' New' ' England' ' Patriots' '.' '\n' '\n' 'The'
+ ' Patriots' ' are' ' the' ' only' ' team' ' to' ' have' ' won' ' the'
+ ' Super' ' Bowl']" " and the New England Patriots .
+
+ The Patriots are the only team to have won the Super Bowl" False that included Tom Brady and Charles Woodson. 3 [' that', ' included', ' Tom', ' Brady']
+473 131 The professional sport played by x -1 The professional sport played by Blaise Matuidi soccer Blaise Matuidi "[',' ' a' ' French' ' midfielder' ' who' ' plays' ' for' ' Paris' ' Saint'
+ '-' 'G' 'erm' 'ain' '.' '\n' '\n' 'The' ' French' ' midfielder' ' is']" ", a French midfielder who plays for Paris Saint - G erm ain .
+
+ The French midfielder is" False Issiar Dia, Blaise Matuidi and Serge Gakpé 9 [' Iss', 'iar', ' Dia', ',', ' Bl', 'a', 'ise', ' Mat', 'u', 'idi']
+474 131 The professional sport played by x -1 The professional sport played by Blaise Matuidi soccer Blaise Matuidi "[',' ' a' ' French' ' midfielder' ' who' ' plays' ' for' ' Paris' ' Saint'
+ '-' 'G' 'erm' 'ain' '.' '\n' '\n' 'The' ' French' ' midfielder' ' is']" ", a French midfielder who plays for Paris Saint - G erm ain .
+
+ The French midfielder is" False by Issiar Dia, Blaise Matuidi and Serge Gakpé 10 [' by', ' Iss', 'iar', ' Dia', ',', ' Bl', 'a', 'ise', ' Mat', 'u', 'idi']
+475 131 The professional sport played by x -1 The professional sport played by Blaise Matuidi soccer Blaise Matuidi "[',' ' a' ' French' ' midfielder' ' who' ' plays' ' for' ' Paris' ' Saint'
+ '-' 'G' 'erm' 'ain' '.' '\n' '\n' 'The' ' French' ' midfielder' ' is']" ", a French midfielder who plays for Paris Saint - G erm ain .
+
+ The French midfielder is" False joined by Issiar Dia, Blaise Matuidi and Serge Gakpé 11 [' joined', ' by', ' Iss', 'iar', ' Dia', ',', ' Bl', 'a', 'ise', ' Mat', 'u', 'idi']
+476 131 The professional sport played by x -1 The professional sport played by Blaise Matuidi soccer Blaise Matuidi "[',' ' a' ' French' ' midfielder' ' who' ' plays' ' for' ' Paris' ' Saint'
+ '-' 'G' 'erm' 'ain' '.' '\n' '\n' 'The' ' French' ' midfielder' ' is']" ", a French midfielder who plays for Paris Saint - G erm ain .
+
+ The French midfielder is" False joined by Issiar Dia, Blaise Matuidi and Serge Gakpé with 11 [' joined', ' by', ' Iss', 'iar', ' Dia', ',', ' Bl', 'a', 'ise', ' Mat', 'u', 'idi']
+477 131 The professional sport played by x -1 The professional sport played by Blaise Matuidi soccer Blaise Matuidi "[',' ' a' ' French' ' midfielder' ' who' ' plays' ' for' ' Paris' ' Saint'
+ '-' 'G' 'erm' 'ain' '.' '\n' '\n' 'The' ' French' ' midfielder' ' is']" ", a French midfielder who plays for Paris Saint - G erm ain .
+
+ The French midfielder is" False by Issiar Dia, Blaise Matuidi and Serge 10 [' by', ' Iss', 'iar', ' Dia', ',', ' Bl', 'a', 'ise', ' Mat', 'u', 'idi']
+478 133 The professional sport played by x -1 The professional sport played by Kaká soccer Kaká "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' Brazilian' ' national' ' team' ' since' ' the' ' age' ' of' ' 17' ','
+ ' has' ' been']" , who has been a regular starter for the Brazilian national team since the age of 17 , has been False previous signings of Kaká and Cristiano Ronaldo. 4 [' previous', ' signings', ' of', ' Kak', 'á']
+479 133 The professional sport played by x -1 The professional sport played by Kaká soccer Kaká "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' Brazilian' ' national' ' team' ' since' ' the' ' age' ' of' ' 17' ','
+ ' has' ' been']" , who has been a regular starter for the Brazilian national team since the age of 17 , has been False high pass to Kaká who was ruled to 4 [' high', ' pass', ' to', ' Kak', 'á']
+480 133 The professional sport played by x -1 The professional sport played by Kaká soccer Kaká "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' Brazilian' ' national' ' team' ' since' ' the' ' age' ' of' ' 17' ','
+ ' has' ' been']" , who has been a regular starter for the Brazilian national team since the age of 17 , has been False few minutes later. Kaká received the 5 [' few', ' minutes', ' later', '.', ' Kak', 'á']
+481 133 The professional sport played by x -1 The professional sport played by Kaká soccer Kaká "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' Brazilian' ' national' ' team' ' since' ' the' ' age' ' of' ' 17' ','
+ ' has' ' been']" , who has been a regular starter for the Brazilian national team since the age of 17 , has been False previous signings of Kaká and Cristiano Ronaldo. 4 [' previous', ' signings', ' of', ' Kak', 'á']
+482 133 The professional sport played by x -1 The professional sport played by Kaká soccer Kaká "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' Brazilian' ' national' ' team' ' since' ' the' ' age' ' of' ' 17' ','
+ ' has' ' been']" , who has been a regular starter for the Brazilian national team since the age of 17 , has been False high pass to Kaká who was ruled 4 [' high', ' pass', ' to', ' Kak', 'á']
+483 134 The professional sport played by x -1 The professional sport played by George Halas baseball George Halas "[',' ' the' ' first' ' NFL' ' player' ' to' ' win' ' the' ' Super' ' Bowl'
+ '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' league' ' of' ' its']" ", the first NFL player to win the Super Bowl .
+
+ The NFL is a league of its" False Bears owner / coach George Halas took a chance 6 [' Bears', ' owner', ' /', ' coach', ' George', ' Hal', 'as']
+484 134 The professional sport played by x -1 The professional sport played by George Halas baseball George Halas "[',' ' the' ' first' ' NFL' ' player' ' to' ' win' ' the' ' Super' ' Bowl'
+ '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' league' ' of' ' its']" ", the first NFL player to win the Super Bowl .
+
+ The NFL is a league of its" False " I do about Tony, if George Halas had gotten him,""" 8 [' I', ' do', ' about', ' Tony', ',', ' if', ' George', ' Hal', 'as']
+485 134 The professional sport played by x -1 The professional sport played by George Halas baseball George Halas "[',' ' the' ' first' ' NFL' ' player' ' to' ' win' ' the' ' Super' ' Bowl'
+ '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' league' ' of' ' its']" ", the first NFL player to win the Super Bowl .
+
+ The NFL is a league of its" False to his old associate George Halas of the Chicago 6 [' to', ' his', ' old', ' associate', ' George', ' Hal', 'as']
+486 134 The professional sport played by x -1 The professional sport played by George Halas baseball George Halas "[',' ' the' ' first' ' NFL' ' player' ' to' ' win' ' the' ' Super' ' Bowl'
+ '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' league' ' of' ' its']" ", the first NFL player to win the Super Bowl .
+
+ The NFL is a league of its" False area. In 1977, George Halas surrendered 51 7 [' area', '.', ' In', ' 1977', ',', ' George', ' Hal', 'as']
+487 134 The professional sport played by x -1 The professional sport played by George Halas baseball George Halas "[',' ' the' ' first' ' NFL' ' player' ' to' ' win' ' the' ' Super' ' Bowl'
+ '.' '\n' '\n' 'The' ' NFL' ' is' ' a' ' league' ' of' ' its']" ", the first NFL player to win the Super Bowl .
+
+ The NFL is a league of its" False criticize head coach George Halas on some of his coaching 5 [' criticize', ' head', ' coach', ' George', ' Hal', 'as']
+488 136 The professional sport played by x -1 The professional sport played by Dick Butkus football Dick Butkus "[',' ' the' ' Chicago' ' Bears' ""'"" ' linebacker' ',' ' was' ' a' ' great'
+ ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great' ' person' '.']" , the Chicago Bears ' linebacker , was a great player , but he was also a great person . False " legends such as Dick Butkus: ""Taylor is the best" 5 [' legends', ' such', ' as', ' Dick', ' But', 'kus']
+489 136 The professional sport played by x -1 The professional sport played by Dick Butkus football Dick Butkus "[',' ' the' ' Chicago' ' Bears' ""'"" ' linebacker' ',' ' was' ' a' ' great'
+ ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great' ' person' '.']" , the Chicago Bears ' linebacker , was a great player , but he was also a great person . False " NFL legends such as Dick Butkus: ""Taylor is the" 6 [' NFL', ' legends', ' such', ' as', ' Dick', ' But', 'kus']
+490 136 The professional sport played by x -1 The professional sport played by Dick Butkus football Dick Butkus "[',' ' the' ' Chicago' ' Bears' ""'"" ' linebacker' ',' ' was' ' a' ' great'
+ ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great' ' person' '.']" , the Chicago Bears ' linebacker , was a great player , but he was also a great person . False " legends such as Dick Butkus: ""Taylor is" 5 [' legends', ' such', ' as', ' Dick', ' But', 'kus']
+491 137 The professional sport played by x -1 The professional sport played by Pavol Demitra hockey Pavol Demitra "[',' ' who' ' was' ' born' ' in' ' Slovakia' ',' ' is' ' a' ' former'
+ ' NHL' ' player' ' who' ' has' ' played' ' for' ' the' ' Montreal'
+ ' Canadiens' ',']" , who was born in Slovakia , is a former NHL player who has played for the Montreal Canadiens , False acquisitions Pavol Demitra and Mats Sundin. 5 [' acquisitions', ' Pav', 'ol', ' Dem', 'it', 'ra']
+492 137 The professional sport played by x -1 The professional sport played by Pavol Demitra hockey Pavol Demitra "[',' ' who' ' was' ' born' ' in' ' Slovakia' ',' ' is' ' a' ' former'
+ ' NHL' ' player' ' who' ' has' ' played' ' for' ' the' ' Montreal'
+ ' Canadiens' ',']" , who was born in Slovakia , is a former NHL player who has played for the Montreal Canadiens , False acquisitions Pavol Demitra and Mats Sundin. 5 [' acquisitions', ' Pav', 'ol', ' Dem', 'it', 'ra']
+493 137 The professional sport played by x -1 The professional sport played by Pavol Demitra hockey Pavol Demitra "[',' ' who' ' was' ' born' ' in' ' Slovakia' ',' ' is' ' a' ' former'
+ ' NHL' ' player' ' who' ' has' ' played' ' for' ' the' ' Montreal'
+ ' Canadiens' ',']" , who was born in Slovakia , is a former NHL player who has played for the Montreal Canadiens , False Canucks to replace Pavol Demitra who suffered a 7 [' Canucks', ' to', ' replace', ' Pav', 'ol', ' Dem', 'it', 'ra']
+494 137 The professional sport played by x -1 The professional sport played by Pavol Demitra hockey Pavol Demitra "[',' ' who' ' was' ' born' ' in' ' Slovakia' ',' ' is' ' a' ' former'
+ ' NHL' ' player' ' who' ' has' ' played' ' for' ' the' ' Montreal'
+ ' Canadiens' ',']" , who was born in Slovakia , is a former NHL player who has played for the Montreal Canadiens , False Canucks to replace Pavol Demitra who suffered a fractured 7 [' Canucks', ' to', ' replace', ' Pav', 'ol', ' Dem', 'it', 'ra']
+495 137 The professional sport played by x -1 The professional sport played by Pavol Demitra hockey Pavol Demitra "[',' ' who' ' was' ' born' ' in' ' Slovakia' ',' ' is' ' a' ' former'
+ ' NHL' ' player' ' who' ' has' ' played' ' for' ' the' ' Montreal'
+ ' Canadiens' ',']" , who was born in Slovakia , is a former NHL player who has played for the Montreal Canadiens , False agent acquisitions Pavol Demitra and Mats Sundin. 6 [' agent', ' acquisitions', ' Pav', 'ol', ' Dem', 'it', 'ra']
+496 138 The professional sport played by x -1 The professional sport played by Brad Friedel soccer Brad Friedel "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.'
+ ' national' ' team' ' that' ' won' ' the' ' World' ' Cup' ' in' ' 1994']" , who was a member of the U . S . national team that won the World Cup in 1994 False team-mates such as Brad Friedel and captain Stiliyan 7 [' team', '-', 'mates', ' such', ' as', ' Brad', ' Fried', 'el']
+497 139 The professional sport played by x -1 The professional sport played by Kasey Keller soccer Kasey Keller "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , a former NFL quarterback , is a great example of the kind of person who can make a difference False Mauricio Taricco and Kasey Keller blocked from Mendieta, 9 [' Maur', 'icio', ' Tar', 'ic', 'co', ' and', ' K', 'ase', 'y', ' Keller']
+498 139 The professional sport played by x -1 The professional sport played by Kasey Keller soccer Kasey Keller "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , a former NFL quarterback , is a great example of the kind of person who can make a difference False shootout with goalkeeper Kasey Keller making two saves 6 [' shootout', ' with', ' goalkeeper', ' K', 'ase', 'y', ' Keller']
+499 139 The professional sport played by x -1 The professional sport played by Kasey Keller soccer Kasey Keller "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , a former NFL quarterback , is a great example of the kind of person who can make a difference False goalkeeper Kasey Keller to make a save. Pappa 4 [' goalkeeper', ' K', 'ase', 'y', ' Keller']
+500 139 The professional sport played by x -1 The professional sport played by Kasey Keller soccer Kasey Keller "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , a former NFL quarterback , is a great example of the kind of person who can make a difference False Taricco and Kasey Keller blocked from Mendieta, 7 [' Tar', 'ic', 'co', ' and', ' K', 'ase', 'y', ' Keller']
+501 139 The professional sport played by x -1 The professional sport played by Kasey Keller soccer Kasey Keller "[',' ' a' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , a former NFL quarterback , is a great example of the kind of person who can make a difference False 64th minute, Kasey Keller made a diving 7 [' 64', 'th', ' minute', ',', ' K', 'ase', 'y', ' Keller']
+502 140 The professional sport played by x -1 The professional sport played by Red Grange football Red Grange "[',' ' the' ' first' ' professional' ' football' ' player' ' to' ' win'
+ ' the' ' Heisman' ' Trophy' '.' '\n' '\n' 'The' ' Heisman' ' Trophy'
+ ' is' ' awarded' ' annually']" ", the first professional football player to win the Heisman Trophy .
+
+ The Heisman Trophy is awarded annually" True NFL, most notably Red Grange from the University 6 [' NFL', ',', ' most', ' notably', ' Red', ' Gr', 'ange']
+503 140 The professional sport played by x -1 The professional sport played by Red Grange football Red Grange "[',' ' the' ' first' ' professional' ' football' ' player' ' to' ' win'
+ ' the' ' Heisman' ' Trophy' '.' '\n' '\n' 'The' ' Heisman' ' Trophy'
+ ' is' ' awarded' ' annually']" ", the first professional football player to win the Heisman Trophy .
+
+ The Heisman Trophy is awarded annually" True NFL, most notably Red Grange from the University 6 [' NFL', ',', ' most', ' notably', ' Red', ' Gr', 'ange']
+504 140 The professional sport played by x -1 The professional sport played by Red Grange football Red Grange "[',' ' the' ' first' ' professional' ' football' ' player' ' to' ' win'
+ ' the' ' Heisman' ' Trophy' '.' '\n' '\n' 'The' ' Heisman' ' Trophy'
+ ' is' ' awarded' ' annually']" ", the first professional football player to win the Heisman Trophy .
+
+ The Heisman Trophy is awarded annually" True Radovich decision. Red Grange and Bell testified 6 [' Rad', 'ovich', ' decision', '.', ' Red', ' Gr', 'ange']
+505 140 The professional sport played by x -1 The professional sport played by Red Grange football Red Grange "[',' ' the' ' first' ' professional' ' football' ' player' ' to' ' win'
+ ' the' ' Heisman' ' Trophy' '.' '\n' '\n' 'The' ' Heisman' ' Trophy'
+ ' is' ' awarded' ' annually']" ", the first professional football player to win the Heisman Trophy .
+
+ The Heisman Trophy is awarded annually" True he outran Bears star Red Grange for a touchdown. However, 7 [' he', ' out', 'ran', ' Bears', ' star', ' Red', ' Gr', 'ange']
+506 140 The professional sport played by x -1 The professional sport played by Red Grange football Red Grange "[',' ' the' ' first' ' professional' ' football' ' player' ' to' ' win'
+ ' the' ' Heisman' ' Trophy' '.' '\n' '\n' 'The' ' Heisman' ' Trophy'
+ ' is' ' awarded' ' annually']" ", the first professional football player to win the Heisman Trophy .
+
+ The Heisman Trophy is awarded annually" True NFL, most notably Red Grange from the University 6 [' NFL', ',', ' most', ' notably', ' Red', ' Gr', 'ange']
+507 141 The professional sport played by x -1 The professional sport played by Stan Musial baseball Stan Musial "[',' ' the' ' Cardinals' ""'"" ' first' ' baseman' ',' ' was' ' a' ' great'
+ ' hitter' ',' ' but' ' he' ' was' ' also' ' a' ' great' ' teammate' '.']" , the Cardinals ' first baseman , was a great hitter , but he was also a great teammate . False tied twice. When Stan Musial set a record by hitting 6 [' tied', ' twice', '.', ' When', ' Stan', ' Mus', 'ial']
+508 141 The professional sport played by x -1 The professional sport played by Stan Musial baseball Stan Musial "[',' ' the' ' Cardinals' ""'"" ' first' ' baseman' ',' ' was' ' a' ' great'
+ ' hitter' ',' ' but' ' he' ' was' ' also' ' a' ' great' ' teammate' '.']" , the Cardinals ' first baseman , was a great hitter , but he was also a great teammate . False tied twice. When Stan Musial set a record by 6 [' tied', ' twice', '.', ' When', ' Stan', ' Mus', 'ial']
+509 141 The professional sport played by x -1 The professional sport played by Stan Musial baseball Stan Musial "[',' ' the' ' Cardinals' ""'"" ' first' ' baseman' ',' ' was' ' a' ' great'
+ ' hitter' ',' ' but' ' he' ' was' ' also' ' a' ' great' ' teammate' '.']" , the Cardinals ' first baseman , was a great hitter , but he was also a great teammate . False incorrect, as Stan Musial also met and far 5 [' incorrect', ',', ' as', ' Stan', ' Mus', 'ial']
+510 141 The professional sport played by x -1 The professional sport played by Stan Musial baseball Stan Musial "[',' ' the' ' Cardinals' ""'"" ' first' ' baseman' ',' ' was' ' a' ' great'
+ ' hitter' ',' ' but' ' he' ' was' ' also' ' a' ' great' ' teammate' '.']" , the Cardinals ' first baseman , was a great hitter , but he was also a great teammate . False the first since Stan Musial in the 1946 World Series, 5 [' the', ' first', ' since', ' Stan', ' Mus', 'ial']
+511 141 The professional sport played by x -1 The professional sport played by Stan Musial baseball Stan Musial "[',' ' the' ' Cardinals' ""'"" ' first' ' baseman' ',' ' was' ' a' ' great'
+ ' hitter' ',' ' but' ' he' ' was' ' also' ' a' ' great' ' teammate' '.']" , the Cardinals ' first baseman , was a great hitter , but he was also a great teammate . False Cardinals moved Stan Musial to first base to 4 [' Cardinals', ' moved', ' Stan', ' Mus', 'ial']
+512 142 The professional sport played by x -1 The professional sport played by Billy Sunday baseball Billy Sunday "[',' ' the' ' former' ' NFL' ' player' ',' ' was' ' a' ' huge' ' hit'
+ ' with' ' the' ' crowd' '.' '\n' '\n' 'The' ' crowd' ' was' ' so']" ", the former NFL player , was a huge hit with the crowd .
+
+ The crowd was so" False 1 ['Billy', ' Sunday']
+513 142 The professional sport played by x -1 The professional sport played by Billy Sunday baseball Billy Sunday "[',' ' the' ' former' ' NFL' ' player' ',' ' was' ' a' ' huge' ' hit'
+ ' with' ' the' ' crowd' '.' '\n' '\n' 'The' ' crowd' ' was' ' so']" ", the former NFL player , was a huge hit with the crowd .
+
+ The crowd was so" False reenacting Billy Sunday sermons and attending 4 [' re', 'en', 'acting', ' Billy', ' Sunday']
+514 142 The professional sport played by x -1 The professional sport played by Billy Sunday baseball Billy Sunday "[',' ' the' ' former' ' NFL' ' player' ',' ' was' ' a' ' huge' ' hit'
+ ' with' ' the' ' crowd' '.' '\n' '\n' 'The' ' crowd' ' was' ' so']" ", the former NFL player , was a huge hit with the crowd .
+
+ The crowd was so" False nation reenacting the Billy Sunday sermons Get 6 [' nation', ' re', 'en', 'acting', ' the', ' Billy', ' Sunday']
+515 142 The professional sport played by x -1 The professional sport played by Billy Sunday baseball Billy Sunday "[',' ' the' ' former' ' NFL' ' player' ',' ' was' ' a' ' huge' ' hit'
+ ' with' ' the' ' crowd' '.' '\n' '\n' 'The' ' crowd' ' was' ' so']" ", the former NFL player , was a huge hit with the crowd .
+
+ The crowd was so" False 1 ['Billy', ' Sunday']
+516 142 The professional sport played by x -1 The professional sport played by Billy Sunday baseball Billy Sunday "[',' ' the' ' former' ' NFL' ' player' ',' ' was' ' a' ' huge' ' hit'
+ ' with' ' the' ' crowd' '.' '\n' '\n' 'The' ' crowd' ' was' ' so']" ", the former NFL player , was a huge hit with the crowd .
+
+ The crowd was so" False " 108 yards rushing. Billy Sunday wrote, ""That jump" 5 [' 108', ' yards', ' rushing', '.', ' Billy', ' Sunday']
+517 143 The professional sport played by x -1 The professional sport played by Tommie Smith football Tommie Smith "[' and' ' John' ' Carlos' ',' ' who' ' were' ' both' ' black' ',' ' and'
+ ' the' ' American' ' spr' 'inter' ' T' 'omm' 'ie' ' Smith' ',' ' who']" and John Carlos , who were both black , and the American spr inter T omm ie Smith , who False City John Carlos, Tommie Smith and Peter Norman ignored 7 [' City', ' John', ' Carlos', ',', ' T', 'omm', 'ie', ' Smith']
+518 143 The professional sport played by x -1 The professional sport played by Tommie Smith football Tommie Smith "[' and' ' John' ' Carlos' ',' ' who' ' were' ' both' ' black' ',' ' and'
+ ' the' ' American' ' spr' 'inter' ' T' 'omm' 'ie' ' Smith' ',' ' who']" and John Carlos , who were both black , and the American spr inter T omm ie Smith , who False John Carlos, Tommie Smith and Peter Norman ignored 6 [' John', ' Carlos', ',', ' T', 'omm', 'ie', ' Smith']
+519 143 The professional sport played by x -1 The professional sport played by Tommie Smith football Tommie Smith "[' and' ' John' ' Carlos' ',' ' who' ' were' ' both' ' black' ',' ' and'
+ ' the' ' American' ' spr' 'inter' ' T' 'omm' 'ie' ' Smith' ',' ' who']" and John Carlos , who were both black , and the American spr inter T omm ie Smith , who False field athletes, Tommie Smith and John Carlos, who 6 [' field', ' athletes', ',', ' T', 'omm', 'ie', ' Smith']
+520 143 The professional sport played by x -1 The professional sport played by Tommie Smith football Tommie Smith "[' and' ' John' ' Carlos' ',' ' who' ' were' ' both' ' black' ',' ' and'
+ ' the' ' American' ' spr' 'inter' ' T' 'omm' 'ie' ' Smith' ',' ' who']" and John Carlos , who were both black , and the American spr inter T omm ie Smith , who False Black Power salute of Tommie Smith and John Carlos 7 [' Black', ' Power', ' salute', ' of', ' T', 'omm', 'ie', ' Smith']
+521 143 The professional sport played by x -1 The professional sport played by Tommie Smith football Tommie Smith "[' and' ' John' ' Carlos' ',' ' who' ' were' ' both' ' black' ',' ' and'
+ ' the' ' American' ' spr' 'inter' ' T' 'omm' 'ie' ' Smith' ',' ' who']" and John Carlos , who were both black , and the American spr inter T omm ie Smith , who False John Carlos, Tommie Smith and Peter Norman ignored 6 [' John', ' Carlos', ',', ' T', 'omm', 'ie', ' Smith']
+522 145 The professional sport played by x -1 The professional sport played by Lionel Conacher hockey Lionel Conacher "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' and'
+ ' his' ' wife' ',' ' the' ' former' ' Miss' ' America' ',' ' is' ' a'
+ ' great' ' example']" , a former professional baseball player , and his wife , the former Miss America , is a great example False voting for the 2000 Lionel Conacher Award, which is 6 [' voting', ' for', ' the', ' 2000', ' Lionel', ' Con', 'acher']
+523 145 The professional sport played by x -1 The professional sport played by Lionel Conacher hockey Lionel Conacher "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' and'
+ ' his' ' wife' ',' ' the' ' former' ' Miss' ' America' ',' ' is' ' a'
+ ' great' ' example']" , a former professional baseball player , and his wife , the former Miss America , is a great example False 4 ['L', 'ion', 'el', ' Con', 'acher']
+524 145 The professional sport played by x -1 The professional sport played by Lionel Conacher hockey Lionel Conacher "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' and'
+ ' his' ' wife' ',' ' the' ' former' ' Miss' ' America' ',' ' is' ' a'
+ ' great' ' example']" , a former professional baseball player , and his wife , the former Miss America , is a great example False for the 2000 Lionel Conacher Award, which 5 [' for', ' the', ' 2000', ' Lionel', ' Con', 'acher']
+525 145 The professional sport played by x -1 The professional sport played by Lionel Conacher hockey Lionel Conacher "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' and'
+ ' his' ' wife' ',' ' the' ' former' ' Miss' ' America' ',' ' is' ' a'
+ ' great' ' example']" , a former professional baseball player , and his wife , the former Miss America , is a great example False " Award
+" 5 [' Award', 'L', 'ion', 'el', ' Con', 'acher']
+526 145 The professional sport played by x -1 The professional sport played by Lionel Conacher hockey Lionel Conacher "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' and'
+ ' his' ' wife' ',' ' the' ' former' ' Miss' ' America' ',' ' is' ' a'
+ ' great' ' example']" , a former professional baseball player , and his wife , the former Miss America , is a great example False Thus, for instance, Lionel Conacher is in the Hall but 7 [' Thus', ',', ' for', ' instance', ',', ' Lionel', ' Con', 'acher']
+527 146 The professional sport played by x -1 The professional sport played by Derek Jeter baseball Derek Jeter "[',' ' the' ' Yankees' ""'"" ' shortstop' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a' ' great']" , the Yankees ' shortstop , is a great example of the kind of player who can be a great False for assignment when Derek Jeter returned from the 5 [' for', ' assignment', ' when', ' Derek', ' J', 'eter']
+528 146 The professional sport played by x -1 The professional sport played by Derek Jeter baseball Derek Jeter "[',' ' the' ' Yankees' ""'"" ' shortstop' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a' ' great']" , the Yankees ' shortstop , is a great example of the kind of player who can be a great False assignment when Derek Jeter returned from the 4 [' assignment', ' when', ' Derek', ' J', 'eter']
+529 146 The professional sport played by x -1 The professional sport played by Derek Jeter baseball Derek Jeter "[',' ' the' ' Yankees' ""'"" ' shortstop' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a' ' great']" , the Yankees ' shortstop , is a great example of the kind of player who can be a great False third, scoring Derek Jeter and Johnny Damon. 5 [' third', ',', ' scoring', ' Derek', ' J', 'eter']
+530 146 The professional sport played by x -1 The professional sport played by Derek Jeter baseball Derek Jeter "[',' ' the' ' Yankees' ""'"" ' shortstop' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a' ' great']" , the Yankees ' shortstop , is a great example of the kind of player who can be a great False including Derek Jeter and Mariano 3 [' including', ' Derek', ' J', 'eter']
+531 146 The professional sport played by x -1 The professional sport played by Derek Jeter baseball Derek Jeter "[',' ' the' ' Yankees' ""'"" ' shortstop' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' who' ' can' ' be' ' a' ' great']" , the Yankees ' shortstop , is a great example of the kind of player who can be a great False leadoff single to Derek Jeter before retiring the 6 [' lead', 'off', ' single', ' to', ' Derek', ' J', 'eter']
+532 147 The professional sport played by x -1 The professional sport played by Honus Wagner baseball Honus Wagner "[',' ' the' ' Pittsburgh' ' Pirates' ',' ' and' ' the' ' Pittsburgh'
+ ' Pirates' '.' '\n' '\n' 'The' ' Pittsburgh' ' Pirates' ' are' ' a'
+ ' professional' ' baseball' ' team']" ", the Pittsburgh Pirates , and the Pittsburgh Pirates .
+
+ The Pittsburgh Pirates are a professional baseball team" True Fame member Honus Wagner played on the 4 [' Fame', ' member', ' Hon', 'us', ' Wagner']
+533 147 The professional sport played by x -1 The professional sport played by Honus Wagner baseball Honus Wagner "[',' ' the' ' Pittsburgh' ' Pirates' ',' ' and' ' the' ' Pittsburgh'
+ ' Pirates' '.' '\n' '\n' 'The' ' Pittsburgh' ' Pirates' ' are' ' a'
+ ' professional' ' baseball' ' team']" ", the Pittsburgh Pirates , and the Pittsburgh Pirates .
+
+ The Pittsburgh Pirates are a professional baseball team" True " === T206 Honus Wagner recent sales ===
+" 5 [' ===', ' T', '206', ' Hon', 'us', ' Wagner']
+534 147 The professional sport played by x -1 The professional sport played by Honus Wagner baseball Honus Wagner "[',' ' the' ' Pittsburgh' ' Pirates' ',' ' and' ' the' ' Pittsburgh'
+ ' Pirates' '.' '\n' '\n' 'The' ' Pittsburgh' ' Pirates' ' are' ' a'
+ ' professional' ' baseball' ' team']" ", the Pittsburgh Pirates , and the Pittsburgh Pirates .
+
+ The Pittsburgh Pirates are a professional baseball team" True === T206 Honus Wagner recent sales 5 [' ===', ' T', '206', ' Hon', 'us', ' Wagner']
+535 147 The professional sport played by x -1 The professional sport played by Honus Wagner baseball Honus Wagner "[',' ' the' ' Pittsburgh' ' Pirates' ',' ' and' ' the' ' Pittsburgh'
+ ' Pirates' '.' '\n' '\n' 'The' ' Pittsburgh' ' Pirates' ' are' ' a'
+ ' professional' ' baseball' ' team']" ", the Pittsburgh Pirates , and the Pittsburgh Pirates .
+
+ The Pittsburgh Pirates are a professional baseball team" True August 11 he passed Honus Wagner climbing to 6 [' August', ' 11', ' he', ' passed', ' Hon', 'us', ' Wagner']
+536 147 The professional sport played by x -1 The professional sport played by Honus Wagner baseball Honus Wagner "[',' ' the' ' Pittsburgh' ' Pirates' ',' ' and' ' the' ' Pittsburgh'
+ ' Pirates' '.' '\n' '\n' 'The' ' Pittsburgh' ' Pirates' ' are' ' a'
+ ' professional' ' baseball' ' team']" ", the Pittsburgh Pirates , and the Pittsburgh Pirates .
+
+ The Pittsburgh Pirates are a professional baseball team" True triples, two – Honus Wagner and Roger Connor 7 [' tri', 'ples', ',', ' two', ' –', ' Hon', 'us', ' Wagner']
+537 148 The professional sport played by x -1 The professional sport played by Mickey Mantle baseball Mickey Mantle "[',' ' the' ' most' ' famous' ' baseball' ' player' ' of' ' all' ' time'
+ ',' ' was' ' born' ' in' ' this' ' city' '.' '\n' '\n' 'The' ' city']" ", the most famous baseball player of all time , was born in this city .
+
+ The city" True Jimmie Foxx and Mickey Mantle career marks in his 7 [' Jim', 'mie', ' Fox', 'x', ' and', ' Mickey', ' M', 'antle']
+538 148 The professional sport played by x -1 The professional sport played by Mickey Mantle baseball Mickey Mantle "[',' ' the' ' most' ' famous' ' baseball' ' player' ' of' ' all' ' time'
+ ',' ' was' ' born' ' in' ' this' ' city' '.' '\n' '\n' 'The' ' city']" ", the most famous baseball player of all time , was born in this city .
+
+ The city" True run, passing Mickey Mantle for 15th place on 5 [' run', ',', ' passing', ' Mickey', ' M', 'antle']
+539 148 The professional sport played by x -1 The professional sport played by Mickey Mantle baseball Mickey Mantle "[',' ' the' ' most' ' famous' ' baseball' ' player' ' of' ' all' ' time'
+ ',' ' was' ' born' ' in' ' this' ' city' '.' '\n' '\n' 'The' ' city']" ", the most famous baseball player of all time , was born in this city .
+
+ The city" True " ""a die-hard Mickey Mantle fan;"" the book included" 7 "[' ""', 'a', ' die', '-', 'hard', ' Mickey', ' M', 'antle']"
+540 148 The professional sport played by x -1 The professional sport played by Mickey Mantle baseball Mickey Mantle "[',' ' the' ' most' ' famous' ' baseball' ' player' ' of' ' all' ' time'
+ ',' ' was' ' born' ' in' ' this' ' city' '.' '\n' '\n' 'The' ' city']" ", the most famous baseball player of all time , was born in this city .
+
+ The city" True Jimmie Foxx and Mickey Mantle career marks in his 7 [' Jim', 'mie', ' Fox', 'x', ' and', ' Mickey', ' M', 'antle']
+541 148 The professional sport played by x -1 The professional sport played by Mickey Mantle baseball Mickey Mantle "[',' ' the' ' most' ' famous' ' baseball' ' player' ' of' ' all' ' time'
+ ',' ' was' ' born' ' in' ' this' ' city' '.' '\n' '\n' 'The' ' city']" ", the most famous baseball player of all time , was born in this city .
+
+ The city" True Roger Maris and Mickey Mantle in 1961. On 6 [' Roger', ' Mar', 'is', ' and', ' Mickey', ' M', 'antle']
+542 149 The professional sport played by x -1 The professional sport played by Bastian Schweinsteiger soccer Bastian Schweinsteiger "[',' ' the' ' German' ' midfielder' ',' ' is' ' a' ' very' ' good'
+ ' example' ' of' ' this' '.' ' He' ' is' ' a' ' very' ' good' ' player'
+ ',']" , the German midfielder , is a very good example of this . He is a very good player , False Although overshadowed by Bastian Schweinsteiger ’ s game-winning 8 [' Although', ' overshadowed', ' by', ' Bast', 'ian', ' Schwe', 'in', 'ste', 'iger']
+543 149 The professional sport played by x -1 The professional sport played by Bastian Schweinsteiger soccer Bastian Schweinsteiger "[',' ' the' ' German' ' midfielder' ',' ' is' ' a' ' very' ' good'
+ ' example' ' of' ' this' '.' ' He' ' is' ' a' ' very' ' good' ' player'
+ ',']" , the German midfielder , is a very good example of this . He is a very good player , False overshadowed by Bastian Schweinsteiger ’ s game-winning 7 [' overshadowed', ' by', ' Bast', 'ian', ' Schwe', 'in', 'ste', 'iger']
+544 149 The professional sport played by x -1 The professional sport played by Bastian Schweinsteiger soccer Bastian Schweinsteiger "[',' ' the' ' German' ' midfielder' ',' ' is' ' a' ' very' ' good'
+ ' example' ' of' ' this' '.' ' He' ' is' ' a' ' very' ' good' ' player'
+ ',']" , the German midfielder , is a very good example of this . He is a very good player , False 72nd minute for Bastian Schweinsteiger in a 7 – 1 9 [' 72', 'nd', ' minute', ' for', ' Bast', 'ian', ' Schwe', 'in', 'ste', 'iger']
+545 149 The professional sport played by x -1 The professional sport played by Bastian Schweinsteiger soccer Bastian Schweinsteiger "[',' ' the' ' German' ' midfielder' ',' ' is' ' a' ' very' ' good'
+ ' example' ' of' ' this' '.' ' He' ' is' ' a' ' very' ' good' ' player'
+ ',']" , the German midfielder , is a very good example of this . He is a very good player , False sliding tackle on Bastian Schweinsteiger in the first half; 8 [' sliding', ' tackle', ' on', ' Bast', 'ian', ' Schwe', 'in', 'ste', 'iger']
+546 150 The professional sport played by x -1 The professional sport played by Hakeem Olajuwon basketball Hakeem Olajuwon "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False NBA, only Hakeem Olajuwon and Andrei 9 [' NBA', ',', ' only', ' H', 'ake', 'em', ' Ol', 'aj', 'u', 'won']
+547 150 The professional sport played by x -1 The professional sport played by Hakeem Olajuwon basketball Hakeem Olajuwon "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False " 1984 – 2001: The Hakeem Olajuwon era ===
+" 11 [' 1984', ' –', ' 2001', ':', ' The', ' H', 'ake', 'em', ' Ol', 'aj', 'u', 'won']
+548 150 The professional sport played by x -1 The professional sport played by Hakeem Olajuwon basketball Hakeem Olajuwon "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False used to select Hakeem Olajuwon from the University 9 [' used', ' to', ' select', ' H', 'ake', 'em', ' Ol', 'aj', 'u', 'won']
+549 150 The professional sport played by x -1 The professional sport played by Hakeem Olajuwon basketball Hakeem Olajuwon "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False Ben Wallace and Hakeem Olajuwon are the only players 9 [' Ben', ' Wallace', ' and', ' H', 'ake', 'em', ' Ol', 'aj', 'u', 'won']
+550 150 The professional sport played by x -1 The professional sport played by Hakeem Olajuwon basketball Hakeem Olajuwon "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NBA 's all - time leading scorer , is a game that has been played for over a False " 1984 – 2001: The Hakeem Olajuwon era ===
+" 11 [' 1984', ' –', ' 2001', ':', ' The', ' H', 'ake', 'em', ' Ol', 'aj', 'u', 'won']
+551 151 The professional sport played by x -1 The professional sport played by Allen Iverson basketball Allen Iverson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a perfect example of the kind of player who False three-for-one trade of Allen Iverson for Chauncey Billups, 9 [' three', '-', 'for', '-', 'one', ' trade', ' of', ' Allen', ' I', 'verson']
+552 151 The professional sport played by x -1 The professional sport played by Allen Iverson basketball Allen Iverson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a perfect example of the kind of player who False Garnett and Allen Iverson completed the 5 [' Garn', 'ett', ' and', ' Allen', ' I', 'verson']
+553 151 The professional sport played by x -1 The professional sport played by Allen Iverson basketball Allen Iverson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a perfect example of the kind of player who False Antonio. They met Allen Iverson and the Philadelphia 6 [' Antonio', '.', ' They', ' met', ' Allen', ' I', 'verson']
+554 151 The professional sport played by x -1 The professional sport played by Allen Iverson basketball Allen Iverson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a perfect example of the kind of player who False three-for-one trade of Allen Iverson for Chauncey 9 [' three', '-', 'for', '-', 'one', ' trade', ' of', ' Allen', ' I', 'verson']
+555 151 The professional sport played by x -1 The professional sport played by Allen Iverson basketball Allen Iverson "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' perfect' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a perfect example of the kind of player who False 1996 Draft, behind Allen Iverson and Marcus Camby. 6 [' 1996', ' Draft', ',', ' behind', ' Allen', ' I', 'verson']
+556 153 The professional sport played by x -1 The professional sport played by Tom Seaver baseball Tom Seaver "[',' ' the' ' former' ' New' ' York' ' Mets' ' pitcher' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' be']" , the former New York Mets pitcher , is a great example of the kind of person who can be False (10) set by Tom Seaver in 1970. He finished 7 [' (', '10', ')', ' set', ' by', ' Tom', ' Sea', 'ver']
+557 153 The professional sport played by x -1 The professional sport played by Tom Seaver baseball Tom Seaver "[',' ' the' ' former' ' New' ' York' ' Mets' ' pitcher' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' be']" , the former New York Mets pitcher , is a great example of the kind of person who can be False record (10) set by Tom Seaver in 1970. He 8 [' record', ' (', '10', ')', ' set', ' by', ' Tom', ' Sea', 'ver']
+558 153 The professional sport played by x -1 The professional sport played by Tom Seaver baseball Tom Seaver "[',' ' the' ' former' ' New' ' York' ' Mets' ' pitcher' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' be']" , the former New York Mets pitcher , is a great example of the kind of person who can be False record (10) set by Tom Seaver in 1970. He finished 8 [' record', ' (', '10', ')', ' set', ' by', ' Tom', ' Sea', 'ver']
+559 154 The professional sport played by x -1 The professional sport played by Jozy Altidore soccer Jozy Altidore "[',' ' who' ' scored' ' the' ' game' '-' 'winning' ' goal' ' in' ' the'
+ ' final' ' minutes' ' of' ' the' ' game' '.' '\n' '\n' 'The' ' game']" ", who scored the game - winning goal in the final minutes of the game .
+
+ The game" False exchange sending Jozy Altidore in the other direction. 6 [' exchange', ' sending', ' Jo', 'zy', ' Alt', 'id', 'ore']
+560 154 The professional sport played by x -1 The professional sport played by Jozy Altidore soccer Jozy Altidore "[',' ' who' ' scored' ' the' ' game' '-' 'winning' ' goal' ' in' ' the'
+ ' final' ' minutes' ' of' ' the' ' game' '.' '\n' '\n' 'The' ' game']" ", who scored the game - winning goal in the final minutes of the game .
+
+ The game" False player exchange sending Jozy Altidore in the other direction. 7 [' player', ' exchange', ' sending', ' Jo', 'zy', ' Alt', 'id', 'ore']
+561 155 The professional sport played by x -1 The professional sport played by Tim Duncan basketball Tim Duncan "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False LeBron James and Tim Duncan are tied for 4 [' LeBron', ' James', ' and', ' Tim', ' Duncan']
+562 155 The professional sport played by x -1 The professional sport played by Tim Duncan basketball Tim Duncan "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False 1 ['Tim', ' Duncan']
+563 155 The professional sport played by x -1 The professional sport played by Tim Duncan basketball Tim Duncan "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False 1 ['Tim', ' Duncan']
+564 155 The professional sport played by x -1 The professional sport played by Tim Duncan basketball Tim Duncan "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False the Spurs in the Tim Duncan era, which began 5 [' the', ' Spurs', ' in', ' the', ' Tim', ' Duncan']
+565 155 The professional sport played by x -1 The professional sport played by Tim Duncan basketball Tim Duncan "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False Spurs in the Tim Duncan era, which began 4 [' Spurs', ' in', ' the', ' Tim', ' Duncan']
+566 157 The professional sport played by x -1 The professional sport played by Brett Favre football Brett Favre "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' a']" , the quarterback for the Green Bay Packers , is a great example of a player who has been a False them to select Brett Favre fell through. The 5 [' them', ' to', ' select', ' Brett', ' Fav', 're']
+567 157 The professional sport played by x -1 The professional sport played by Brett Favre football Brett Favre "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' a']" , the quarterback for the Green Bay Packers , is a great example of a player who has been a False " football quarterback Brett Favre was nicknamed ""The" 4 [' football', ' quarterback', ' Brett', ' Fav', 're']
+568 157 The professional sport played by x -1 The professional sport played by Brett Favre football Brett Favre "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' a']" , the quarterback for the Green Bay Packers , is a great example of a player who has been a False backing up Brett Favre during his first 4 [' backing', ' up', ' Brett', ' Fav', 're']
+569 157 The professional sport played by x -1 The professional sport played by Brett Favre football Brett Favre "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' a']" , the quarterback for the Green Bay Packers , is a great example of a player who has been a False Falcons for quarterback Brett Favre on February 10, 5 [' Falcons', ' for', ' quarterback', ' Brett', ' Fav', 're']
+570 157 The professional sport played by x -1 The professional sport played by Brett Favre football Brett Favre "[',' ' the' ' quarterback' ' for' ' the' ' Green' ' Bay' ' Packers' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' a']" , the quarterback for the Green Bay Packers , is a great example of a player who has been a False shoulder, Brett Favre was marked inactive 4 [' shoulder', ',', ' Brett', ' Fav', 're']
+571 158 The professional sport played by x -1 The professional sport played by Michael Strahan football Michael Strahan "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' person' ' who' ' has' ' been' ' able' ' to']" , who is a former NFL player , is a great example of a person who has been able to False careers sack leader Michael Strahan retired before 5 [' careers', ' sack', ' leader', ' Michael', ' Stra', 'han']
+572 158 The professional sport played by x -1 The professional sport played by Michael Strahan football Michael Strahan "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' person' ' who' ' has' ' been' ' able' ' to']" , who is a former NFL player , is a great example of a person who has been able to False sack leader Michael Strahan retired before 4 [' sack', ' leader', ' Michael', ' Stra', 'han']
+573 158 The professional sport played by x -1 The professional sport played by Michael Strahan football Michael Strahan "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' person' ' who' ' has' ' been' ' able' ' to']" , who is a former NFL player , is a great example of a person who has been able to False Marcus Allen and Michael Strahan who collectively 5 [' Marcus', ' Allen', ' and', ' Michael', ' Stra', 'han']
+574 158 The professional sport played by x -1 The professional sport played by Michael Strahan football Michael Strahan "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' person' ' who' ' has' ' been' ' able' ' to']" , who is a former NFL player , is a great example of a person who has been able to False sack leader Michael Strahan retired before 4 [' sack', ' leader', ' Michael', ' Stra', 'han']
+575 158 The professional sport played by x -1 The professional sport played by Michael Strahan football Michael Strahan "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' person' ' who' ' has' ' been' ' able' ' to']" , who is a former NFL player , is a great example of a person who has been able to False careers sack leader Michael Strahan retired before the 5 [' careers', ' sack', ' leader', ' Michael', ' Stra', 'han']
+576 159 The professional sport played by x -1 The professional sport played by Sergei Makarov hockey Sergei Makarov "[',' ' a' ' Russian' ' professional' ' ice' ' hockey' ' player' ',' ' who'
+ ' played' ' in' ' the' ' NHL' ' for' ' the' ' New' ' York' ' Islanders'
+ ',' ' New']" , a Russian professional ice hockey player , who played in the NHL for the New York Islanders , New True 1989 – 90. Sergei Makarov joined the Flames 6 [' 1989', ' –', ' 90', '.', ' Sergei', ' Mak', 'arov']
+577 159 The professional sport played by x -1 The professional sport played by Sergei Makarov hockey Sergei Makarov "[',' ' a' ' Russian' ' professional' ' ice' ' hockey' ' player' ',' ' who'
+ ' played' ' in' ' the' ' NHL' ' for' ' the' ' New' ' York' ' Islanders'
+ ',' ' New']" , a Russian professional ice hockey player , who played in the NHL for the New York Islanders , New True and 1998-2001). Sergei Makarov was informed on 7 [' and', ' 1998', '-', '2001', ').', ' Sergei', ' Mak', 'arov']
+578 159 The professional sport played by x -1 The professional sport played by Sergei Makarov hockey Sergei Makarov "[',' ' a' ' Russian' ' professional' ' ice' ' hockey' ' player' ',' ' who'
+ ' played' ' in' ' the' ' NHL' ' for' ' the' ' New' ' York' ' Islanders'
+ ',' ' New']" , a Russian professional ice hockey player , who played in the NHL for the New York Islanders , New True Vladimir Krutov, Sergei Makarov and Alexei Kasatonov. 7 [' Vladimir', ' Kr', 'ut', 'ov', ',', ' Sergei', ' Mak', 'arov']
+579 159 The professional sport played by x -1 The professional sport played by Sergei Makarov hockey Sergei Makarov "[',' ' a' ' Russian' ' professional' ' ice' ' hockey' ' player' ',' ' who'
+ ' played' ' in' ' the' ' NHL' ' for' ' the' ' New' ' York' ' Islanders'
+ ',' ' New']" , a Russian professional ice hockey player , who played in the NHL for the New York Islanders , New True Igor Larionov and Sergei Makarov on offence, as 7 [' Igor', ' Lar', 'ion', 'ov', ' and', ' Sergei', ' Mak', 'arov']
+580 159 The professional sport played by x -1 The professional sport played by Sergei Makarov hockey Sergei Makarov "[',' ' a' ' Russian' ' professional' ' ice' ' hockey' ' player' ',' ' who'
+ ' played' ' in' ' the' ' NHL' ' for' ' the' ' New' ' York' ' Islanders'
+ ',' ' New']" , a Russian professional ice hockey player , who played in the NHL for the New York Islanders , New True beginning in 1989 – 90. Sergei Makarov joined the 8 [' beginning', ' in', ' 1989', ' –', ' 90', '.', ' Sergei', ' Mak', 'arov']
+581 160 The professional sport played by x -1 The professional sport played by Danny Ainge baseball Danny Ainge "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False basketball operations Danny Ainge said all four 4 [' basketball', ' operations', ' Danny', ' A', 'inge']
+582 160 The professional sport played by x -1 The professional sport played by Danny Ainge baseball Danny Ainge "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False causing Suns coach Danny Ainge to play Duncan with 5 [' causing', ' Suns', ' coach', ' Danny', ' A', 'inge']
+583 160 The professional sport played by x -1 The professional sport played by Danny Ainge baseball Danny Ainge "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False causing Suns coach Danny Ainge to play Duncan with 5 [' causing', ' Suns', ' coach', ' Danny', ' A', 'inge']
+584 160 The professional sport played by x -1 The professional sport played by Danny Ainge baseball Danny Ainge "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False Suns coach Danny Ainge to play Duncan with 4 [' Suns', ' coach', ' Danny', ' A', 'inge']
+585 161 The professional sport played by x -1 The professional sport played by Barry Bonds baseball Barry Bonds "[',' ' the' ' former' ' San' ' Francisco' ' Giants' ' slug' 'ger' ','
+ ' was' ' arrested' ' in' ' San' ' Francisco' ' on' ' Tuesday' ' for'
+ ' allegedly' ' assaulting' ' a']" , the former San Francisco Giants slug ger , was arrested in San Francisco on Tuesday for allegedly assaulting a False playoff spot, and Barry Bonds hit a home run off 5 [' playoff', ' spot', ',', ' and', ' Barry', ' Bonds']
+586 161 The professional sport played by x -1 The professional sport played by Barry Bonds baseball Barry Bonds "[',' ' the' ' former' ' San' ' Francisco' ' Giants' ' slug' 'ger' ','
+ ' was' ' arrested' ' in' ' San' ' Francisco' ' on' ' Tuesday' ' for'
+ ' allegedly' ' assaulting' ' a']" , the former San Francisco Giants slug ger , was arrested in San Francisco on Tuesday for allegedly assaulting a False The Giants' Barry Bonds drew criticism 4 "[' The', ' Giants', ""'"", ' Barry', ' Bonds']"
+587 161 The professional sport played by x -1 The professional sport played by Barry Bonds baseball Barry Bonds "[',' ' the' ' former' ' San' ' Francisco' ' Giants' ' slug' 'ger' ','
+ ' was' ' arrested' ' in' ' San' ' Francisco' ' on' ' Tuesday' ' for'
+ ' allegedly' ' assaulting' ' a']" , the former San Francisco Giants slug ger , was arrested in San Francisco on Tuesday for allegedly assaulting a False 2 ['B', 'arry', ' Bonds']
+588 161 The professional sport played by x -1 The professional sport played by Barry Bonds baseball Barry Bonds "[',' ' the' ' former' ' San' ' Francisco' ' Giants' ' slug' 'ger' ','
+ ' was' ' arrested' ' in' ' San' ' Francisco' ' on' ' Tuesday' ' for'
+ ' allegedly' ' assaulting' ' a']" , the former San Francisco Giants slug ger , was arrested in San Francisco on Tuesday for allegedly assaulting a False Love Me, Hate Me: Barry Bonds and the Making of 7 [' Love', ' Me', ',', ' Hate', ' Me', ':', ' Barry', ' Bonds']
+589 161 The professional sport played by x -1 The professional sport played by Barry Bonds baseball Barry Bonds "[',' ' the' ' former' ' San' ' Francisco' ' Giants' ' slug' 'ger' ','
+ ' was' ' arrested' ' in' ' San' ' Francisco' ' on' ' Tuesday' ' for'
+ ' allegedly' ' assaulting' ' a']" , the former San Francisco Giants slug ger , was arrested in San Francisco on Tuesday for allegedly assaulting a False same game that Barry Bonds hit his record-breaking 4 [' same', ' game', ' that', ' Barry', ' Bonds']
+590 162 The professional sport played by x -1 The professional sport played by Roger Staubach football Roger Staubach "[',' ' the' ' quarterback' ' for' ' the' ' Dallas' ' Cowboys' ',' ' was'
+ ' a' ' great' ' quarterback' '.' ' He' ' was' ' a' ' great'
+ ' quarterback' '.' ' He']" , the quarterback for the Dallas Cowboys , was a great quarterback . He was a great quarterback . He False Cowboys quarterback Roger Staubach near the sideline 5 [' Cowboys', ' quarterback', ' Roger', ' Sta', 'ub', 'ach']
+591 163 The professional sport played by x -1 The professional sport played by Roy Campanella baseball Roy Campanella "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' won' ' the' ' World' ' Series' ' in'
+ ' 1955']" ", the first black player in the major leagues .
+
+ The Dodgers won the World Series in 1955" False consecutive games are Roy Campanella (1950), Adrian 6 [' consecutive', ' games', ' are', ' Roy', ' Camp', 'an', 'ella']
+592 163 The professional sport played by x -1 The professional sport played by Roy Campanella baseball Roy Campanella "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' won' ' the' ' World' ' Series' ' in'
+ ' 1955']" ", the first black player in the major leagues .
+
+ The Dodgers won the World Series in 1955" False Dodgers catcher Roy Campanella had a medical billing 5 [' Dodgers', ' catcher', ' Roy', ' Camp', 'an', 'ella']
+593 163 The professional sport played by x -1 The professional sport played by Roy Campanella baseball Roy Campanella "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' won' ' the' ' World' ' Series' ' in'
+ ' 1955']" ", the first black player in the major leagues .
+
+ The Dodgers won the World Series in 1955" False Robinson flied out and Roy Campanella grounded out. 8 [' Robinson', ' fl', 'ied', ' out', ' and', ' Roy', ' Camp', 'an', 'ella']
+594 163 The professional sport played by x -1 The professional sport played by Roy Campanella baseball Roy Campanella "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' won' ' the' ' World' ' Series' ' in'
+ ' 1955']" ", the first black player in the major leagues .
+
+ The Dodgers won the World Series in 1955" False consecutive games are Roy Campanella (1950), Adrian 6 [' consecutive', ' games', ' are', ' Roy', ' Camp', 'an', 'ella']
+595 163 The professional sport played by x -1 The professional sport played by Roy Campanella baseball Roy Campanella "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' won' ' the' ' World' ' Series' ' in'
+ ' 1955']" ", the first black player in the major leagues .
+
+ The Dodgers won the World Series in 1955" False voting, behind Roy Campanella and Duke Snider. 6 [' voting', ',', ' behind', ' Roy', ' Camp', 'an', 'ella']
+596 164 The professional sport played by x -1 The professional sport played by Kevin Love basketball Kevin Love "[',' ' the' ' NBA' ""'s"" ' most' ' valuable' ' player' ',' ' is' ' a'
+ ' basketball' ' player' ' who' ' is' ' a' ' forward' ' for' ' the'
+ ' Cleveland' ' Cavaliers']" , the NBA 's most valuable player , is a basketball player who is a forward for the Cleveland Cavaliers True All-Star forward Kevin Love and point guard Ricky 5 [' All', '-', 'Star', ' forward', ' Kevin', ' Love']
+597 164 The professional sport played by x -1 The professional sport played by Kevin Love basketball Kevin Love "[',' ' the' ' NBA' ""'s"" ' most' ' valuable' ' player' ',' ' is' ' a'
+ ' basketball' ' player' ' who' ' is' ' a' ' forward' ' for' ' the'
+ ' Cleveland' ' Cavaliers']" , the NBA 's most valuable player , is a basketball player who is a forward for the Cleveland Cavaliers True Hazzard (although Kevin Love was actually the last 6 [' H', 'azz', 'ard', ' (', 'although', ' Kevin', ' Love']
+598 164 The professional sport played by x -1 The professional sport played by Kevin Love basketball Kevin Love "[',' ' the' ' NBA' ""'s"" ' most' ' valuable' ' player' ',' ' is' ' a'
+ ' basketball' ' player' ' who' ' is' ' a' ' forward' ' for' ' the'
+ ' Cleveland' ' Cavaliers']" , the NBA 's most valuable player , is a basketball player who is a forward for the Cleveland Cavaliers True Hazzard (although Kevin Love was actually 6 [' H', 'azz', 'ard', ' (', 'although', ' Kevin', ' Love']
+599 164 The professional sport played by x -1 The professional sport played by Kevin Love basketball Kevin Love "[',' ' the' ' NBA' ""'s"" ' most' ' valuable' ' player' ',' ' is' ' a'
+ ' basketball' ' player' ' who' ' is' ' a' ' forward' ' for' ' the'
+ ' Cleveland' ' Cavaliers']" , the NBA 's most valuable player , is a basketball player who is a forward for the Cleveland Cavaliers True Cavaliers acquired Kevin Love from the Minnesota 3 [' Cavaliers', ' acquired', ' Kevin', ' Love']
+600 164 The professional sport played by x -1 The professional sport played by Kevin Love basketball Kevin Love "[',' ' the' ' NBA' ""'s"" ' most' ' valuable' ' player' ',' ' is' ' a'
+ ' basketball' ' player' ' who' ' is' ' a' ' forward' ' for' ' the'
+ ' Cleveland' ' Cavaliers']" , the NBA 's most valuable player , is a basketball player who is a forward for the Cleveland Cavaliers True starters Ricky Rubio, Kevin Love and Luke Ridnour, 5 [' starters', ' Ricky', ' Rubio', ',', ' Kevin', ' Love']
+601 165 The professional sport played by x -1 The professional sport played by Wilt Chamberlain basketball Wilt Chamberlain "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False All-Star Game since Wilt Chamberlain recorded 22 in 1967. 7 [' All', '-', 'Star', ' Game', ' since', ' W', 'ilt', ' Chamberlain']
+602 165 The professional sport played by x -1 The professional sport played by Wilt Chamberlain basketball Wilt Chamberlain "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False Elgin Baylor and Wilt Chamberlain surpass his 27.0 points 6 [' El', 'gin', ' Baylor', ' and', ' W', 'ilt', ' Chamberlain']
+603 165 The professional sport played by x -1 The professional sport played by Wilt Chamberlain basketball Wilt Chamberlain "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False " Chamberlain ===
+" 4 [' Chamberlain', ' ===', 'W', 'ilt', ' Chamberlain']
+604 165 The professional sport played by x -1 The professional sport played by Wilt Chamberlain basketball Wilt Chamberlain "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False " Wilt Chamberlain =
+" 2 [' W', 'ilt', ' Chamberlain']
+605 165 The professional sport played by x -1 The professional sport played by Wilt Chamberlain basketball Wilt Chamberlain "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False Most Valuable Player Wilt Chamberlain of the Philadelphia 6 [' Most', ' Val', 'uable', ' Player', ' W', 'ilt', ' Chamberlain']
+606 166 The professional sport played by x -1 The professional sport played by Youri Djorkaeff soccer Youri Djorkaeff "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional soccer player , is a game that is played by two teams of two players . True minutes remaining. Youri Djorkaeff would have put 8 [' minutes', ' remaining', '.', ' You', 'ri', ' Dj', 'ork', 'ae', 'ff']
+607 166 The professional sport played by x -1 The professional sport played by Youri Djorkaeff soccer Youri Djorkaeff "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional soccer player , is a game that is played by two teams of two players . True minutes remaining. Youri Djorkaeff would have put Bolton 8 [' minutes', ' remaining', '.', ' You', 'ri', ' Dj', 'ork', 'ae', 'ff']
+608 166 The professional sport played by x -1 The professional sport played by Youri Djorkaeff soccer Youri Djorkaeff "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional soccer player , is a game that is played by two teams of two players . True remaining. Youri Djorkaeff would have put Bolton 7 [' remaining', '.', ' You', 'ri', ' Dj', 'ork', 'ae', 'ff']
+609 166 The professional sport played by x -1 The professional sport played by Youri Djorkaeff soccer Youri Djorkaeff "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional soccer player , is a game that is played by two teams of two players . True minutes remaining. Youri Djorkaeff would have put Bolton 8 [' minutes', ' remaining', '.', ' You', 'ri', ' Dj', 'ork', 'ae', 'ff']
+610 166 The professional sport played by x -1 The professional sport played by Youri Djorkaeff soccer Youri Djorkaeff "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional soccer player , is a game that is played by two teams of two players . True remaining. Youri Djorkaeff would have put 7 [' remaining', '.', ' You', 'ri', ' Dj', 'ork', 'ae', 'ff']
+611 167 The professional sport played by x -1 The professional sport played by Steve Yzerman hockey Steve Yzerman "[',' ' the' ' Detroit' ' Red' ' Wings' ',' ' and' ' the' ' Detroit'
+ ' Tigers' '.' '\n' '\n' 'The' ' Detroit' ' Red' ' Wings' ' are' ' a'
+ ' professional']" ", the Detroit Red Wings , and the Detroit Tigers .
+
+ The Detroit Red Wings are a professional" False Bourque and Steve Yzerman are the others 6 [' Bour', 'que', ' and', ' Steve', ' Y', 'z', 'erman']
+612 167 The professional sport played by x -1 The professional sport played by Steve Yzerman hockey Steve Yzerman "[',' ' the' ' Detroit' ' Red' ' Wings' ',' ' and' ' the' ' Detroit'
+ ' Tigers' '.' '\n' '\n' 'The' ' Detroit' ' Red' ' Wings' ' are' ' a'
+ ' professional']" ", the Detroit Red Wings , and the Detroit Tigers .
+
+ The Detroit Red Wings are a professional" False retirement of Steve Yzerman a month later, 5 [' retirement', ' of', ' Steve', ' Y', 'z', 'erman']
+613 167 The professional sport played by x -1 The professional sport played by Steve Yzerman hockey Steve Yzerman "[',' ' the' ' Detroit' ' Red' ' Wings' ',' ' and' ' the' ' Detroit'
+ ' Tigers' '.' '\n' '\n' 'The' ' Detroit' ' Red' ' Wings' ' are' ' a'
+ ' professional']" ", the Detroit Red Wings , and the Detroit Tigers .
+
+ The Detroit Red Wings are a professional" False retirement of Steve Yzerman a month later, on July 5 [' retirement', ' of', ' Steve', ' Y', 'z', 'erman']
+614 167 The professional sport played by x -1 The professional sport played by Steve Yzerman hockey Steve Yzerman "[',' ' the' ' Detroit' ' Red' ' Wings' ',' ' and' ' the' ' Detroit'
+ ' Tigers' '.' '\n' '\n' 'The' ' Detroit' ' Red' ' Wings' ' are' ' a'
+ ' professional']" ", the Detroit Red Wings , and the Detroit Tigers .
+
+ The Detroit Red Wings are a professional" False General Manager Steve Yzerman continued to express 5 [' General', ' Manager', ' Steve', ' Y', 'z', 'erman']
+615 167 The professional sport played by x -1 The professional sport played by Steve Yzerman hockey Steve Yzerman "[',' ' the' ' Detroit' ' Red' ' Wings' ',' ' and' ' the' ' Detroit'
+ ' Tigers' '.' '\n' '\n' 'The' ' Detroit' ' Red' ' Wings' ' are' ' a'
+ ' professional']" ", the Detroit Red Wings , and the Detroit Tigers .
+
+ The Detroit Red Wings are a professional" False retirement of Steve Yzerman at the conclusion 5 [' retirement', ' of', ' Steve', ' Y', 'z', 'erman']
+616 168 The professional sport played by x -1 The professional sport played by Dean Cain football Dean Cain "[',' ' who' ' plays' ' the' ' role' ' of' ' the' ' hero' ',' ' is' ' a'
+ ' former' ' NFL' ' player' ' who' ' has' ' been' ' in' ' the' ' league']" , who plays the role of the hero , is a former NFL player who has been in the league False " on CBS, starring Dean Cain and Robert Patrick.
+" 5 [' on', ' CBS', ',', ' starring', ' Dean', ' Cain']
+617 168 The professional sport played by x -1 The professional sport played by Dean Cain football Dean Cain "[',' ' who' ' plays' ' the' ' role' ' of' ' the' ' hero' ',' ' is' ' a'
+ ' former' ' NFL' ' player' ' who' ' has' ' been' ' in' ' the' ' league']" , who plays the role of the hero , is a former NFL player who has been in the league False crime, while Dean Cain plays the main 4 [' crime', ',', ' while', ' Dean', ' Cain']
+618 168 The professional sport played by x -1 The professional sport played by Dean Cain football Dean Cain "[',' ' who' ' plays' ' the' ' role' ' of' ' the' ' hero' ',' ' is' ' a'
+ ' former' ' NFL' ' player' ' who' ' has' ' been' ' in' ' the' ' league']" , who plays the role of the hero , is a former NFL player who has been in the league False " CBS, starring Dean Cain and Robert Patrick.
+" 4 [' CBS', ',', ' starring', ' Dean', ' Cain']
+619 168 The professional sport played by x -1 The professional sport played by Dean Cain football Dean Cain "[',' ' who' ' plays' ' the' ' role' ' of' ' the' ' hero' ',' ' is' ' a'
+ ' former' ' NFL' ' player' ' who' ' has' ' been' ' in' ' the' ' league']" , who plays the role of the hero , is a former NFL player who has been in the league False organized crime, while Dean Cain plays the main 5 [' organized', ' crime', ',', ' while', ' Dean', ' Cain']
+620 168 The professional sport played by x -1 The professional sport played by Dean Cain football Dean Cain "[',' ' who' ' plays' ' the' ' role' ' of' ' the' ' hero' ',' ' is' ' a'
+ ' former' ' NFL' ' player' ' who' ' has' ' been' ' in' ' the' ' league']" , who plays the role of the hero , is a former NFL player who has been in the league False " on CBS, starring Dean Cain and Robert Patrick.
+" 5 [' on', ' CBS', ',', ' starring', ' Dean', ' Cain']
+621 169 The professional sport played by x -1 The professional sport played by Don Shula football Don Shula "[',' ' the' ' coach' ' of' ' the' ' Miami' ' Dolphins' ',' ' was' ' a'
+ ' great' ' coach' '.' ' He' ' was' ' a' ' great' ' coach' '.' ' He']" , the coach of the Miami Dolphins , was a great coach . He was a great coach . He False (0.8 km) later, the Don Shula Expressway 11 [' (', '0', '.', '8', ' km', ')', ' later', ',', ' the', ' Don', ' Sh', 'ula']
+622 169 The professional sport played by x -1 The professional sport played by Don Shula football Don Shula "[',' ' the' ' coach' ' of' ' the' ' Miami' ' Dolphins' ',' ' was' ' a'
+ ' great' ' coach' '.' ' He' ' was' ' a' ' great' ' coach' '.' ' He']" , the coach of the Miami Dolphins , was a great coach . He was a great coach . He False Drive, Bird Road and Don Shula Expressway interchanges 7 [' Drive', ',', ' Bird', ' Road', ' and', ' Don', ' Sh', 'ula']
+623 169 The professional sport played by x -1 The professional sport played by Don Shula football Don Shula "[',' ' the' ' coach' ' of' ' the' ' Miami' ' Dolphins' ',' ' was' ' a'
+ ' great' ' coach' '.' ' He' ' was' ' a' ' great' ' coach' '.' ' He']" , the coach of the Miami Dolphins , was a great coach . He was a great coach . He False Dolphins head coach Don Shula presented the 5 [' Dolphins', ' head', ' coach', ' Don', ' Sh', 'ula']
+624 169 The professional sport played by x -1 The professional sport played by Don Shula football Don Shula "[',' ' the' ' coach' ' of' ' the' ' Miami' ' Dolphins' ',' ' was' ' a'
+ ' great' ' coach' '.' ' He' ' was' ' a' ' great' ' coach' '.' ' He']" , the coach of the Miami Dolphins , was a great coach . He was a great coach . He False Bird Road and Don Shula Expressway interchanges 5 [' Bird', ' Road', ' and', ' Don', ' Sh', 'ula']
+625 169 The professional sport played by x -1 The professional sport played by Don Shula football Don Shula "[',' ' the' ' coach' ' of' ' the' ' Miami' ' Dolphins' ',' ' was' ' a'
+ ' great' ' coach' '.' ' He' ' was' ' a' ' great' ' coach' '.' ' He']" , the coach of the Miami Dolphins , was a great coach . He was a great coach . He False south of the Don Shula Expressway 5 [' south', ' of', ' the', ' Don', ' Sh', 'ula']
+626 170 The professional sport played by x -1 The professional sport played by Bronko Nagurski football Bronko Nagurski "[',' ' the' ' former' ' Chicago' ' Bears' ' linebacker' ',' ' was' ' a'
+ ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person' '.']" , the former Chicago Bears linebacker , was a great player , but he was also a great person . False watch list for the Bronko Nagurski Trophy (awarded to 8 [' watch', ' list', ' for', ' the', ' Bron', 'ko', ' Nag', 'urs', 'ki']
+627 170 The professional sport played by x -1 The professional sport played by Bronko Nagurski football Bronko Nagurski "[',' ' the' ' former' ' Chicago' ' Bears' ' linebacker' ',' ' was' ' a'
+ ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person' '.']" , the former Chicago Bears linebacker , was a great player , but he was also a great person . False Award and the Bronko Nagurski Trophy; and William 7 [' Award', ' and', ' the', ' Bron', 'ko', ' Nag', 'urs', 'ki']
+628 170 The professional sport played by x -1 The professional sport played by Bronko Nagurski football Bronko Nagurski "[',' ' the' ' former' ' Chicago' ' Bears' ' linebacker' ',' ' was' ' a'
+ ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person' '.']" , the former Chicago Bears linebacker , was a great player , but he was also a great person . False " Heavyweight Champion Bronko Nagurski in a first-ever ""champion" 7 [' Heavy', 'weight', ' Champion', ' Bron', 'ko', ' Nag', 'urs', 'ki']
+629 170 The professional sport played by x -1 The professional sport played by Bronko Nagurski football Bronko Nagurski "[',' ' the' ' former' ' Chicago' ' Bears' ' linebacker' ',' ' was' ' a'
+ ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person' '.']" , the former Chicago Bears linebacker , was a great player , but he was also a great person . False Moore was awarded the Bronko Nagurski Trophy, given 8 [' Moore', ' was', ' awarded', ' the', ' Bron', 'ko', ' Nag', 'urs', 'ki']
+630 170 The professional sport played by x -1 The professional sport played by Bronko Nagurski football Bronko Nagurski "[',' ' the' ' former' ' Chicago' ' Bears' ' linebacker' ',' ' was' ' a'
+ ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person' '.']" , the former Chicago Bears linebacker , was a great player , but he was also a great person . False Hightower for the Bronko Nagurski Trophy; McElroy for 9 [' H', 'ight', 'ower', ' for', ' the', ' Bron', 'ko', ' Nag', 'urs', 'ki']
+631 171 The professional sport played by x -1 The professional sport played by Ara Parseghian football Ara Parseghian "[',' ' the' ' former' ' Notre' ' Dame' ' coach' ',' ' was' ' a' ' great'
+ ' coach' ',' ' but' ' he' ' was' ' a' ' great' ' coach' ' because' ' he']" , the former Notre Dame coach , was a great coach , but he was a great coach because he False the game 14 – 7. Ara Parseghian was a standout on both 10 [' the', ' game', ' 14', ' –', ' 7', '.', ' Ara', ' Par', 'se', 'gh', 'ian']
+632 171 The professional sport played by x -1 The professional sport played by Ara Parseghian football Ara Parseghian "[',' ' the' ' former' ' Notre' ' Dame' ' coach' ',' ' was' ' a' ' great'
+ ' coach' ',' ' but' ' he' ' was' ' a' ' great' ' coach' ' because' ' he']" , the former Notre Dame coach , was a great coach , but he was a great coach because he False the third quarter. Ara Parseghian ran for another 8 [' the', ' third', ' quarter', '.', ' Ara', ' Par', 'se', 'gh', 'ian']
+633 171 The professional sport played by x -1 The professional sport played by Ara Parseghian football Ara Parseghian "[',' ' the' ' former' ' Notre' ' Dame' ' coach' ',' ' was' ' a' ' great'
+ ' coach' ',' ' but' ' he' ' was' ' a' ' great' ' coach' ' because' ' he']" , the former Notre Dame coach , was a great coach , but he was a great coach because he False Stu Holcomb. Ara Parseghian was named as 9 [' St', 'u', ' Hol', 'comb', '.', ' Ara', ' Par', 'se', 'gh', 'ian']
+634 171 The professional sport played by x -1 The professional sport played by Ara Parseghian football Ara Parseghian "[',' ' the' ' former' ' Notre' ' Dame' ' coach' ',' ' was' ' a' ' great'
+ ' coach' ',' ' but' ' he' ' was' ' a' ' great' ' coach' ' because' ' he']" , the former Notre Dame coach , was a great coach , but he was a great coach because he False touchdowns by Ara Parseghian and Bill Boedeker, 6 [' touchdowns', ' by', ' Ara', ' Par', 'se', 'gh', 'ian']
+635 171 The professional sport played by x -1 The professional sport played by Ara Parseghian football Ara Parseghian "[',' ' the' ' former' ' Notre' ' Dame' ' coach' ',' ' was' ' a' ' great'
+ ' coach' ',' ' but' ' he' ' was' ' a' ' great' ' coach' ' because' ' he']" , the former Notre Dame coach , was a great coach , but he was a great coach because he False Stu Holcomb. Ara Parseghian was named as his 9 [' St', 'u', ' Hol', 'comb', '.', ' Ara', ' Par', 'se', 'gh', 'ian']
+636 173 The professional sport played by x -1 The professional sport played by Ray Allen basketball Ray Allen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Celtics' ',' ' and'
+ ' the' ' NBA' ' champion' ' Boston' ' Celtics' '.' '\n' '\n' 'The'
+ ' Celtics']" ", who was a member of the Celtics , and the NBA champion Boston Celtics .
+
+ The Celtics" False " friendship with Ray Allen of the Celtics.
+" 3 [' friendship', ' with', ' Ray', ' Allen']
+637 173 The professional sport played by x -1 The professional sport played by Ray Allen basketball Ray Allen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Celtics' ',' ' and'
+ ' the' ' NBA' ' champion' ' Boston' ' Celtics' '.' '\n' '\n' 'The'
+ ' Celtics']" ", who was a member of the Celtics , and the NBA champion Boston Celtics .
+
+ The Celtics" False shooting guard Ray Allen to be among 3 [' shooting', ' guard', ' Ray', ' Allen']
+638 173 The professional sport played by x -1 The professional sport played by Ray Allen basketball Ray Allen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Celtics' ',' ' and'
+ ' the' ' NBA' ' champion' ' Boston' ' Celtics' '.' '\n' '\n' 'The'
+ ' Celtics']" ", who was a member of the Celtics , and the NBA champion Boston Celtics .
+
+ The Celtics" False academics Ray Allen Billington, Richard 2 [' academics', ' Ray', ' Allen']
+639 173 The professional sport played by x -1 The professional sport played by Ray Allen basketball Ray Allen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Celtics' ',' ' and'
+ ' the' ' NBA' ' champion' ' Boston' ' Celtics' '.' '\n' '\n' 'The'
+ ' Celtics']" ", who was a member of the Celtics , and the NBA champion Boston Celtics .
+
+ The Celtics" False due to injury and Ray Allen was named to replace 5 [' due', ' to', ' injury', ' and', ' Ray', ' Allen']
+640 173 The professional sport played by x -1 The professional sport played by Ray Allen basketball Ray Allen "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Celtics' ',' ' and'
+ ' the' ' NBA' ' champion' ' Boston' ' Celtics' '.' '\n' '\n' 'The'
+ ' Celtics']" ", who was a member of the Celtics , and the NBA champion Boston Celtics .
+
+ The Celtics" False shooting guard Ray Allen to be among the most 3 [' shooting', ' guard', ' Ray', ' Allen']
+641 174 The professional sport played by x -1 The professional sport played by Frank Robinson baseball Frank Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Giants'
+ '.' '\n' '\n' 'The' ' Giants' ' were' ' the' ' first' ' team' ' to']" ", who was a member of the New York Giants .
+
+ The Giants were the first team to" False stadium was Frank Robinson in 1966; his blast 3 [' stadium', ' was', ' Frank', ' Robinson']
+642 174 The professional sport played by x -1 The professional sport played by Frank Robinson baseball Frank Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Giants'
+ '.' '\n' '\n' 'The' ' Giants' ' were' ' the' ' first' ' team' ' to']" ", who was a member of the New York Giants .
+
+ The Giants were the first team to" False a home run to Frank Robinson in one inning 5 [' a', ' home', ' run', ' to', ' Frank', ' Robinson']
+643 174 The professional sport played by x -1 The professional sport played by Frank Robinson baseball Frank Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Giants'
+ '.' '\n' '\n' 'The' ' Giants' ' were' ' the' ' first' ' team' ' to']" ", who was a member of the New York Giants .
+
+ The Giants were the first team to" False including a home run to Frank Robinson in one inning 6 [' including', ' a', ' home', ' run', ' to', ' Frank', ' Robinson']
+644 174 The professional sport played by x -1 The professional sport played by Frank Robinson baseball Frank Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Giants'
+ '.' '\n' '\n' 'The' ' Giants' ' were' ' the' ' first' ' team' ' to']" ", who was a member of the New York Giants .
+
+ The Giants were the first team to" False including a home run to Frank Robinson in one inning of 6 [' including', ' a', ' home', ' run', ' to', ' Frank', ' Robinson']
+645 174 The professional sport played by x -1 The professional sport played by Frank Robinson baseball Frank Robinson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Giants'
+ '.' '\n' '\n' 'The' ' Giants' ' were' ' the' ' first' ' team' ' to']" ", who was a member of the New York Giants .
+
+ The Giants were the first team to" False of his PJC tenure, Frank Robinson (to whom Robinson 7 [' of', ' his', ' P', 'JC', ' tenure', ',', ' Frank', ' Robinson']
+646 175 The professional sport played by x -1 The professional sport played by Carlos Beltrán baseball Carlos Beltrán "[',' ' who' ' is' ' a' ' former' ' MLB' ' player' ' for' ' the' ' San'
+ ' Francisco' ' Giants' '.' '\n' '\n' 'The' ' game' ' is' ' played' ' on']" ", who is a former MLB player for the San Francisco Giants .
+
+ The game is played on" False York Mets' Carlos Beltrán on the All-Star roster 6 "[' York', ' Mets', ""'"", ' Carlos', ' Belt', 'r', 'án']"
+647 175 The professional sport played by x -1 The professional sport played by Carlos Beltrán baseball Carlos Beltrán "[',' ' who' ' is' ' a' ' former' ' MLB' ' player' ' for' ' the' ' San'
+ ' Francisco' ' Giants' '.' '\n' '\n' 'The' ' game' ' is' ' played' ' on']" ", who is a former MLB player for the San Francisco Giants .
+
+ The game is played on" False Adrian Johnson; the Carlos Beltrán line drive appeared 7 [' Adrian', ' Johnson', ';', ' the', ' Carlos', ' Belt', 'r', 'án']
+648 175 The professional sport played by x -1 The professional sport played by Carlos Beltrán baseball Carlos Beltrán "[',' ' who' ' is' ' a' ' former' ' MLB' ' player' ' for' ' the' ' San'
+ ' Francisco' ' Giants' '.' '\n' '\n' 'The' ' game' ' is' ' played' ' on']" ", who is a former MLB player for the San Francisco Giants .
+
+ The game is played on" False teammates Lance Lynn, Carlos Beltrán, Rafael Furcal, and 7 [' teammates', ' Lance', ' Lynn', ',', ' Carlos', ' Belt', 'r', 'án']
+649 175 The professional sport played by x -1 The professional sport played by Carlos Beltrán baseball Carlos Beltrán "[',' ' who' ' is' ' a' ' former' ' MLB' ' player' ' for' ' the' ' San'
+ ' Francisco' ' Giants' '.' '\n' '\n' 'The' ' game' ' is' ' played' ' on']" ", who is a former MLB player for the San Francisco Giants .
+
+ The game is played on" False Matt Carpenter and Carlos Beltrán and a single 6 [' Matt', ' Carpenter', ' and', ' Carlos', ' Belt', 'r', 'án']
+650 175 The professional sport played by x -1 The professional sport played by Carlos Beltrán baseball Carlos Beltrán "[',' ' who' ' is' ' a' ' former' ' MLB' ' player' ' for' ' the' ' San'
+ ' Francisco' ' Giants' '.' '\n' '\n' 'The' ' game' ' is' ' played' ' on']" ", who is a former MLB player for the San Francisco Giants .
+
+ The game is played on" False 87.5 % (91-of-104). Carlos Beltrán was the best, at 88.3 14 [' 87', '.', '5', ' %', ' (', '91', '-', 'of', '-', '104', ').', ' Carlos', ' Belt', 'r', 'án']
+651 176 The professional sport played by x -1 The professional sport played by Emmitt Smith football Emmitt Smith "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys' '.' '\n' '\n']" ", the NFL 's all - time leading rusher , was a member of the Dallas Cowboys .
+
+" False Deion Sanders, Emmitt Smith and Michael Irvin. 6 [' De', 'ion', ' Sanders', ',', ' Em', 'mitt', ' Smith']
+652 176 The professional sport played by x -1 The professional sport played by Emmitt Smith football Emmitt Smith "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys' '.' '\n' '\n']" ", the NFL 's all - time leading rusher , was a member of the Dallas Cowboys .
+
+" False source of inspiration. Emmitt Smith tearfully paid homage 6 [' source', ' of', ' inspiration', '.', ' Em', 'mitt', ' Smith']
+653 176 The professional sport played by x -1 The professional sport played by Emmitt Smith football Emmitt Smith "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys' '.' '\n' '\n']" ", the NFL 's all - time leading rusher , was a member of the Dallas Cowboys .
+
+" False Deion Sanders, Emmitt Smith and Michael Irvin. 6 [' De', 'ion', ' Sanders', ',', ' Em', 'mitt', ' Smith']
+654 176 The professional sport played by x -1 The professional sport played by Emmitt Smith football Emmitt Smith "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys' '.' '\n' '\n']" ", the NFL 's all - time leading rusher , was a member of the Dallas Cowboys .
+
+" False idolized Deion Sanders, Emmitt Smith and Michael 8 [' idol', 'ized', ' De', 'ion', ' Sanders', ',', ' Em', 'mitt', ' Smith']
+655 176 The professional sport played by x -1 The professional sport played by Emmitt Smith football Emmitt Smith "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Dallas' ' Cowboys' '.' '\n' '\n']" ", the NFL 's all - time leading rusher , was a member of the Dallas Cowboys .
+
+" False NFL season, when Emmitt Smith broke his record. 6 [' NFL', ' season', ',', ' when', ' Em', 'mitt', ' Smith']
+656 178 The professional sport played by x -1 The professional sport played by Hugo Sánchez soccer Hugo Sánchez "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' has' ' been' ' played' ' in' ' the' ' United' ' States'
+ ' since' ' the']" , a former professional soccer player , is a game that has been played in the United States since the True Mexican forward Hugo Sánchez became notable 6 [' Mexican', ' forward', ' Hugo', ' S', 'án', 'che', 'z']
+657 178 The professional sport played by x -1 The professional sport played by Hugo Sánchez soccer Hugo Sánchez "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' has' ' been' ' played' ' in' ' the' ' United' ' States'
+ ' since' ' the']" , a former professional soccer player , is a game that has been played in the United States since the True (1958 – 66), Hugo Sánchez (1985 – 92) and 10 [' (', '19', '58', ' –', ' 66', '),', ' Hugo', ' S', 'án', 'che', 'z']
+658 179 The professional sport played by x -1 The professional sport played by Terry Sawchuk hockey Terry Sawchuk "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Hockey' ' Hall' ' of'
+ ' Fame' '.' '\n' '\n' 'The' ' NHL' ' has' ' been' ' a' ' great']" ", who was a member of the Hockey Hall of Fame .
+
+ The NHL has been a great" False 37-year-old Terry Sawchuk and 42-year-old 7 [' 37', '-', 'year', '-', 'old', ' Terry', ' Saw', 'chuk']
+659 179 The professional sport played by x -1 The professional sport played by Terry Sawchuk hockey Terry Sawchuk "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Hockey' ' Hall' ' of'
+ ' Fame' '.' '\n' '\n' 'The' ' NHL' ' has' ' been' ' a' ' great']" ", who was a member of the Hockey Hall of Fame .
+
+ The NHL has been a great" False 2 ['Terry', ' Saw', 'chuk']
+660 179 The professional sport played by x -1 The professional sport played by Terry Sawchuk hockey Terry Sawchuk "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Hockey' ' Hall' ' of'
+ ' Fame' '.' '\n' '\n' 'The' ' NHL' ' has' ' been' ' a' ' great']" ", who was a member of the Hockey Hall of Fame .
+
+ The NHL has been a great" False 37-year-old Terry Sawchuk and 42-year-old Johnny 7 [' 37', '-', 'year', '-', 'old', ' Terry', ' Saw', 'chuk']
+661 179 The professional sport played by x -1 The professional sport played by Terry Sawchuk hockey Terry Sawchuk "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Hockey' ' Hall' ' of'
+ ' Fame' '.' '\n' '\n' 'The' ' NHL' ' has' ' been' ' a' ' great']" ", who was a member of the Hockey Hall of Fame .
+
+ The NHL has been a great" False 37-year-old Terry Sawchuk and 42-year-old 7 [' 37', '-', 'year', '-', 'old', ' Terry', ' Saw', 'chuk']
+662 179 The professional sport played by x -1 The professional sport played by Terry Sawchuk hockey Terry Sawchuk "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Hockey' ' Hall' ' of'
+ ' Fame' '.' '\n' '\n' 'The' ' NHL' ' has' ' been' ' a' ' great']" ", who was a member of the Hockey Hall of Fame .
+
+ The NHL has been a great" False 2 ['Terry', ' Saw', 'chuk']
+663 180 The professional sport played by x -1 The professional sport played by Pete Rose baseball Pete Rose "[',' ' the' ' former' ' Cincinnati' ' Reds' ' baseball' ' player' ','
+ ' was' ' arrested' ' for' ' DUI' ' in' ' Florida' ' in' ' February' '.'
+ '\n' '\n' 'The']" ", the former Cincinnati Reds baseball player , was arrested for DUI in Florida in February .
+
+ The" True at first base when Pete Rose tied the NL record 5 [' at', ' first', ' base', ' when', ' Pete', ' Rose']
+664 180 The professional sport played by x -1 The professional sport played by Pete Rose baseball Pete Rose "[',' ' the' ' former' ' Cincinnati' ' Reds' ' baseball' ' player' ','
+ ' was' ' arrested' ' for' ' DUI' ' in' ' Florida' ' in' ' February' '.'
+ '\n' '\n' 'The']" ", the former Cincinnati Reds baseball player , was arrested for DUI in Florida in February .
+
+ The" True first base when Pete Rose tied the NL record 4 [' first', ' base', ' when', ' Pete', ' Rose']
+665 180 The professional sport played by x -1 The professional sport played by Pete Rose baseball Pete Rose "[',' ' the' ' former' ' Cincinnati' ' Reds' ' baseball' ' player' ','
+ ' was' ' arrested' ' for' ' DUI' ' in' ' Florida' ' in' ' February' '.'
+ '\n' '\n' 'The']" ", the former Cincinnati Reds baseball player , was arrested for DUI in Florida in February .
+
+ The" True Machine. Ellis admired Pete Rose and was concerned 5 [' Machine', '.', ' Ellis', ' admired', ' Pete', ' Rose']
+666 180 The professional sport played by x -1 The professional sport played by Pete Rose baseball Pete Rose "[',' ' the' ' former' ' Cincinnati' ' Reds' ' baseball' ' player' ','
+ ' was' ' arrested' ' for' ' DUI' ' in' ' Florida' ' in' ' February' '.'
+ '\n' '\n' 'The']" ", the former Cincinnati Reds baseball player , was arrested for DUI in Florida in February .
+
+ The" True tied him with Pete Rose for the second-most 4 [' tied', ' him', ' with', ' Pete', ' Rose']
+667 180 The professional sport played by x -1 The professional sport played by Pete Rose baseball Pete Rose "[',' ' the' ' former' ' Cincinnati' ' Reds' ' baseball' ' player' ','
+ ' was' ' arrested' ' for' ' DUI' ' in' ' Florida' ' in' ' February' '.'
+ '\n' '\n' 'The']" ", the former Cincinnati Reds baseball player , was arrested for DUI in Florida in February .
+
+ The" True associated with Pete Rose could be emphasized 3 [' associated', ' with', ' Pete', ' Rose']
+668 181 The professional sport played by x -1 The professional sport played by Bob Gibson baseball Bob Gibson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' St' '.' ' Louis'
+ ' Cardinals' '.' '\n' '\n' 'The' ' Cardinals' ' won' ' the' ' World'
+ ' Series']" ", who was a member of the St . Louis Cardinals .
+
+ The Cardinals won the World Series" False including pitchers Bob Gibson and Juan Marichal, 3 [' including', ' pitchers', ' Bob', ' Gibson']
+669 181 The professional sport played by x -1 The professional sport played by Bob Gibson baseball Bob Gibson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' St' '.' ' Louis'
+ ' Cardinals' '.' '\n' '\n' 'The' ' Cardinals' ' won' ' the' ' World'
+ ' Series']" ", who was a member of the St . Louis Cardinals .
+
+ The Cardinals won the World Series" False Ochs met folksinger Bob Gibson that summer 7 [' O', 'ch', 's', ' met', ' folks', 'inger', ' Bob', ' Gibson']
+670 181 The professional sport played by x -1 The professional sport played by Bob Gibson baseball Bob Gibson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' St' '.' ' Louis'
+ ' Cardinals' '.' '\n' '\n' 'The' ' Cardinals' ' won' ' the' ' World'
+ ' Series']" ", who was a member of the St . Louis Cardinals .
+
+ The Cardinals won the World Series" False Ochs met folksinger Bob Gibson that summer 7 [' O', 'ch', 's', ' met', ' folks', 'inger', ' Bob', ' Gibson']
+671 181 The professional sport played by x -1 The professional sport played by Bob Gibson baseball Bob Gibson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' St' '.' ' Louis'
+ ' Cardinals' '.' '\n' '\n' 'The' ' Cardinals' ' won' ' the' ' World'
+ ' Series']" ", who was a member of the St . Louis Cardinals .
+
+ The Cardinals won the World Series" False career, Wacha joined Bob Gibson as the only 6 [' career', ',', ' W', 'acha', ' joined', ' Bob', ' Gibson']
+672 181 The professional sport played by x -1 The professional sport played by Bob Gibson baseball Bob Gibson "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' St' '.' ' Louis'
+ ' Cardinals' '.' '\n' '\n' 'The' ' Cardinals' ' won' ' the' ' World'
+ ' Series']" ", who was a member of the St . Louis Cardinals .
+
+ The Cardinals won the World Series" False starting pitcher Bob Gibson achieved an 3 [' starting', ' pitcher', ' Bob', ' Gibson']
+673 182 The professional sport played by x -1 The professional sport played by Steve Nash basketball Steve Nash "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to']" , who is a former NBA player , is a great example of a player who has been able to False Dallas guard Steve Nash commented that Rodman 3 [' Dallas', ' guard', ' Steve', ' Nash']
+674 182 The professional sport played by x -1 The professional sport played by Steve Nash basketball Steve Nash "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to']" , who is a former NBA player , is a great example of a player who has been able to False opened the first Steve Nash Sports Club 4 [' opened', ' the', ' first', ' Steve', ' Nash']
+675 182 The professional sport played by x -1 The professional sport played by Steve Nash basketball Steve Nash "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to']" , who is a former NBA player , is a great example of a player who has been able to False up Blake, while Steve Nash was injured. On November 5 [' up', ' Blake', ',', ' while', ' Steve', ' Nash']
+676 182 The professional sport played by x -1 The professional sport played by Steve Nash basketball Steve Nash "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to']" , who is a former NBA player , is a great example of a player who has been able to False Nowitzki's close friend Steve Nash left Dallas and returned 7 "[' Now', 'itz', 'ki', ""'s"", ' close', ' friend', ' Steve', ' Nash']"
+677 182 The professional sport played by x -1 The professional sport played by Steve Nash basketball Steve Nash "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' able' ' to']" , who is a former NBA player , is a great example of a player who has been able to False Finley and future MVPs Steve Nash and Nowitzki, Howard 7 [' Fin', 'ley', ' and', ' future', ' MVP', 's', ' Steve', ' Nash']
+678 183 The professional sport played by x -1 The professional sport played by Randy Moss football Randy Moss "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' touchdown'
+ ' receptions' ',' ' and' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader']" , the NFL 's all - time leader in touchdown receptions , and the NFL 's all - time leader False Culpepper and receivers Randy Moss and Cris Carter 6 [' Cul', 'pe', 'pper', ' and', ' receivers', ' Randy', ' Moss']
+679 183 The professional sport played by x -1 The professional sport played by Randy Moss football Randy Moss "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' touchdown'
+ ' receptions' ',' ' and' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader']" , the NFL 's all - time leader in touchdown receptions , and the NFL 's all - time leader False and receivers Randy Moss and Cris Carter 3 [' and', ' receivers', ' Randy', ' Moss']
+680 183 The professional sport played by x -1 The professional sport played by Randy Moss football Randy Moss "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' touchdown'
+ ' receptions' ',' ' and' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader']" , the NFL 's all - time leader in touchdown receptions , and the NFL 's all - time leader False like receiver Randy Moss was as a rookie. 3 [' like', ' receiver', ' Randy', ' Moss']
+681 183 The professional sport played by x -1 The professional sport played by Randy Moss football Randy Moss "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' touchdown'
+ ' receptions' ',' ' and' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader']" , the NFL 's all - time leader in touchdown receptions , and the NFL 's all - time leader False Vikings receivers Randy Moss and Cris Carter 3 [' Vikings', ' receivers', ' Randy', ' Moss']
+682 183 The professional sport played by x -1 The professional sport played by Randy Moss football Randy Moss "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' touchdown'
+ ' receptions' ',' ' and' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leader']" , the NFL 's all - time leader in touchdown receptions , and the NFL 's all - time leader False Vikings receivers Randy Moss and Cris Carter 3 [' Vikings', ' receivers', ' Randy', ' Moss']
+683 184 The professional sport played by x -1 The professional sport played by Johnny Unitas football Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the quarterback for the Baltimore Colts , was a great player , but he was also a great person False Award winner, the Johnny Unitas Award winner, and 6 [' Award', ' winner', ',', ' the', ' Johnny', ' Unit', 'as']
+684 184 The professional sport played by x -1 The professional sport played by Johnny Unitas football Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the quarterback for the Baltimore Colts , was a great player , but he was also a great person False " Award
+" 3 [' Award', 'Johnny', ' Unit', 'as']
+685 184 The professional sport played by x -1 The professional sport played by Johnny Unitas football Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the quarterback for the Baltimore Colts , was a great player , but he was also a great person False Award winner, the Johnny Unitas Award winner, 6 [' Award', ' winner', ',', ' the', ' Johnny', ' Unit', 'as']
+686 184 The professional sport played by x -1 The professional sport played by Johnny Unitas football Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the quarterback for the Baltimore Colts , was a great player , but he was also a great person False He was named to the Johnny Unitas Golden Arm Award 7 [' He', ' was', ' named', ' to', ' the', ' Johnny', ' Unit', 'as']
+687 184 The professional sport played by x -1 The professional sport played by Johnny Unitas football Johnny Unitas "[',' ' the' ' quarterback' ' for' ' the' ' Baltimore' ' Colts' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the quarterback for the Baltimore Colts , was a great player , but he was also a great person False John Elway and Johnny Unitas for the second most 6 [' John', ' El', 'way', ' and', ' Johnny', ' Unit', 'as']
+688 185 The professional sport played by x -1 The professional sport played by Willie Stargell baseball Willie Stargell "[',' ' the' ' Pittsburgh' ' Pirates' ""'"" ' first' ' baseman' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the Pittsburgh Pirates ' first baseman , was a great player , but he was also a great person False former Pirate Willie Stargell threw out the 5 [' former', ' Pirate', ' Willie', ' St', 'arge', 'll']
+689 185 The professional sport played by x -1 The professional sport played by Willie Stargell baseball Willie Stargell "[',' ' the' ' Pittsburgh' ' Pirates' ""'"" ' first' ' baseman' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the Pittsburgh Pirates ' first baseman , was a great player , but he was also a great person False Roberto Clemente, Willie Stargell and Bill Mazeroski 7 [' Roberto', ' Clement', 'e', ',', ' Willie', ' St', 'arge', 'll']
+690 185 The professional sport played by x -1 The professional sport played by Willie Stargell baseball Willie Stargell "[',' ' the' ' Pittsburgh' ' Pirates' ""'"" ' first' ' baseman' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the Pittsburgh Pirates ' first baseman , was a great player , but he was also a great person False game, former Pirate Willie Stargell threw out the ceremonial 7 [' game', ',', ' former', ' Pirate', ' Willie', ' St', 'arge', 'll']
+691 185 The professional sport played by x -1 The professional sport played by Willie Stargell baseball Willie Stargell "[',' ' the' ' Pittsburgh' ' Pirates' ""'"" ' first' ' baseman' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the Pittsburgh Pirates ' first baseman , was a great player , but he was also a great person False right-field side. Willie Stargell is the all-time leader 8 [' right', '-', 'field', ' side', '.', ' Willie', ' St', 'arge', 'll']
+692 185 The professional sport played by x -1 The professional sport played by Willie Stargell baseball Willie Stargell "[',' ' the' ' Pittsburgh' ' Pirates' ""'"" ' first' ' baseman' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a' ' great'
+ ' person']" , the Pittsburgh Pirates ' first baseman , was a great player , but he was also a great person False Roberto Clemente, Willie Stargell and Bill Mazeroski 7 [' Roberto', ' Clement', 'e', ',', ' Willie', ' St', 'arge', 'll']
+693 186 The professional sport played by x -1 The professional sport played by Jarome Iginla hockey Jarome Iginla "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Calgary'
+ ' Flames' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1980' '.'
+ ' He' ' has']" , who has been a member of the Calgary Flames since the team 's inception in 1980 . He has False surpassed by Jarome Iginla in 2009. The Flames, 6 [' surpassed', ' by', ' Jar', 'ome', ' I', 'gin', 'la']
+694 186 The professional sport played by x -1 The professional sport played by Jarome Iginla hockey Jarome Iginla "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Calgary'
+ ' Flames' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1980' '.'
+ ' He' ' has']" , who has been a member of the Calgary Flames since the team 's inception in 1980 . He has False surpassed by Jarome Iginla in 2009. The Flames, 6 [' surpassed', ' by', ' Jar', 'ome', ' I', 'gin', 'la']
+695 186 The professional sport played by x -1 The professional sport played by Jarome Iginla hockey Jarome Iginla "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Calgary'
+ ' Flames' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1980' '.'
+ ' He' ' has']" , who has been a member of the Calgary Flames since the team 's inception in 1980 . He has False Finnish team, while Jarome Iginla was named an 8 [' Finnish', ' team', ',', ' while', ' Jar', 'ome', ' I', 'gin', 'la']
+696 186 The professional sport played by x -1 The professional sport played by Jarome Iginla hockey Jarome Iginla "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Calgary'
+ ' Flames' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1980' '.'
+ ' He' ' has']" , who has been a member of the Calgary Flames since the team 's inception in 1980 . He has False season. Team captain Jarome Iginla scored his 8 [' season', '.', ' Team', ' captain', ' Jar', 'ome', ' I', 'gin', 'la']
+697 186 The professional sport played by x -1 The professional sport played by Jarome Iginla hockey Jarome Iginla "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Calgary'
+ ' Flames' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1980' '.'
+ ' He' ' has']" , who has been a member of the Calgary Flames since the team 's inception in 1980 . He has False when he succeeded Jarome Iginla as the Flames player 7 [' when', ' he', ' succeeded', ' Jar', 'ome', ' I', 'gin', 'la']
+698 187 The professional sport played by x -1 The professional sport played by Cam Newton football Cam Newton "[',' ' the' ' quarterback' ' for' ' the' ' Carolina' ' Panthers' ',' ' is'
+ ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' able' ' to']" , the quarterback for the Carolina Panthers , is a great example of a player who has been able to False players and eliminated Cam Newton to win the 4 [' players', ' and', ' eliminated', ' Cam', ' Newton']
+699 187 The professional sport played by x -1 The professional sport played by Cam Newton football Cam Newton "[',' ' the' ' quarterback' ' for' ' the' ' Carolina' ' Panthers' ',' ' is'
+ ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' able' ' to']" , the quarterback for the Carolina Panthers , is a great example of a player who has been able to False Trophy-winning quarterback Cam Newton with the first 5 [' Trophy', '-', 'winning', ' quarterback', ' Cam', ' Newton']
+700 187 The professional sport played by x -1 The professional sport played by Cam Newton football Cam Newton "[',' ' the' ' quarterback' ' for' ' the' ' Carolina' ' Panthers' ',' ' is'
+ ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' able' ' to']" , the quarterback for the Carolina Panthers , is a great example of a player who has been able to False posted a sack of Cam Newton in his debut with 5 [' posted', ' a', ' sack', ' of', ' Cam', ' Newton']
+701 187 The professional sport played by x -1 The professional sport played by Cam Newton football Cam Newton "[',' ' the' ' quarterback' ' for' ' the' ' Carolina' ' Panthers' ',' ' is'
+ ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' able' ' to']" , the quarterback for the Carolina Panthers , is a great example of a player who has been able to False Trophy-winning quarterback Cam Newton with the first 5 [' Trophy', '-', 'winning', ' quarterback', ' Cam', ' Newton']
+702 187 The professional sport played by x -1 The professional sport played by Cam Newton football Cam Newton "[',' ' the' ' quarterback' ' for' ' the' ' Carolina' ' Panthers' ',' ' is'
+ ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' has' ' been'
+ ' able' ' to']" , the quarterback for the Carolina Panthers , is a great example of a player who has been able to False " assisted in shoving Cam Newton into the endzone.
+" 5 [' assisted', ' in', ' sh', 'oving', ' Cam', ' Newton']
+703 189 The professional sport played by x -1 The professional sport played by Babe Ruth baseball Babe Ruth "[',' ' the' ' baseball' ' player' ',' ' was' ' a' ' great' ' American'
+ ' hero' '.' '\n' '\n' 'The' ' professional' ' baseball' ' player' ' was'
+ ' a' ' great']" ", the baseball player , was a great American hero .
+
+ The professional baseball player was a great" True Yankees slugger Babe Ruth had set the single 4 [' Yankees', ' slug', 'ger', ' Babe', ' Ruth']
+704 189 The professional sport played by x -1 The professional sport played by Babe Ruth baseball Babe Ruth "[',' ' the' ' baseball' ' player' ',' ' was' ' a' ' great' ' American'
+ ' hero' '.' '\n' '\n' 'The' ' professional' ' baseball' ' player' ' was'
+ ' a' ' great']" ", the baseball player , was a great American hero .
+
+ The professional baseball player was a great" True April 27, 1947, Babe Ruth Day around the 6 [' April', ' 27', ',', ' 1947', ',', ' Babe', ' Ruth']
+705 189 The professional sport played by x -1 The professional sport played by Babe Ruth baseball Babe Ruth "[',' ' the' ' baseball' ' player' ',' ' was' ' a' ' great' ' American'
+ ' hero' '.' '\n' '\n' 'The' ' professional' ' baseball' ' player' ' was'
+ ' a' ' great']" ", the baseball player , was a great American hero .
+
+ The professional baseball player was a great" True played at the stadium. Babe Ruth hit the new 6 [' played', ' at', ' the', ' stadium', '.', ' Babe', ' Ruth']
+706 189 The professional sport played by x -1 The professional sport played by Babe Ruth baseball Babe Ruth "[',' ' the' ' baseball' ' player' ',' ' was' ' a' ' great' ' American'
+ ' hero' '.' '\n' '\n' 'The' ' professional' ' baseball' ' player' ' was'
+ ' a' ' great']" ", the baseball player , was a great American hero .
+
+ The professional baseball player was a great" True English rock group Babe Ruth in his block 4 [' English', ' rock', ' group', ' Babe', ' Ruth']
+707 189 The professional sport played by x -1 The professional sport played by Babe Ruth baseball Babe Ruth "[',' ' the' ' baseball' ' player' ',' ' was' ' a' ' great' ' American'
+ ' hero' '.' '\n' '\n' 'The' ' professional' ' baseball' ' player' ' was'
+ ' a' ' great']" ", the baseball player , was a great American hero .
+
+ The professional baseball player was a great" True Legion and Babe Ruth League baseball 3 [' Legion', ' and', ' Babe', ' Ruth']
+708 190 The professional sport played by x -1 The professional sport played by Gary Payton basketball Gary Payton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' show' '.' '\n' '\n' 'The' ' show' ' was' ' a' ' huge']" ", who was a member of the original cast of the show .
+
+ The show was a huge" False NBA players Gary Payton and Vin Baker. After 3 [' NBA', ' players', ' Gary', ' Payton']
+709 190 The professional sport played by x -1 The professional sport played by Gary Payton basketball Gary Payton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' show' '.' '\n' '\n' 'The' ' show' ' was' ' a' ' huge']" ", who was a member of the original cast of the show .
+
+ The show was a huge" False Rick Fox and Gary Payton to Boston, for 4 [' Rick', ' Fox', ' and', ' Gary', ' Payton']
+710 190 The professional sport played by x -1 The professional sport played by Gary Payton basketball Gary Payton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' show' '.' '\n' '\n' 'The' ' show' ' was' ' a' ' huge']" ", who was a member of the original cast of the show .
+
+ The show was a huge" False 1 ['Gary', ' Payton']
+711 190 The professional sport played by x -1 The professional sport played by Gary Payton basketball Gary Payton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' show' '.' '\n' '\n' 'The' ' show' ' was' ' a' ' huge']" ", who was a member of the original cast of the show .
+
+ The show was a huge" False Pack and outscored Gary Payton five of nine times 6 [' Pack', ' and', ' out', 'sc', 'ored', ' Gary', ' Payton']
+712 190 The professional sport played by x -1 The professional sport played by Gary Payton basketball Gary Payton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' show' '.' '\n' '\n' 'The' ' show' ' was' ' a' ' huge']" ", who was a member of the original cast of the show .
+
+ The show was a huge" False fellow NBA players Gary Payton and Vin Baker. 4 [' fellow', ' NBA', ' players', ' Gary', ' Payton']
+713 191 The professional sport played by x -1 The professional sport played by Brett Hull hockey Brett Hull "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NHL' ""'s"" ' all' '-'
+ 'time' ' scoring' ' leaders' ',' ' and' ' the' ' NHL' ""'s"" ' all']" , who was a member of the NHL 's all - time scoring leaders , and the NHL 's all False former player Brett Hull endorsed. Increased 3 [' former', ' player', ' Brett', ' Hull']
+714 191 The professional sport played by x -1 The professional sport played by Brett Hull hockey Brett Hull "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NHL' ""'s"" ' all' '-'
+ 'time' ' scoring' ' leaders' ',' ' and' ' the' ' NHL' ""'s"" ' all']" , who was a member of the NHL 's all - time scoring leaders , and the NHL 's all False history. Bobby and Brett Hull are the only 5 [' history', '.', ' Bobby', ' and', ' Brett', ' Hull']
+715 191 The professional sport played by x -1 The professional sport played by Brett Hull hockey Brett Hull "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NHL' ""'s"" ' all' '-'
+ 'time' ' scoring' ' leaders' ',' ' and' ' the' ' NHL' ""'s"" ' all']" , who was a member of the NHL 's all - time scoring leaders , and the NHL 's all False November 9, 2009, Brett Hull became the fifth 6 [' November', ' 9', ',', ' 2009', ',', ' Brett', ' Hull']
+716 191 The professional sport played by x -1 The professional sport played by Brett Hull hockey Brett Hull "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NHL' ""'s"" ' all' '-'
+ 'time' ' scoring' ' leaders' ',' ' and' ' the' ' NHL' ""'s"" ' all']" , who was a member of the NHL 's all - time scoring leaders , and the NHL 's all False November 9, 2009, Brett Hull became the 6 [' November', ' 9', ',', ' 2009', ',', ' Brett', ' Hull']
+717 191 The professional sport played by x -1 The professional sport played by Brett Hull hockey Brett Hull "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NHL' ""'s"" ' all' '-'
+ 'time' ' scoring' ' leaders' ',' ' and' ' the' ' NHL' ""'s"" ' all']" , who was a member of the NHL 's all - time scoring leaders , and the NHL 's all False System video game called Brett Hull Hockey. Hull was a 5 [' System', ' video', ' game', ' called', ' Brett', ' Hull']
+718 192 The professional sport played by x -1 The professional sport played by Mark Messier hockey Mark Messier "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' goals' ','
+ ' assists' ',' ' points' ',' ' and' ' games' ' played' '.' '\n']" ", the NHL 's all - time leader in goals , assists , points , and games played .
+" False Canada General Manager Mark Messier made a specific effort 5 [' Canada', ' General', ' Manager', ' Mark', ' Mess', 'ier']
+719 192 The professional sport played by x -1 The professional sport played by Mark Messier hockey Mark Messier "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' goals' ','
+ ' assists' ',' ' points' ',' ' and' ' games' ' played' '.' '\n']" ", the NHL 's all - time leader in goals , assists , points , and games played .
+" False on a line with Mark Messier at centre and 6 [' on', ' a', ' line', ' with', ' Mark', ' Mess', 'ier']
+720 192 The professional sport played by x -1 The professional sport played by Mark Messier hockey Mark Messier "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' goals' ','
+ ' assists' ',' ' points' ',' ' and' ' games' ' played' '.' '\n']" ", the NHL 's all - time leader in goals , assists , points , and games played .
+" False Sedin. Crosby won the Mark Messier Leadership Award, 8 [' Sed', 'in', '.', ' Crosby', ' won', ' the', ' Mark', ' Mess', 'ier']
+721 192 The professional sport played by x -1 The professional sport played by Mark Messier hockey Mark Messier "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' goals' ','
+ ' assists' ',' ' points' ',' ' and' ' games' ' played' '.' '\n']" ", the NHL 's all - time leader in goals , assists , points , and games played .
+" False 2 ['Mark', ' Mess', 'ier']
+722 192 The professional sport played by x -1 The professional sport played by Mark Messier hockey Mark Messier "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leader' ' in' ' goals' ','
+ ' assists' ',' ' points' ',' ' and' ' games' ' played' '.' '\n']" ", the NHL 's all - time leader in goals , assists , points , and games played .
+" False confrontations with Mark Messier as part of 5 [' confront', 'ations', ' with', ' Mark', ' Mess', 'ier']
+723 193 The professional sport played by x -1 The professional sport played by Kevin Garnett basketball Kevin Garnett "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' man' ' who' ' can' ' score' ' in' ' the' ' paint']" , the NBA 's all - time leading scorer , is a big man who can score in the paint False points and passed Kevin Garnett to move into 14th 5 [' points', ' and', ' passed', ' Kevin', ' Garn', 'ett']
+724 193 The professional sport played by x -1 The professional sport played by Kevin Garnett basketball Kevin Garnett "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' man' ' who' ' can' ' score' ' in' ' the' ' paint']" , the NBA 's all - time leading scorer , is a big man who can score in the paint False between him and Kevin Garnett due to their 5 [' between', ' him', ' and', ' Kevin', ' Garn', 'ett']
+725 193 The professional sport played by x -1 The professional sport played by Kevin Garnett basketball Kevin Garnett "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' man' ' who' ' can' ' score' ' in' ' the' ' paint']" , the NBA 's all - time leading scorer , is a big man who can score in the paint False basket over Kevin Garnett against the Boston 4 [' basket', ' over', ' Kevin', ' Garn', 'ett']
+726 193 The professional sport played by x -1 The professional sport played by Kevin Garnett basketball Kevin Garnett "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' man' ' who' ' can' ' score' ' in' ' the' ' paint']" , the NBA 's all - time leading scorer , is a big man who can score in the paint False between him and Kevin Garnett due to their similar 5 [' between', ' him', ' and', ' Kevin', ' Garn', 'ett']
+727 193 The professional sport played by x -1 The professional sport played by Kevin Garnett basketball Kevin Garnett "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' man' ' who' ' can' ' score' ' in' ' the' ' paint']" , the NBA 's all - time leading scorer , is a big man who can score in the paint False between him and Kevin Garnett due to their similar 5 [' between', ' him', ' and', ' Kevin', ' Garn', 'ett']
+728 194 The professional sport played by x -1 The professional sport played by Tony Parker basketball Tony Parker "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False play of Duncan and Tony Parker were not enough 5 [' play', ' of', ' Duncan', ' and', ' Tony', ' Parker']
+729 194 The professional sport played by x -1 The professional sport played by Tony Parker basketball Tony Parker "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False " Tony Parker =
+" 1 [' Tony', ' Parker']
+730 194 The professional sport played by x -1 The professional sport played by Tony Parker basketball Tony Parker "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False with his friend Tony Parker at games when not 4 [' with', ' his', ' friend', ' Tony', ' Parker']
+731 194 The professional sport played by x -1 The professional sport played by Tony Parker basketball Tony Parker "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False offensive threats in Tony Parker and Manu Ginóbili, 4 [' offensive', ' threats', ' in', ' Tony', ' Parker']
+732 194 The professional sport played by x -1 The professional sport played by Tony Parker basketball Tony Parker "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False his friend Tony Parker at games when not 3 [' his', ' friend', ' Tony', ' Parker']
+733 195 The professional sport played by x -1 The professional sport played by Moses Malone basketball Moses Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' top' ' 10' ' in' ' scoring' ',' ' reb' 'ounding' ',']" , who was a member of the NBA 's all - time top 10 in scoring , reb ounding , False Suns, which passed Moses Malone for 11th all-time. 5 [' Suns', ',', ' which', ' passed', ' Moses', ' Malone']
+734 195 The professional sport played by x -1 The professional sport played by Moses Malone basketball Moses Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' top' ' 10' ' in' ' scoring' ',' ' reb' 'ounding' ',']" , who was a member of the NBA 's all - time top 10 in scoring , reb ounding , False rebounding center through Moses Malone, who he had coached 5 [' reb', 'ounding', ' center', ' through', ' Moses', ' Malone']
+735 195 The professional sport played by x -1 The professional sport played by Moses Malone basketball Moses Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' top' ' 10' ' in' ' scoring' ',' ' reb' 'ounding' ',']" , who was a member of the NBA 's all - time top 10 in scoring , reb ounding , False Julius Erving, Moses Malone and Maurice Cheeks, 5 [' Julius', ' Er', 'ving', ',', ' Moses', ' Malone']
+736 195 The professional sport played by x -1 The professional sport played by Moses Malone basketball Moses Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' top' ' 10' ' in' ' scoring' ',' ' reb' 'ounding' ',']" , who was a member of the NBA 's all - time top 10 in scoring , reb ounding , False center through Moses Malone, who he had coached 3 [' center', ' through', ' Moses', ' Malone']
+737 195 The professional sport played by x -1 The professional sport played by Moses Malone basketball Moses Malone "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' NBA' ""'s"" ' all' '-'
+ 'time' ' top' ' 10' ' in' ' scoring' ',' ' reb' 'ounding' ',']" , who was a member of the NBA 's all - time top 10 in scoring , reb ounding , False and the first since Moses Malone in 1982. Although 5 [' and', ' the', ' first', ' since', ' Moses', ' Malone']
+738 197 The professional sport played by x -1 The professional sport played by Jack Kemp football Jack Kemp "[',' ' the' ' former' ' Secretary' ' of' ' Housing' ' and' ' Urban'
+ ' Development' ',' ' and' ' the' ' former' ' Secretary' ' of' ' State'
+ ',' ' who' ' is' ' now']" , the former Secretary of Housing and Urban Development , and the former Secretary of State , who is now False Presbyterian faith. Jack Kemp was a 33rd degree Freemason 4 [' Presbyterian', ' faith', '.', ' Jack', ' Kemp']
+739 197 The professional sport played by x -1 The professional sport played by Jack Kemp football Jack Kemp "[',' ' the' ' former' ' Secretary' ' of' ' Housing' ' and' ' Urban'
+ ' Development' ',' ' and' ' the' ' former' ' Secretary' ' of' ' State'
+ ',' ' who' ' is' ' now']" , the former Secretary of Housing and Urban Development , and the former Secretary of State , who is now False " = Jack Kemp =
+" 2 [' =', ' Jack', ' Kemp']
+740 197 The professional sport played by x -1 The professional sport played by Jack Kemp football Jack Kemp "[',' ' the' ' former' ' Secretary' ' of' ' Housing' ' and' ' Urban'
+ ' Development' ',' ' and' ' the' ' former' ' Secretary' ' of' ' State'
+ ',' ' who' ' is' ' now']" , the former Secretary of Housing and Urban Development , and the former Secretary of State , who is now False House colleague Jack Kemp to Ohio to 3 [' House', ' colleague', ' Jack', ' Kemp']
+741 197 The professional sport played by x -1 The professional sport played by Jack Kemp football Jack Kemp "[',' ' the' ' former' ' Secretary' ' of' ' Housing' ' and' ' Urban'
+ ' Development' ',' ' and' ' the' ' former' ' Secretary' ' of' ' State'
+ ',' ' who' ' is' ' now']" , the former Secretary of Housing and Urban Development , and the former Secretary of State , who is now False " Jack Kemp =
+" 1 [' Jack', ' Kemp']
+742 197 The professional sport played by x -1 The professional sport played by Jack Kemp football Jack Kemp "[',' ' the' ' former' ' Secretary' ' of' ' Housing' ' and' ' Urban'
+ ' Development' ',' ' and' ' the' ' former' ' Secretary' ' of' ' State'
+ ',' ' who' ' is' ' now']" , the former Secretary of Housing and Urban Development , and the former Secretary of State , who is now False House colleague Jack Kemp to Ohio to 3 [' House', ' colleague', ' Jack', ' Kemp']
+743 198 The professional sport played by x -1 The professional sport played by Jerry West basketball Jerry West "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Lakers' ""'"" ' front' ' office' '.']" , the NBA 's all - time leading scorer , was a member of the Lakers ' front office . False " Lakers guard Jerry West stated, ""If I had" 3 [' Lakers', ' guard', ' Jerry', ' West']
+744 198 The professional sport played by x -1 The professional sport played by Jerry West basketball Jerry West "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Lakers' ""'"" ' front' ' office' '.']" , the NBA 's all - time leading scorer , was a member of the Lakers ' front office . False Jason Kidd. Jerry West often stated that 4 [' Jason', ' Kidd', '.', ' Jerry', ' West']
+745 198 The professional sport played by x -1 The professional sport played by Jerry West basketball Jerry West "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Lakers' ""'"" ' front' ' office' '.']" , the NBA 's all - time leading scorer , was a member of the Lakers ' front office . False " Jerry West =
+" 1 [' Jerry', ' West']
+746 198 The professional sport played by x -1 The professional sport played by Jerry West basketball Jerry West "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Lakers' ""'"" ' front' ' office' '.']" , the NBA 's all - time leading scorer , was a member of the Lakers ' front office . False to the 20-man Jerry West Award preseason 6 [' to', ' the', ' 20', '-', 'man', ' Jerry', ' West']
+747 198 The professional sport played by x -1 The professional sport played by Jerry West basketball Jerry West "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Lakers' ""'"" ' front' ' office' '.']" , the NBA 's all - time leading scorer , was a member of the Lakers ' front office . False team selected Jerry West from West Virginia 3 [' team', ' selected', ' Jerry', ' West']
+748 199 The professional sport played by x -1 The professional sport played by Stan Mikita hockey Stan Mikita "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False Orr (1969 – 72) and Stan Mikita (1964 – 67). Though 10 [' Or', 'r', ' (', '1969', ' –', ' 72', ')', ' and', ' Stan', ' Mik', 'ita']
+749 199 The professional sport played by x -1 The professional sport played by Stan Mikita hockey Stan Mikita "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False (1969 – 72) and Stan Mikita (1964 – 67). Though 8 [' (', '1969', ' –', ' 72', ')', ' and', ' Stan', ' Mik', 'ita']
+750 199 The professional sport played by x -1 The professional sport played by Stan Mikita hockey Stan Mikita "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False Hall of Famers Stan Mikita and Bobby Hull 6 [' Hall', ' of', ' Fam', 'ers', ' Stan', ' Mik', 'ita']
+751 199 The professional sport played by x -1 The professional sport played by Stan Mikita hockey Stan Mikita "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False Eddie Johnston. Stan Mikita replaced Sanderson. 5 [' Eddie', ' Johnston', '.', ' Stan', ' Mik', 'ita']
+752 199 The professional sport played by x -1 The professional sport played by Stan Mikita hockey Stan Mikita "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False early 1960s, Stan Mikita inadvertently 6 [' early', ' 1960', 's', ',', ' Stan', ' Mik', 'ita']
+753 200 The professional sport played by x -1 The professional sport played by Jean Béliveau hockey Jean Béliveau "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False Canadiens player Jean Béliveau stated that, although 6 [' Canadiens', ' player', ' Jean', ' B', 'é', 'live', 'au']
+754 200 The professional sport played by x -1 The professional sport played by Jean Béliveau hockey Jean Béliveau "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False capture the Jean Béliveau Trophy as the league's 6 [' capture', ' the', ' Jean', ' B', 'é', 'live', 'au']
+755 200 The professional sport played by x -1 The professional sport played by Jean Béliveau hockey Jean Béliveau "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False to capture the Jean Béliveau Trophy as 7 [' to', ' capture', ' the', ' Jean', ' B', 'é', 'live', 'au']
+756 200 The professional sport played by x -1 The professional sport played by Jean Béliveau hockey Jean Béliveau "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False Canadiens player Jean Béliveau stated that, although 6 [' Canadiens', ' player', ' Jean', ' B', 'é', 'live', 'au']
+757 200 The professional sport played by x -1 The professional sport played by Jean Béliveau hockey Jean Béliveau "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False games to capture the Jean Béliveau Trophy as the 8 [' games', ' to', ' capture', ' the', ' Jean', ' B', 'é', 'live', 'au']
+758 201 The professional sport played by x -1 The professional sport played by Júlio César soccer Júlio César "[',' ' a' ' Brazilian' ' footballer' ' who' ' plays' ' as' ' a'
+ ' midfielder' ' for' ' the' ' Brazilian' ' club' ' Atl' 'ét' 'ico'
+ ' Mine' 'iro' '.' '\n']" ", a Brazilian footballer who plays as a midfielder for the Brazilian club Atl ét ico Mine iro .
+" False goalkeeper Júlio César both offered 6 [' goalkeeper', ' J', 'ú', 'lio', ' C', 'és', 'ar']
+759 201 The professional sport played by x -1 The professional sport played by Júlio César soccer Júlio César "[',' ' a' ' Brazilian' ' footballer' ' who' ' plays' ' as' ' a'
+ ' midfielder' ' for' ' the' ' Brazilian' ' club' ' Atl' 'ét' 'ico'
+ ' Mine' 'iro' '.' '\n']" ", a Brazilian footballer who plays as a midfielder for the Brazilian club Atl ét ico Mine iro .
+" False Navy, Admiral Júlio César de Noronha, signed 8 [' Navy', ',', ' Admiral', ' J', 'ú', 'lio', ' C', 'és', 'ar']
+760 201 The professional sport played by x -1 The professional sport played by Júlio César soccer Júlio César "[',' ' a' ' Brazilian' ' footballer' ' who' ' plays' ' as' ' a'
+ ' midfielder' ' for' ' the' ' Brazilian' ' club' ' Atl' 'ét' 'ico'
+ ' Mine' 'iro' '.' '\n']" ", a Brazilian footballer who plays as a midfielder for the Brazilian club Atl ét ico Mine iro .
+" False national team keeper Júlio César joined on 8 [' national', ' team', ' keeper', ' J', 'ú', 'lio', ' C', 'és', 'ar']
+761 201 The professional sport played by x -1 The professional sport played by Júlio César soccer Júlio César "[',' ' a' ' Brazilian' ' footballer' ' who' ' plays' ' as' ' a'
+ ' midfielder' ' for' ' the' ' Brazilian' ' club' ' Atl' 'ét' 'ico'
+ ' Mine' 'iro' '.' '\n']" ", a Brazilian footballer who plays as a midfielder for the Brazilian club Atl ét ico Mine iro .
+" False and goalkeeper Júlio César both offered 7 [' and', ' goalkeeper', ' J', 'ú', 'lio', ' C', 'és', 'ar']
+762 201 The professional sport played by x -1 The professional sport played by Júlio César soccer Júlio César "[',' ' a' ' Brazilian' ' footballer' ' who' ' plays' ' as' ' a'
+ ' midfielder' ' for' ' the' ' Brazilian' ' club' ' Atl' 'ét' 'ico'
+ ' Mine' 'iro' '.' '\n']" ", a Brazilian footballer who plays as a midfielder for the Brazilian club Atl ét ico Mine iro .
+" False scoring again, with Júlio César denying Müller twice. 9 [' scoring', ' again', ',', ' with', ' J', 'ú', 'lio', ' C', 'és', 'ar']
+763 202 The professional sport played by x -1 The professional sport played by Dominik Hašek hockey Dominik Hašek "[',' ' a' ' Czech' 'oslov' 'ak' 'ian' '-' 'born' ' American' ' who' ' was'
+ ' a' ' member' ' of' ' the' ' Czech' 'oslov' 'ak' 'ian' ' national']" , a Czech oslov ak ian - born American who was a member of the Czech oslov ak ian national False " Dominik Hašek =
+" 4 [' Domin', 'ik', ' Ha', 'š', 'ek']
+764 202 The professional sport played by x -1 The professional sport played by Dominik Hašek hockey Dominik Hašek "[',' ' a' ' Czech' 'oslov' 'ak' 'ian' '-' 'born' ' American' ' who' ' was'
+ ' a' ' member' ' of' ' the' ' Czech' 'oslov' 'ak' 'ian' ' national']" , a Czech oslov ak ian - born American who was a member of the Czech oslov ak ian national False goalie when Dominik Hašek came out of retirement, 6 [' goalie', ' when', ' Domin', 'ik', ' Ha', 'š', 'ek']
+765 202 The professional sport played by x -1 The professional sport played by Dominik Hašek hockey Dominik Hašek "[',' ' a' ' Czech' 'oslov' 'ak' 'ian' '-' 'born' ' American' ' who' ' was'
+ ' a' ' member' ' of' ' the' ' Czech' 'oslov' 'ak' 'ian' ' national']" , a Czech oslov ak ian - born American who was a member of the Czech oslov ak ian national False goaltender Dominik Hašek and lost the 5 [' goaltender', ' Domin', 'ik', ' Ha', 'š', 'ek']
+766 202 The professional sport played by x -1 The professional sport played by Dominik Hašek hockey Dominik Hašek "[',' ' a' ' Czech' 'oslov' 'ak' 'ian' '-' 'born' ' American' ' who' ' was'
+ ' a' ' member' ' of' ' the' ' Czech' 'oslov' 'ak' 'ian' ' national']" , a Czech oslov ak ian - born American who was a member of the Czech oslov ak ian national False being shut out by Dominik Hašek and the Czech Republic. 8 [' being', ' shut', ' out', ' by', ' Domin', 'ik', ' Ha', 'š', 'ek']
+767 202 The professional sport played by x -1 The professional sport played by Dominik Hašek hockey Dominik Hašek "[',' ' a' ' Czech' 'oslov' 'ak' 'ian' '-' 'born' ' American' ' who' ' was'
+ ' a' ' member' ' of' ' the' ' Czech' 'oslov' 'ak' 'ian' ' national']" , a Czech oslov ak ian - born American who was a member of the Czech oslov ak ian national False Czech goaltender Dominik Hašek in a shootout 6 [' Czech', ' goaltender', ' Domin', 'ik', ' Ha', 'š', 'ek']
+768 203 The professional sport played by x -1 The professional sport played by Lawrence Taylor football Lawrence Taylor "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' arrested' ' for' ' DUI'
+ ' in' ' New' ' York' ' City' ' on' ' Tuesday' ',' ' according' ' to'
+ ' the']" , the former NFL star , was arrested for DUI in New York City on Tuesday , according to the False possession. Lawrence Taylor caught the fumble 3 [' possession', '.', ' Lawrence', ' Taylor']
+769 203 The professional sport played by x -1 The professional sport played by Lawrence Taylor football Lawrence Taylor "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' arrested' ' for' ' DUI'
+ ' in' ' New' ' York' ' City' ' on' ' Tuesday' ',' ' according' ' to'
+ ' the']" , the former NFL star , was arrested for DUI in New York City on Tuesday , according to the False " decision vocalist Lawrence Taylor said ""hopefully" 4 [' decision', ' vocal', 'ist', ' Lawrence', ' Taylor']
+770 203 The professional sport played by x -1 The professional sport played by Lawrence Taylor football Lawrence Taylor "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' arrested' ' for' ' DUI'
+ ' in' ' New' ' York' ' City' ' on' ' Tuesday' ',' ' according' ' to'
+ ' the']" , the former NFL star , was arrested for DUI in New York City on Tuesday , according to the False yards, and Lawrence Taylor had four sacks and 4 [' yards', ',', ' and', ' Lawrence', ' Taylor']
+771 203 The professional sport played by x -1 The professional sport played by Lawrence Taylor football Lawrence Taylor "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' arrested' ' for' ' DUI'
+ ' in' ' New' ' York' ' City' ' on' ' Tuesday' ',' ' according' ' to'
+ ' the']" , the former NFL star , was arrested for DUI in New York City on Tuesday , according to the False matches but that Lawrence Taylor put on a solid 4 [' matches', ' but', ' that', ' Lawrence', ' Taylor']
+772 203 The professional sport played by x -1 The professional sport played by Lawrence Taylor football Lawrence Taylor "[',' ' the' ' former' ' NFL' ' star' ',' ' was' ' arrested' ' for' ' DUI'
+ ' in' ' New' ' York' ' City' ' on' ' Tuesday' ',' ' according' ' to'
+ ' the']" , the former NFL star , was arrested for DUI in New York City on Tuesday , according to the False Burt, linebackers Lawrence Taylor and Harry Carson, 5 [' B', 'urt', ',', ' linebackers', ' Lawrence', ' Taylor']
+773 204 The professional sport played by x -1 The professional sport played by Hristo Stoichkov soccer Hristo Stoichkov "[',' ' the' ' Bulgarian' ' footballer' ',' ' was' ' born' ' in' ' the'
+ ' city' ' of' ' Pl' 'ov' 'div' ',' ' Bulgaria' '.' ' He' ' is' ' a']" , the Bulgarian footballer , was born in the city of Pl ov div , Bulgaria . He is a False Manchester United and Hristo Stoichkov for Barcelona), exhibitions 8 [' Manchester', ' United', ' and', ' H', 'rist', 'o', ' Sto', 'ich', 'kov']
+774 204 The professional sport played by x -1 The professional sport played by Hristo Stoichkov soccer Hristo Stoichkov "[',' ' the' ' Bulgarian' ' footballer' ',' ' was' ' born' ' in' ' the'
+ ' city' ' of' ' Pl' 'ov' 'div' ',' ' Bulgaria' '.' ' He' ' is' ' a']" , the Bulgarian footballer , was born in the city of Pl ov div , Bulgaria . He is a False Manchester United and Hristo Stoichkov for Barcelona), 8 [' Manchester', ' United', ' and', ' H', 'rist', 'o', ' Sto', 'ich', 'kov']
+775 204 The professional sport played by x -1 The professional sport played by Hristo Stoichkov soccer Hristo Stoichkov "[',' ' the' ' Bulgarian' ' footballer' ',' ' was' ' born' ' in' ' the'
+ ' city' ' of' ' Pl' 'ov' 'div' ',' ' Bulgaria' '.' ' He' ' is' ' a']" , the Bulgarian footballer , was born in the city of Pl ov div , Bulgaria . He is a False United and Hristo Stoichkov for Barcelona), exhibitions 7 [' United', ' and', ' H', 'rist', 'o', ' Sto', 'ich', 'kov']
+776 206 The professional sport played by x -1 The professional sport played by Elroy Hirsch football Elroy Hirsch "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . False a short pass to Elroy Hirsch from quarterback 7 [' a', ' short', ' pass', ' to', ' El', 'roy', ' H', 'irsch']
+777 206 The professional sport played by x -1 The professional sport played by Elroy Hirsch football Elroy Hirsch "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . False the first by Elroy Hirsch on an 81-yard 6 [' the', ' first', ' by', ' El', 'roy', ' H', 'irsch']
+778 206 The professional sport played by x -1 The professional sport played by Elroy Hirsch football Elroy Hirsch "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . False – a short pass to Elroy Hirsch from quarterback 8 [' –', ' a', ' short', ' pass', ' to', ' El', 'roy', ' H', 'irsch']
+779 206 The professional sport played by x -1 The professional sport played by Elroy Hirsch football Elroy Hirsch "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . False short pass to Elroy Hirsch from quarterback 6 [' short', ' pass', ' to', ' El', 'roy', ' H', 'irsch']
+780 207 The professional sport played by x -1 The professional sport played by Ernie Banks baseball Ernie Banks "[',' ' the' ' Chicago' ' Cubs' ',' ' and' ' the' ' Chicago' ' White'
+ ' Sox' '.' '\n' '\n' 'The' ' Cubs' ' are' ' the' ' only' ' team' ' in']" ", the Chicago Cubs , and the Chicago White Sox .
+
+ The Cubs are the only team in" False 2 ['Er', 'nie', ' Banks']
+781 207 The professional sport played by x -1 The professional sport played by Ernie Banks baseball Ernie Banks "[',' ' the' ' Chicago' ' Cubs' ',' ' and' ' the' ' Chicago' ' White'
+ ' Sox' '.' '\n' '\n' 'The' ' Cubs' ' are' ' the' ' only' ' team' ' in']" ", the Chicago Cubs , and the Chicago White Sox .
+
+ The Cubs are the only team in" False " Ernie Banks =
+" 2 [' Er', 'nie', ' Banks']
+782 207 The professional sport played by x -1 The professional sport played by Ernie Banks baseball Ernie Banks "[',' ' the' ' Chicago' ' Cubs' ',' ' and' ' the' ' Chicago' ' White'
+ ' Sox' '.' '\n' '\n' 'The' ' Cubs' ' are' ' the' ' only' ' team' ' in']" ", the Chicago Cubs , and the Chicago White Sox .
+
+ The Cubs are the only team in" False wine called Ernie Banks 512 Chardonnay, 4 [' wine', ' called', ' Er', 'nie', ' Banks']
+783 207 The professional sport played by x -1 The professional sport played by Ernie Banks baseball Ernie Banks "[',' ' the' ' Chicago' ' Cubs' ',' ' and' ' the' ' Chicago' ' White'
+ ' Sox' '.' '\n' '\n' 'The' ' Cubs' ' are' ' the' ' only' ' team' ' in']" ", the Chicago Cubs , and the Chicago White Sox .
+
+ The Cubs are the only team in" False and first baseman Ernie Banks requested that 5 [' and', ' first', ' baseman', ' Er', 'nie', ' Banks']
+784 207 The professional sport played by x -1 The professional sport played by Ernie Banks baseball Ernie Banks "[',' ' the' ' Chicago' ' Cubs' ',' ' and' ' the' ' Chicago' ' White'
+ ' Sox' '.' '\n' '\n' 'The' ' Cubs' ' are' ' the' ' only' ' team' ' in']" ", the Chicago Cubs , and the Chicago White Sox .
+
+ The Cubs are the only team in" False Mathews overtook Ernie Banks of the Chicago Cubs 6 [' Mat', 'hews', ' overt', 'ook', ' Er', 'nie', ' Banks']
+785 208 The professional sport played by x -1 The professional sport played by Roger Maris baseball Roger Maris "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' player' ' to' ' hit' ' a' ' home' ' run']" , who was a member of the New York Yankees , and the first player to hit a home run False in 115 games since Roger Maris and Mickey Mantle 6 [' in', ' 115', ' games', ' since', ' Roger', ' Mar', 'is']
+786 208 The professional sport played by x -1 The professional sport played by Roger Maris baseball Roger Maris "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' player' ' to' ' hit' ' a' ' home' ' run']" , who was a member of the New York Yankees , and the first player to hit a home run False pennant years, with Roger Maris playing right field 7 [' penn', 'ant', ' years', ',', ' with', ' Roger', ' Mar', 'is']
+787 208 The professional sport played by x -1 The professional sport played by Roger Maris baseball Roger Maris "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' player' ' to' ' hit' ' a' ' home' ' run']" , who was a member of the New York Yankees , and the first player to hit a home run False home run record by Roger Maris and Mickey 6 [' home', ' run', ' record', ' by', ' Roger', ' Mar', 'is']
+788 208 The professional sport played by x -1 The professional sport played by Roger Maris baseball Roger Maris "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' player' ' to' ' hit' ' a' ' home' ' run']" , who was a member of the New York Yankees , and the first player to hit a home run False coincidentally helped Roger Maris break Babe Ruth's 5 [' coinc', 'identally', ' helped', ' Roger', ' Mar', 'is']
+789 208 The professional sport played by x -1 The professional sport played by Roger Maris baseball Roger Maris "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ',' ' and' ' the' ' first' ' player' ' to' ' hit' ' a' ' home' ' run']" , who was a member of the New York Yankees , and the first player to hit a home run False home run record by Roger Maris and Mickey Mantle. 6 [' home', ' run', ' record', ' by', ' Roger', ' Mar', 'is']
+790 209 The professional sport played by x -1 The professional sport played by Tim Howard soccer Tim Howard "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' US' 'M' 'NT'
+ ' since' ' the' ' early' ' 2000' 's' ',' ' has' ' been']" , who has been a main stay of the US M NT since the early 2000 s , has been False 1 ['Tim', ' Howard']
+791 209 The professional sport played by x -1 The professional sport played by Tim Howard soccer Tim Howard "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' US' 'M' 'NT'
+ ' since' ' the' ' early' ' 2000' 's' ',' ' has' ' been']" , who has been a main stay of the US M NT since the early 2000 s , has been False Clint Dempsey, Tim Howard and Michael Bradley 4 [' Clint', ' Dempsey', ',', ' Tim', ' Howard']
+792 209 The professional sport played by x -1 The professional sport played by Tim Howard soccer Tim Howard "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' US' 'M' 'NT'
+ ' since' ' the' ' early' ' 2000' 's' ',' ' has' ' been']" , who has been a main stay of the US M NT since the early 2000 s , has been False Clint Dempsey, Tim Howard and Michael 4 [' Clint', ' Dempsey', ',', ' Tim', ' Howard']
+793 209 The professional sport played by x -1 The professional sport played by Tim Howard soccer Tim Howard "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' US' 'M' 'NT'
+ ' since' ' the' ' early' ' 2000' 's' ',' ' has' ' been']" , who has been a main stay of the US M NT since the early 2000 s , has been False shoot-out. Goalkeeper Tim Howard saved Van Bronckhorst 7 [' shoot', '-', 'out', '.', ' Goal', 'keeper', ' Tim', ' Howard']
+794 209 The professional sport played by x -1 The professional sport played by Tim Howard soccer Tim Howard "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' US' 'M' 'NT'
+ ' since' ' the' ' early' ' 2000' 's' ',' ' has' ' been']" , who has been a main stay of the US M NT since the early 2000 s , has been False Goalkeepers Tim Howard and Jens Lehmann 3 [' Goal', 'keepers', ' Tim', ' Howard']
+795 210 The professional sport played by x -1 The professional sport played by Jaromír Jágr hockey Jaromír Jágr "[',' ' the' ' Czech' ' Republic' ""'s"" ' most' ' successful' ' player'
+ ' of' ' all' ' time' ',' ' is' ' a' ' game' ' that' ' has' ' been'
+ ' played' ' in']" , the Czech Republic 's most successful player of all time , is a game that has been played in False Selänne and leader Jaromír Jágr of the Pittsburgh 12 [' Sel', 'ä', 'n', 'ne', ' and', ' leader', ' Jar', 'om', 'í', 'r', ' J', 'á', 'gr']
+796 210 The professional sport played by x -1 The professional sport played by Jaromír Jágr hockey Jaromír Jágr "[',' ' the' ' Czech' ' Republic' ""'s"" ' most' ' successful' ' player'
+ ' of' ' all' ' time' ',' ' is' ' a' ' game' ' that' ' has' ' been'
+ ' played' ' in']" , the Czech Republic 's most successful player of all time , is a game that has been played in False and leader Jaromír Jágr of the Pittsburgh 8 [' and', ' leader', ' Jar', 'om', 'í', 'r', ' J', 'á', 'gr']
+797 210 The professional sport played by x -1 The professional sport played by Jaromír Jágr hockey Jaromír Jágr "[',' ' the' ' Czech' ' Republic' ""'s"" ' most' ' successful' ' player'
+ ' of' ' all' ' time' ',' ' is' ' a' ' game' ' that' ' has' ' been'
+ ' played' ' in']" , the Czech Republic 's most successful player of all time , is a game that has been played in False Ross Trophy winner Jaromír Jágr as the league's 9 [' Ross', ' Trophy', ' winner', ' Jar', 'om', 'í', 'r', ' J', 'á', 'gr']
+798 210 The professional sport played by x -1 The professional sport played by Jaromír Jágr hockey Jaromír Jágr "[',' ' the' ' Czech' ' Republic' ""'s"" ' most' ' successful' ' player'
+ ' of' ' all' ' time' ',' ' is' ' a' ' game' ' that' ' has' ' been'
+ ' played' ' in']" , the Czech Republic 's most successful player of all time , is a game that has been played in False Selänne and leader Jaromír Jágr of the Pittsburgh 12 [' Sel', 'ä', 'n', 'ne', ' and', ' leader', ' Jar', 'om', 'í', 'r', ' J', 'á', 'gr']
+799 210 The professional sport played by x -1 The professional sport played by Jaromír Jágr hockey Jaromír Jágr "[',' ' the' ' Czech' ' Republic' ""'s"" ' most' ' successful' ' player'
+ ' of' ' all' ' time' ',' ' is' ' a' ' game' ' that' ' has' ' been'
+ ' played' ' in']" , the Czech Republic 's most successful player of all time , is a game that has been played in False eight points behind Jaromír Jágr for the scoring 9 [' eight', ' points', ' behind', ' Jar', 'om', 'í', 'r', ' J', 'á', 'gr']
+800 211 The professional sport played by x -1 The professional sport played by Larry Bird basketball Larry Bird "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' '.' '\n' '\n' 'The']" ", the NBA 's all - time leading scorer , was a member of the Celtics .
+
+ The" False " them was the ""Larry Bird exception"", named" 5 "[' them', ' was', ' the', ' ""', 'Larry', ' Bird']"
+801 211 The professional sport played by x -1 The professional sport played by Larry Bird basketball Larry Bird "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' '.' '\n' '\n' 'The']" ", the NBA 's all - time leading scorer , was a member of the Celtics .
+
+ The" False 1986, Celtics star Larry Bird described 5 [' 1986', ',', ' Celtics', ' star', ' Larry', ' Bird']
+802 211 The professional sport played by x -1 The professional sport played by Larry Bird basketball Larry Bird "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' '.' '\n' '\n' 'The']" ", the NBA 's all - time leading scorer , was a member of the Celtics .
+
+ The" False Pacer General Manager Larry Bird noted that he had 5 [' Pac', 'er', ' General', ' Manager', ' Larry', ' Bird']
+803 211 The professional sport played by x -1 The professional sport played by Larry Bird basketball Larry Bird "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' '.' '\n' '\n' 'The']" ", the NBA 's all - time leading scorer , was a member of the Celtics .
+
+ The" False season) and the MVC Larry Bird Trophy Player 7 [' season', ')', ' and', ' the', ' M', 'VC', ' Larry', ' Bird']
+804 211 The professional sport played by x -1 The professional sport played by Larry Bird basketball Larry Bird "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Celtics' '.' '\n' '\n' 'The']" ", the NBA 's all - time leading scorer , was a member of the Celtics .
+
+ The" False combination of Larry Bird and 'Pistol' Pete (Maravich). 3 [' combination', ' of', ' Larry', ' Bird']
+805 212 The professional sport played by x -1 The professional sport played by Carl Weathers football Carl Weathers "[',' ' who' ' was' ' a' ' former' ' NFL' ' linebacker' ',' ' and' ' the'
+ ' former' ' NFL' ' linebacker' ',' ' who' ' was' ' a' ' former' ' NFL'
+ ' linebacker']" , who was a former NFL linebacker , and the former NFL linebacker , who was a former NFL linebacker False to be Gob's son. Carl Weathers plays a parodied 8 "[' to', ' be', ' Gob', ""'s"", ' son', '.', ' Carl', ' Weather', 's']"
+806 212 The professional sport played by x -1 The professional sport played by Carl Weathers football Carl Weathers "[',' ' who' ' was' ' a' ' former' ' NFL' ' linebacker' ',' ' and' ' the'
+ ' former' ' NFL' ' linebacker' ',' ' who' ' was' ' a' ' former' ' NFL'
+ ' linebacker']" , who was a former NFL linebacker , and the former NFL linebacker , who was a former NFL linebacker False Gob's son. Carl Weathers plays a parodied 6 "[' Gob', ""'s"", ' son', '.', ' Carl', ' Weather', 's']"
+807 212 The professional sport played by x -1 The professional sport played by Carl Weathers football Carl Weathers "[',' ' who' ' was' ' a' ' former' ' NFL' ' linebacker' ',' ' and' ' the'
+ ' former' ' NFL' ' linebacker' ',' ' who' ' was' ' a' ' former' ' NFL'
+ ' linebacker']" , who was a former NFL linebacker , and the former NFL linebacker , who was a former NFL linebacker False including Carl Weathers as himself, 3 [' including', ' Carl', ' Weather', 's']
+808 212 The professional sport played by x -1 The professional sport played by Carl Weathers football Carl Weathers "[',' ' who' ' was' ' a' ' former' ' NFL' ' linebacker' ',' ' and' ' the'
+ ' former' ' NFL' ' linebacker' ',' ' who' ' was' ' a' ' former' ' NFL'
+ ' linebacker']" , who was a former NFL linebacker , and the former NFL linebacker , who was a former NFL linebacker False discovered to be Gob's son. Carl Weathers plays a parodied 9 "[' discovered', ' to', ' be', ' Gob', ""'s"", ' son', '.', ' Carl', ' Weather', 's']"
+809 212 The professional sport played by x -1 The professional sport played by Carl Weathers football Carl Weathers "[',' ' who' ' was' ' a' ' former' ' NFL' ' linebacker' ',' ' and' ' the'
+ ' former' ' NFL' ' linebacker' ',' ' who' ' was' ' a' ' former' ' NFL'
+ ' linebacker']" , who was a former NFL linebacker , and the former NFL linebacker , who was a former NFL linebacker False be Gob's son. Carl Weathers plays a parodied version 7 "[' be', ' Gob', ""'s"", ' son', '.', ' Carl', ' Weather', 's']"
+810 213 The professional sport played by x -1 The professional sport played by Kenny Lofton baseball Kenny Lofton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ' from' ' 1995' ' to' ' 2001' '.' '\n' '\n' 'The' ' Yankees' ' have']" ", who was a member of the New York Yankees from 1995 to 2001 .
+
+ The Yankees have" False " 95 games. ""With Kenny Lofton we got a quality" 7 "[' 95', ' games', '.', ' ""', 'With', ' Kenny', ' Loft', 'on']"
+811 213 The professional sport played by x -1 The professional sport played by Kenny Lofton baseball Kenny Lofton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ' from' ' 1995' ' to' ' 2001' '.' '\n' '\n' 'The' ' Yankees' ' have']" ", who was a member of the New York Yankees from 1995 to 2001 .
+
+ The Yankees have" False Vizquel, and Kenny Lofton all started getting 6 [' Viz', 'quel', ',', ' and', ' Kenny', ' Loft', 'on']
+812 213 The professional sport played by x -1 The professional sport played by Kenny Lofton baseball Kenny Lofton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ' from' ' 1995' ' to' ' 2001' '.' '\n' '\n' 'The' ' Yankees' ' have']" ", who was a member of the New York Yankees from 1995 to 2001 .
+
+ The Yankees have" False electrifying moments on the Kenny Lofton highlight reel, none 7 [' electr', 'ifying', ' moments', ' on', ' the', ' Kenny', ' Loft', 'on']
+813 213 The professional sport played by x -1 The professional sport played by Kenny Lofton baseball Kenny Lofton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ' from' ' 1995' ' to' ' 2001' '.' '\n' '\n' 'The' ' Yankees' ' have']" ", who was a member of the New York Yankees from 1995 to 2001 .
+
+ The Yankees have" False " lost 95 games. ""With Kenny Lofton we got a quality" 8 "[' lost', ' 95', ' games', '.', ' ""', 'With', ' Kenny', ' Loft', 'on']"
+814 213 The professional sport played by x -1 The professional sport played by Kenny Lofton baseball Kenny Lofton "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' New' ' York' ' Yankees'
+ ' from' ' 1995' ' to' ' 2001' '.' '\n' '\n' 'The' ' Yankees' ' have']" ", who was a member of the New York Yankees from 1995 to 2001 .
+
+ The Yankees have" False " Lofton =
+" 6 [' Loft', 'on', ' =', 'K', 'enny', ' Loft', 'on']
+815 214 The professional sport played by x -1 The professional sport played by Nick Swisher baseball Nick Swisher "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' World'
+ ' Series' ' champion' ' with' ' the' ' New' ' York' ' Yankees' '.' '\n'
+ '\n']" ", who is a former All - Star and World Series champion with the New York Yankees .
+
+" False successfully helped Nick Swisher and Curtis 4 [' successfully', ' helped', ' Nick', ' Sw', 'isher']
+816 214 The professional sport played by x -1 The professional sport played by Nick Swisher baseball Nick Swisher "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' World'
+ ' Series' ' champion' ' with' ' the' ' New' ' York' ' Yankees' '.' '\n'
+ '\n']" ", who is a former All - Star and World Series champion with the New York Yankees .
+
+" False 2 ['Nick', ' Sw', 'isher']
+817 214 The professional sport played by x -1 The professional sport played by Nick Swisher baseball Nick Swisher "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' World'
+ ' Series' ' champion' ' with' ' the' ' New' ' York' ' Yankees' '.' '\n'
+ '\n']" ", who is a former All - Star and World Series champion with the New York Yankees .
+
+" False successfully helped Nick Swisher and Curtis Granderson 4 [' successfully', ' helped', ' Nick', ' Sw', 'isher']
+818 214 The professional sport played by x -1 The professional sport played by Nick Swisher baseball Nick Swisher "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' World'
+ ' Series' ' champion' ' with' ' the' ' New' ' York' ' Yankees' '.' '\n'
+ '\n']" ", who is a former All - Star and World Series champion with the New York Yankees .
+
+" False outfielder when Nick Swisher experienced tendinitis 4 [' outfielder', ' when', ' Nick', ' Sw', 'isher']
+819 214 The professional sport played by x -1 The professional sport played by Nick Swisher baseball Nick Swisher "[',' ' who' ' is' ' a' ' former' ' All' '-' 'Star' ' and' ' World'
+ ' Series' ' champion' ' with' ' the' ' New' ' York' ' Yankees' '.' '\n'
+ '\n']" ", who is a former All - Star and World Series champion with the New York Yankees .
+
+" False run after review. Nick Swisher opened the top of 6 [' run', ' after', ' review', '.', ' Nick', ' Sw', 'isher']
+820 215 The professional sport played by x -1 The professional sport played by DaMarcus Beasley soccer DaMarcus Beasley "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' team'
+ ' since' ' the' ' beginning' ' of' ' the' ' season' ',' ' has' ' been'
+ ' a']" , who has been a main stay of the team since the beginning of the season , has been a False to Seattle, DaMarcus Beasley from the Liga MX 6 [' to', ' Seattle', ',', ' Da', 'Marcus', ' Be', 'asley']
+821 215 The professional sport played by x -1 The professional sport played by DaMarcus Beasley soccer DaMarcus Beasley "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' team'
+ ' since' ' the' ' beginning' ' of' ' the' ' season' ',' ' has' ' been'
+ ' a']" , who has been a main stay of the team since the beginning of the season , has been a False League to Seattle, DaMarcus Beasley from the Liga 7 [' League', ' to', ' Seattle', ',', ' Da', 'Marcus', ' Be', 'asley']
+822 216 The professional sport played by x -1 The professional sport played by Shane Doan hockey Shane Doan "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Arizona' ' Coyotes' ','
+ ' and' ' the' ' NHL' ""'s"" ' first' '-' 'ever' ' rookie' ' of' ' the']" , who was a member of the Arizona Coyotes , and the NHL 's first - ever rookie of the False the suspended Shane Doan as team captain 4 [' the', ' suspended', ' Shane', ' Do', 'an']
+823 216 The professional sport played by x -1 The professional sport played by Shane Doan hockey Shane Doan "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Arizona' ' Coyotes' ','
+ ' and' ' the' ' NHL' ""'s"" ' first' '-' 'ever' ' rookie' ' of' ' the']" , who was a member of the Arizona Coyotes , and the NHL 's first - ever rookie of the False the suspended Shane Doan as team captain 4 [' the', ' suspended', ' Shane', ' Do', 'an']
+824 216 The professional sport played by x -1 The professional sport played by Shane Doan hockey Shane Doan "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Arizona' ' Coyotes' ','
+ ' and' ' the' ' NHL' ""'s"" ' first' '-' 'ever' ' rookie' ' of' ' the']" , who was a member of the Arizona Coyotes , and the NHL 's first - ever rookie of the False of the suspended Shane Doan as team captain from 5 [' of', ' the', ' suspended', ' Shane', ' Do', 'an']
+825 216 The professional sport played by x -1 The professional sport played by Shane Doan hockey Shane Doan "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Arizona' ' Coyotes' ','
+ ' and' ' the' ' NHL' ""'s"" ' first' '-' 'ever' ' rookie' ' of' ' the']" , who was a member of the Arizona Coyotes , and the NHL 's first - ever rookie of the False Coyotes forward Shane Doan during a game on December 4 [' Coyotes', ' forward', ' Shane', ' Do', 'an']
+826 217 The professional sport played by x -1 The professional sport played by John Elway football John Elway "[',' ' the' ' Denver' ' Broncos' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' a' ' great'
+ ' leader' ' and']" , the Denver Broncos quarterback , is a great example of a player who has been a great leader and False sneakiness about it; John Elway had simply shown what 7 [' sneak', 'iness', ' about', ' it', ';', ' John', ' El', 'way']
+827 217 The professional sport played by x -1 The professional sport played by John Elway football John Elway "[',' ' the' ' Denver' ' Broncos' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' a' ' great'
+ ' leader' ' and']" , the Denver Broncos quarterback , is a great example of a player who has been a great leader and False quarter, tying him with John Elway and Johnny 7 [' quarter', ',', ' tying', ' him', ' with', ' John', ' El', 'way']
+828 217 The professional sport played by x -1 The professional sport played by John Elway football John Elway "[',' ' the' ' Denver' ' Broncos' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' a' ' great'
+ ' leader' ' and']" , the Denver Broncos quarterback , is a great example of a player who has been a great leader and False quarterback draw by John Elway to make the score 5 [' quarterback', ' draw', ' by', ' John', ' El', 'way']
+829 217 The professional sport played by x -1 The professional sport played by John Elway football John Elway "[',' ' the' ' Denver' ' Broncos' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' a' ' great'
+ ' leader' ' and']" , the Denver Broncos quarterback , is a great example of a player who has been a great leader and False quarterback draw by John Elway to make the 5 [' quarterback', ' draw', ' by', ' John', ' El', 'way']
+830 217 The professional sport played by x -1 The professional sport played by John Elway football John Elway "[',' ' the' ' Denver' ' Broncos' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' player' ' who' ' has' ' been' ' a' ' great'
+ ' leader' ' and']" , the Denver Broncos quarterback , is a great example of a player who has been a great leader and False the Broncos, led by John Elway in his final 7 [' the', ' Broncos', ',', ' led', ' by', ' John', ' El', 'way']
+831 218 The professional sport played by x -1 The professional sport played by David Villa soccer David Villa "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' Spanish'
+ ' national' ' team' ' since' ' his' ' debut' ' in' ' 2002' ',' ' has'
+ ' been']" , who has been a main stay of the Spanish national team since his debut in 2002 , has been False " David Villa =
+" 1 [' David', ' Villa']
+832 218 The professional sport played by x -1 The professional sport played by David Villa soccer David Villa "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' Spanish'
+ ' national' ' team' ' since' ' his' ' debut' ' in' ' 2002' ',' ' has'
+ ' been']" , who has been a main stay of the Spanish national team since his debut in 2002 , has been False 1 ['David', ' Villa']
+833 218 The professional sport played by x -1 The professional sport played by David Villa soccer David Villa "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' Spanish'
+ ' national' ' team' ' since' ' his' ' debut' ' in' ' 2002' ',' ' has'
+ ' been']" , who has been a main stay of the Spanish national team since his debut in 2002 , has been False Rosell signed David Villa from Valencia for 4 [' Rose', 'll', ' signed', ' David', ' Villa']
+834 218 The professional sport played by x -1 The professional sport played by David Villa soccer David Villa "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' Spanish'
+ ' national' ' team' ' since' ' his' ' debut' ' in' ' 2002' ',' ' has'
+ ' been']" , who has been a main stay of the Spanish national team since his debut in 2002 , has been False " like to play is David Villa of Valencia.""" 5 [' like', ' to', ' play', ' is', ' David', ' Villa']
+835 218 The professional sport played by x -1 The professional sport played by David Villa soccer David Villa "[',' ' who' ' has' ' been' ' a' ' main' 'stay' ' of' ' the' ' Spanish'
+ ' national' ' team' ' since' ' his' ' debut' ' in' ' 2002' ',' ' has'
+ ' been']" , who has been a main stay of the Spanish national team since his debut in 2002 , has been False leading scorer David Villa and Chelsea's 3 [' leading', ' scorer', ' David', ' Villa']
+836 219 The professional sport played by x -1 The professional sport played by Terry Crews football Terry Crews "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' former'
+ ' NFL' ' player' ',' ' and' ' a' ' former' ' NFL' ' player' '.']" , who is a former NFL player , is a former NFL player , and a former NFL player . False 2 ['Terry', ' Crew', 's']
+837 219 The professional sport played by x -1 The professional sport played by Terry Crews football Terry Crews "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' former'
+ ' NFL' ' player' ',' ' and' ' a' ' former' ' NFL' ' player' '.']" , who is a former NFL player , is a former NFL player , and a former NFL player . False 2013 – 14 season, Terry Crews in the following 7 [' 2013', ' –', ' 14', ' season', ',', ' Terry', ' Crew', 's']
+838 219 The professional sport played by x -1 The professional sport played by Terry Crews football Terry Crews "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' former'
+ ' NFL' ' player' ',' ' and' ' a' ' former' ' NFL' ' player' '.']" , who is a former NFL player , is a former NFL player , and a former NFL player . False being succeeded by Terry Crews for the 2014 5 [' being', ' succeeded', ' by', ' Terry', ' Crew', 's']
+839 219 The professional sport played by x -1 The professional sport played by Terry Crews football Terry Crews "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' former'
+ ' NFL' ' player' ',' ' and' ' a' ' former' ' NFL' ' player' '.']" , who is a former NFL player , is a former NFL player , and a former NFL player . False " Crews as Hale Caesar:
+" 8 [' Crew', 's', ' as', ' Hale', ' Caesar', ':', 'Terry', ' Crew', 's']
+840 219 The professional sport played by x -1 The professional sport played by Terry Crews football Terry Crews "[',' ' who' ' is' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' former'
+ ' NFL' ' player' ',' ' and' ' a' ' former' ' NFL' ' player' '.']" , who is a former NFL player , is a former NFL player , and a former NFL player . False succeeded by Terry Crews for the 2014 – 4 [' succeeded', ' by', ' Terry', ' Crew', 's']
+841 220 The professional sport played by x -1 The professional sport played by Jacques Plante hockey Jacques Plante "[',' ' the' ' NHL' ""'s"" ' first' ' superstar' ' goalie' ',' ' was' ' a'
+ ' goalie' ' who' ' was' ' a' ' great' ' goalie' '.' ' He' ' was' ' a']" , the NHL 's first superstar goalie , was a goalie who was a great goalie . He was a False Canadiens. The Jacques Plante Memorial Trophy 5 [' Canadiens', '.', ' The', ' Jacques', ' Pl', 'ante']
+842 220 The professional sport played by x -1 The professional sport played by Jacques Plante hockey Jacques Plante "[',' ' the' ' NHL' ""'s"" ' first' ' superstar' ' goalie' ',' ' was' ' a'
+ ' goalie' ' who' ' was' ' a' ' great' ' goalie' '.' ' He' ' was' ' a']" , the NHL 's first superstar goalie , was a goalie who was a great goalie . He was a False Canadiens. The Jacques Plante Memorial Trophy 5 [' Canadiens', '.', ' The', ' Jacques', ' Pl', 'ante']
+843 220 The professional sport played by x -1 The professional sport played by Jacques Plante hockey Jacques Plante "[',' ' the' ' NHL' ""'s"" ' first' ' superstar' ' goalie' ',' ' was' ' a'
+ ' goalie' ' who' ' was' ' a' ' great' ' goalie' '.' ' He' ' was' ' a']" , the NHL 's first superstar goalie , was a goalie who was a great goalie . He was a False " ""Boom Boom"" Geoffrion, Jacques Plante and Richard's younger" 11 "[' ""', 'B', 'oom', ' Boom', '""', ' Geoff', 'r', 'ion', ',', ' Jacques', ' Pl', 'ante']"
+844 220 The professional sport played by x -1 The professional sport played by Jacques Plante hockey Jacques Plante "[',' ' the' ' NHL' ""'s"" ' first' ' superstar' ' goalie' ',' ' was' ' a'
+ ' goalie' ' who' ' was' ' a' ' great' ' goalie' '.' ' He' ' was' ' a']" , the NHL 's first superstar goalie , was a goalie who was a great goalie . He was a False historians credit Jacques Plante for originating 4 [' historians', ' credit', ' Jacques', ' Pl', 'ante']
+845 220 The professional sport played by x -1 The professional sport played by Jacques Plante hockey Jacques Plante "[',' ' the' ' NHL' ""'s"" ' first' ' superstar' ' goalie' ',' ' was' ' a'
+ ' goalie' ' who' ' was' ' a' ' great' ' goalie' '.' ' He' ' was' ' a']" , the NHL 's first superstar goalie , was a goalie who was a great goalie . He was a False years later, Jacques Plante made the goalie 5 [' years', ' later', ',', ' Jacques', ' Pl', 'ante']
+846 221 The professional sport played by x -1 The professional sport played by Júlio Baptista soccer Júlio Baptista "[',' ' a' ' Brazilian' ' footballer' ' who' ' plays' ' as' ' a'
+ ' midfielder' ' for' ' the' ' Brazilian' ' club' ' Atl' 'ét' 'ico'
+ ' Mine' 'iro' '.' '\n']" ", a Brazilian footballer who plays as a midfielder for the Brazilian club Atl ét ico Mine iro .
+" False several new signings: Júlio Baptista (€ 24 million), 8 [' several', ' new', ' signings', ':', ' J', 'ú', 'lio', ' Bapt', 'ista']
+847 221 The professional sport played by x -1 The professional sport played by Júlio Baptista soccer Júlio Baptista "[',' ' a' ' Brazilian' ' footballer' ' who' ' plays' ' as' ' a'
+ ' midfielder' ' for' ' the' ' Brazilian' ' club' ' Atl' 'ét' 'ico'
+ ' Mine' 'iro' '.' '\n']" ", a Brazilian footballer who plays as a midfielder for the Brazilian club Atl ét ico Mine iro .
+" False signings: Júlio Baptista (€ 24 million), 6 [' signings', ':', ' J', 'ú', 'lio', ' Bapt', 'ista']
+848 222 The professional sport played by x -1 The professional sport played by Mikaël Silvestre soccer Mikaël Silvestre "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' sport' ' that' ' is' ' played' ' in' ' the' ' United' ' States' ','
+ ' Canada' ',']" , a former professional soccer player , is a sport that is played in the United States , Canada , True post – an unmarked Mikaël Silvestre headed it into 10 [' post', ' –', ' an', ' unmarked', ' M', 'ika', 'ë', 'l', ' Sil', 'vest', 're']
+849 222 The professional sport played by x -1 The professional sport played by Mikaël Silvestre soccer Mikaël Silvestre "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' sport' ' that' ' is' ' played' ' in' ' the' ' United' ' States' ','
+ ' Canada' ',']" , a former professional soccer player , is a sport that is played in the United States , Canada , True to clear the ball. Mikaël Silvestre scored an own 11 [' to', ' clear', ' the', ' ball', '.', ' M', 'ika', 'ë', 'l', ' Sil', 'vest', 're']
+850 222 The professional sport played by x -1 The professional sport played by Mikaël Silvestre soccer Mikaël Silvestre "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' sport' ' that' ' is' ' played' ' in' ' the' ' United' ' States' ','
+ ' Canada' ',']" , a former professional soccer player , is a sport that is played in the United States , Canada , True however, defender Mikaël Silvestre was declared 9 [' however', ',', ' defender', ' M', 'ika', 'ë', 'l', ' Sil', 'vest', 're']
+851 222 The professional sport played by x -1 The professional sport played by Mikaël Silvestre soccer Mikaël Silvestre "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' sport' ' that' ' is' ' played' ' in' ' the' ' United' ' States' ','
+ ' Canada' ',']" , a former professional soccer player , is a sport that is played in the United States , Canada , True clear the ball. Mikaël Silvestre scored an own 10 [' clear', ' the', ' ball', '.', ' M', 'ika', 'ë', 'l', ' Sil', 'vest', 're']
+852 222 The professional sport played by x -1 The professional sport played by Mikaël Silvestre soccer Mikaël Silvestre "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' sport' ' that' ' is' ' played' ' in' ' the' ' United' ' States' ','
+ ' Canada' ',']" , a former professional soccer player , is a sport that is played in the United States , Canada , True Phil Neville and Mikaël Silvestre were all preferred 9 [' Phil', ' Neville', ' and', ' M', 'ika', 'ë', 'l', ' Sil', 'vest', 're']
+853 223 The professional sport played by x -1 The professional sport played by Nigel de Jong soccer Nigel de Jong "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Dutch' ' side'
+ ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular in the Dutch side since the age of 16 , has been a key False Carlos Tevez, and Nigel de Jong and Gareth Barry in 7 [' Carlos', ' Te', 'vez', ',', ' and', ' Nigel', ' de', ' Jong']
+854 223 The professional sport played by x -1 The professional sport played by Nigel de Jong soccer Nigel de Jong "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Dutch' ' side'
+ ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular in the Dutch side since the age of 16 , has been a key False first half when Nigel de Jong scored, ten 5 [' first', ' half', ' when', ' Nigel', ' de', ' Jong']
+855 223 The professional sport played by x -1 The professional sport played by Nigel de Jong soccer Nigel de Jong "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Dutch' ' side'
+ ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular in the Dutch side since the age of 16 , has been a key False the first half when Nigel de Jong scored, ten minutes 6 [' the', ' first', ' half', ' when', ' Nigel', ' de', ' Jong']
+856 223 The professional sport played by x -1 The professional sport played by Nigel de Jong soccer Nigel de Jong "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Dutch' ' side'
+ ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular in the Dutch side since the age of 16 , has been a key False first half when Nigel de Jong scored, ten 5 [' first', ' half', ' when', ' Nigel', ' de', ' Jong']
+857 223 The professional sport played by x -1 The professional sport played by Nigel de Jong soccer Nigel de Jong "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Dutch' ' side'
+ ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular in the Dutch side since the age of 16 , has been a key False the first half when Nigel de Jong scored, ten minutes 6 [' the', ' first', ' half', ' when', ' Nigel', ' de', ' Jong']
+858 225 The professional sport played by x -1 The professional sport played by Obafemi Martins soccer Obafemi Martins "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' club' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular starter for the club since the age of 16 , has been a key False 16, however, Obafemi Martins left the game 8 [' 16', ',', ' however', ',', ' Ob', 'af', 'emi', ' Mart', 'ins']
+859 225 The professional sport played by x -1 The professional sport played by Obafemi Martins soccer Obafemi Martins "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' club' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular starter for the club since the age of 16 , has been a key False from Nikola Žigić and Obafemi Martins and securing qualification 11 [' from', ' Nikola', ' �', '�', 'igi', 'ć', ' and', ' Ob', 'af', 'emi', ' Mart', 'ins']
+860 225 The professional sport played by x -1 The professional sport played by Obafemi Martins soccer Obafemi Martins "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' club' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular starter for the club since the age of 16 , has been a key False scored twice while Obafemi Martins and Pappa both had 7 [' scored', ' twice', ' while', ' Ob', 'af', 'emi', ' Mart', 'ins']
+861 225 The professional sport played by x -1 The professional sport played by Obafemi Martins soccer Obafemi Martins "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' club' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular starter for the club since the age of 16 , has been a key False decision not to start Obafemi Martins and Pappa was 8 [' decision', ' not', ' to', ' start', ' Ob', 'af', 'emi', ' Mart', 'ins']
+862 225 The professional sport played by x -1 The professional sport played by Obafemi Martins soccer Obafemi Martins "[',' ' who' ' has' ' been' ' a' ' regular' ' starter' ' for' ' the'
+ ' club' ' since' ' the' ' age' ' of' ' 16' ',' ' has' ' been' ' a' ' key']" , who has been a regular starter for the club since the age of 16 , has been a key False time period, and Obafemi Martins sealed a Seattle 8 [' time', ' period', ',', ' and', ' Ob', 'af', 'emi', ' Mart', 'ins']
+863 227 The professional sport played by x -1 The professional sport played by Evgeni Malkin hockey Evgeni Malkin "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Pittsburgh'
+ ' Penguins' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1967' '.'
+ '\n' '\n']" ", who has been a member of the Pittsburgh Penguins since the team 's inception in 1967 .
+
+" False of the season, Evgeni Malkin acquired his 300th 8 [' of', ' the', ' season', ',', ' Ev', 'gen', 'i', ' Malk', 'in']
+864 227 The professional sport played by x -1 The professional sport played by Evgeni Malkin hockey Evgeni Malkin "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Pittsburgh'
+ ' Penguins' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1967' '.'
+ '\n' '\n']" ", who has been a member of the Pittsburgh Penguins since the team 's inception in 1967 .
+
+" False Alexander Ovechkin and Evgeni Malkin stated that they 10 [' Alexander', ' O', 've', 'ch', 'kin', ' and', ' Ev', 'gen', 'i', ' Malk', 'in']
+865 227 The professional sport played by x -1 The professional sport played by Evgeni Malkin hockey Evgeni Malkin "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Pittsburgh'
+ ' Penguins' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1967' '.'
+ '\n' '\n']" ", who has been a member of the Pittsburgh Penguins since the team 's inception in 1967 .
+
+" False 4 ['Ev', 'gen', 'i', ' Malk', 'in']
+866 227 The professional sport played by x -1 The professional sport played by Evgeni Malkin hockey Evgeni Malkin "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Pittsburgh'
+ ' Penguins' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1967' '.'
+ '\n' '\n']" ", who has been a member of the Pittsburgh Penguins since the team 's inception in 1967 .
+
+" False in a row. Teammate Evgeni Malkin was selected 10 [' in', ' a', ' row', '.', ' Team', 'mate', ' Ev', 'gen', 'i', ' Malk', 'in']
+867 227 The professional sport played by x -1 The professional sport played by Evgeni Malkin hockey Evgeni Malkin "[',' ' who' ' has' ' been' ' a' ' member' ' of' ' the' ' Pittsburgh'
+ ' Penguins' ' since' ' the' ' team' ""'s"" ' inception' ' in' ' 1967' '.'
+ '\n' '\n']" ", who has been a member of the Pittsburgh Penguins since the team 's inception in 1967 .
+
+" False Sidney Crosby, Evgeni Malkin, Tyler Kennedy and 7 [' Sidney', ' Crosby', ',', ' Ev', 'gen', 'i', ' Malk', 'in']
+868 229 The professional sport played by x -1 The professional sport played by Ilya Kovalchuk hockey Ilya Kovalchuk "[',' ' who' ' is' ' a' ' Russian' ' hockey' ' player' '.' '\n' '\n' 'The'
+ ' NHL' ' is' ' a' ' professional' ' ice' ' hockey' ' league' ' in'
+ ' North']" ", who is a Russian hockey player .
+
+ The NHL is a professional ice hockey league in North" True goal-scoring title with Ilya Kovalchuk and Rick Nash 9 [' goal', '-', 'scoring', ' title', ' with', ' Ily', 'a', ' K', 'oval', 'chuk']
+869 229 The professional sport played by x -1 The professional sport played by Ilya Kovalchuk hockey Ilya Kovalchuk "[',' ' who' ' is' ' a' ' Russian' ' hockey' ' player' '.' '\n' '\n' 'The'
+ ' NHL' ' is' ' a' ' professional' ' ice' ' hockey' ' league' ' in'
+ ' North']" ", who is a Russian hockey player .
+
+ The NHL is a professional ice hockey league in North" True as a cross between Ilya Kovalchuk and Maxim Afinogenov. 8 [' as', ' a', ' cross', ' between', ' Ily', 'a', ' K', 'oval', 'chuk']
+870 229 The professional sport played by x -1 The professional sport played by Ilya Kovalchuk hockey Ilya Kovalchuk "[',' ' who' ' is' ' a' ' Russian' ' hockey' ' player' '.' '\n' '\n' 'The'
+ ' NHL' ' is' ' a' ' professional' ' ice' ' hockey' ' league' ' in'
+ ' North']" ", who is a Russian hockey player .
+
+ The NHL is a professional ice hockey league in North" True surpassing Ilya Kovalchuk (18 years, 172 6 [' surpass', 'ing', ' Ily', 'a', ' K', 'oval', 'chuk']
+871 229 The professional sport played by x -1 The professional sport played by Ilya Kovalchuk hockey Ilya Kovalchuk "[',' ' who' ' is' ' a' ' Russian' ' hockey' ' player' '.' '\n' '\n' 'The'
+ ' NHL' ' is' ' a' ' professional' ' ice' ' hockey' ' league' ' in'
+ ' North']" ", who is a Russian hockey player .
+
+ The NHL is a professional ice hockey league in North" True goal-scoring title with Ilya Kovalchuk and Rick Nash 9 [' goal', '-', 'scoring', ' title', ' with', ' Ily', 'a', ' K', 'oval', 'chuk']
+872 229 The professional sport played by x -1 The professional sport played by Ilya Kovalchuk hockey Ilya Kovalchuk "[',' ' who' ' is' ' a' ' Russian' ' hockey' ' player' '.' '\n' '\n' 'The'
+ ' NHL' ' is' ' a' ' professional' ' ice' ' hockey' ' league' ' in'
+ ' North']" ", who is a Russian hockey player .
+
+ The NHL is a professional ice hockey league in North" True games since Ilya Kovalchuk in November 2007 6 [' games', ' since', ' Ily', 'a', ' K', 'oval', 'chuk']
+873 230 The professional sport played by x -1 The professional sport played by Frank Lampard soccer Frank Lampard "[',' ' who' ' has' ' been' ' a' ' professional' ' footballer' ' for'
+ ' over' ' a' ' decade' ',' ' has' ' been' ' a' ' professional'
+ ' footballer' ' for' ' over' ' a']" , who has been a professional footballer for over a decade , has been a professional footballer for over a False Chelsea's leading scorer Frank Lampard to New York, 6 "[' Chelsea', ""'s"", ' leading', ' scorer', ' Frank', ' Lamp', 'ard']"
+874 230 The professional sport played by x -1 The professional sport played by Frank Lampard soccer Frank Lampard "[',' ' who' ' has' ' been' ' a' ' professional' ' footballer' ' for'
+ ' over' ' a' ' decade' ',' ' has' ' been' ' a' ' professional'
+ ' footballer' ' for' ' over' ' a']" , who has been a professional footballer for over a decade , has been a professional footballer for over a False Sagna, Fernandinho, Frank Lampard and Sergio Agüero. 8 [' Sag', 'na', ',', ' Fernand', 'inho', ',', ' Frank', ' Lamp', 'ard']
+875 230 The professional sport played by x -1 The professional sport played by Frank Lampard soccer Frank Lampard "[',' ' who' ' has' ' been' ' a' ' professional' ' footballer' ' for'
+ ' over' ' a' ' decade' ',' ' has' ' been' ' a' ' professional'
+ ' footballer' ' for' ' over' ' a']" , who has been a professional footballer for over a decade , has been a professional footballer for over a False Fernandinho, Frank Lampard and Sergio Agüero. 5 [' Fernand', 'inho', ',', ' Frank', ' Lamp', 'ard']
+876 230 The professional sport played by x -1 The professional sport played by Frank Lampard soccer Frank Lampard "[',' ' who' ' has' ' been' ' a' ' professional' ' footballer' ' for'
+ ' over' ' a' ' decade' ',' ' has' ' been' ' a' ' professional'
+ ' footballer' ' for' ' over' ' a']" , who has been a professional footballer for over a decade , has been a professional footballer for over a False outstanding penalty save from Frank Lampard to deny Chelsea an 6 [' outstanding', ' penalty', ' save', ' from', ' Frank', ' Lamp', 'ard']
+877 230 The professional sport played by x -1 The professional sport played by Frank Lampard soccer Frank Lampard "[',' ' who' ' has' ' been' ' a' ' professional' ' footballer' ' for'
+ ' over' ' a' ' decade' ',' ' has' ' been' ' a' ' professional'
+ ' footballer' ' for' ' over' ' a']" , who has been a professional footballer for over a decade , has been a professional footballer for over a False and Thierry Henry. Frank Lampard scored a late 8 [' and', ' Th', 'ier', 'ry', ' Henry', '.', ' Frank', ' Lamp', 'ard']
+878 231 The professional sport played by x -1 The professional sport played by Freddie Ljungberg soccer Freddie Ljungberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'References' '\n'
+ '\n']" ", who was a member of the Swedish national team from 1995 to 2004 .
+
+ References
+
+" False Adams, Anelka and Freddie Ljungberg helped the champions 10 [' Adams', ',', ' An', 'el', 'ka', ' and', ' Freddie', ' L', 'j', 'ung', 'berg']
+879 231 The professional sport played by x -1 The professional sport played by Freddie Ljungberg soccer Freddie Ljungberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'References' '\n'
+ '\n']" ", who was a member of the Swedish national team from 1995 to 2004 .
+
+ References
+
+" False at Old Trafford, Freddie Ljungberg scored the 8 [' at', ' Old', ' Trafford', ',', ' Freddie', ' L', 'j', 'ung', 'berg']
+880 231 The professional sport played by x -1 The professional sport played by Freddie Ljungberg soccer Freddie Ljungberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'References' '\n'
+ '\n']" ", who was a member of the Swedish national team from 1995 to 2004 .
+
+ References
+
+" False 2003 at Old Trafford, Freddie Ljungberg scored the winning 9 [' 2003', ' at', ' Old', ' Trafford', ',', ' Freddie', ' L', 'j', 'ung', 'berg']
+881 231 The professional sport played by x -1 The professional sport played by Freddie Ljungberg soccer Freddie Ljungberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'References' '\n'
+ '\n']" ", who was a member of the Swedish national team from 1995 to 2004 .
+
+ References
+
+" False Nicolas Anelka and Freddie Ljungberg condemned Manchester 9 [' Nicolas', ' An', 'el', 'ka', ' and', ' Freddie', ' L', 'j', 'ung', 'berg']
+882 231 The professional sport played by x -1 The professional sport played by Freddie Ljungberg soccer Freddie Ljungberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1995' ' to' ' 2004' '.' '\n' '\n' 'References' '\n'
+ '\n']" ", who was a member of the Swedish national team from 1995 to 2004 .
+
+ References
+
+" False " to midfielder Freddie Ljungberg for another tap-in.""" 6 [' to', ' midfielder', ' Freddie', ' L', 'j', 'ung', 'berg']
+883 232 The professional sport played by x -1 The professional sport played by Peter Forsberg hockey Peter Forsberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1992' ' to' ' 2002' '.' '\n' '\n' 'The' ' first'
+ ' time']" ", who was a member of the Swedish national team from 1992 to 2002 .
+
+ The first time" False Persson in 1980 and Peter Forsberg in 1994. Henrik 7 [' Pers', 'son', ' in', ' 1980', ' and', ' Peter', ' Fors', 'berg']
+884 232 The professional sport played by x -1 The professional sport played by Peter Forsberg hockey Peter Forsberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1992' ' to' ' 2002' '.' '\n' '\n' 'The' ' first'
+ ' time']" ", who was a member of the Swedish national team from 1992 to 2002 .
+
+ The first time" False Hejduk and Peter Forsberg. All but Hejduk were 7 [' He', 'j', 'du', 'k', ' and', ' Peter', ' Fors', 'berg']
+885 232 The professional sport played by x -1 The professional sport played by Peter Forsberg hockey Peter Forsberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1992' ' to' ' 2002' '.' '\n' '\n' 'The' ' first'
+ ' time']" ", who was a member of the Swedish national team from 1992 to 2002 .
+
+ The first time" False Former Avs great Peter Forsberg attempted a comeback 6 [' Former', ' Av', 's', ' great', ' Peter', ' Fors', 'berg']
+886 232 The professional sport played by x -1 The professional sport played by Peter Forsberg hockey Peter Forsberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1992' ' to' ' 2002' '.' '\n' '\n' 'The' ' first'
+ ' time']" ", who was a member of the Swedish national team from 1992 to 2002 .
+
+ The first time" False Modo teammate Peter Forsberg on the squad. 5 [' Mod', 'o', ' teammate', ' Peter', ' Fors', 'berg']
+887 232 The professional sport played by x -1 The professional sport played by Peter Forsberg hockey Peter Forsberg "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Swedish' ' national'
+ ' team' ' from' ' 1992' ' to' ' 2002' '.' '\n' '\n' 'The' ' first'
+ ' time']" ", who was a member of the Swedish national team from 1992 to 2002 .
+
+ The first time" False Chris Drury, Peter Forsberg and Peter Stastny 6 [' Chris', ' Dr', 'ury', ',', ' Peter', ' Fors', 'berg']
+888 233 The professional sport played by x -1 The professional sport played by Claudio Reyna soccer Claudio Reyna "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' U' '.' 'S' '.'
+ ' national' ' team' ' that' ' won' ' the' ' gold' ' medal' ' in' ' the']" , who was a member of the U . S . national team that won the gold medal in the False national team stars Claudio Reyna and Brian McBride, 6 [' national', ' team', ' stars', ' Cl', 'audio', ' Reyn', 'a']
+889 234 The professional sport played by x -1 The professional sport played by Boris Diaw basketball Boris Diaw "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False stalwarts Raja Bell and Boris Diaw to the Charlotte 8 [' stal', 'warts', ' Raj', 'a', ' Bell', ' and', ' Boris', ' Dia', 'w']
+890 234 The professional sport played by x -1 The professional sport played by Boris Diaw basketball Boris Diaw "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False stalwarts Raja Bell and Boris Diaw to the Charlotte 8 [' stal', 'warts', ' Raj', 'a', ' Bell', ' and', ' Boris', ' Dia', 'w']
+891 234 The professional sport played by x -1 The professional sport played by Boris Diaw basketball Boris Diaw "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False stalwarts Raja Bell and Boris Diaw to the Charlotte 8 [' stal', 'warts', ' Raj', 'a', ' Bell', ' and', ' Boris', ' Dia', 'w']
+892 234 The professional sport played by x -1 The professional sport played by Boris Diaw basketball Boris Diaw "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False future teammate) Boris Diaw was Parker's best 5 [' future', ' teammate', ')', ' Boris', ' Dia', 'w']
+893 234 The professional sport played by x -1 The professional sport played by Boris Diaw basketball Boris Diaw "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' is' ' a' ' former'
+ ' NBA' ' player' ',' ' and' ' a' ' former' ' NBA' ' player' '.']" , who is a former NBA player , is a former NBA player , and a former NBA player . False Raja Bell and Boris Diaw to the Charlotte 6 [' Raj', 'a', ' Bell', ' and', ' Boris', ' Dia', 'w']
+894 235 The professional sport played by x -1 The professional sport played by Nate Schierholtz baseball Nate Schierholtz "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' the' ' game' ' can' ' be' ' played' '.' '\n' '\n']" ", a former NFL player , is a great example of how the game can be played .
+
+" False 5 ['N', 'ate', ' Sch', 'ier', 'hol', 'tz']
+895 235 The professional sport played by x -1 The professional sport played by Nate Schierholtz baseball Nate Schierholtz "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' the' ' game' ' can' ' be' ' played' '.' '\n' '\n']" ", a former NFL player , is a great example of how the game can be played .
+
+" False Roy Halladay. Nate Schierholtz then pinch-ran 8 [' Roy', ' Hall', 'aday', '.', ' Nate', ' Sch', 'ier', 'hol', 'tz']
+896 235 The professional sport played by x -1 The professional sport played by Nate Schierholtz baseball Nate Schierholtz "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' the' ' game' ' can' ' be' ' played' '.' '\n' '\n']" ", a former NFL player , is a great example of how the game can be played .
+
+" False outfielder Nate Schierholtz in a May 4 game, 5 [' outfielder', ' Nate', ' Sch', 'ier', 'hol', 'tz']
+897 235 The professional sport played by x -1 The professional sport played by Nate Schierholtz baseball Nate Schierholtz "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' the' ' game' ' can' ' be' ' played' '.' '\n' '\n']" ", a former NFL player , is a great example of how the game can be played .
+
+" False Oeltjen and Nate Schierholtz behind Tim Raines, 8 [' O', 'elt', 'jen', ' and', ' Nate', ' Sch', 'ier', 'hol', 'tz']
+898 235 The professional sport played by x -1 The professional sport played by Nate Schierholtz baseball Nate Schierholtz "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' the' ' game' ' can' ' be' ' played' '.' '\n' '\n']" ", a former NFL player , is a great example of how the game can be played .
+
+" False Roy Halladay. Nate Schierholtz then pinch-ran for 8 [' Roy', ' Hall', 'aday', '.', ' Nate', ' Sch', 'ier', 'hol', 'tz']
+899 236 The professional sport played by x -1 The professional sport played by Warren Moon football Warren Moon "[',' ' who' ' was' ' a' ' quarterback' ' for' ' the' ' Houston' ' Oilers'
+ ' in' ' the' ' NFL' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in']" ", who was a quarterback for the Houston Oilers in the NFL .
+
+ The game was played in" False with the signing of Warren Moon out of the Canadian 5 [' with', ' the', ' signing', ' of', ' Warren', ' Moon']
+900 236 The professional sport played by x -1 The professional sport played by Warren Moon football Warren Moon "[',' ' who' ' was' ' a' ' quarterback' ' for' ' the' ' Houston' ' Oilers'
+ ' in' ' the' ' NFL' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in']" ", who was a quarterback for the Houston Oilers in the NFL .
+
+ The game was played in" False Quarterback Warren Moon led the Eskimos 3 [' Quarter', 'back', ' Warren', ' Moon']
+901 236 The professional sport played by x -1 The professional sport played by Warren Moon football Warren Moon "[',' ' who' ' was' ' a' ' quarterback' ' for' ' the' ' Houston' ' Oilers'
+ ' in' ' the' ' NFL' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in']" ", who was a quarterback for the Houston Oilers in the NFL .
+
+ The game was played in" False the signing of Warren Moon out of the Canadian 4 [' the', ' signing', ' of', ' Warren', ' Moon']
+902 236 The professional sport played by x -1 The professional sport played by Warren Moon football Warren Moon "[',' ' who' ' was' ' a' ' quarterback' ' for' ' the' ' Houston' ' Oilers'
+ ' in' ' the' ' NFL' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in']" ", who was a quarterback for the Houston Oilers in the NFL .
+
+ The game was played in" False lead. Quarterback Warren Moon led the Eskimos 5 [' lead', '.', ' Quarter', 'back', ' Warren', ' Moon']
+903 236 The professional sport played by x -1 The professional sport played by Warren Moon football Warren Moon "[',' ' who' ' was' ' a' ' quarterback' ' for' ' the' ' Houston' ' Oilers'
+ ' in' ' the' ' NFL' '.' '\n' '\n' 'The' ' game' ' was' ' played' ' in']" ", who was a quarterback for the Houston Oilers in the NFL .
+
+ The game was played in" False signing of Warren Moon out of the Canadian 3 [' signing', ' of', ' Warren', ' Moon']
+904 237 The professional sport played by x -1 The professional sport played by Julius Erving basketball Julius Erving "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Boston' ' Celtics' '.' '\n' '\n']" ", the NBA 's all - time leading scorer , was a member of the Boston Celtics .
+
+" False following season, Julius Erving announced his retirement 5 [' following', ' season', ',', ' Julius', ' Er', 'ving']
+905 237 The professional sport played by x -1 The professional sport played by Julius Erving basketball Julius Erving "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Boston' ' Celtics' '.' '\n' '\n']" ", the NBA 's all - time leading scorer , was a member of the Boston Celtics .
+
+" False 76ers guard Julius Erving began to strangle 5 [' 76', 'ers', ' guard', ' Julius', ' Er', 'ving']
+906 237 The professional sport played by x -1 The professional sport played by Julius Erving basketball Julius Erving "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Boston' ' Celtics' '.' '\n' '\n']" ", the NBA 's all - time leading scorer , was a member of the Boston Celtics .
+
+" False 76ers star Julius Erving portrayed himself 5 [' 76', 'ers', ' star', ' Julius', ' Er', 'ving']
+907 237 The professional sport played by x -1 The professional sport played by Julius Erving basketball Julius Erving "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Boston' ' Celtics' '.' '\n' '\n']" ", the NBA 's all - time leading scorer , was a member of the Boston Celtics .
+
+" False held New York's star Julius Erving to 16 points but 7 "[' held', ' New', ' York', ""'s"", ' star', ' Julius', ' Er', 'ving']"
+908 237 The professional sport played by x -1 The professional sport played by Julius Erving basketball Julius Erving "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Boston' ' Celtics' '.' '\n' '\n']" ", the NBA 's all - time leading scorer , was a member of the Boston Celtics .
+
+" False Philadelphia 76ers star Julius Erving portrayed himself 6 [' Philadelphia', ' 76', 'ers', ' star', ' Julius', ' Er', 'ving']
+909 238 The professional sport played by x -1 The professional sport played by Maurice Richard hockey Maurice Richard "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False All-Star team, Maurice Richard had scored 6 [' All', '-', 'Star', ' team', ',', ' Maurice', ' Richard']
+910 238 The professional sport played by x -1 The professional sport played by Maurice Richard hockey Maurice Richard "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False 3 ['M', 'aur', 'ice', ' Richard']
+911 238 The professional sport played by x -1 The professional sport played by Maurice Richard hockey Maurice Richard "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False surpassed his brother Maurice Richard with his 127th Stanley 4 [' surpassed', ' his', ' brother', ' Maurice', ' Richard']
+912 238 The professional sport played by x -1 The professional sport played by Maurice Richard hockey Maurice Richard "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False statistics: Maurice Richard playing card, 3 [' statistics', ':', ' Maurice', ' Richard']
+913 238 The professional sport played by x -1 The professional sport played by Maurice Richard hockey Maurice Richard "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' Montreal' ',' ' Quebec' ',' ' Canada' '.' ' He']" , the NHL 's all - time leading scorer , was born in Montreal , Quebec , Canada . He False Joseph Henri Maurice Richard was born August 4, 3 [' Joseph', ' Henri', ' Maurice', ' Richard']
+914 239 The professional sport played by x -1 The professional sport played by Cuauhtémoc Blanco soccer Cuauhtémoc Blanco "[',' ' a' ' Mexican' ' professional' ' soccer' ' player' '.' '\n' '\n'
+ 'The' ' name' ' of' ' the' ' game' ' is' ' to' ' get' ' the' ' ball'
+ ' into']" ", a Mexican professional soccer player .
+
+ The name of the game is to get the ball into" True Rule. Players such as Cuauhtémoc Blanco (Chicago Fire) and 12 [' Rule', '.', ' Players', ' such', ' as', ' Cu', 'au', 'ht', 'é', 'm', 'oc', ' Bl', 'anco']
+915 240 The professional sport played by x -1 The professional sport played by Rick Barry basketball Rick Barry "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NBA star , is a great example of the kind of person who can make a difference False Musburger and Rick Barry for the 1976 NBA 5 [' Mus', 'bur', 'ger', ' and', ' Rick', ' Barry']
+916 240 The professional sport played by x -1 The professional sport played by Rick Barry basketball Rick Barry "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NBA star , is a great example of the kind of person who can make a difference False Brent Musburger and Rick Barry for the 1976 6 [' Brent', ' Mus', 'bur', 'ger', ' and', ' Rick', ' Barry']
+917 240 The professional sport played by x -1 The professional sport played by Rick Barry basketball Rick Barry "[',' ' the' ' former' ' NBA' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , the former NBA star , is a great example of the kind of person who can make a difference False Musburger and Rick Barry for the 1976 5 [' Mus', 'bur', 'ger', ' and', ' Rick', ' Barry']
+918 241 The professional sport played by x -1 The professional sport played by Ivan Hlinka hockey Ivan Hlinka "[',' ' the' ' Czech' ' Republic' ',' ' and' ' the' ' Czech' ' Republic'
+ '.' '\n' '\n' 'The' ' Czech' ' Republic' ' is' ' a' ' country' ' in'
+ ' Central']" ", the Czech Republic , and the Czech Republic .
+
+ The Czech Republic is a country in Central" False at the 2007 Ivan Hlinka Memorial Tournament, 6 [' at', ' the', ' 2007', ' Ivan', ' H', 'link', 'a']
+919 241 The professional sport played by x -1 The professional sport played by Ivan Hlinka hockey Ivan Hlinka "[',' ' the' ' Czech' ' Republic' ',' ' and' ' the' ' Czech' ' Republic'
+ '.' '\n' '\n' 'The' ' Czech' ' Republic' ' is' ' a' ' country' ' in'
+ ' Central']" ", the Czech Republic , and the Czech Republic .
+
+ The Czech Republic is a country in Central" False under-18 team at the Ivan Hlinka Memorial Tournament. 9 [' under', '-', '18', ' team', ' at', ' the', ' Ivan', ' H', 'link', 'a']
+920 241 The professional sport played by x -1 The professional sport played by Ivan Hlinka hockey Ivan Hlinka "[',' ' the' ' Czech' ' Republic' ',' ' and' ' the' ' Czech' ' Republic'
+ '.' '\n' '\n' 'The' ' Czech' ' Republic' ' is' ' a' ' country' ' in'
+ ' Central']" ", the Czech Republic , and the Czech Republic .
+
+ The Czech Republic is a country in Central" False Czech goaltender, but Ivan Hlinka deliberately 7 [' Czech', ' goaltender', ',', ' but', ' Ivan', ' H', 'link', 'a']
+921 241 The professional sport played by x -1 The professional sport played by Ivan Hlinka hockey Ivan Hlinka "[',' ' the' ' Czech' ' Republic' ',' ' and' ' the' ' Czech' ' Republic'
+ '.' '\n' '\n' 'The' ' Czech' ' Republic' ' is' ' a' ' country' ' in'
+ ' Central']" ", the Czech Republic , and the Czech Republic .
+
+ The Czech Republic is a country in Central" False at the 2007 Ivan Hlinka Memorial Tournament, 6 [' at', ' the', ' 2007', ' Ivan', ' H', 'link', 'a']
+922 241 The professional sport played by x -1 The professional sport played by Ivan Hlinka hockey Ivan Hlinka "[',' ' the' ' Czech' ' Republic' ',' ' and' ' the' ' Czech' ' Republic'
+ '.' '\n' '\n' 'The' ' Czech' ' Republic' ' is' ' a' ' country' ' in'
+ ' Central']" ", the Czech Republic , and the Czech Republic .
+
+ The Czech Republic is a country in Central" False competed in the 2008 Ivan Hlinka Memorial Tournament, 7 [' competed', ' in', ' the', ' 2008', ' Ivan', ' H', 'link', 'a']
+923 242 The professional sport played by x -1 The professional sport played by Jason Kidd basketball Jason Kidd "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NFL' ' player' ',' ' and' ' the' ' other']" , who is a former NBA player , and the other is a former NFL player , and the other False disappointed that Jason Kidd was no longer available 3 [' disappointed', ' that', ' Jason', ' Kidd']
+924 242 The professional sport played by x -1 The professional sport played by Jason Kidd basketball Jason Kidd "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NFL' ' player' ',' ' and' ' the' ' other']" , who is a former NBA player , and the other is a former NFL player , and the other False four times involving Jason Kidd (while playing 4 [' four', ' times', ' involving', ' Jason', ' Kidd']
+925 242 The professional sport played by x -1 The professional sport played by Jason Kidd basketball Jason Kidd "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NFL' ' player' ',' ' and' ' the' ' other']" , who is a former NBA player , and the other is a former NFL player , and the other False four times involving Jason Kidd (while playing 4 [' four', ' times', ' involving', ' Jason', ' Kidd']
+926 242 The professional sport played by x -1 The professional sport played by Jason Kidd basketball Jason Kidd "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NFL' ' player' ',' ' and' ' the' ' other']" , who is a former NBA player , and the other is a former NFL player , and the other False disappointed that Jason Kidd was no longer available 3 [' disappointed', ' that', ' Jason', ' Kidd']
+927 242 The professional sport played by x -1 The professional sport played by Jason Kidd basketball Jason Kidd "[',' ' who' ' is' ' a' ' former' ' NBA' ' player' ',' ' and' ' the'
+ ' other' ' is' ' a' ' former' ' NFL' ' player' ',' ' and' ' the' ' other']" , who is a former NBA player , and the other is a former NFL player , and the other False disappointed that Jason Kidd was no longer available 3 [' disappointed', ' that', ' Jason', ' Kidd']
+928 245 The professional sport played by x -1 The professional sport played by Guy Lafleur hockey Guy Lafleur "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False team captain. Guy Lafleur remained the team's 6 [' team', ' captain', '.', ' Guy', ' Laf', 'le', 'ur']
+929 245 The professional sport played by x -1 The professional sport played by Guy Lafleur hockey Guy Lafleur "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False Bobby Hull, Guy Lafleur and Marcel 6 [' Bobby', ' Hull', ',', ' Guy', ' Laf', 'le', 'ur']
+930 245 The professional sport played by x -1 The professional sport played by Guy Lafleur hockey Guy Lafleur "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False Bernie Geoffrion, Guy Lafleur and Marcel Bonin. 8 [' Bernie', ' Geoff', 'r', 'ion', ',', ' Guy', ' Laf', 'le', 'ur']
+931 245 The professional sport played by x -1 The professional sport played by Guy Lafleur hockey Guy Lafleur "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False stars such as Guy Lafleur or Wayne Gretzky. 6 [' stars', ' such', ' as', ' Guy', ' Laf', 'le', 'ur']
+932 245 The professional sport played by x -1 The professional sport played by Guy Lafleur hockey Guy Lafleur "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' member' ' of' ' the' ' Montreal' ' Canadiens' '.' '\n' '\n']" ", the NHL 's all - time leading scorer , was a member of the Montreal Canadiens .
+
+" False Gretzky's pairing with Guy Lafleur was highly anticipated 9 "[' Gret', 'z', 'ky', ""'s"", ' pairing', ' with', ' Guy', ' Laf', 'le', 'ur']"
+933 246 The professional sport played by x -1 The professional sport played by Alexi Lalas soccer Alexi Lalas "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True The shift allowed Alexi Lalas to play as a sweeper 6 [' The', ' shift', ' allowed', ' Alex', 'i', ' Lal', 'as']
+934 246 The professional sport played by x -1 The professional sport played by Alexi Lalas soccer Alexi Lalas "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True shift allowed Alexi Lalas to play as 5 [' shift', ' allowed', ' Alex', 'i', ' Lal', 'as']
+935 246 The professional sport played by x -1 The professional sport played by Alexi Lalas soccer Alexi Lalas "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True season. The shift allowed Alexi Lalas to play as a sweeper 8 [' season', '.', ' The', ' shift', ' allowed', ' Alex', 'i', ' Lal', 'as']
+936 246 The professional sport played by x -1 The professional sport played by Alexi Lalas soccer Alexi Lalas "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can'
+ ' make' ' a']" , a former professional soccer player , is a great example of the kind of person who can make a True shift allowed Alexi Lalas to play as a 5 [' shift', ' allowed', ' Alex', 'i', ' Lal', 'as']
+937 247 The professional sport played by x -1 The professional sport played by Kobe Bryant basketball Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False Michael Jordan and Kobe Bryant act when it 4 [' Michael', ' Jordan', ' and', ' Kobe', ' Bryant']
+938 247 The professional sport played by x -1 The professional sport played by Kobe Bryant basketball Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False to replace Kobe Bryant in the 2014 3 [' to', ' replace', ' Kobe', ' Bryant']
+939 247 The professional sport played by x -1 The professional sport played by Kobe Bryant basketball Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False nights later, Kobe Bryant praised him for 4 [' nights', ' later', ',', ' Kobe', ' Bryant']
+940 247 The professional sport played by x -1 The professional sport played by Kobe Bryant basketball Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False " === 1996 – 2016: The Kobe Bryant era ===
+" 7 [' ===', ' 1996', ' –', ' 2016', ':', ' The', ' Kobe', ' Bryant']
+941 247 The professional sport played by x -1 The professional sport played by Kobe Bryant basketball Kobe Bryant "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False the Lakers, he and Kobe Bryant spoke to one another 6 [' the', ' Lakers', ',', ' he', ' and', ' Kobe', ' Bryant']
+942 248 The professional sport played by x -1 The professional sport played by Dwight Howard basketball Dwight Howard "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ' center' ','
+ ' is' ' a' ' big' ' fan' ' of' ' the' ' game' '.' '\n']" ", who is a former NBA All - Star center , is a big fan of the game .
+" False class both Dwight Howard (2004) and Oden (2006) 3 [' class', ' both', ' Dwight', ' Howard']
+943 248 The professional sport played by x -1 The professional sport played by Dwight Howard basketball Dwight Howard "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ' center' ','
+ ' is' ' a' ' big' ' fan' ' of' ' the' ' game' '.' '\n']" ", who is a former NBA All - Star center , is a big fan of the game .
+" False 3 ['D', 'w', 'ight', ' Howard']
+944 248 The professional sport played by x -1 The professional sport played by Dwight Howard basketball Dwight Howard "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ' center' ','
+ ' is' ' a' ' big' ' fan' ' of' ' the' ' game' '.' '\n']" ", who is a former NBA All - Star center , is a big fan of the game .
+" False James's 2003 class both Dwight Howard (2004) and Oden (2006) 6 "[' James', ""'s"", ' 2003', ' class', ' both', ' Dwight', ' Howard']"
+945 248 The professional sport played by x -1 The professional sport played by Dwight Howard basketball Dwight Howard "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ' center' ','
+ ' is' ' a' ' big' ' fan' ' of' ' the' ' game' '.' '\n']" ", who is a former NBA All - Star center , is a big fan of the game .
+" False backup to center Dwight Howard as Team USA went 4 [' backup', ' to', ' center', ' Dwight', ' Howard']
+946 248 The professional sport played by x -1 The professional sport played by Dwight Howard basketball Dwight Howard "[',' ' who' ' is' ' a' ' former' ' NBA' ' All' '-' 'Star' ' center' ','
+ ' is' ' a' ' big' ' fan' ' of' ' the' ' game' '.' '\n']" ", who is a former NBA All - Star center , is a big fan of the game .
+" False Shaquille O 'Neal, Dwight Howard and other poor free 8 "[' Sha', 'qu', 'ille', ' O', "" '"", 'Neal', ',', ' Dwight', ' Howard']"
+947 249 The professional sport played by x -1 The professional sport played by Clayton Kershaw baseball Clayton Kershaw "[',' ' who' ' has' ' been' ' on' ' the' ' disabled' ' list' ' since'
+ ' June' ' with' ' a' ' strained' ' left' ' ob' 'lique' ' muscle' '.' '\n'
+ '\n']" ", who has been on the disabled list since June with a strained left ob lique muscle .
+
+" False 2014, Dodgers ace Clayton Kershaw pitched 41 consecutive 6 [' 2014', ',', ' Dodgers', ' ace', ' Clayton', ' Kers', 'haw']
+948 249 The professional sport played by x -1 The professional sport played by Clayton Kershaw baseball Clayton Kershaw "[',' ' who' ' has' ' been' ' on' ' the' ' disabled' ' list' ' since'
+ ' June' ' with' ' a' ' strained' ' left' ' ob' 'lique' ' muscle' '.' '\n'
+ '\n']" ", who has been on the disabled list since June with a strained left ob lique muscle .
+
+" False matched up with Clayton Kershaw less than a year 5 [' matched', ' up', ' with', ' Clayton', ' Kers', 'haw']
+949 249 The professional sport played by x -1 The professional sport played by Clayton Kershaw baseball Clayton Kershaw "[',' ' who' ' has' ' been' ' on' ' the' ' disabled' ' list' ' since'
+ ' June' ' with' ' a' ' strained' ' left' ' ob' 'lique' ' muscle' '.' '\n'
+ '\n']" ", who has been on the disabled list since June with a strained left ob lique muscle .
+
+" False he outdueled Clayton Kershaw in 6 2 ⁄ 3 innings 6 [' he', ' out', 'du', 'eled', ' Clayton', ' Kers', 'haw']
+950 249 The professional sport played by x -1 The professional sport played by Clayton Kershaw baseball Clayton Kershaw "[',' ' who' ' has' ' been' ' on' ' the' ' disabled' ' list' ' since'
+ ' June' ' with' ' a' ' strained' ' left' ' ob' 'lique' ' muscle' '.' '\n'
+ '\n']" ", who has been on the disabled list since June with a strained left ob lique muscle .
+
+" False paired up and beat Clayton Kershaw for the third 6 [' paired', ' up', ' and', ' beat', ' Clayton', ' Kers', 'haw']
+951 249 The professional sport played by x -1 The professional sport played by Clayton Kershaw baseball Clayton Kershaw "[',' ' who' ' has' ' been' ' on' ' the' ' disabled' ' list' ' since'
+ ' June' ' with' ' a' ' strained' ' left' ' ob' 'lique' ' muscle' '.' '\n'
+ '\n']" ", who has been on the disabled list since June with a strained left ob lique muscle .
+
+" False matched up with Clayton Kershaw less than a year 5 [' matched', ' up', ' with', ' Clayton', ' Kers', 'haw']
+952 250 The professional sport played by x -1 The professional sport played by Dwyane Wade basketball Dwyane Wade "[',' ' LeBron' ' James' ',' ' and' ' the' ' Miami' ' Heat' '.' '\n' '\n'
+ 'The' ' NBA' ' Finals' ' are' ' over' ',' ' and' ' the' ' Miami']" ", LeBron James , and the Miami Heat .
+
+ The NBA Finals are over , and the Miami" False 30 was tense as Dwyane Wade reportedly yelled 7 [' 30', ' was', ' tense', ' as', ' Dw', 'y', 'ane', ' Wade']
+953 250 The professional sport played by x -1 The professional sport played by Dwyane Wade basketball Dwyane Wade "[',' ' LeBron' ' James' ',' ' and' ' the' ' Miami' ' Heat' '.' '\n' '\n'
+ 'The' ' NBA' ' Finals' ' are' ' over' ',' ' and' ' the' ' Miami']" ", LeBron James , and the Miami Heat .
+
+ The NBA Finals are over , and the Miami" False included former champion Dwyane Wade as well as a 6 [' included', ' former', ' champion', ' Dw', 'y', 'ane', ' Wade']
+954 250 The professional sport played by x -1 The professional sport played by Dwyane Wade basketball Dwyane Wade "[',' ' LeBron' ' James' ',' ' and' ' the' ' Miami' ' Heat' '.' '\n' '\n'
+ 'The' ' NBA' ' Finals' ' are' ' over' ',' ' and' ' the' ' Miami']" ", LeBron James , and the Miami Heat .
+
+ The NBA Finals are over , and the Miami" False partnered with Dwyane Wade and Joshua DuBois 5 [' partnered', ' with', ' Dw', 'y', 'ane', ' Wade']
+955 250 The professional sport played by x -1 The professional sport played by Dwyane Wade basketball Dwyane Wade "[',' ' LeBron' ' James' ',' ' and' ' the' ' Miami' ' Heat' '.' '\n' '\n'
+ 'The' ' NBA' ' Finals' ' are' ' over' ',' ' and' ' the' ' Miami']" ", LeBron James , and the Miami Heat .
+
+ The NBA Finals are over , and the Miami" False partnered with Dwyane Wade and Joshua DuBois 5 [' partnered', ' with', ' Dw', 'y', 'ane', ' Wade']
+956 250 The professional sport played by x -1 The professional sport played by Dwyane Wade basketball Dwyane Wade "[',' ' LeBron' ' James' ',' ' and' ' the' ' Miami' ' Heat' '.' '\n' '\n'
+ 'The' ' NBA' ' Finals' ' are' ' over' ',' ' and' ' the' ' Miami']" ", LeBron James , and the Miami Heat .
+
+ The NBA Finals are over , and the Miami" False Chris Bosh and Dwyane Wade had also announced 6 [' Chris', ' Bosh', ' and', ' Dw', 'y', 'ane', ' Wade']
+957 251 The professional sport played by x -1 The professional sport played by George Plimpton football George Plimpton "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False popularized by George Plimpton with his books 6 [' popular', 'ized', ' by', ' George', ' Pl', 'im', 'pton']
+958 251 The professional sport played by x -1 The professional sport played by George Plimpton football George Plimpton "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False popularized by George Plimpton with his books 6 [' popular', 'ized', ' by', ' George', ' Pl', 'im', 'pton']
+959 251 The professional sport played by x -1 The professional sport played by George Plimpton football George Plimpton "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False Paris Review founders George Plimpton and John Train, and 6 [' Paris', ' Review', ' founders', ' George', ' Pl', 'im', 'pton']
+960 251 The professional sport played by x -1 The professional sport played by George Plimpton football George Plimpton "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False Review founders George Plimpton and John Train, 5 [' Review', ' founders', ' George', ' Pl', 'im', 'pton']
+961 251 The professional sport played by x -1 The professional sport played by George Plimpton football George Plimpton "[',' ' the' ' former' ' NFL' ' quarterback' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL quarterback , is a great example of the kind of person who can make a difference False Review founders George Plimpton and John Train, 5 [' Review', ' founders', ' George', ' Pl', 'im', 'pton']
+962 252 The professional sport played by x -1 The professional sport played by Yu Darvish baseball Yu Darvish "[',' ' the' ' Japanese' ' pitcher' ' who' ' has' ' been' ' the' ' subject'
+ ' of' ' trade' ' rumors' ' for' ' the' ' past' ' few' ' weeks' ',' ' is'
+ ' now']" , the Japanese pitcher who has been the subject of trade rumors for the past few weeks , is now False season, after Yu Darvish and Yusmeiro Petit. 6 [' season', ',', ' after', ' Yu', ' Dar', 'v', 'ish']
+963 252 The professional sport played by x -1 The professional sport played by Yu Darvish baseball Yu Darvish "[',' ' the' ' Japanese' ' pitcher' ' who' ' has' ' been' ' the' ' subject'
+ ' of' ' trade' ' rumors' ' for' ' the' ' past' ' few' ' weeks' ',' ' is'
+ ' now']" , the Japanese pitcher who has been the subject of trade rumors for the past few weeks , is now False 2013 season, after Yu Darvish and Yusmeiro Petit. 7 [' 2013', ' season', ',', ' after', ' Yu', ' Dar', 'v', 'ish']
+964 252 The professional sport played by x -1 The professional sport played by Yu Darvish baseball Yu Darvish "[',' ' the' ' Japanese' ' pitcher' ' who' ' has' ' been' ' the' ' subject'
+ ' of' ' trade' ' rumors' ' for' ' the' ' past' ' few' ' weeks' ',' ' is'
+ ' now']" , the Japanese pitcher who has been the subject of trade rumors for the past few weeks , is now False 2013 season, after Yu Darvish and Yusmeiro Petit. 7 [' 2013', ' season', ',', ' after', ' Yu', ' Dar', 'v', 'ish']
+965 253 The professional sport played by x -1 The professional sport played by Hope Solo soccer Hope Solo "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' team'
+ ' goalkeeper' ',' ' has' ' been' ' suspended' ' for' ' the' ' remainder']" , the U . S . women � � s national team goalkeeper , has been suspended for the remainder False Abby Wambach and Hope Solo with salaries well 6 [' Abby', ' W', 'amb', 'ach', ' and', ' Hope', ' Solo']
+966 253 The professional sport played by x -1 The professional sport played by Hope Solo soccer Hope Solo "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' team'
+ ' goalkeeper' ',' ' has' ' been' ' suspended' ' for' ' the' ' remainder']" , the U . S . women � � s national team goalkeeper , has been suspended for the remainder False Morgan and Hope Solo in a Bank of America 3 [' Morgan', ' and', ' Hope', ' Solo']
+967 253 The professional sport played by x -1 The professional sport played by Hope Solo soccer Hope Solo "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' team'
+ ' goalkeeper' ',' ' has' ' been' ' suspended' ' for' ' the' ' remainder']" , the U . S . women � � s national team goalkeeper , has been suspended for the remainder False teammates Alex Morgan and Hope Solo in a Bank of America 5 [' teammates', ' Alex', ' Morgan', ' and', ' Hope', ' Solo']
+968 253 The professional sport played by x -1 The professional sport played by Hope Solo soccer Hope Solo "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' team'
+ ' goalkeeper' ',' ' has' ' been' ' suspended' ' for' ' the' ' remainder']" , the U . S . women � � s national team goalkeeper , has been suspended for the remainder False Alex Morgan and Hope Solo in a Bank of 4 [' Alex', ' Morgan', ' and', ' Hope', ' Solo']
+969 253 The professional sport played by x -1 The professional sport played by Hope Solo soccer Hope Solo "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' team'
+ ' goalkeeper' ',' ' has' ' been' ' suspended' ' for' ' the' ' remainder']" , the U . S . women � � s national team goalkeeper , has been suspended for the remainder False Abby Wambach and Hope Solo with salaries well 6 [' Abby', ' W', 'amb', 'ach', ' and', ' Hope', ' Solo']
+970 254 The professional sport played by x -1 The professional sport played by Giovani dos Santos soccer Giovani dos Santos "[',' ' who' ' has' ' been' ' a' ' revelation' ' for' ' the' ' Galaxy'
+ ' this' ' season' '.' '\n' '\n' 'The' ' Galaxy' ' have' ' been' ' a'
+ ' bit']" ", who has been a revelation for the Galaxy this season .
+
+ The Galaxy have been a bit" False Ayew of Ghana and Giovani dos Santos of Mexico. For 8 [' Ay', 'ew', ' of', ' Ghana', ' and', ' Giov', 'ani', ' dos', ' Santos']
+971 254 The professional sport played by x -1 The professional sport played by Giovani dos Santos soccer Giovani dos Santos "[',' ' who' ' has' ' been' ' a' ' revelation' ' for' ' the' ' Galaxy'
+ ' this' ' season' '.' '\n' '\n' 'The' ' Galaxy' ' have' ' been' ' a'
+ ' bit']" ", who has been a revelation for the Galaxy this season .
+
+ The Galaxy have been a bit" False Ghana and Giovani dos Santos of Mexico. For both 5 [' Ghana', ' and', ' Giov', 'ani', ' dos', ' Santos']
+972 256 The professional sport played by x -1 The professional sport played by Casey Stengel baseball Casey Stengel "[',' ' the' ' New' ' York' ' Yankees' ',' ' and' ' the' ' New' ' York'
+ ' Mets' '.' '\n' '\n' 'The' ' Yankees' ' are' ' the' ' most'
+ ' successful']" ", the New York Yankees , and the New York Mets .
+
+ The Yankees are the most successful" False Yankees brought in Casey Stengel to manage. At the 6 [' Yankees', ' brought', ' in', ' Casey', ' St', 'eng', 'el']
+973 256 The professional sport played by x -1 The professional sport played by Casey Stengel baseball Casey Stengel "[',' ' the' ' New' ' York' ' Yankees' ',' ' and' ' the' ' New' ' York'
+ ' Mets' '.' '\n' '\n' 'The' ' Yankees' ' are' ' the' ' most'
+ ' successful']" ", the New York Yankees , and the New York Mets .
+
+ The Yankees are the most successful" False Yankees brought in Casey Stengel to manage. At the 6 [' Yankees', ' brought', ' in', ' Casey', ' St', 'eng', 'el']
+974 256 The professional sport played by x -1 The professional sport played by Casey Stengel baseball Casey Stengel "[',' ' the' ' New' ' York' ' Yankees' ',' ' and' ' the' ' New' ' York'
+ ' Mets' '.' '\n' '\n' 'The' ' Yankees' ' are' ' the' ' most'
+ ' successful']" ", the New York Yankees , and the New York Mets .
+
+ The Yankees are the most successful" False League manager Casey Stengel added him, Lemon, 5 [' League', ' manager', ' Casey', ' St', 'eng', 'el']
+975 256 The professional sport played by x -1 The professional sport played by Casey Stengel baseball Casey Stengel "[',' ' the' ' New' ' York' ' Yankees' ',' ' and' ' the' ' New' ' York'
+ ' Mets' '.' '\n' '\n' 'The' ' Yankees' ' are' ' the' ' most'
+ ' successful']" ", the New York Yankees , and the New York Mets .
+
+ The Yankees are the most successful" False " prompting Yankees manager Casey Stengel to comment, ""Well," 6 [' prompting', ' Yankees', ' manager', ' Casey', ' St', 'eng', 'el']
+976 256 The professional sport played by x -1 The professional sport played by Casey Stengel baseball Casey Stengel "[',' ' the' ' New' ' York' ' Yankees' ',' ' and' ' the' ' New' ' York'
+ ' Mets' '.' '\n' '\n' 'The' ' Yankees' ' are' ' the' ' most'
+ ' successful']" ", the New York Yankees , and the New York Mets .
+
+ The Yankees are the most successful" False Yankees brought in Casey Stengel to manage. At the time, 6 [' Yankees', ' brought', ' in', ' Casey', ' St', 'eng', 'el']
+977 257 The professional sport played by x -1 The professional sport played by Walter Payton football Walter Payton "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False receptions, and on a Walter Payton run. The Buccaneers'Williams 6 [' receptions', ',', ' and', ' on', ' a', ' Walter', ' Payton']
+978 257 The professional sport played by x -1 The professional sport played by Walter Payton football Walter Payton "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False he played alongside Walter Payton for three years, who 4 [' he', ' played', ' alongside', ' Walter', ' Payton']
+979 257 The professional sport played by x -1 The professional sport played by Walter Payton football Walter Payton "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False 2 ['Wal', 'ter', ' Payton']
+980 257 The professional sport played by x -1 The professional sport played by Walter Payton football Walter Payton "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False Throughout his life Walter Payton had claimed his date 4 [' Throughout', ' his', ' life', ' Walter', ' Payton']
+981 257 The professional sport played by x -1 The professional sport played by Walter Payton football Walter Payton "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False Reed attended Walter Payton College Prep 3 [' Reed', ' attended', ' Walter', ' Payton']
+982 258 The professional sport played by x -1 The professional sport played by Rosey Grier football Rosey Grier "[',' ' the' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , the former NFL player , is a game that is played by two teams of two players . The False football player Rosey Grier of the Los Angeles 5 [' football', ' player', ' Rose', 'y', ' G', 'rier']
+983 258 The professional sport played by x -1 The professional sport played by Rosey Grier football Rosey Grier "[',' ' the' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , the former NFL player , is a game that is played by two teams of two players . The False football player Rosey Grier of the Los Angeles 5 [' football', ' player', ' Rose', 'y', ' G', 'rier']
+984 258 The professional sport played by x -1 The professional sport played by Rosey Grier football Rosey Grier "[',' ' the' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , the former NFL player , is a game that is played by two teams of two players . The False players Troy Aikman, Rosey Grier and Dan Marino. All 9 [' players', ' Troy', ' A', 'ik', 'man', ',', ' Rose', 'y', ' G', 'rier']
+985 258 The professional sport played by x -1 The professional sport played by Rosey Grier football Rosey Grier "[',' ' the' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , the former NFL player , is a game that is played by two teams of two players . The False professional football player Rosey Grier of the Los 6 [' professional', ' football', ' player', ' Rose', 'y', ' G', 'rier']
+986 258 The professional sport played by x -1 The professional sport played by Rosey Grier football Rosey Grier "[',' ' the' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , the former NFL player , is a game that is played by two teams of two players . The False football player Rosey Grier of the Los Angeles 5 [' football', ' player', ' Rose', 'y', ' G', 'rier']
+987 259 The professional sport played by x -1 The professional sport played by Raimo Helminen hockey Raimo Helminen "[',' ' a' ' Finnish' ' professional' ' ice' ' hockey' ' player' '.' ' He'
+ ' is' ' currently' ' playing' ' for' ' the' ' San' ' Jose' ' Sharks'
+ ' of' ' the' ' National']" , a Finnish professional ice hockey player . He is currently playing for the San Jose Sharks of the National True Phil Housley, Finn Raimo Helminen and Czechoslovak 11 [' Phil', ' H', 'ous', 'ley', ',', ' Finn', ' R', 'aim', 'o', ' Hel', 'min', 'en']
+988 260 The professional sport played by x -1 The professional sport played by Carmelo Anthony basketball Carmelo Anthony "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False joining fellow players Carmelo Anthony and Chris Paul, 5 [' joining', ' fellow', ' players', ' Carm', 'elo', ' Anthony']
+989 260 The professional sport played by x -1 The professional sport played by Carmelo Anthony basketball Carmelo Anthony "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False Antawn Jamison, Carmelo Anthony and Dwyane 7 [' Ant', 'awn', ' Jam', 'ison', ',', ' Carm', 'elo', ' Anthony']
+990 260 The professional sport played by x -1 The professional sport played by Carmelo Anthony basketball Carmelo Anthony "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False in the brawl. Carmelo Anthony was also criticized 6 [' in', ' the', ' brawl', '.', ' Carm', 'elo', ' Anthony']
+991 260 The professional sport played by x -1 The professional sport played by Carmelo Anthony basketball Carmelo Anthony "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False to an end, Carmelo Anthony confronted Collins 6 [' to', ' an', ' end', ',', ' Carm', 'elo', ' Anthony']
+992 260 The professional sport played by x -1 The professional sport played by Carmelo Anthony basketball Carmelo Anthony "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' big' ' fan' ' of' ' the' ' game' '.' ' He' ""'s""]" , the NBA 's all - time leading scorer , is a big fan of the game . He 's False Antawn Jamison, Carmelo Anthony and Dwyane 7 [' Ant', 'awn', ' Jam', 'ison', ',', ' Carm', 'elo', ' Anthony']
+993 261 The professional sport played by x -1 The professional sport played by Bill Goldberg football Bill Goldberg "[',' ' the' ' former' ' professional' ' wrestler' ',' ' is' ' a' ' great'
+ ' example' ' of' ' a' ' man' ' who' ' has' ' been' ' able' ' to' ' make'
+ ' a']" , the former professional wrestler , is a great example of a man who has been able to make a False Clark, and facing Bill Goldberg and Keiji Mutoh 5 [' Clark', ',', ' and', ' facing', ' Bill', ' Goldberg']
+994 262 The professional sport played by x -1 The professional sport played by Nani soccer Nani "['wa' ',' ' who' ' is' ' a' ' former' ' professional' ' player' ',' ' is'
+ ' a' ' very' ' good' ' example' ' of' ' this' '.' ' He' ' is' ' a']" wa , who is a former professional player , is a very good example of this . He is a False September 2013, Nani renewed his contract 4 [' September', ' 2013', ',', ' N', 'ani']
+995 262 The professional sport played by x -1 The professional sport played by Nani soccer Nani "['wa' ',' ' who' ' is' ' a' ' former' ' professional' ' player' ',' ' is'
+ ' a' ' very' ' good' ' example' ' of' ' this' '.' ' He' ' is' ' a']" wa , who is a former professional player , is a very good example of this . He is a False 1 ['N', 'ani']
+996 262 The professional sport played by x -1 The professional sport played by Nani soccer Nani "['wa' ',' ' who' ' is' ' a' ' former' ' professional' ' player' ',' ' is'
+ ' a' ' very' ' good' ' example' ' of' ' this' '.' ' He' ' is' ' a']" wa , who is a former professional player , is a very good example of this . He is a False and Daniel Ho for He Nani, Amy Hanaiali 'i for 6 [' and', ' Daniel', ' Ho', ' for', ' He', ' N', 'ani']
+997 262 The professional sport played by x -1 The professional sport played by Nani soccer Nani "['wa' ',' ' who' ' is' ' a' ' former' ' professional' ' player' ',' ' is'
+ ' a' ' very' ' good' ' example' ' of' ' this' '.' ' He' ' is' ' a']" wa , who is a former professional player , is a very good example of this . He is a False commonly known as Nani (Portuguese 4 [' commonly', ' known', ' as', ' N', 'ani']
+998 262 The professional sport played by x -1 The professional sport played by Nani soccer Nani "['wa' ',' ' who' ' is' ' a' ' former' ' professional' ' player' ',' ' is'
+ ' a' ' very' ' good' ' example' ' of' ' this' '.' ' He' ' is' ' a']" wa , who is a former professional player , is a very good example of this . He is a False On 6 July 2015, Nani joined Turkish 6 [' On', ' 6', ' July', ' 2015', ',', ' N', 'ani']
+999 263 The professional sport played by x -1 The professional sport played by Pau Gasol basketball Pau Gasol "[',' ' the' ' NBA' ""'s"" ' best' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' a' ' team' ' of' ' two' '.' ' The']" , the NBA 's best player , is a game that is played by a team of two . The False " Gasol =
+" 6 [' Gas', 'ol', ' =', 'P', 'au', ' Gas', 'ol']
+1000 263 The professional sport played by x -1 The professional sport played by Pau Gasol basketball Pau Gasol "[',' ' the' ' NBA' ""'s"" ' best' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' a' ' team' ' of' ' two' '.' ' The']" , the NBA 's best player , is a game that is played by a team of two . The False 3 ['P', 'au', ' Gas', 'ol']
+1001 263 The professional sport played by x -1 The professional sport played by Pau Gasol basketball Pau Gasol "[',' ' the' ' NBA' ""'s"" ' best' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' a' ' team' ' of' ' two' '.' ' The']" , the NBA 's best player , is a game that is played by a team of two . The False 3 ['P', 'au', ' Gas', 'ol']
+1002 263 The professional sport played by x -1 The professional sport played by Pau Gasol basketball Pau Gasol "[',' ' the' ' NBA' ""'s"" ' best' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' a' ' team' ' of' ' two' '.' ' The']" , the NBA 's best player , is a game that is played by a team of two . The False " Gasol =
+" 6 [' Gas', 'ol', ' =', 'P', 'au', ' Gas', 'ol']
+1003 263 The professional sport played by x -1 The professional sport played by Pau Gasol basketball Pau Gasol "[',' ' the' ' NBA' ""'s"" ' best' ' player' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' a' ' team' ' of' ' two' '.' ' The']" , the NBA 's best player , is a game that is played by a team of two . The False Virgin Islands), Pau Gasol (Spain), Dirk 5 [' Virgin', ' Islands', '),', ' Pau', ' Gas', 'ol']
+1004 264 The professional sport played by x -1 The professional sport played by Sidney Crosby hockey Sidney Crosby "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False younger than when Sidney Crosby was named captain of 4 [' younger', ' than', ' when', ' Sidney', ' Crosby']
+1005 264 The professional sport played by x -1 The professional sport played by Sidney Crosby hockey Sidney Crosby "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False on the ice when Sidney Crosby scored the tournament-winning 5 [' on', ' the', ' ice', ' when', ' Sidney', ' Crosby']
+1006 264 The professional sport played by x -1 The professional sport played by Sidney Crosby hockey Sidney Crosby "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False selected by the NHL. Sidney Crosby received the 6 [' selected', ' by', ' the', ' NHL', '.', ' Sidney', ' Crosby']
+1007 264 The professional sport played by x -1 The professional sport played by Sidney Crosby hockey Sidney Crosby "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False both Getzlaf and Sidney Crosby failed to capitalize 6 [' both', ' Get', 'zl', 'af', ' and', ' Sidney', ' Crosby']
+1008 264 The professional sport played by x -1 The professional sport played by Sidney Crosby hockey Sidney Crosby "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False to as the Sidney Crosby Lottery or the 4 [' to', ' as', ' the', ' Sidney', ' Crosby']
+1009 266 The professional sport played by x -1 The professional sport played by Frank Mahovlich hockey Frank Mahovlich "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' five' ' players' ' each' '.']" , a former NHL player , is a game that is played by two teams of five players each . False Brad Park and Frank Mahovlich also criticized the 6 [' Brad', ' Park', ' and', ' Frank', ' Mah', 'ov', 'lich']
+1010 266 The professional sport played by x -1 The professional sport played by Frank Mahovlich hockey Frank Mahovlich "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' five' ' players' ' each' '.']" , a former NHL player , is a game that is played by two teams of five players each . False cast. In December, a Frank Mahovlich check caused 9 [' cast', '.', ' In', ' December', ',', ' a', ' Frank', ' Mah', 'ov', 'lich']
+1011 266 The professional sport played by x -1 The professional sport played by Frank Mahovlich hockey Frank Mahovlich "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' five' ' players' ' each' '.']" , a former NHL player , is a game that is played by two teams of five players each . False stole the puck from Frank Mahovlich and broke away and 7 [' stole', ' the', ' puck', ' from', ' Frank', ' Mah', 'ov', 'lich']
+1012 266 The professional sport played by x -1 The professional sport played by Frank Mahovlich hockey Frank Mahovlich "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' five' ' players' ' each' '.']" , a former NHL player , is a game that is played by two teams of five players each . False " and conditioning. Frank Mahovlich said ""give the" 6 [' and', ' conditioning', '.', ' Frank', ' Mah', 'ov', 'lich']
+1013 266 The professional sport played by x -1 The professional sport played by Frank Mahovlich hockey Frank Mahovlich "[',' ' a' ' former' ' NHL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' five' ' players' ' each' '.']" , a former NHL player , is a game that is played by two teams of five players each . False Brad Park and Frank Mahovlich also criticized 6 [' Brad', ' Park', ' and', ' Frank', ' Mah', 'ov', 'lich']
+1014 267 The professional sport played by x -1 The professional sport played by Marco Ureña soccer Marco Ureña "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' five'
+ ' players' ' each']" , a former professional soccer player , is a game that is played by two teams of five players each True " with Costa Rica's Marco Ureña three goals.
+" 7 "[' with', ' Costa', ' Rica', ""'s"", ' Marco', ' U', 're', 'ña']"
+1015 267 The professional sport played by x -1 The professional sport played by Marco Ureña soccer Marco Ureña "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' five'
+ ' players' ' each']" , a former professional soccer player , is a game that is played by two teams of five players each True " with Costa Rica's Marco Ureña three goals.
+" 7 "[' with', ' Costa', ' Rica', ""'s"", ' Marco', ' U', 're', 'ña']"
+1016 269 The professional sport played by x -1 The professional sport played by Tony Dungy football Tony Dungy "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' coach' ' can' ' be' ' a' ' great' ' leader' '.']" , the former NFL coach , is a great example of how a coach can be a great leader . False 112 for the Bears. Tony Dungy is the third 7 [' 112', ' for', ' the', ' Bears', '.', ' Tony', ' Dun', 'gy']
+1017 269 The professional sport played by x -1 The professional sport played by Tony Dungy football Tony Dungy "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' coach' ' can' ' be' ' a' ' great' ' leader' '.']" , the former NFL coach , is a great example of how a coach can be a great leader . False Colts coach Tony Dungy confirmed this 4 [' Colts', ' coach', ' Tony', ' Dun', 'gy']
+1018 269 The professional sport played by x -1 The professional sport played by Tony Dungy football Tony Dungy "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' coach' ' can' ' be' ' a' ' great' ' leader' '.']" , the former NFL coach , is a great example of how a coach can be a great leader . False 112 for the Bears. Tony Dungy is the third 7 [' 112', ' for', ' the', ' Bears', '.', ' Tony', ' Dun', 'gy']
+1019 269 The professional sport played by x -1 The professional sport played by Tony Dungy football Tony Dungy "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' coach' ' can' ' be' ' a' ' great' ' leader' '.']" , the former NFL coach , is a great example of how a coach can be a great leader . False and the Colts' Tony Dungy both became the 6 "[' and', ' the', ' Colts', ""'"", ' Tony', ' Dun', 'gy']"
+1020 269 The professional sport played by x -1 The professional sport played by Tony Dungy football Tony Dungy "[',' ' the' ' former' ' NFL' ' coach' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' coach' ' can' ' be' ' a' ' great' ' leader' '.']" , the former NFL coach , is a great example of how a coach can be a great leader . False Nevertheless, Tony Dungy made the decision 4 [' Nevertheless', ',', ' Tony', ' Dun', 'gy']
+1021 271 The professional sport played by x -1 The professional sport played by Phil Jackson basketball Phil Jackson "[',' ' the' ' coach' ' of' ' the' ' Los' ' Angeles' ' Lakers' ',' ' is'
+ ' a' ' basketball' ' coach' ' who' ' has' ' won' ' five' ' NBA'
+ ' championships' '.']" , the coach of the Los Angeles Lakers , is a basketball coach who has won five NBA championships . True head coach Phil Jackson returned for the 3 [' head', ' coach', ' Phil', ' Jackson']
+1022 271 The professional sport played by x -1 The professional sport played by Phil Jackson basketball Phil Jackson "[',' ' the' ' coach' ' of' ' the' ' Los' ' Angeles' ' Lakers' ',' ' is'
+ ' a' ' basketball' ' coach' ' who' ' has' ' won' ' five' ' NBA'
+ ' championships' '.']" , the coach of the Los Angeles Lakers , is a basketball coach who has won five NBA championships . True Head coach Phil Jackson and several 3 [' Head', ' coach', ' Phil', ' Jackson']
+1023 271 The professional sport played by x -1 The professional sport played by Phil Jackson basketball Phil Jackson "[',' ' the' ' coach' ' of' ' the' ' Los' ' Angeles' ' Lakers' ',' ' is'
+ ' a' ' basketball' ' coach' ' who' ' has' ' won' ' five' ' NBA'
+ ' championships' '.']" , the coach of the Los Angeles Lakers , is a basketball coach who has won five NBA championships . True West coach, Phil Jackson decided to send 4 [' West', ' coach', ',', ' Phil', ' Jackson']
+1024 271 The professional sport played by x -1 The professional sport played by Phil Jackson basketball Phil Jackson "[',' ' the' ' coach' ' of' ' the' ' Los' ' Angeles' ' Lakers' ',' ' is'
+ ' a' ' basketball' ' coach' ' who' ' has' ' won' ' five' ' NBA'
+ ' championships' '.']" , the coach of the Los Angeles Lakers , is a basketball coach who has won five NBA championships . True Lakers'coach Phil Jackson to be more assertive 5 "[' Lakers', ""'"", 'co', 'ach', ' Phil', ' Jackson']"
+1025 271 The professional sport played by x -1 The professional sport played by Phil Jackson basketball Phil Jackson "[',' ' the' ' coach' ' of' ' the' ' Los' ' Angeles' ' Lakers' ',' ' is'
+ ' a' ' basketball' ' coach' ' who' ' has' ' won' ' five' ' NBA'
+ ' championships' '.']" , the coach of the Los Angeles Lakers , is a basketball coach who has won five NBA championships . True by Lakers'coach Phil Jackson to be more assertive 6 "[' by', ' Lakers', ""'"", 'co', 'ach', ' Phil', ' Jackson']"
+1026 272 The professional sport played by x -1 The professional sport played by Len Ford football Len Ford "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' original' ' Broadway' ' production' ' of' ' the' ' musical' ','
+ ' and' ' who']" , who was a member of the original cast of the original Broadway production of the musical , and who False " Len Ford =
+" 1 [' Len', ' Ford']
+1027 272 The professional sport played by x -1 The professional sport played by Len Ford football Len Ford "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' original' ' Broadway' ' production' ' of' ' the' ' musical' ','
+ ' and' ' who']" , who was a member of the original cast of the original Broadway production of the musical , and who False " = Len Ford =
+" 2 [' =', ' Len', ' Ford']
+1028 272 The professional sport played by x -1 The professional sport played by Len Ford football Len Ford "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' original' ' cast' ' of'
+ ' the' ' original' ' Broadway' ' production' ' of' ' the' ' musical' ','
+ ' and' ' who']" , who was a member of the original cast of the original Broadway production of the musical , and who False " = Len Ford =
+" 2 [' =', ' Len', ' Ford']
+1029 273 The professional sport played by x -1 The professional sport played by Colin Kaepernick football Colin Kaepernick "[',' ' the' ' quarterback' ' for' ' the' ' San' ' Francisco' ' 49' 'ers'
+ ',' ' has' ' been' ' a' ' polar' 'izing' ' figure' ' in' ' the' ' NFL'
+ ' for']" , the quarterback for the San Francisco 49 ers , has been a polar izing figure in the NFL for False the passing game, Colin Kaepernick achieved a 5 [' the', ' passing', ' game', ',', ' Colin', ' Kaepernick']
+1030 273 The professional sport played by x -1 The professional sport played by Colin Kaepernick football Colin Kaepernick "[',' ' the' ' quarterback' ' for' ' the' ' San' ' Francisco' ' 49' 'ers'
+ ',' ' has' ' been' ' a' ' polar' 'izing' ' figure' ' in' ' the' ' NFL'
+ ' for']" , the quarterback for the San Francisco 49 ers , has been a polar izing figure in the NFL for False the passing game, Colin Kaepernick achieved a 51.1 5 [' the', ' passing', ' game', ',', ' Colin', ' Kaepernick']
+1031 273 The professional sport played by x -1 The professional sport played by Colin Kaepernick football Colin Kaepernick "[',' ' the' ' quarterback' ' for' ' the' ' San' ' Francisco' ' 49' 'ers'
+ ',' ' has' ' been' ' a' ' polar' 'izing' ' figure' ' in' ' the' ' NFL'
+ ' for']" , the quarterback for the San Francisco 49 ers , has been a polar izing figure in the NFL for False of the Year Colin Kaepernick ran for more than 4 [' of', ' the', ' Year', ' Colin', ' Kaepernick']
+1032 273 The professional sport played by x -1 The professional sport played by Colin Kaepernick football Colin Kaepernick "[',' ' the' ' quarterback' ' for' ' the' ' San' ' Francisco' ' 49' 'ers'
+ ',' ' has' ' been' ' a' ' polar' 'izing' ' figure' ' in' ' the' ' NFL'
+ ' for']" , the quarterback for the San Francisco 49 ers , has been a polar izing figure in the NFL for False 1, broken by Colin Kaepernick (181) in 2013. He 5 [' 1', ',', ' broken', ' by', ' Colin', ' Kaepernick']
+1033 273 The professional sport played by x -1 The professional sport played by Colin Kaepernick football Colin Kaepernick "[',' ' the' ' quarterback' ' for' ' the' ' San' ' Francisco' ' 49' 'ers'
+ ',' ' has' ' been' ' a' ' polar' 'izing' ' figure' ' in' ' the' ' NFL'
+ ' for']" , the quarterback for the San Francisco 49 ers , has been a polar izing figure in the NFL for False Nevada quarterback Colin Kaepernick set the Humanitarian 3 [' Nevada', ' quarterback', ' Colin', ' Kaepernick']
+1034 274 The professional sport played by x -1 The professional sport played by Gordie Howe hockey Gordie Howe "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' the' ' city' ' of' ' Sask' 'atoon' ',' ' Saskatchewan']" , the NHL 's all - time leading scorer , was born in the city of Sask atoon , Saskatchewan False ankle from a Gordie Howe slap shot. Despite 5 [' ankle', ' from', ' a', ' Gord', 'ie', ' Howe']
+1035 274 The professional sport played by x -1 The professional sport played by Gordie Howe hockey Gordie Howe "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' the' ' city' ' of' ' Sask' 'atoon' ',' ' Saskatchewan']" , the NHL 's all - time leading scorer , was born in the city of Sask atoon , Saskatchewan False " fixing."" The family of Gordie Howe also commented" 7 "[' fixing', '.""', ' The', ' family', ' of', ' Gord', 'ie', ' Howe']"
+1036 274 The professional sport played by x -1 The professional sport played by Gordie Howe hockey Gordie Howe "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' the' ' city' ' of' ' Sask' 'atoon' ',' ' Saskatchewan']" , the NHL 's all - time leading scorer , was born in the city of Sask atoon , Saskatchewan False Wayne Gretzky and Gordie Howe for a Saskatoon 7 [' Wayne', ' Gret', 'z', 'ky', ' and', ' Gord', 'ie', ' Howe']
+1037 274 The professional sport played by x -1 The professional sport played by Gordie Howe hockey Gordie Howe "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' the' ' city' ' of' ' Sask' 'atoon' ',' ' Saskatchewan']" , the NHL 's all - time leading scorer , was born in the city of Sask atoon , Saskatchewan False " to Detroit's Gordie Howe overall.
+" 5 "[' to', ' Detroit', ""'s"", ' Gord', 'ie', ' Howe']"
+1038 274 The professional sport played by x -1 The professional sport played by Gordie Howe hockey Gordie Howe "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' born' ' in' ' the' ' city' ' of' ' Sask' 'atoon' ',' ' Saskatchewan']" , the NHL 's all - time leading scorer , was born in the city of Sask atoon , Saskatchewan False Crosby recorded a Gordie Howe hat trick on 5 [' Crosby', ' recorded', ' a', ' Gord', 'ie', ' Howe']
+1039 275 The professional sport played by x -1 The professional sport played by Joe Namath football Joe Namath "[',' ' the' ' quarterback' ' of' ' the' ' New' ' York' ' Jets' ',' ' was'
+ ' a' ' big' ' hit' ' with' ' the' ' fans' '.' '\n' '\n' 'The']" ", the quarterback of the New York Jets , was a big hit with the fans .
+
+ The" False Ewbank and quarterback Joe Namath led the Jets to 7 [' E', 'w', 'bank', ' and', ' quarterback', ' Joe', ' Nam', 'ath']
+1040 275 The professional sport played by x -1 The professional sport played by Joe Namath football Joe Namath "[',' ' the' ' quarterback' ' of' ' the' ' New' ' York' ' Jets' ',' ' was'
+ ' a' ' big' ' hit' ' with' ' the' ' fans' '.' '\n' '\n' 'The']" ", the quarterback of the New York Jets , was a big hit with the fans .
+
+ The" False many records, Joe Namath and Len Dawson 5 [' many', ' records', ',', ' Joe', ' Nam', 'ath']
+1041 275 The professional sport played by x -1 The professional sport played by Joe Namath football Joe Namath "[',' ' the' ' quarterback' ' of' ' the' ' New' ' York' ' Jets' ',' ' was'
+ ' a' ' big' ' hit' ' with' ' the' ' fans' '.' '\n' '\n' 'The']" ", the quarterback of the New York Jets , was a big hit with the fans .
+
+ The" False Alabama passer Joe Namath rejected the 4 [' Alabama', ' passer', ' Joe', ' Nam', 'ath']
+1042 275 The professional sport played by x -1 The professional sport played by Joe Namath football Joe Namath "[',' ' the' ' quarterback' ' of' ' the' ' New' ' York' ' Jets' ',' ' was'
+ ' a' ' big' ' hit' ' with' ' the' ' fans' '.' '\n' '\n' 'The']" ", the quarterback of the New York Jets , was a big hit with the fans .
+
+ The" False by beating Joe Namath and the New York 4 [' by', ' beating', ' Joe', ' Nam', 'ath']
+1043 275 The professional sport played by x -1 The professional sport played by Joe Namath football Joe Namath "[',' ' the' ' quarterback' ' of' ' the' ' New' ' York' ' Jets' ',' ' was'
+ ' a' ' big' ' hit' ' with' ' the' ' fans' '.' '\n' '\n' 'The']" ", the quarterback of the New York Jets , was a big hit with the fans .
+
+ The" False signed rookie Joe Namath to a then-record 4 [' signed', ' rookie', ' Joe', ' Nam', 'ath']
+1044 276 The professional sport played by x -1 The professional sport played by Patrick Ewing basketball Patrick Ewing "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False Malone (860), Patrick Ewing (834), and 6 [' Malone', ' (', '860', '),', ' Patrick', ' E', 'wing']
+1045 276 The professional sport played by x -1 The professional sport played by Patrick Ewing basketball Patrick Ewing "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False in NBA history for Patrick Ewing and Danny Ferry. 6 [' in', ' NBA', ' history', ' for', ' Patrick', ' E', 'wing']
+1046 276 The professional sport played by x -1 The professional sport played by Patrick Ewing basketball Patrick Ewing "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False championship against Patrick Ewing and the New York Knicks. 4 [' championship', ' against', ' Patrick', ' E', 'wing']
+1047 276 The professional sport played by x -1 The professional sport played by Patrick Ewing basketball Patrick Ewing "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False with stars such as Patrick Ewing and Dikembe 6 [' with', ' stars', ' such', ' as', ' Patrick', ' E', 'wing']
+1048 276 The professional sport played by x -1 The professional sport played by Patrick Ewing basketball Patrick Ewing "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of player who False Nuggets to surpass Patrick Ewing for sixth overall 5 [' Nuggets', ' to', ' surpass', ' Patrick', ' E', 'wing']
+1049 277 The professional sport played by x -1 The professional sport played by Merlin Olsen football Merlin Olsen "[',' ' the' ' former' ' NFL' ' linebacker' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL linebacker , is a great example of the kind of person who can make a difference False Bruce Matthews, and Merlin Olsen for most Pro Bowl selections 5 [' Bruce', ' Matthews', ',', ' and', ' Merlin', ' Olsen']
+1050 277 The professional sport played by x -1 The professional sport played by Merlin Olsen football Merlin Olsen "[',' ' the' ' former' ' NFL' ' linebacker' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL linebacker , is a great example of the kind of person who can make a difference False defensive tackle Merlin Olsen paid Youngblood 3 [' defensive', ' tackle', ' Merlin', ' Olsen']
+1051 277 The professional sport played by x -1 The professional sport played by Merlin Olsen football Merlin Olsen "[',' ' the' ' former' ' NFL' ' linebacker' ',' ' is' ' a' ' great'
+ ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make'
+ ' a' ' difference']" , the former NFL linebacker , is a great example of the kind of person who can make a difference False defensive tackle Merlin Olsen paid Youngblood 3 [' defensive', ' tackle', ' Merlin', ' Olsen']
+1052 278 The professional sport played by x -1 The professional sport played by Dan Marino football Dan Marino "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NFL 's all - time leading passer , is a great example of the kind of player who False the careers of Dan Marino and Jimmy Johnson 4 [' the', ' careers', ' of', ' Dan', ' Marino']
+1053 278 The professional sport played by x -1 The professional sport played by Dan Marino football Dan Marino "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NFL 's all - time leading passer , is a great example of the kind of player who False quarterback, after Dan Marino and Brett Favre, 4 [' quarterback', ',', ' after', ' Dan', ' Marino']
+1054 278 The professional sport played by x -1 The professional sport played by Dan Marino football Dan Marino "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NFL 's all - time leading passer , is a great example of the kind of player who False award, passing Dan Marino for the most all-time 4 [' award', ',', ' passing', ' Dan', ' Marino']
+1055 278 The professional sport played by x -1 The professional sport played by Dan Marino football Dan Marino "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NFL 's all - time leading passer , is a great example of the kind of player who False him ahead of Dan Marino (67) for the most in 4 [' him', ' ahead', ' of', ' Dan', ' Marino']
+1056 278 The professional sport played by x -1 The professional sport played by Dan Marino football Dan Marino "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NFL 's all - time leading passer , is a great example of the kind of player who False moving him ahead of Dan Marino (67) for the 5 [' moving', ' him', ' ahead', ' of', ' Dan', ' Marino']
+1057 279 The professional sport played by x -1 The professional sport played by Jim Thorpe baseball Jim Thorpe "[',' ' the' ' first' ' American' ' to' ' win' ' the' ' pent' 'athlon'
+ ' at' ' the' ' Olympics' '.' '\n' '\n' 'The' ' first' ' Olympic' ' Games'
+ ' were']" ", the first American to win the pent athlon at the Olympics .
+
+ The first Olympic Games were" False decathlon champion Jim Thorpe was stripped of his 5 [' dec', 'athlon', ' champion', ' Jim', ' Thor', 'pe']
+1058 279 The professional sport played by x -1 The professional sport played by Jim Thorpe baseball Jim Thorpe "[',' ' the' ' first' ' American' ' to' ' win' ' the' ' pent' 'athlon'
+ ' at' ' the' ' Olympics' '.' '\n' '\n' 'The' ' first' ' Olympic' ' Games'
+ ' were']" ", the first American to win the pent athlon at the Olympics .
+
+ The first Olympic Games were" False a finalist for the Jim Thorpe Award and the Bronko 7 [' a', ' final', 'ist', ' for', ' the', ' Jim', ' Thor', 'pe']
+1059 279 The professional sport played by x -1 The professional sport played by Jim Thorpe baseball Jim Thorpe "[',' ' the' ' first' ' American' ' to' ' win' ' the' ' pent' 'athlon'
+ ' at' ' the' ' Olympics' '.' '\n' '\n' 'The' ' first' ' Olympic' ' Games'
+ ' were']" ", the first American to win the pent athlon at the Olympics .
+
+ The first Olympic Games were" False Clinton-Dix for the Jim Thorpe Award; and Yeldon 8 [' Clinton', '-', 'D', 'ix', ' for', ' the', ' Jim', ' Thor', 'pe']
+1060 279 The professional sport played by x -1 The professional sport played by Jim Thorpe baseball Jim Thorpe "[',' ' the' ' first' ' American' ' to' ' win' ' the' ' pent' 'athlon'
+ ' at' ' the' ' Olympics' '.' '\n' '\n' 'The' ' first' ' Olympic' ' Games'
+ ' were']" ", the first American to win the pent athlon at the Olympics .
+
+ The first Olympic Games were" False officers, installing Jim Thorpe as president. Under 5 [' officers', ',', ' installing', ' Jim', ' Thor', 'pe']
+1061 279 The professional sport played by x -1 The professional sport played by Jim Thorpe baseball Jim Thorpe "[',' ' the' ' first' ' American' ' to' ' win' ' the' ' pent' 'athlon'
+ ' at' ' the' ' Olympics' '.' '\n' '\n' 'The' ' first' ' Olympic' ' Games'
+ ' were']" ", the first American to win the pent athlon at the Olympics .
+
+ The first Olympic Games were" False Reynolds received the Jim Thorpe Lifetime Achievement 5 [' Reynolds', ' received', ' the', ' Jim', ' Thor', 'pe']
+1062 280 The professional sport played by x -1 The professional sport played by Dave Winfield baseball Dave Winfield "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , a former NFL player , is a game that is played by two teams of two players . The False Wycombe and Dave Winfield from Shrewsbury 5 [' Wy', 'combe', ' and', ' Dave', ' Win', 'field']
+1063 280 The professional sport played by x -1 The professional sport played by Dave Winfield baseball Dave Winfield "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , a former NFL player , is a game that is played by two teams of two players . The False Wycombe and Dave Winfield from Shrewsbury 5 [' Wy', 'combe', ' and', ' Dave', ' Win', 'field']
+1064 280 The professional sport played by x -1 The professional sport played by Dave Winfield baseball Dave Winfield "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , a former NFL player , is a game that is played by two teams of two players . The False Wanderers and Dave Winfield from Shrewsbury 5 [' Wand', 'erers', ' and', ' Dave', ' Win', 'field']
+1065 280 The professional sport played by x -1 The professional sport played by Dave Winfield baseball Dave Winfield "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , a former NFL player , is a game that is played by two teams of two players . The False Wycombe Wanderers and Dave Winfield from Shrewsbury Town, 7 [' Wy', 'combe', ' Wand', 'erers', ' and', ' Dave', ' Win', 'field']
+1066 280 The professional sport played by x -1 The professional sport played by Dave Winfield baseball Dave Winfield "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' game' ' that' ' is'
+ ' played' ' by' ' two' ' teams' ' of' ' two' ' players' '.' ' The']" , a former NFL player , is a game that is played by two teams of two players . The False Series, joining Dave Winfield and Phil Nevin as 5 [' Series', ',', ' joining', ' Dave', ' Win', 'field']
+1067 281 The professional sport played by x -1 The professional sport played by John Matuszak football John Matuszak "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' of' ' skill' ' and' ' strategy' '.' ' The' ' object' ' of'
+ ' the' ' game']" , a former professional baseball player , is a game of skill and strategy . The object of the game False athletes such as John Matuszak and Gary Huff, 7 [' athletes', ' such', ' as', ' John', ' Mat', 'us', 'z', 'ak']
+1068 281 The professional sport played by x -1 The professional sport played by John Matuszak football John Matuszak "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' of' ' skill' ' and' ' strategy' '.' ' The' ' object' ' of'
+ ' the' ' game']" , a former professional baseball player , is a game of skill and strategy . The object of the game False athletes such as John Matuszak and Gary Huff, 7 [' athletes', ' such', ' as', ' John', ' Mat', 'us', 'z', 'ak']
+1069 281 The professional sport played by x -1 The professional sport played by John Matuszak football John Matuszak "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' of' ' skill' ' and' ' strategy' '.' ' The' ' object' ' of'
+ ' the' ' game']" , a former professional baseball player , is a game of skill and strategy . The object of the game False athletes such as John Matuszak and Gary Huff, 7 [' athletes', ' such', ' as', ' John', ' Mat', 'us', 'z', 'ak']
+1070 282 The professional sport played by x -1 The professional sport played by Wim Suurbier soccer Wim Suurbier "[',' ' a' ' Dutch' ' professional' ' footballer' ' who' ' plays' ' as'
+ ' a' ' midfielder' ' for' ' FC' ' Ut' 're' 'cht' '.' '\n' '\n' 'Care'
+ 'er']" ", a Dutch professional footballer who plays as a midfielder for FC Ut re cht .
+
+ Care er" False Ajax player Wim Suurbier was carried 5 [' Ajax', ' player', ' Wim', ' Su', 'urb', 'ier']
+1071 282 The professional sport played by x -1 The professional sport played by Wim Suurbier soccer Wim Suurbier "[',' ' a' ' Dutch' ' professional' ' footballer' ' who' ' plays' ' as'
+ ' a' ' midfielder' ' for' ' FC' ' Ut' 're' 'cht' '.' '\n' '\n' 'Care'
+ 'er']" ", a Dutch professional footballer who plays as a midfielder for FC Ut re cht .
+
+ Care er" False field. Ajax player Wim Suurbier was carried off 7 [' field', '.', ' Ajax', ' player', ' Wim', ' Su', 'urb', 'ier']
+1072 283 The professional sport played by x -1 The professional sport played by Yogi Berra baseball Yogi Berra "[',' ' the' ' catcher' ' for' ' the' ' New' ' York' ' Yankees' ',' ' was'
+ ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri' ',' ' on' ' this']" , the catcher for the New York Yankees , was born in St . Louis , Missouri , on this False refused to include Yogi Berra and the deal fell 6 [' refused', ' to', ' include', ' Yog', 'i', ' Ber', 'ra']
+1073 283 The professional sport played by x -1 The professional sport played by Yogi Berra baseball Yogi Berra "[',' ' the' ' catcher' ' for' ' the' ' New' ' York' ' Yankees' ',' ' was'
+ ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri' ',' ' on' ' this']" , the catcher for the New York Yankees , was born in St . Louis , Missouri , on this False by coaches Yogi Berra and Elston Howard 5 [' by', ' coaches', ' Yog', 'i', ' Ber', 'ra']
+1074 283 The professional sport played by x -1 The professional sport played by Yogi Berra baseball Yogi Berra "[',' ' the' ' catcher' ' for' ' the' ' New' ' York' ' Yankees' ',' ' was'
+ ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri' ',' ' on' ' this']" , the catcher for the New York Yankees , was born in St . Louis , Missouri , on this False War II veteran Yogi Berra to the mound, 6 [' War', ' II', ' veteran', ' Yog', 'i', ' Ber', 'ra']
+1075 283 The professional sport played by x -1 The professional sport played by Yogi Berra baseball Yogi Berra "[',' ' the' ' catcher' ' for' ' the' ' New' ' York' ' Yankees' ',' ' was'
+ ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri' ',' ' on' ' this']" , the catcher for the New York Yankees , was born in St . Louis , Missouri , on this False star catcher Yogi Berra due to car accidents, 5 [' star', ' catcher', ' Yog', 'i', ' Ber', 'ra']
+1076 283 The professional sport played by x -1 The professional sport played by Yogi Berra baseball Yogi Berra "[',' ' the' ' catcher' ' for' ' the' ' New' ' York' ' Yankees' ',' ' was'
+ ' born' ' in' ' St' '.' ' Louis' ',' ' Missouri' ',' ' on' ' this']" , the catcher for the New York Yankees , was born in St . Louis , Missouri , on this False 1, 2002. The Yogi Berra Museum and Learning 8 [' 1', ',', ' 2002', '.', ' The', ' Yog', 'i', ' Ber', 'ra']
+1077 284 The professional sport played by x -1 The professional sport played by Emanuel Pogatetz soccer Emanuel Pogatetz "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' of' ' skill' ' and' ' strategy' '.' ' The' ' game' ' is'
+ ' played' ' on']" , a former professional soccer player , is a game of skill and strategy . The game is played on True Chopra and an Emanuel Pogatetz own goal, thus 7 [' Chop', 'ra', ' and', ' an', ' Emanuel', ' Pog', 'at', 'etz']
+1078 284 The professional sport played by x -1 The professional sport played by Emanuel Pogatetz soccer Emanuel Pogatetz "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' of' ' skill' ' and' ' strategy' '.' ' The' ' game' ' is'
+ ' played' ' on']" , a former professional soccer player , is a game of skill and strategy . The game is played on True Chopra and an Emanuel Pogatetz own goal, thus securing 7 [' Chop', 'ra', ' and', ' an', ' Emanuel', ' Pog', 'at', 'etz']
+1079 285 The professional sport played by x -1 The professional sport played by Roman Reigns football Roman Reigns "[',' ' the' ' WWE' ' Super' 'star' ' who' ' is' ' currently' ' the' ' WWE'
+ ' World' ' Heavy' 'weight' ' Champion' '.' '\n' '\n' 'The' ' WWE' ' is']" ", the WWE Super star who is currently the WWE World Heavy weight Champion .
+
+ The WWE is" False challenge for Roman Reigns and Seth Rollins' 4 [' challenge', ' for', ' Roman', ' Reign', 's']
+1080 285 The professional sport played by x -1 The professional sport played by Roman Reigns football Roman Reigns "[',' ' the' ' WWE' ' Super' 'star' ' who' ' is' ' currently' ' the' ' WWE'
+ ' World' ' Heavy' 'weight' ' Champion' '.' '\n' '\n' 'The' ' WWE' ' is']" ", the WWE Super star who is currently the WWE World Heavy weight Champion .
+
+ The WWE is" False episode of Raw, Roman Reigns attacked Orton, 6 [' episode', ' of', ' Raw', ',', ' Roman', ' Reign', 's']
+1081 285 The professional sport played by x -1 The professional sport played by Roman Reigns football Roman Reigns "[',' ' the' ' WWE' ' Super' 'star' ' who' ' is' ' currently' ' the' ' WWE'
+ ' World' ' Heavy' 'weight' ' Champion' '.' '\n' '\n' 'The' ' WWE' ' is']" ", the WWE Super star who is currently the WWE World Heavy weight Champion .
+
+ The WWE is" False Ascension, while Roman Reigns won the Royal Rumble 5 [' Ascension', ',', ' while', ' Roman', ' Reign', 's']
+1082 285 The professional sport played by x -1 The professional sport played by Roman Reigns football Roman Reigns "[',' ' the' ' WWE' ' Super' 'star' ' who' ' is' ' currently' ' the' ' WWE'
+ ' World' ' Heavy' 'weight' ' Champion' '.' '\n' '\n' 'The' ' WWE' ' is']" ", the WWE Super star who is currently the WWE World Heavy weight Champion .
+
+ The WWE is" False was eliminated by Roman Reigns in the final 5 [' was', ' eliminated', ' by', ' Roman', ' Reign', 's']
+1083 285 The professional sport played by x -1 The professional sport played by Roman Reigns football Roman Reigns "[',' ' the' ' WWE' ' Super' 'star' ' who' ' is' ' currently' ' the' ' WWE'
+ ' World' ' Heavy' 'weight' ' Champion' '.' '\n' '\n' 'The' ' WWE' ' is']" ", the WWE Super star who is currently the WWE World Heavy weight Champion .
+
+ The WWE is" False episode of Raw, Roman Reigns attacked Orton, which 6 [' episode', ' of', ' Raw', ',', ' Roman', ' Reign', 's']
+1084 286 The professional sport played by x -1 The professional sport played by Steven Gerrard soccer Steven Gerrard "[',' ' the' ' Liverpool' ' captain' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' club' ' is'
+ ' looking' ' for']" , the Liverpool captain , is a great example of the kind of player that the club is looking for False Liverpool midfielder Steven Gerrard analysed the match 3 [' Liverpool', ' midfielder', ' Steven', ' Gerrard']
+1085 286 The professional sport played by x -1 The professional sport played by Steven Gerrard soccer Steven Gerrard "[',' ' the' ' Liverpool' ' captain' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' club' ' is'
+ ' looking' ' for']" , the Liverpool captain , is a great example of the kind of player that the club is looking for False " Former teammate Steven Gerrard said he was ""devastated""" 3 [' Former', ' teammate', ' Steven', ' Gerrard']
+1086 286 The professional sport played by x -1 The professional sport played by Steven Gerrard soccer Steven Gerrard "[',' ' the' ' Liverpool' ' captain' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' club' ' is'
+ ' looking' ' for']" , the Liverpool captain , is a great example of the kind of player that the club is looking for False Liverpool's midfielder Steven Gerrard on 29 April. For 4 "[' Liverpool', ""'s"", ' midfielder', ' Steven', ' Gerrard']"
+1087 286 The professional sport played by x -1 The professional sport played by Steven Gerrard soccer Steven Gerrard "[',' ' the' ' Liverpool' ' captain' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' club' ' is'
+ ' looking' ' for']" , the Liverpool captain , is a great example of the kind of player that the club is looking for False on the wings, while Steven Gerrard was deployed 6 [' on', ' the', ' wings', ',', ' while', ' Steven', ' Gerrard']
+1088 286 The professional sport played by x -1 The professional sport played by Steven Gerrard soccer Steven Gerrard "[',' ' the' ' Liverpool' ' captain' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' player' ' that' ' the' ' club' ' is'
+ ' looking' ' for']" , the Liverpool captain , is a great example of the kind of player that the club is looking for False fitness but, as Steven Gerrard was injured, 5 [' fitness', ' but', ',', ' as', ' Steven', ' Gerrard']
+1089 288 The professional sport played by x -1 The professional sport played by Ted Williams baseball Ted Williams "[',' ' the' ' former' ' Boston' ' Red' ' Sox' ' slug' 'ger' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can']" , the former Boston Red Sox slug ger , is a great example of the kind of person who can False 67, including Ted Williams Parkway, Twin 4 [' 67', ',', ' including', ' Ted', ' Williams']
+1090 288 The professional sport played by x -1 The professional sport played by Ted Williams baseball Ted Williams "[',' ' the' ' former' ' Boston' ' Red' ' Sox' ' slug' 'ger' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can']" , the former Boston Red Sox slug ger , is a great example of the kind of person who can False intensified in 1941 when Ted Williams of the Red Sox batted 5 [' intensified', ' in', ' 1941', ' when', ' Ted', ' Williams']
+1091 288 The professional sport played by x -1 The professional sport played by Ted Williams baseball Ted Williams "[',' ' the' ' former' ' Boston' ' Red' ' Sox' ' slug' 'ger' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can']" , the former Boston Red Sox slug ger , is a great example of the kind of person who can False younger, after Ted Williams and before Frank 4 [' younger', ',', ' after', ' Ted', ' Williams']
+1092 288 The professional sport played by x -1 The professional sport played by Ted Williams baseball Ted Williams "[',' ' the' ' former' ' Boston' ' Red' ' Sox' ' slug' 'ger' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can']" , the former Boston Red Sox slug ger , is a great example of the kind of person who can False SR 67, including Ted Williams Parkway, Twin 5 [' SR', ' 67', ',', ' including', ' Ted', ' Williams']
+1093 288 The professional sport played by x -1 The professional sport played by Ted Williams baseball Ted Williams "[',' ' the' ' former' ' Boston' ' Red' ' Sox' ' slug' 'ger' ',' ' is' ' a'
+ ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who' ' can']" , the former Boston Red Sox slug ger , is a great example of the kind of person who can False Hall of Famer Ted Williams as manager at 5 [' Hall', ' of', ' F', 'amer', ' Ted', ' Williams']
+1094 289 The professional sport played by x -1 The professional sport played by Abel Xavier soccer Abel Xavier "[',' ' a' ' former' ' professional' ' soccer' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional soccer player , is a game that is played by two teams of two players . True blocked by defender Abel Xavier and two of his efforts 4 [' blocked', ' by', ' defender', ' Abel', ' Xavier']
+1095 290 The professional sport played by x -1 The professional sport played by Wayne Rooney soccer Wayne Rooney "[',' ' the' ' English' ' footballer' ',' ' is' ' a' ' game' ' of' ' skill'
+ ',' ' not' ' luck' '.' '\n' '\n' 'The' ' game' ' is' ' played']" ", the English footballer , is a game of skill , not luck .
+
+ The game is played" False started alongside Wayne Rooney for the opening 3 [' started', ' alongside', ' Wayne', ' Rooney']
+1096 290 The professional sport played by x -1 The professional sport played by Wayne Rooney soccer Wayne Rooney "[',' ' the' ' English' ' footballer' ',' ' is' ' a' ' game' ' of' ' skill'
+ ',' ' not' ' luck' '.' '\n' '\n' 'The' ' game' ' is' ' played']" ", the English footballer , is a game of skill , not luck .
+
+ The game is played" False Sol Campbell on Wayne Rooney earned Manchester 4 [' Sol', ' Campbell', ' on', ' Wayne', ' Rooney']
+1097 290 The professional sport played by x -1 The professional sport played by Wayne Rooney soccer Wayne Rooney "[',' ' the' ' English' ' footballer' ',' ' is' ' a' ' game' ' of' ' skill'
+ ',' ' not' ' luck' '.' '\n' '\n' 'The' ' game' ' is' ' played']" ", the English footballer , is a game of skill , not luck .
+
+ The game is played" False a substitute for Wayne Rooney in a friendly against 4 [' a', ' substitute', ' for', ' Wayne', ' Rooney']
+1098 290 The professional sport played by x -1 The professional sport played by Wayne Rooney soccer Wayne Rooney "[',' ' the' ' English' ' footballer' ',' ' is' ' a' ' game' ' of' ' skill'
+ ',' ' not' ' luck' '.' '\n' '\n' 'The' ' game' ' is' ' played']" ", the English footballer , is a game of skill , not luck .
+
+ The game is played" False started alongside Wayne Rooney for the opening 3 [' started', ' alongside', ' Wayne', ' Rooney']
+1099 290 The professional sport played by x -1 The professional sport played by Wayne Rooney soccer Wayne Rooney "[',' ' the' ' English' ' footballer' ',' ' is' ' a' ' game' ' of' ' skill'
+ ',' ' not' ' luck' '.' '\n' '\n' 'The' ' game' ' is' ' played']" ", the English footballer , is a game of skill , not luck .
+
+ The game is played" False Stadium, replacing Wayne Rooney in the second 4 [' Stadium', ',', ' replacing', ' Wayne', ' Rooney']
+1100 291 The professional sport played by x -1 The professional sport played by Robbie Keane soccer Robbie Keane "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False Irishman, though Robbie Keane was closing 5 [' Irish', 'man', ',', ' though', ' Robbie', ' Keane']
+1101 291 The professional sport played by x -1 The professional sport played by Robbie Keane soccer Robbie Keane "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False assisted a goal for Robbie Keane as they went on 5 [' assisted', ' a', ' goal', ' for', ' Robbie', ' Keane']
+1102 291 The professional sport played by x -1 The professional sport played by Robbie Keane soccer Robbie Keane "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False Following the return of Robbie Keane at Tottenham, 5 [' Following', ' the', ' return', ' of', ' Robbie', ' Keane']
+1103 291 The professional sport played by x -1 The professional sport played by Robbie Keane soccer Robbie Keane "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False reasons David Beckham and Robbie Keane were left out of 5 [' reasons', ' David', ' Beckham', ' and', ' Robbie', ' Keane']
+1104 291 The professional sport played by x -1 The professional sport played by Robbie Keane soccer Robbie Keane "[',' ' who' ' has' ' been' ' a' ' regular' ' in' ' the' ' Premier'
+ ' League' ' for' ' the' ' past' ' decade' ',' ' has' ' been' ' a'
+ ' revelation' ' in']" , who has been a regular in the Premier League for the past decade , has been a revelation in False capped Irishman, though Robbie Keane was closing in on 6 [' capped', ' Irish', 'man', ',', ' though', ' Robbie', ' Keane']
+1105 292 The professional sport played by x -1 The professional sport played by Wayne Gretzky hockey Wayne Gretzky "[',' ' Mario' ' Lem' 'ie' 'ux' ',' ' and' ' Mario' ' Lem' 'ie' 'ux' '.'
+ '\n' '\n' 'The' ' NHL' ' is' ' a' ' professional' ' ice']" ", Mario Lem ie ux , and Mario Lem ie ux .
+
+ The NHL is a professional ice" False after his idol, Wayne Gretzky. After showing 7 [' after', ' his', ' idol', ',', ' Wayne', ' Gret', 'z', 'ky']
+1106 292 The professional sport played by x -1 The professional sport played by Wayne Gretzky hockey Wayne Gretzky "[',' ' Mario' ' Lem' 'ie' 'ux' ',' ' and' ' Mario' ' Lem' 'ie' 'ux' '.'
+ '\n' '\n' 'The' ' NHL' ' is' ' a' ' professional' ' ice']" ", Mario Lem ie ux , and Mario Lem ie ux .
+
+ The NHL is a professional ice" False " James ""the Wayne Gretzky of women's hockey"".
+" 6 "[' James', ' ""', 'the', ' Wayne', ' Gret', 'z', 'ky']"
+1107 292 The professional sport played by x -1 The professional sport played by Wayne Gretzky hockey Wayne Gretzky "[',' ' Mario' ' Lem' 'ie' 'ux' ',' ' and' ' Mario' ' Lem' 'ie' 'ux' '.'
+ '\n' '\n' 'The' ' NHL' ' is' ' a' ' professional' ' ice']" ", Mario Lem ie ux , and Mario Lem ie ux .
+
+ The NHL is a professional ice" False intentionally injuring Wayne Gretzky during a game 5 [' intentionally', ' injuring', ' Wayne', ' Gret', 'z', 'ky']
+1108 292 The professional sport played by x -1 The professional sport played by Wayne Gretzky hockey Wayne Gretzky "[',' ' Mario' ' Lem' 'ie' 'ux' ',' ' and' ' Mario' ' Lem' 'ie' 'ux' '.'
+ '\n' '\n' 'The' ' NHL' ' is' ' a' ' professional' ' ice']" ", Mario Lem ie ux , and Mario Lem ie ux .
+
+ The NHL is a professional ice" False 4 ['Way', 'ne', ' Gret', 'z', 'ky']
+1109 292 The professional sport played by x -1 The professional sport played by Wayne Gretzky hockey Wayne Gretzky "[',' ' Mario' ' Lem' 'ie' 'ux' ',' ' and' ' Mario' ' Lem' 'ie' 'ux' '.'
+ '\n' '\n' 'The' ' NHL' ' is' ' a' ' professional' ' ice']" ", Mario Lem ie ux , and Mario Lem ie ux .
+
+ The NHL is a professional ice" False " until broken by Wayne Gretzky in 1984.
+" 6 [' until', ' broken', ' by', ' Wayne', ' Gret', 'z', 'ky']
+1110 293 The professional sport played by x -1 The professional sport played by Rick Fox basketball Rick Fox "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' that']" , the NBA 's all - time leading scorer , is a great example of the kind of player that False Lakers traded Rick Fox and Gary Payton 3 [' Lakers', ' traded', ' Rick', ' Fox']
+1111 293 The professional sport played by x -1 The professional sport played by Rick Fox basketball Rick Fox "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' that']" , the NBA 's all - time leading scorer , is a great example of the kind of player that False Lakers player Rick Fox guest starred 3 [' Lakers', ' player', ' Rick', ' Fox']
+1112 293 The professional sport played by x -1 The professional sport played by Rick Fox basketball Rick Fox "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' that']" , the NBA 's all - time leading scorer , is a great example of the kind of player that False Lakers traded Rick Fox and Gary Payton 3 [' Lakers', ' traded', ' Rick', ' Fox']
+1113 293 The professional sport played by x -1 The professional sport played by Rick Fox basketball Rick Fox "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' that']" , the NBA 's all - time leading scorer , is a great example of the kind of player that False The Lakers traded Rick Fox and Gary Payton 4 [' The', ' Lakers', ' traded', ' Rick', ' Fox']
+1114 293 The professional sport played by x -1 The professional sport played by Rick Fox basketball Rick Fox "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' that']" , the NBA 's all - time leading scorer , is a great example of the kind of player that False The Lakers traded Rick Fox and Gary Payton 4 [' The', ' Lakers', ' traded', ' Rick', ' Fox']
+1115 294 The professional sport played by x -1 The professional sport played by Deion Sanders baseball Deion Sanders "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' person' ' can' ' be' ' a' ' great' ' athlete' ' and']" , the former NFL star , is a great example of how a person can be a great athlete and False who kept a Deion Sanders poster in his 5 [' who', ' kept', ' a', ' De', 'ion', ' Sanders']
+1116 294 The professional sport played by x -1 The professional sport played by Deion Sanders baseball Deion Sanders "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' person' ' can' ' be' ' a' ' great' ' athlete' ' and']" , the former NFL star , is a great example of how a person can be a great athlete and False befriended Deion Sanders through Ed Reed, 4 [' befriend', 'ed', ' De', 'ion', ' Sanders']
+1117 294 The professional sport played by x -1 The professional sport played by Deion Sanders baseball Deion Sanders "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' person' ' can' ' be' ' a' ' great' ' athlete' ' and']" , the former NFL star , is a great example of how a person can be a great athlete and False and who kept a Deion Sanders poster in his bedroom, 6 [' and', ' who', ' kept', ' a', ' De', 'ion', ' Sanders']
+1118 294 The professional sport played by x -1 The professional sport played by Deion Sanders baseball Deion Sanders "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' person' ' can' ' be' ' a' ' great' ' athlete' ' and']" , the former NFL star , is a great example of how a person can be a great athlete and False and who kept a Deion Sanders poster in his bedroom, 6 [' and', ' who', ' kept', ' a', ' De', 'ion', ' Sanders']
+1119 294 The professional sport played by x -1 The professional sport played by Deion Sanders baseball Deion Sanders "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' great' ' example'
+ ' of' ' how' ' a' ' person' ' can' ' be' ' a' ' great' ' athlete' ' and']" , the former NFL star , is a great example of how a person can be a great athlete and False NFL and who kept a Deion Sanders poster in his 7 [' NFL', ' and', ' who', ' kept', ' a', ' De', 'ion', ' Sanders']
+1120 295 The professional sport played by x -1 The professional sport played by Javier Hernández soccer Javier Hernández "[',' ' a' ' former' ' Major' ' League' ' Baseball' ' player' '.' '\n' '\n'
+ 'The' ' professional' ' sport' ' played' ' by' ' Javier' ' H' 'ern' 'á'
+ 'nd']" ", a former Major League Baseball player .
+
+ The professional sport played by Javier H ern á nd" False " Kaká."" While scout Javier Hernández wished for" 10 "[' Kak', 'á', '.""', ' While', ' scout', ' Javier', ' H', 'ern', 'á', 'nd', 'ez']"
+1121 295 The professional sport played by x -1 The professional sport played by Javier Hernández soccer Javier Hernández "[',' ' a' ' former' ' Major' ' League' ' Baseball' ' player' '.' '\n' '\n'
+ 'The' ' professional' ' sport' ' played' ' by' ' Javier' ' H' 'ern' 'á'
+ 'nd']" ", a former Major League Baseball player .
+
+ The professional sport played by Javier H ern á nd" False alongside teammate Javier Hernández on 8 April. 7 [' alongside', ' teammate', ' Javier', ' H', 'ern', 'á', 'nd', 'ez']
+1122 295 The professional sport played by x -1 The professional sport played by Javier Hernández soccer Javier Hernández "[',' ' a' ' former' ' Major' ' League' ' Baseball' ' player' '.' '\n' '\n'
+ 'The' ' professional' ' sport' ' played' ' by' ' Javier' ' H' 'ern' 'á'
+ 'nd']" ", a former Major League Baseball player .
+
+ The professional sport played by Javier H ern á nd" False alongside teammate Javier Hernández on 8 April. However, 7 [' alongside', ' teammate', ' Javier', ' H', 'ern', 'á', 'nd', 'ez']
+1123 295 The professional sport played by x -1 The professional sport played by Javier Hernández soccer Javier Hernández "[',' ' a' ' former' ' Major' ' League' ' Baseball' ' player' '.' '\n' '\n'
+ 'The' ' professional' ' sport' ' played' ' by' ' Javier' ' H' 'ern' 'á'
+ 'nd']" ", a former Major League Baseball player .
+
+ The professional sport played by Javier H ern á nd" False " Kaká."" While scout Javier Hernández wished for" 10 "[' Kak', 'á', '.""', ' While', ' scout', ' Javier', ' H', 'ern', 'á', 'nd', 'ez']"
+1124 295 The professional sport played by x -1 The professional sport played by Javier Hernández soccer Javier Hernández "[',' ' a' ' former' ' Major' ' League' ' Baseball' ' player' '.' '\n' '\n'
+ 'The' ' professional' ' sport' ' played' ' by' ' Javier' ' H' 'ern' 'á'
+ 'nd']" ", a former Major League Baseball player .
+
+ The professional sport played by Javier H ern á nd" False alongside teammate Javier Hernández on 8 April. 7 [' alongside', ' teammate', ' Javier', ' H', 'ern', 'á', 'nd', 'ez']
+1125 296 The professional sport played by x -1 The professional sport played by Frank Gifford football Frank Gifford "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' game' ' that' ' has'
+ ' been' ' played' ' for' ' centuries' '.' ' It' ' is' ' a' ' game']" , the former NFL star , is a game that has been played for centuries . It is a game False goal, Arledge informed Frank Gifford and Howard 8 [' goal', ',', ' Ar', 'ledge', ' informed', ' Frank', ' G', 'iff', 'ord']
+1126 296 The professional sport played by x -1 The professional sport played by Frank Gifford football Frank Gifford "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' game' ' that' ' has'
+ ' been' ' played' ' for' ' centuries' '.' ' It' ' is' ' a' ' game']" , the former NFL star , is a game that has been played for centuries . It is a game False with former Giants Frank Gifford and Tom Scott, who 6 [' with', ' former', ' Giants', ' Frank', ' G', 'iff', 'ord']
+1127 296 The professional sport played by x -1 The professional sport played by Frank Gifford football Frank Gifford "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' game' ' that' ' has'
+ ' been' ' played' ' for' ' centuries' '.' ' It' ' is' ' a' ' game']" , the former NFL star , is a game that has been played for centuries . It is a game False just second to Frank Gifford an NFL record for 6 [' just', ' second', ' to', ' Frank', ' G', 'iff', 'ord']
+1128 296 The professional sport played by x -1 The professional sport played by Frank Gifford football Frank Gifford "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' game' ' that' ' has'
+ ' been' ' played' ' for' ' centuries' '.' ' It' ' is' ' a' ' game']" , the former NFL star , is a game that has been played for centuries . It is a game False Detroit Lions, Frank Gifford and Sam Huff of the 6 [' Detroit', ' Lions', ',', ' Frank', ' G', 'iff', 'ord']
+1129 296 The professional sport played by x -1 The professional sport played by Frank Gifford football Frank Gifford "[',' ' the' ' former' ' NFL' ' star' ',' ' is' ' a' ' game' ' that' ' has'
+ ' been' ' played' ' for' ' centuries' '.' ' It' ' is' ' a' ' game']" , the former NFL star , is a game that has been played for centuries . It is a game False touchdown pass to Frank Gifford to take the lead, 6 [' touchdown', ' pass', ' to', ' Frank', ' G', 'iff', 'ord']
+1130 298 The professional sport played by x -1 The professional sport played by Hank Greenberg baseball Hank Greenberg "[',' ' the' ' former' ' Yankee' ' great' ',' ' was' ' a' ' great'
+ ' player' ',' ' but' ' he' ' was' ' a' ' terrible' ' manager' '.' ' He'
+ ' was']" , the former Yankee great , was a great player , but he was a terrible manager . He was False (184) and Hank Greenberg (183) ever 5 [' (', '184', ')', ' and', ' Hank', ' Greenberg']
+1131 298 The professional sport played by x -1 The professional sport played by Hank Greenberg baseball Hank Greenberg "[',' ' the' ' former' ' Yankee' ' great' ',' ' was' ' a' ' great'
+ ' player' ',' ' but' ' he' ' was' ' a' ' terrible' ' manager' '.' ' He'
+ ' was']" , the former Yankee great , was a great player , but he was a terrible manager . He was False best seasons. Only Hank Greenberg (58), Shawn Green 5 [' best', ' seasons', '.', ' Only', ' Hank', ' Greenberg']
+1132 298 The professional sport played by x -1 The professional sport played by Hank Greenberg baseball Hank Greenberg "[',' ' the' ' former' ' Yankee' ' great' ',' ' was' ' a' ' great'
+ ' player' ',' ' but' ' he' ' was' ' a' ' terrible' ' manager' '.' ' He'
+ ' was']" , the former Yankee great , was a great player , but he was a terrible manager . He was False Baseball Almanac, Hank Greenberg was the first 5 [' Baseball', ' Al', 'manac', ',', ' Hank', ' Greenberg']
+1133 298 The professional sport played by x -1 The professional sport played by Hank Greenberg baseball Hank Greenberg "[',' ' the' ' former' ' Yankee' ' great' ',' ' was' ' a' ' great'
+ ' player' ',' ' but' ' he' ' was' ' a' ' terrible' ' manager' '.' ' He'
+ ' was']" , the former Yankee great , was a great player , but he was a terrible manager . He was False in the 2008 Hank Greenberg 75th Anniversary 4 [' in', ' the', ' 2008', ' Hank', ' Greenberg']
+1134 298 The professional sport played by x -1 The professional sport played by Hank Greenberg baseball Hank Greenberg "[',' ' the' ' former' ' Yankee' ' great' ',' ' was' ' a' ' great'
+ ' player' ',' ' but' ' he' ' was' ' a' ' terrible' ' manager' '.' ' He'
+ ' was']" , the former Yankee great , was a great player , but he was a terrible manager . He was False Baseball Almanac, Hank Greenberg was the first with 5 [' Baseball', ' Al', 'manac', ',', ' Hank', ' Greenberg']
+1135 299 The professional sport played by x -1 The professional sport played by Viacheslav Fetisov hockey Viacheslav Fetisov "[',' ' the' ' greatest' ' hockey' ' player' ' of' ' all' ' time' '.' '\n'
+ '\n' 'The' ' Russian' ' hockey' ' player' ',' ' who' ' is' ' now' ' a']" ", the greatest hockey player of all time .
+
+ The Russian hockey player , who is now a" True two Soviet stars – Viacheslav Fetisov and Sergei Starikov. 9 [' two', ' Soviet', ' stars', ' –', ' Vi', 'aches', 'lav', ' Fet', 'is', 'ov']
+1136 299 The professional sport played by x -1 The professional sport played by Viacheslav Fetisov hockey Viacheslav Fetisov "[',' ' the' ' greatest' ' hockey' ' player' ' of' ' all' ' time' '.' '\n'
+ '\n' 'The' ' Russian' ' hockey' ' player' ',' ' who' ' is' ' now' ' a']" ", the greatest hockey player of all time .
+
+ The Russian hockey player , who is now a" True Red Army teammate Viacheslav Fetisov used a company 8 [' Red', ' Army', ' teammate', ' Vi', 'aches', 'lav', ' Fet', 'is', 'ov']
+1137 299 The professional sport played by x -1 The professional sport played by Viacheslav Fetisov hockey Viacheslav Fetisov "[',' ' the' ' greatest' ' hockey' ' player' ' of' ' all' ' time' '.' '\n'
+ '\n' 'The' ' Russian' ' hockey' ' player' ',' ' who' ' is' ' now' ' a']" ", the greatest hockey player of all time .
+
+ The Russian hockey player , who is now a" True Soviet stars – Viacheslav Fetisov and Sergei Starikov. 8 [' Soviet', ' stars', ' –', ' Vi', 'aches', 'lav', ' Fet', 'is', 'ov']
+1138 299 The professional sport played by x -1 The professional sport played by Viacheslav Fetisov hockey Viacheslav Fetisov "[',' ' the' ' greatest' ' hockey' ' player' ' of' ' all' ' time' '.' '\n'
+ '\n' 'The' ' Russian' ' hockey' ' player' ',' ' who' ' is' ' now' ' a']" ", the greatest hockey player of all time .
+
+ The Russian hockey player , who is now a" True Army teammate Viacheslav Fetisov used a company of 7 [' Army', ' teammate', ' Vi', 'aches', 'lav', ' Fet', 'is', 'ov']
+1139 299 The professional sport played by x -1 The professional sport played by Viacheslav Fetisov hockey Viacheslav Fetisov "[',' ' the' ' greatest' ' hockey' ' player' ' of' ' all' ' time' '.' '\n'
+ '\n' 'The' ' Russian' ' hockey' ' player' ',' ' who' ' is' ' now' ' a']" ", the greatest hockey player of all time .
+
+ The Russian hockey player , who is now a" True Soviet stars – Viacheslav Fetisov and Sergei Starikov. 8 [' Soviet', ' stars', ' –', ' Vi', 'aches', 'lav', ' Fet', 'is', 'ov']
+1140 300 The professional sport played by x -1 The professional sport played by Mario Lemieux hockey Mario Lemieux "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NHL 's all - time leading scorer , is a game that has been played for over a False Penguins' owner Mario Lemieux and sat with former 6 "[' Penguins', ""'"", ' owner', ' Mario', ' Lem', 'ie', 'ux']"
+1141 300 The professional sport played by x -1 The professional sport played by Mario Lemieux hockey Mario Lemieux "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NHL 's all - time leading scorer , is a game that has been played for over a False out Pittsburgh's Mario Lemieux (who scored 199 points) 6 "[' out', ' Pittsburgh', ""'s"", ' Mario', ' Lem', 'ie', 'ux']"
+1142 300 The professional sport played by x -1 The professional sport played by Mario Lemieux hockey Mario Lemieux "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NHL 's all - time leading scorer , is a game that has been played for over a False Jerome Bettis and Mario Lemieux dropped the ceremonial 7 [' Jerome', ' Bett', 'is', ' and', ' Mario', ' Lem', 'ie', 'ux']
+1143 300 The professional sport played by x -1 The professional sport played by Mario Lemieux hockey Mario Lemieux "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NHL 's all - time leading scorer , is a game that has been played for over a False " Lemieux ===
+" 7 [' Lem', 'ie', 'ux', ' ===', 'Mario', ' Lem', 'ie', 'ux']
+1144 300 The professional sport played by x -1 The professional sport played by Mario Lemieux hockey Mario Lemieux "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NHL 's all - time leading scorer , is a game that has been played for over a False Mike Bossy and Mario Lemieux. In the 1993-94 7 [' Mike', ' Boss', 'y', ' and', ' Mario', ' Lem', 'ie', 'ux']
+1145 301 The professional sport played by x -1 The professional sport played by Thierry Henry soccer Thierry Henry "[',' ' the' ' French' ' striker' ',' ' is' ' a' ' game' ' of' ' skill' ','
+ ' speed' ' and' ' power' '.' ' It' ' is' ' a' ' game' ' of']" , the French striker , is a game of skill , speed and power . It is a game of False fixture at Middlesbrough. Thierry Henry scored the opening 9 [' fixture', ' at', ' Middles', 'b', 'rough', '.', ' Th', 'ier', 'ry', ' Henry']
+1146 301 The professional sport played by x -1 The professional sport played by Thierry Henry soccer Thierry Henry "[',' ' the' ' French' ' striker' ',' ' is' ' a' ' game' ' of' ' skill' ','
+ ' speed' ' and' ' power' '.' ' It' ' is' ' a' ' game' ' of']" , the French striker , is a game of skill , speed and power . It is a game of False Coupet. At 2 – 0, Thierry Henry scored a goal, 11 [' Cou', 'pet', '.', ' At', ' 2', ' –', ' 0', ',', ' Th', 'ier', 'ry', ' Henry']
+1147 301 The professional sport played by x -1 The professional sport played by Thierry Henry soccer Thierry Henry "[',' ' the' ' French' ' striker' ',' ' is' ' a' ' game' ' of' ' skill' ','
+ ' speed' ' and' ' power' '.' ' It' ' is' ' a' ' game' ' of']" , the French striker , is a game of skill , speed and power . It is a game of False " Thierry Henry =
+" 3 [' Th', 'ier', 'ry', ' Henry']
+1148 301 The professional sport played by x -1 The professional sport played by Thierry Henry soccer Thierry Henry "[',' ' the' ' French' ' striker' ',' ' is' ' a' ' game' ' of' ' skill' ','
+ ' speed' ' and' ' power' '.' ' It' ' is' ' a' ' game' ' of']" , the French striker , is a game of skill , speed and power . It is a game of False season. Forward Thierry Henry and semi-final-winning 6 [' season', '.', ' Forward', ' Th', 'ier', 'ry', ' Henry']
+1149 301 The professional sport played by x -1 The professional sport played by Thierry Henry soccer Thierry Henry "[',' ' the' ' French' ' striker' ',' ' is' ' a' ' game' ' of' ' skill' ','
+ ' speed' ' and' ' power' '.' ' It' ' is' ' a' ' game' ' of']" , the French striker , is a game of skill , speed and power . It is a game of False the honour since Thierry Henry in 2006. Nasri was 6 [' the', ' honour', ' since', ' Th', 'ier', 'ry', ' Henry']
+1150 302 The professional sport played by x -1 The professional sport played by Stephen Curry basketball Stephen Curry "[',' ' the' ' NBA' ""'s"" ' reigning' ' MVP' ',' ' is' ' a' ' game' ' that'
+ ' is' ' played' ' by' ' a' ' team' ' of' ' two' '.' ' The']" , the NBA 's reigning MVP , is a game that is played by a team of two . The False and 8 hours later, Stephen Curry challenged him, via 6 [' and', ' 8', ' hours', ' later', ',', ' Stephen', ' Curry']
+1151 303 The professional sport played by x -1 The professional sport played by Brock Lesnar football Brock Lesnar "[' is' ' a' ' sport' ' that' ' is' ' not' ' only' ' a' ' sport' ',' ' but'
+ ' a' ' lifestyle' '.' ' It' ' is' ' a' ' lifestyle' ' that' ' is']" is a sport that is not only a sport , but a lifestyle . It is a lifestyle that is False to confront Brock Lesnar during separate 4 [' to', ' confront', ' Brock', ' Les', 'nar']
+1152 303 The professional sport played by x -1 The professional sport played by Brock Lesnar football Brock Lesnar "[' is' ' a' ' sport' ' that' ' is' ' not' ' only' ' a' ' sport' ',' ' but'
+ ' a' ' lifestyle' '.' ' It' ' is' ' a' ' lifestyle' ' that' ' is']" is a sport that is not only a sport , but a lifestyle . It is a lifestyle that is False Chris Benoit and Brock Lesnar in an attempt 6 [' Chris', ' Ben', 'oit', ' and', ' Brock', ' Les', 'nar']
+1153 303 The professional sport played by x -1 The professional sport played by Brock Lesnar football Brock Lesnar "[' is' ' a' ' sport' ' that' ' is' ' not' ' only' ' a' ' sport' ',' ' but'
+ ' a' ' lifestyle' '.' ' It' ' is' ' a' ' lifestyle' ' that' ' is']" is a sport that is not only a sport , but a lifestyle . It is a lifestyle that is False Angle defeated Brock Lesnar in a standard 4 [' Angle', ' defeated', ' Brock', ' Les', 'nar']
+1154 303 The professional sport played by x -1 The professional sport played by Brock Lesnar football Brock Lesnar "[' is' ' a' ' sport' ' that' ' is' ' not' ' only' ' a' ' sport' ',' ' but'
+ ' a' ' lifestyle' '.' ' It' ' is' ' a' ' lifestyle' ' that' ' is']" is a sport that is not only a sport , but a lifestyle . It is a lifestyle that is False saw John Cena face Brock Lesnar in an Extreme Rules 6 [' saw', ' John', ' Cena', ' face', ' Brock', ' Les', 'nar']
+1155 303 The professional sport played by x -1 The professional sport played by Brock Lesnar football Brock Lesnar "[' is' ' a' ' sport' ' that' ' is' ' not' ' only' ' a' ' sport' ',' ' but'
+ ' a' ' lifestyle' '.' ' It' ' is' ' a' ' lifestyle' ' that' ' is']" is a sport that is not only a sport , but a lifestyle . It is a lifestyle that is False week, WWE Champion Brock Lesnar was scheduled to defend 6 [' week', ',', ' WWE', ' Champion', ' Brock', ' Les', 'nar']
+1156 304 The professional sport played by x -1 The professional sport played by Byron White football Byron White "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False However, as Justice Byron White noted in his 5 [' However', ',', ' as', ' Justice', ' Byron', ' White']
+1157 304 The professional sport played by x -1 The professional sport played by Byron White football Byron White "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False Burger and Justices Byron White and William Rehnquist 5 [' Burger', ' and', ' Just', 'ices', ' Byron', ' White']
+1158 304 The professional sport played by x -1 The professional sport played by Byron White football Byron White "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False reading it, Justice Byron White made an argument 5 [' reading', ' it', ',', ' Justice', ' Byron', ' White']
+1159 304 The professional sport played by x -1 The professional sport played by Byron White football Byron White "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False as Justice Byron White noted in his dissenting 3 [' as', ' Justice', ' Byron', ' White']
+1160 304 The professional sport played by x -1 The professional sport played by Byron White football Byron White "[',' ' a' ' former' ' NFL' ' player' ',' ' is' ' a' ' great' ' example'
+ ' of' ' the' ' kind' ' of' ' person' ' who' ' can' ' make' ' a'
+ ' difference']" , a former NFL player , is a great example of the kind of person who can make a difference False associate justices Byron White and Anthony 3 [' associate', ' justices', ' Byron', ' White']
+1161 306 The professional sport played by x -1 The professional sport played by Bill Russell basketball Bill Russell "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False in NCAA Division I. Bill Russell led the Don's 6 [' in', ' NCAA', ' Division', ' I', '.', ' Bill', ' Russell']
+1162 306 The professional sport played by x -1 The professional sport played by Bill Russell basketball Bill Russell "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False Lopes, shortstop Bill Russell and third baseman 5 [' L', 'opes', ',', ' shortstop', ' Bill', ' Russell']
+1163 306 The professional sport played by x -1 The professional sport played by Bill Russell basketball Bill Russell "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False Boston Celtics – when Bill Russell was still just a rookie 5 [' Boston', ' Celtics', ' –', ' when', ' Bill', ' Russell']
+1164 306 The professional sport played by x -1 The professional sport played by Bill Russell basketball Bill Russell "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False accomplishments. The Bill Russell Legacy Foundation, 4 [' accomplishments', '.', ' The', ' Bill', ' Russell']
+1165 306 The professional sport played by x -1 The professional sport played by Bill Russell basketball Bill Russell "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NBA 's all - time leading scorer , was a great player , but he was also a False team would consist of Bill Russell — who in the former's 5 [' team', ' would', ' consist', ' of', ' Bill', ' Russell']
+1166 308 The professional sport played by x -1 The professional sport played by Jackie Robinson baseball Jackie Robinson "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' were' ' the' ' first' ' team' ' to'
+ ' integrate']" ", the first black player in the major leagues .
+
+ The Dodgers were the first team to integrate" False press as the Jackie Robinson of classical ballet. 4 [' press', ' as', ' the', ' Jackie', ' Robinson']
+1167 308 The professional sport played by x -1 The professional sport played by Jackie Robinson baseball Jackie Robinson "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' were' ' the' ' first' ' team' ' to'
+ ' integrate']" ", the first black player in the major leagues .
+
+ The Dodgers were the first team to integrate" False " renamed the ""Jackie Robinson Award"" in honor" 5 "[' renamed', ' the', ' ""', 'Jack', 'ie', ' Robinson']"
+1168 308 The professional sport played by x -1 The professional sport played by Jackie Robinson baseball Jackie Robinson "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' were' ' the' ' first' ' team' ' to'
+ ' integrate']" ", the first black player in the major leagues .
+
+ The Dodgers were the first team to integrate" False 60 years until Jackie Robinson joined the Brooklyn 4 [' 60', ' years', ' until', ' Jackie', ' Robinson']
+1169 308 The professional sport played by x -1 The professional sport played by Jackie Robinson baseball Jackie Robinson "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' were' ' the' ' first' ' team' ' to'
+ ' integrate']" ", the first black player in the major leagues .
+
+ The Dodgers were the first team to integrate" False parks where Jackie Robinson never appeared. And 3 [' parks', ' where', ' Jackie', ' Robinson']
+1170 308 The professional sport played by x -1 The professional sport played by Jackie Robinson baseball Jackie Robinson "[',' ' the' ' first' ' black' ' player' ' in' ' the' ' major' ' leagues'
+ '.' '\n' '\n' 'The' ' Dodgers' ' were' ' the' ' first' ' team' ' to'
+ ' integrate']" ", the first black player in the major leagues .
+
+ The Dodgers were the first team to integrate" False the signing of Jackie Robinson to a minor league 4 [' the', ' signing', ' of', ' Jackie', ' Robinson']
+1171 309 The professional sport played by x -1 The professional sport played by Charles Barkley basketball Charles Barkley "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of person who False " fail in the NBA, and Charles Barkley said he would ""kiss" 7 [' fail', ' in', ' the', ' NBA', ',', ' and', ' Charles', ' Barkley']
+1172 309 The professional sport played by x -1 The professional sport played by Charles Barkley basketball Charles Barkley "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of person who False all-star forward Charles Barkley in 1996, but 5 [' all', '-', 'star', ' forward', ' Charles', ' Barkley']
+1173 309 The professional sport played by x -1 The professional sport played by Charles Barkley basketball Charles Barkley "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of person who False fail in the NBA, and Charles Barkley said he would 7 [' fail', ' in', ' the', ' NBA', ',', ' and', ' Charles', ' Barkley']
+1174 309 The professional sport played by x -1 The professional sport played by Charles Barkley basketball Charles Barkley "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of person who False " Hall-of-Famers Charles Barkley and John Stockton.
+" 8 [' Hall', '-', 'of', '-', 'F', 'am', 'ers', ' Charles', ' Barkley']
+1175 309 The professional sport played by x -1 The professional sport played by Charles Barkley basketball Charles Barkley "[',' ' the' ' NBA' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' person' ' who']" , the NBA 's all - time leading scorer , is a great example of the kind of person who False in the MVP race to Charles Barkley with 22 votes to 6 [' in', ' the', ' MVP', ' race', ' to', ' Charles', ' Barkley']
+1176 310 The professional sport played by x -1 The professional sport played by Paul Kariya hockey Paul Kariya "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False he cross-checked Paul Kariya in the head as the 7 [' he', ' cross', '-', 'checked', ' Paul', ' K', 'ari', 'ya']
+1177 310 The professional sport played by x -1 The professional sport played by Paul Kariya hockey Paul Kariya "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False down Ducks forward Paul Kariya in Game Six. However, 6 [' down', ' Ducks', ' forward', ' Paul', ' K', 'ari', 'ya']
+1178 310 The professional sport played by x -1 The professional sport played by Paul Kariya hockey Paul Kariya "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False Selänne considers Paul Kariya and Joe Sakic to 8 [' Sel', 'ä', 'n', 'ne', ' considers', ' Paul', ' K', 'ari', 'ya']
+1179 310 The professional sport played by x -1 The professional sport played by Paul Kariya hockey Paul Kariya "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False cross-checked Paul Kariya in the head 6 [' cross', '-', 'checked', ' Paul', ' K', 'ari', 'ya']
+1180 310 The professional sport played by x -1 The professional sport played by Paul Kariya hockey Paul Kariya "[',' ' the' ' NHL' ""'s"" ' all' '-' 'time' ' leading' ' scorer' ',' ' is'
+ ' a' ' great' ' example' ' of' ' the' ' kind' ' of' ' player' ' who']" , the NHL 's all - time leading scorer , is a great example of the kind of player who False cross-checked Paul Kariya in the head as 6 [' cross', '-', 'checked', ' Paul', ' K', 'ari', 'ya']
+1181 311 The professional sport played by x -1 The professional sport played by Jim Brown football Jim Brown "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False Lou Groza and Jim Brown extensively 5 [' Lou', ' Gro', 'za', ' and', ' Jim', ' Brown']
+1182 311 The professional sport played by x -1 The professional sport played by Jim Brown football Jim Brown "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False 1 ['Jim', ' Brown']
+1183 311 The professional sport played by x -1 The professional sport played by Jim Brown football Jim Brown "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False retired NFL star Jim Brown acted as Chamberlain's 4 [' retired', ' NFL', ' star', ' Jim', ' Brown']
+1184 311 The professional sport played by x -1 The professional sport played by Jim Brown football Jim Brown "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False 1 record. Jim Brown was not the NFL's 4 [' 1', ' record', '.', ' Jim', ' Brown']
+1185 311 The professional sport played by x -1 The professional sport played by Jim Brown football Jim Brown "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' rusher' ',' ' was'
+ ' a' ' great' ' player' ',' ' but' ' he' ' was' ' also' ' a']" , the NFL 's all - time leading rusher , was a great player , but he was also a False team selected Jim Brown out of Syracuse 3 [' team', ' selected', ' Jim', ' Brown']
+1186 312 The professional sport played by x -1 The professional sport played by Yao Ming basketball Yao Ming "[',' ' the' ' NBA' ""'s"" ' first' ' Chinese' ' player' ',' ' is' ' a'
+ ' game' ' that' ' has' ' been' ' played' ' in' ' China' ' for' ' over'
+ ' a']" , the NBA 's first Chinese player , is a game that has been played in China for over a False 2,805,397 votes. Yao Ming, Tim Duncan, Amar 8 [' 2', ',', '805', ',', '397', ' votes', '.', ' Yao', ' Ming']
+1187 312 The professional sport played by x -1 The professional sport played by Yao Ming basketball Yao Ming "[',' ' the' ' NBA' ""'s"" ' first' ' Chinese' ' player' ',' ' is' ' a'
+ ' game' ' that' ' has' ' been' ' played' ' in' ' China' ' for' ' over'
+ ' a']" , the NBA 's first Chinese player , is a game that has been played in China for over a False In the seventh, Yao Ming injured his left ankle 5 [' In', ' the', ' seventh', ',', ' Yao', ' Ming']
+1188 312 The professional sport played by x -1 The professional sport played by Yao Ming basketball Yao Ming "[',' ' the' ' NBA' ""'s"" ' first' ' Chinese' ' player' ',' ' is' ' a'
+ ' game' ' that' ' has' ' been' ' played' ' in' ' China' ' for' ' over'
+ ' a']" , the NBA 's first Chinese player , is a game that has been played in China for over a False China; however, Yao Ming fell from 6th to 5 [' China', ';', ' however', ',', ' Yao', ' Ming']
+1189 312 The professional sport played by x -1 The professional sport played by Yao Ming basketball Yao Ming "[',' ' the' ' NBA' ""'s"" ' first' ' Chinese' ' player' ',' ' is' ' a'
+ ' game' ' that' ' has' ' been' ' played' ' in' ' China' ' for' ' over'
+ ' a']" , the NBA 's first Chinese player , is a game that has been played in China for over a False alongside future NBA star Yao Ming and established NBA 5 [' alongside', ' future', ' NBA', ' star', ' Yao', ' Ming']
+1190 312 The professional sport played by x -1 The professional sport played by Yao Ming basketball Yao Ming "[',' ' the' ' NBA' ""'s"" ' first' ' Chinese' ' player' ',' ' is' ' a'
+ ' game' ' that' ' has' ' been' ' played' ' in' ' China' ' for' ' over'
+ ' a']" , the NBA 's first Chinese player , is a game that has been played in China for over a False games. In the seventh, Yao Ming injured his 7 [' games', '.', ' In', ' the', ' seventh', ',', ' Yao', ' Ming']
+1191 314 The professional sport played by x -1 The professional sport played by Pavel Bure hockey Pavel Bure "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Vancouver' ' Canucks' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' was' ' a']" , who was a member of the Vancouver Canucks , is a great example of a player who was a False replaced the injured Pavel Bure of the Vancouver 5 [' replaced', ' the', ' injured', ' Pavel', ' B', 'ure']
+1192 314 The professional sport played by x -1 The professional sport played by Pavel Bure hockey Pavel Bure "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Vancouver' ' Canucks' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' was' ' a']" , who was a member of the Vancouver Canucks , is a great example of a player who was a False " Bure =
+" 6 [' B', 'ure', ' =', 'Pa', 'vel', ' B', 'ure']
+1193 314 The professional sport played by x -1 The professional sport played by Pavel Bure hockey Pavel Bure "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Vancouver' ' Canucks' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' was' ' a']" , who was a member of the Vancouver Canucks , is a great example of a player who was a False replaced the injured Pavel Bure of the Vancouver 5 [' replaced', ' the', ' injured', ' Pavel', ' B', 'ure']
+1194 314 The professional sport played by x -1 The professional sport played by Pavel Bure hockey Pavel Bure "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Vancouver' ' Canucks' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' was' ' a']" , who was a member of the Vancouver Canucks , is a great example of a player who was a False Jeremy Roenick and Pavel Bure as players 7 [' Jeremy', ' Ro', 'en', 'ick', ' and', ' Pavel', ' B', 'ure']
+1195 314 The professional sport played by x -1 The professional sport played by Pavel Bure hockey Pavel Bure "[',' ' who' ' was' ' a' ' member' ' of' ' the' ' Vancouver' ' Canucks' ','
+ ' is' ' a' ' great' ' example' ' of' ' a' ' player' ' who' ' was' ' a']" , who was a member of the Vancouver Canucks , is a great example of a player who was a False During the tournament, Pavel Bure set an Olympic Record 6 [' During', ' the', ' tournament', ',', ' Pavel', ' B', 'ure']
+1196 315 The professional sport played by x -1 The professional sport played by Ernie Barnes football Ernie Barnes "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . False neo-mannerist artist Ernie Barnes in 1971. In 1974, 8 [' neo', '-', 'man', 'ner', 'ist', ' artist', ' Er', 'nie', ' Barnes']
+1197 315 The professional sport played by x -1 The professional sport played by Ernie Barnes football Ernie Barnes "[',' ' a' ' former' ' professional' ' baseball' ' player' ',' ' is' ' a'
+ ' game' ' that' ' is' ' played' ' by' ' two' ' teams' ' of' ' two'
+ ' players' '.']" , a former professional baseball player , is a game that is played by two teams of two players . False neo-mannerist artist Ernie Barnes in 1971. In 1974, 8 [' neo', '-', 'man', 'ner', 'ist', ' artist', ' Er', 'nie', ' Barnes']
+1198 316 The professional sport played by x -1 The professional sport played by Peyton Manning football Peyton Manning "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NFL 's all - time leading passer , is a game that has been played for over a False Stadium has been renamed Peyton Manning Pass. Manning 5 [' Stadium', ' has', ' been', ' renamed', ' Peyton', ' Manning']
+1199 316 The professional sport played by x -1 The professional sport played by Peyton Manning football Peyton Manning "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NFL 's all - time leading passer , is a game that has been played for over a False Tennessee quarterback Peyton Manning completed a nine-yard 3 [' Tennessee', ' quarterback', ' Peyton', ' Manning']
+1200 316 The professional sport played by x -1 The professional sport played by Peyton Manning football Peyton Manning "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NFL 's all - time leading passer , is a game that has been played for over a False " Peyton Manning =
+" 1 [' Peyton', ' Manning']
+1201 316 The professional sport played by x -1 The professional sport played by Peyton Manning football Peyton Manning "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NFL 's all - time leading passer , is a game that has been played for over a False Indianapolis Colts quarterback Peyton Manning visited Rochester, 4 [' Indianapolis', ' Colts', ' quarterback', ' Peyton', ' Manning']
+1202 316 The professional sport played by x -1 The professional sport played by Peyton Manning football Peyton Manning "[',' ' the' ' NFL' ""'s"" ' all' '-' 'time' ' leading' ' passer' ',' ' is'
+ ' a' ' game' ' that' ' has' ' been' ' played' ' for' ' over' ' a']" , the NFL 's all - time leading passer , is a game that has been played for over a False been renamed Peyton Manning Pass. Manning also 3 [' been', ' renamed', ' Peyton', ' Manning']
+1203 317 The professional sport played by x -1 The professional sport played by Megan Rapinoe soccer Megan Rapinoe "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' soccer'
+ ' team' ' captain' ',' ' is' ' a' ' perfect' ' example' ' of']" , the U . S . women � � s national soccer team captain , is a perfect example of True " Exclusive: Megan Rapinoe in November 2012.
+" 5 [' Exclusive', ':', ' Megan', ' Rap', 'in', 'oe']
+1204 317 The professional sport played by x -1 The professional sport played by Megan Rapinoe soccer Megan Rapinoe "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' soccer'
+ ' team' ' captain' ',' ' is' ' a' ' perfect' ' example' ' of']" , the U . S . women � � s national soccer team captain , is a perfect example of True national team, along with Megan Rapinoe and Amy Rodriguez, 8 [' national', ' team', ',', ' along', ' with', ' Megan', ' Rap', 'in', 'oe']
+1205 317 The professional sport played by x -1 The professional sport played by Megan Rapinoe soccer Megan Rapinoe "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' soccer'
+ ' team' ' captain' ',' ' is' ' a' ' perfect' ' example' ' of']" , the U . S . women � � s national soccer team captain , is a perfect example of True corner kick from Megan Rapinoe in the 19th minute. 6 [' corner', ' kick', ' from', ' Megan', ' Rap', 'in', 'oe']
+1206 317 The professional sport played by x -1 The professional sport played by Megan Rapinoe soccer Megan Rapinoe "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' soccer'
+ ' team' ' captain' ',' ' is' ' a' ' perfect' ' example' ' of']" , the U . S . women � � s national soccer team captain , is a perfect example of True " Megan Rapinoe =
+" 3 [' Megan', ' Rap', 'in', 'oe']
+1207 317 The professional sport played by x -1 The professional sport played by Megan Rapinoe soccer Megan Rapinoe "[',' ' the' ' U' '.' 'S' '.' ' women' '�' '�' 's' ' national' ' soccer'
+ ' team' ' captain' ',' ' is' ' a' ' perfect' ' example' ' of']" , the U . S . women � � s national soccer team captain , is a perfect example of True corner kick from Megan Rapinoe in the 19th minute. 6 [' corner', ' kick', ' from', ' Megan', ' Rap', 'in', 'oe']
diff --git a/patchscopes/code/preprocessed_data/factual/pokemon_evolutions.tsv b/patchscopes/code/preprocessed_data/factual/pokemon_evolutions.tsv
new file mode 100644
index 00000000..8e02fca5
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/pokemon_evolutions.tsv
@@ -0,0 +1,79 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The evolved form of x -1 The evolved form of Bulbasaur Ivysaur Bulbasaur [' is', ' a', ' Grass', '/', 'Po', 'ison', ' type', '.', '\n', '\n', 'Bul', 'bas', 'aur', ' is', ' a', ' Grass', '/', 'Po', 'ison', ' type'] " is a Grass / Po ison type .
+
+ Bul bas aur is a Grass / Po ison type" False Charmander (along with Bulbasaur and Squirtle) 7 [' Charm', 'ander', ' (', 'along', ' with', ' Bul', 'bas', 'aur']
+1 0 The evolved form of x -1 The evolved form of Bulbasaur Ivysaur Bulbasaur [' is', ' a', ' Grass', '/', 'Po', 'ison', ' type', '.', '\n', '\n', 'Bul', 'bas', 'aur', ' is', ' a', ' Grass', '/', 'Po', 'ison', ' type'] " is a Grass / Po ison type .
+
+ Bul bas aur is a Grass / Po ison type" False Charmander (along with Bulbasaur and Squirtle) 7 [' Charm', 'ander', ' (', 'along', ' with', ' Bul', 'bas', 'aur']
+2 0 The evolved form of x -1 The evolved form of Bulbasaur Ivysaur Bulbasaur [' is', ' a', ' Grass', '/', 'Po', 'ison', ' type', '.', '\n', '\n', 'Bul', 'bas', 'aur', ' is', ' a', ' Grass', '/', 'Po', 'ison', ' type'] " is a Grass / Po ison type .
+
+ Bul bas aur is a Grass / Po ison type" False Charmander (along with Bulbasaur and Squirtle) was added 7 [' Charm', 'ander', ' (', 'along', ' with', ' Bul', 'bas', 'aur']
+3 1 The evolved form of x -1 The evolved form of Charmander Charmeleon Charmander [' is', ' a', ' fire', '-', 'type', ' Pokémon', '.', ' It', ' is', ' a', ' member', ' of', ' the', ' fire', '-', 'type', ' family', '.', ' It', ' is'] is a fire - type Pokémon . It is a member of the fire - type family . It is False still had a Charmander instead of Charizard, 4 [' still', ' had', ' a', ' Charm', 'ander']
+4 1 The evolved form of x -1 The evolved form of Charmander Charmeleon Charmander [' is', ' a', ' fire', '-', 'type', ' Pokémon', '.', ' It', ' is', ' a', ' member', ' of', ' the', ' fire', '-', 'type', ' family', '.', ' It', ' is'] is a fire - type Pokémon . It is a member of the fire - type family . It is False (either Squirtle, Charmander or Bulbasaur) 7 [' (', 'either', ' Squ', 'irt', 'le', ',', ' Charm', 'ander']
+5 1 The evolved form of x -1 The evolved form of Charmander Charmeleon Charmander [' is', ' a', ' fire', '-', 'type', ' Pokémon', '.', ' It', ' is', ' a', ' member', ' of', ' the', ' fire', '-', 'type', ' family', '.', ' It', ' is'] is a fire - type Pokémon . It is a member of the fire - type family . It is False pre-evolutions Charmander and Charmeleon are 5 [' pre', '-', 'ev', 'olutions', ' Charm', 'ander']
+6 1 The evolved form of x -1 The evolved form of Charmander Charmeleon Charmander [' is', ' a', ' fire', '-', 'type', ' Pokémon', '.', ' It', ' is', ' a', ' member', ' of', ' the', ' fire', '-', 'type', ' family', '.', ' It', ' is'] is a fire - type Pokémon . It is a member of the fire - type family . It is False manga, Blue receives a Charmander from his grandfather 6 [' manga', ',', ' Blue', ' receives', ' a', ' Charm', 'ander']
+7 1 The evolved form of x -1 The evolved form of Charmander Charmeleon Charmander [' is', ' a', ' fire', '-', 'type', ' Pokémon', '.', ' It', ' is', ' a', ' member', ' of', ' the', ' fire', '-', 'type', ' family', '.', ' It', ' is'] is a fire - type Pokémon . It is a member of the fire - type family . It is False since he was a Charmander abandoned by his 5 [' since', ' he', ' was', ' a', ' Charm', 'ander']
+8 2 The evolved form of x -1 The evolved form of Squirtle Wartortle Squirtle [' is', ' a', ' water', ' type', ',', ' and', ' it', ' is', ' a', ' water', ' type', ' that', ' is', ' very', ' good', ' at', ' using', ' water', ' attacks', '.'] is a water type , and it is a water type that is very good at using water attacks . False and Painter Squirtle (Maler Klecksel) 4 [' and', ' Painter', ' Squ', 'irt', 'le']
+9 2 The evolved form of x -1 The evolved form of Squirtle Wartortle Squirtle [' is', ' a', ' water', ' type', ',', ' and', ' it', ' is', ' a', ' water', ' type', ' that', ' is', ' very', ' good', ' at', ' using', ' water', ' attacks', '.'] is a water type , and it is a water type that is very good at using water attacks . False " ""Here Comes the Squirtle Squad"", ""Beach" 6 "[' ""', 'Here', ' Comes', ' the', ' Squ', 'irt', 'le']"
+10 2 The evolved form of x -1 The evolved form of Squirtle Wartortle Squirtle [' is', ' a', ' water', ' type', ',', ' and', ' it', ' is', ' a', ' water', ' type', ' that', ' is', ' very', ' good', ' at', ' using', ' water', ' attacks', '.'] is a water type , and it is a water type that is very good at using water attacks . False as in Painter Squirtle (Maler Klecksel); 5 [' as', ' in', ' Painter', ' Squ', 'irt', 'le']
+11 2 The evolved form of x -1 The evolved form of Squirtle Wartortle Squirtle [' is', ' a', ' water', ' type', ',', ' and', ' it', ' is', ' a', ' water', ' type', ' that', ' is', ' very', ' good', ' at', ' using', ' water', ' attacks', '.'] is a water type , and it is a water type that is very good at using water attacks . False " You!"", ""Here Comes the Squirtle Squad"", ""Beach" 8 "[' You', '!"",', ' ""', 'Here', ' Comes', ' the', ' Squ', 'irt', 'le']"
+12 2 The evolved form of x -1 The evolved form of Squirtle Wartortle Squirtle [' is', ' a', ' water', ' type', ',', ' and', ' it', ' is', ' a', ' water', ' type', ' that', ' is', ' very', ' good', ' at', ' using', ' water', ' attacks', '.'] is a water type , and it is a water type that is very good at using water attacks . False The Trainer has a Squirtle and an Ivysaur, 6 [' The', ' Trainer', ' has', ' a', ' Squ', 'irt', 'le']
+13 3 The evolved form of x -1 The evolved form of Pikachu Raichu Pikachu [' is', ' a', ' very', ' powerful', ' Pokémon', '.', ' It', ' is', ' a', ' Psychic', '-', 'type', ' Pokémon', '.', ' It', ' has', ' a', ' base', ' stat', ' total'] is a very powerful Pokémon . It is a Psychic - type Pokémon . It has a base stat total False one more Pokémon, a Pikachu. At first, Pikachu 5 [' one', ' more', ' Pokémon', ',', ' a', ' Pikachu']
+14 3 The evolved form of x -1 The evolved form of Pikachu Raichu Pikachu [' is', ' a', ' very', ' powerful', ' Pokémon', '.', ' It', ' is', ' a', ' Psychic', '-', 'type', ' Pokémon', '.', ' It', ' has', ' a', ' base', ' stat', ' total'] is a very powerful Pokémon . It is a Psychic - type Pokémon . It has a base stat total False generations. For example, Pikachu was introduced 5 [' generations', '.', ' For', ' example', ',', ' Pikachu']
+15 3 The evolved form of x -1 The evolved form of Pikachu Raichu Pikachu [' is', ' a', ' very', ' powerful', ' Pokémon', '.', ' It', ' is', ' a', ' Psychic', '-', 'type', ' Pokémon', '.', ' It', ' has', ' a', ' base', ' stat', ' total'] is a very powerful Pokémon . It is a Psychic - type Pokémon . It has a base stat total False 2 ['P', 'ik', 'achu']
+16 3 The evolved form of x -1 The evolved form of Pikachu Raichu Pikachu [' is', ' a', ' very', ' powerful', ' Pokémon', '.', ' It', ' is', ' a', ' Psychic', '-', 'type', ' Pokémon', '.', ' It', ' has', ' a', ' base', ' stat', ' total'] is a very powerful Pokémon . It is a Psychic - type Pokémon . It has a base stat total False programs with a Pikachu, a mouse-like 3 [' programs', ' with', ' a', ' Pikachu']
+17 3 The evolved form of x -1 The evolved form of Pikachu Raichu Pikachu [' is', ' a', ' very', ' powerful', ' Pokémon', '.', ' It', ' is', ' a', ' Psychic', '-', 'type', ' Pokémon', '.', ' It', ' has', ' a', ' base', ' stat', ' total'] is a very powerful Pokémon . It is a Psychic - type Pokémon . It has a base stat total False thanking the player and Pikachu for their time, 4 [' thanking', ' the', ' player', ' and', ' Pikachu']
+18 7 The evolved form of x -1 The evolved form of Meowth Persian Meowth [' is', ' a', ' cat', '-', 'like', ' Pokémon', ' with', ' a', ' long', ',', ' pointed', ' nose', ' and', ' a', ' long', ' tail', '.', ' It', ' has', ' a'] is a cat - like Pokémon with a long , pointed nose and a long tail . It has a False Patrick Paon) and Meowth to carry out the 6 [' Patrick', ' Pa', 'on', ')', ' and', ' Meow', 'th']
+19 7 The evolved form of x -1 The evolved form of Meowth Persian Meowth [' is', ' a', ' cat', '-', 'like', ' Pokémon', ' with', ' a', ' long', ',', ' pointed', ' nose', ' and', ' a', ' long', ' tail', '.', ' It', ' has', ' a'] is a cat - like Pokémon with a long , pointed nose and a long tail . It has a False Patrick Paon) and Meowth to carry out 6 [' Patrick', ' Pa', 'on', ')', ' and', ' Meow', 'th']
+20 7 The evolved form of x -1 The evolved form of Meowth Persian Meowth [' is', ' a', ' cat', '-', 'like', ' Pokémon', ' with', ' a', ' long', ',', ' pointed', ' nose', ' and', ' a', ' long', ' tail', '.', ' It', ' has', ' a'] is a cat - like Pokémon with a long , pointed nose and a long tail . It has a False masquerade as Pokémon Meowth in lieu of 5 [' mas', 'querade', ' as', ' Pokémon', ' Meow', 'th']
+21 7 The evolved form of x -1 The evolved form of Meowth Persian Meowth [' is', ' a', ' cat', '-', 'like', ' Pokémon', ' with', ' a', ' long', ',', ' pointed', ' nose', ' and', ' a', ' long', ' tail', '.', ' It', ' has', ' a'] is a cat - like Pokémon with a long , pointed nose and a long tail . It has a False Gharret Patrick Paon) and Meowth to carry out the retribution. 9 [' G', 'har', 'ret', ' Patrick', ' Pa', 'on', ')', ' and', ' Meow', 'th']
+22 7 The evolved form of x -1 The evolved form of Meowth Persian Meowth [' is', ' a', ' cat', '-', 'like', ' Pokémon', ' with', ' a', ' long', ',', ' pointed', ' nose', ' and', ' a', ' long', ' tail', '.', ' It', ' has', ' a'] is a cat - like Pokémon with a long , pointed nose and a long tail . It has a False Gharret Patrick Paon) and Meowth to carry out the 9 [' G', 'har', 'ret', ' Patrick', ' Pa', 'on', ')', ' and', ' Meow', 'th']
+23 8 The evolved form of x -1 The evolved form of Psyduck Golduck Psyduck [' is', ' a', ' duck', ' with', ' a', ' Psy', 'du', 'ck', ' head', '.', ' It', ' is', ' a', ' water', '-', 'type', ' Pokémon', '.', ' It', ' is'] is a duck with a Psy du ck head . It is a water - type Pokémon . It is False coverage from a Psyduck at Pokémon News 5 [' coverage', ' from', ' a', ' Psy', 'du', 'ck']
+24 8 The evolved form of x -1 The evolved form of Psyduck Golduck Psyduck [' is', ' a', ' duck', ' with', ' a', ' Psy', 'du', 'ck', ' head', '.', ' It', ' is', ' a', ' water', '-', 'type', ' Pokémon', '.', ' It', ' is'] is a duck with a Psy du ck head . It is a water - type Pokémon . It is False by using a zombie Psyduck to deflect Arbok's 6 [' by', ' using', ' a', ' zombie', ' Psy', 'du', 'ck']
+25 8 The evolved form of x -1 The evolved form of Psyduck Golduck Psyduck [' is', ' a', ' duck', ' with', ' a', ' Psy', 'du', 'ck', ' head', '.', ' It', ' is', ' a', ' water', '-', 'type', ' Pokémon', '.', ' It', ' is'] is a duck with a Psy du ck head . It is a water - type Pokémon . It is False using a zombie Psyduck to deflect 5 [' using', ' a', ' zombie', ' Psy', 'du', 'ck']
+26 8 The evolved form of x -1 The evolved form of Psyduck Golduck Psyduck [' is', ' a', ' duck', ' with', ' a', ' Psy', 'du', 'ck', ' head', '.', ' It', ' is', ' a', ' water', '-', 'type', ' Pokémon', '.', ' It', ' is'] is a duck with a Psy du ck head . It is a water - type Pokémon . It is False using a zombie Psyduck to deflect 5 [' using', ' a', ' zombie', ' Psy', 'du', 'ck']
+27 12 The evolved form of x -1 The evolved form of Abra Kadabra Abra [',', ' the', ' first', ' Pokémon', ' to', ' be', ' released', ' in', ' the', ' Pokémon', ' Black', ' and', ' White', ' series', ',', ' is', ' a', ' Psychic', '-', 'type'] , the first Pokémon to be released in the Pokémon Black and White series , is a Psychic - type False game appearance of Abra, Kadabra, and 4 [' game', ' appearance', ' of', ' Ab', 'ra']
+28 12 The evolved form of x -1 The evolved form of Abra Kadabra Abra [',', ' the', ' first', ' Pokémon', ' to', ' be', ' released', ' in', ' the', ' Pokémon', ' Black', ' and', ' White', ' series', ',', ' is', ' a', ' Psychic', '-', 'type'] , the first Pokémon to be released in the Pokémon Black and White series , is a Psychic - type False Sabrina owns an Abra, which she sends 5 [' Sab', 'rina', ' owns', ' an', ' Ab', 'ra']
+29 12 The evolved form of x -1 The evolved form of Abra Kadabra Abra [',', ' the', ' first', ' Pokémon', ' to', ' be', ' released', ' in', ' the', ' Pokémon', ' Black', ' and', ' White', ' series', ',', ' is', ' a', ' Psychic', '-', 'type'] , the first Pokémon to be released in the Pokémon Black and White series , is a Psychic - type False 1 ['Ab', 'ra']
+30 12 The evolved form of x -1 The evolved form of Abra Kadabra Abra [',', ' the', ' first', ' Pokémon', ' to', ' be', ' released', ' in', ' the', ' Pokémon', ' Black', ' and', ' White', ' series', ',', ' is', ' a', ' Psychic', '-', 'type'] , the first Pokémon to be released in the Pokémon Black and White series , is a Psychic - type False 1 ['Ab', 'ra']
+31 12 The evolved form of x -1 The evolved form of Abra Kadabra Abra [',', ' the', ' first', ' Pokémon', ' to', ' be', ' released', ' in', ' the', ' Pokémon', ' Black', ' and', ' White', ' series', ',', ' is', ' a', ' Psychic', '-', 'type'] , the first Pokémon to be released in the Pokémon Black and White series , is a Psychic - type False " attested that Abra returned in ""a fragile" 4 [' att', 'ested', ' that', ' Ab', 'ra']
+32 16 The evolved form of x -1 The evolved form of Pidgey Pidgeotto Pidgey [' is', ' the', ' P', 'idge', 'y', '-', 'P', 'idge', 'ot', '.', ' It', ' is', ' a', ' flying', ' Pokémon', ' that', ' can', ' use', ' its', ' wings'] is the P idge y - P idge ot . It is a flying Pokémon that can use its wings False Just then, a Pidgey flies by and Ash unsuccessfully 6 [' Just', ' then', ',', ' a', ' P', 'idge', 'y']
+33 18 The evolved form of x -1 The evolved form of Spearow Fearow Spearow [' is', ' a', ' small', ',', ' bip', 'ed', 'al', ',', ' bird', '-', 'like', ' Pokémon', ' with', ' a', ' long', ',', ' sharp', ' be', 'ak', '.'] is a small , bip ed al , bird - like Pokémon with a long , sharp be ak . False throws a rock at a Spearow, which gets angry 6 [' throws', ' a', ' rock', ' at', ' a', ' Spear', 'ow']
+34 18 The evolved form of x -1 The evolved form of Spearow Fearow Spearow [' is', ' a', ' small', ',', ' bip', 'ed', 'al', ',', ' bird', '-', 'like', ' Pokémon', ' with', ' a', ' long', ',', ' sharp', ' be', 'ak', '.'] is a small , bip ed al , bird - like Pokémon with a long , sharp be ak . False 2 ['Spe', 'ar', 'ow']
+35 18 The evolved form of x -1 The evolved form of Spearow Fearow Spearow [' is', ' a', ' small', ',', ' bip', 'ed', 'al', ',', ' bird', '-', 'like', ' Pokémon', ' with', ' a', ' long', ',', ' sharp', ' be', 'ak', '.'] is a small , bip ed al , bird - like Pokémon with a long , sharp be ak . False 2 ['Spe', 'ar', 'ow']
+36 20 The evolved form of x -1 The evolved form of Sandshrew Sandslash Sandshrew [' is', ' a', ' small', ',', ' bip', 'ed', 'al', ',', ' rodent', '-', 'like', ' Pokémon', ' with', ' a', ' long', ' tail', ' and', ' a', ' long', ','] is a small , bip ed al , rodent - like Pokémon with a long tail and a long , False containing a Sandshrew that was lost 4 [' containing', ' a', ' Sand', 'sh', 'rew']
+37 23 The evolved form of x -1 The evolved form of Zubat Golbat Zubat [' is', ' a', ' flying', ',', ' fire', '-', 'b', 'reat', 'hing', ',', ' fire', '-', 'sp', 'itting', ',', ' fire', '-', 'b', 'reat', 'hing'] is a flying , fire - b reat hing , fire - sp itting , fire - b reat hing False body of water, while Zubat can only be caught 6 [' body', ' of', ' water', ',', ' while', ' Zub', 'at']
+38 23 The evolved form of x -1 The evolved form of Zubat Golbat Zubat [' is', ' a', ' flying', ',', ' fire', '-', 'b', 'reat', 'hing', ',', ' fire', '-', 'sp', 'itting', ',', ' fire', '-', 'b', 'reat', 'hing'] is a flying , fire - b reat hing , fire - sp itting , fire - b reat hing False body of water, while Zubat can only be caught 6 [' body', ' of', ' water', ',', ' while', ' Zub', 'at']
+39 23 The evolved form of x -1 The evolved form of Zubat Golbat Zubat [' is', ' a', ' flying', ',', ' fire', '-', 'b', 'reat', 'hing', ',', ' fire', '-', 'sp', 'itting', ',', ' fire', '-', 'b', 'reat', 'hing'] is a flying , fire - b reat hing , fire - sp itting , fire - b reat hing False water, while Zubat can only be caught 4 [' water', ',', ' while', ' Zub', 'at']
+40 23 The evolved form of x -1 The evolved form of Zubat Golbat Zubat [' is', ' a', ' flying', ',', ' fire', '-', 'b', 'reat', 'hing', ',', ' fire', '-', 'sp', 'itting', ',', ' fire', '-', 'b', 'reat', 'hing'] is a flying , fire - b reat hing , fire - sp itting , fire - b reat hing False of water, while Zubat can only be caught 5 [' of', ' water', ',', ' while', ' Zub', 'at']
+41 25 The evolved form of x -1 The evolved form of Tentacool Tentacruel Tentacool [' is', ' a', ' very', ' rare', ' Pokémon', '.', ' It', ' is', ' a', ' water', '-', 'type', ' Pokémon', ' that', ' is', ' found', ' in', ' the', ' ocean', '.'] is a very rare Pokémon . It is a water - type Pokémon that is found in the ocean . False varies. For example, Tentacool can only be 7 [' varies', '.', ' For', ' example', ',', ' Tent', 'ac', 'ool']
+42 25 The evolved form of x -1 The evolved form of Tentacool Tentacruel Tentacool [' is', ' a', ' very', ' rare', ' Pokémon', '.', ' It', ' is', ' a', ' water', '-', 'type', ' Pokémon', ' that', ' is', ' found', ' in', ' the', ' ocean', '.'] is a very rare Pokémon . It is a water - type Pokémon that is found in the ocean . False varies. For example, Tentacool can only be caught 7 [' varies', '.', ' For', ' example', ',', ' Tent', 'ac', 'ool']
+43 25 The evolved form of x -1 The evolved form of Tentacool Tentacruel Tentacool [' is', ' a', ' very', ' rare', ' Pokémon', '.', ' It', ' is', ' a', ' water', '-', 'type', ' Pokémon', ' that', ' is', ' found', ' in', ' the', ' ocean', '.'] is a very rare Pokémon . It is a water - type Pokémon that is found in the ocean . False For example, Tentacool can only be 5 [' For', ' example', ',', ' Tent', 'ac', 'ool']
+44 29 The evolved form of x -1 The evolved form of Magnemite Magneton Magnemite [' is', ' called', ' Magnet', 'on', '.', ' It', ' is', ' a', ' powerful', ' Pokémon', ' that', ' can', ' use', ' its', ' magnetic', ' powers', ' to', ' attract', ' and', ' rep'] is called Magnet on . It is a powerful Pokémon that can use its magnetic powers to attract and rep True opens with a group of Magnemite — magnet-like 7 [' opens', ' with', ' a', ' group', ' of', ' Magn', 'em', 'ite']
+45 29 The evolved form of x -1 The evolved form of Magnemite Magneton Magnemite [' is', ' called', ' Magnet', 'on', '.', ' It', ' is', ' a', ' powerful', ' Pokémon', ' that', ' can', ' use', ' its', ' magnetic', ' powers', ' to', ' attract', ' and', ' rep'] is called Magnet on . It is a powerful Pokémon that can use its magnetic powers to attract and rep True with a group of Magnemite — magnet-like 6 [' with', ' a', ' group', ' of', ' Magn', 'em', 'ite']
+46 29 The evolved form of x -1 The evolved form of Magnemite Magneton Magnemite [' is', ' called', ' Magnet', 'on', '.', ' It', ' is', ' a', ' powerful', ' Pokémon', ' that', ' can', ' use', ' its', ' magnetic', ' powers', ' to', ' attract', ' and', ' rep'] is called Magnet on . It is a powerful Pokémon that can use its magnetic powers to attract and rep True " Unfazed, Oak has the Magnemite deliver a ""retro"" television" 9 [' Un', 'f', 'azed', ',', ' Oak', ' has', ' the', ' Magn', 'em', 'ite']
+47 29 The evolved form of x -1 The evolved form of Magnemite Magneton Magnemite [' is', ' called', ' Magnet', 'on', '.', ' It', ' is', ' a', ' powerful', ' Pokémon', ' that', ' can', ' use', ' its', ' magnetic', ' powers', ' to', ' attract', ' and', ' rep'] is called Magnet on . It is a powerful Pokémon that can use its magnetic powers to attract and rep True with a group of Magnemite — magnet-like Pokémon 6 [' with', ' a', ' group', ' of', ' Magn', 'em', 'ite']
+48 29 The evolved form of x -1 The evolved form of Magnemite Magneton Magnemite [' is', ' called', ' Magnet', 'on', '.', ' It', ' is', ' a', ' powerful', ' Pokémon', ' that', ' can', ' use', ' its', ' magnetic', ' powers', ' to', ' attract', ' and', ' rep'] is called Magnet on . It is a powerful Pokémon that can use its magnetic powers to attract and rep True lost by the delivery Magnemite on their way to 6 [' lost', ' by', ' the', ' delivery', ' Magn', 'em', 'ite']
+49 31 The evolved form of x -1 The evolved form of Seel Dewgong Seel ['ie', ',', ' the', ' F', 'ae', ' are', ' a', ' race', ' of', ' beings', ' that', ' are', ' not', ' human', '.', ' They', ' are', ' beautiful', ',', ' graceful'] ie , the F ae are a race of beings that are not human . They are beautiful , graceful False " Magnificat canticle, ""Meine Seele erhebt den Herren""." 10 "[' Magn', 'ific', 'at', ' cant', 'icle', ',', ' ""', 'Me', 'ine', ' Se', 'el']"
+50 31 The evolved form of x -1 The evolved form of Seel Dewgong Seel ['ie', ',', ' the', ' F', 'ae', ' are', ' a', ' race', ' of', ' beings', ' that', ' are', ' not', ' human', '.', ' They', ' are', ' beautiful', ',', ' graceful'] ie , the F ae are a race of beings that are not human . They are beautiful , graceful False Magnificat Meine Seel erhebt den 6 [' Magn', 'ific', 'at', ' Me', 'ine', ' Se', 'el']
+51 31 The evolved form of x -1 The evolved form of Seel Dewgong Seel ['ie', ',', ' the', ' F', 'ae', ' are', ' a', ' race', ' of', ' beings', ' that', ' are', ' not', ' human', '.', ' They', ' are', ' beautiful', ',', ' graceful'] ie , the F ae are a race of beings that are not human . They are beautiful , graceful False Voyenno-Vozdooshnykh Seel – Air Force Scientific 11 [' Voy', 'enn', 'o', '-', 'V', 'oz', 'do', 'osh', 'ny', 'kh', ' Se', 'el']
+52 31 The evolved form of x -1 The evolved form of Seel Dewgong Seel ['ie', ',', ' the', ' F', 'ae', ' are', ' a', ' race', ' of', ' beings', ' that', ' are', ' not', ' human', '.', ' They', ' are', ' beautiful', ',', ' graceful'] ie , the F ae are a race of beings that are not human . They are beautiful , graceful False contralto soloist at the Seel Street Benedictine 8 [' cont', 'ral', 'to', ' solo', 'ist', ' at', ' the', ' Se', 'el']
+53 31 The evolved form of x -1 The evolved form of Seel Dewgong Seel ['ie', ',', ' the', ' F', 'ae', ' are', ' a', ' race', ' of', ' beings', ' that', ' are', ' not', ' human', '.', ' They', ' are', ' beautiful', ',', ' graceful'] ie , the F ae are a race of beings that are not human . They are beautiful , graceful False Voyenno-Vozdushnykh Seel — NII VVS) between 11 [' Voy', 'enn', 'o', '-', 'V', 'oz', 'd', 'ush', 'ny', 'kh', ' Se', 'el']
+54 34 The evolved form of x -1 The evolved form of Gastly Haunter Gastly [' is', ' a', ' ghost', 'ly', ',', ' translucent', ',', ' and', ' ghost', 'ly', ' ghost', '.', ' It', ' is', ' a', ' ghost', 'ly', ' ghost', '.', ' It'] is a ghost ly , translucent , and ghost ly ghost . It is a ghost ly ghost . It False possessed by a Gastly in the Lavender 4 [' possessed', ' by', ' a', ' Gast', 'ly']
+55 34 The evolved form of x -1 The evolved form of Gastly Haunter Gastly [' is', ' a', ' ghost', 'ly', ',', ' translucent', ',', ' and', ' ghost', 'ly', ' ghost', '.', ' It', ' is', ' a', ' ghost', 'ly', ' ghost', '.', ' It'] is a ghost ly , translucent , and ghost ly ghost . It is a ghost ly ghost . It False possessed by a Gastly in the Lavender 4 [' possessed', ' by', ' a', ' Gast', 'ly']
+56 34 The evolved form of x -1 The evolved form of Gastly Haunter Gastly [' is', ' a', ' ghost', 'ly', ',', ' translucent', ',', ' and', ' ghost', 'ly', ' ghost', '.', ' It', ' is', ' a', ' ghost', 'ly', ' ghost', '.', ' It'] is a ghost ly , translucent , and ghost ly ghost . It is a ghost ly ghost . It False possessed by a Gastly in the Lavender 4 [' possessed', ' by', ' a', ' Gast', 'ly']
+57 36 The evolved form of x -1 The evolved form of Krabby Kingler Krabby [' Patty', ' is', ' Kr', 'abby', ' Patty', '.', '\n', '\n', 'The', ' evolved', ' form', ' of', ' Kr', 'abby', ' Patty', ' is', ' Kr', 'abby', ' Patty', '.'] " Patty is Kr abby Patty .
+
+ The evolved form of Kr abby Patty is Kr abby Patty ." False " Sandy"", ""The Krabby Patty That Ate" 5 "[' Sandy', '"",', ' ""', 'The', ' Kr', 'abby']"
+58 36 The evolved form of x -1 The evolved form of Krabby Kingler Krabby [' Patty', ' is', ' Kr', 'abby', ' Patty', '.', '\n', '\n', 'The', ' evolved', ' form', ' of', ' Kr', 'abby', ' Patty', ' is', ' Kr', 'abby', ' Patty', '.'] " Patty is Kr abby Patty .
+
+ The evolved form of Kr abby Patty is Kr abby Patty ." False " episodes ""Fear of a Krabby Patty"" and ""Shell" 6 "[' episodes', ' ""', 'Fear', ' of', ' a', ' Kr', 'abby']"
+59 36 The evolved form of x -1 The evolved form of Krabby Kingler Krabby [' Patty', ' is', ' Kr', 'abby', ' Patty', '.', '\n', '\n', 'The', ' evolved', ' form', ' of', ' Kr', 'abby', ' Patty', ' is', ' Kr', 'abby', ' Patty', '.'] " Patty is Kr abby Patty .
+
+ The evolved form of Kr abby Patty is Kr abby Patty ." False " For Sandy"", ""The Krabby Patty That Ate Bikini" 6 "[' For', ' Sandy', '"",', ' ""', 'The', ' Kr', 'abby']"
+60 36 The evolved form of x -1 The evolved form of Krabby Kingler Krabby [' Patty', ' is', ' Kr', 'abby', ' Patty', '.', '\n', '\n', 'The', ' evolved', ' form', ' of', ' Kr', 'abby', ' Patty', ' is', ' Kr', 'abby', ' Patty', '.'] " Patty is Kr abby Patty .
+
+ The evolved form of Kr abby Patty is Kr abby Patty ." False " episodes ""Fear of a Krabby Patty"" and ""Shell of" 6 "[' episodes', ' ""', 'Fear', ' of', ' a', ' Kr', 'abby']"
+61 36 The evolved form of x -1 The evolved form of Krabby Kingler Krabby [' Patty', ' is', ' Kr', 'abby', ' Patty', '.', '\n', '\n', 'The', ' evolved', ' form', ' of', ' Kr', 'abby', ' Patty', ' is', ' Kr', 'abby', ' Patty', '.'] " Patty is Kr abby Patty .
+
+ The evolved form of Kr abby Patty is Kr abby Patty ." False " Stooges"", and ""The Krabby Kronicle"" an 8 /" 8 "[' Sto', 'og', 'es', '"",', ' and', ' ""', 'The', ' Kr', 'abby']"
diff --git a/patchscopes/code/preprocessed_data/factual/product_by_company.tsv b/patchscopes/code/preprocessed_data/factual/product_by_company.tsv
new file mode 100644
index 00000000..d6e45845
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/product_by_company.tsv
@@ -0,0 +1,3482 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 3 The company which developed x -1 The company which developed iPhone 5 Apple iPhone 5 "['S' ' and' ' iPhone' ' 5' 'C' '.' '\n' '\n' 'The' ' iPhone' ' 5' 'S'
+ ' and' ' iPhone' ' 5' 'C' ' are' ' the' ' two' ' new']" "S and iPhone 5 C .
+
+ The iPhone 5 S and iPhone 5 C are the two new" False 1 ['iPhone', ' 5']
+1 3 The company which developed x -1 The company which developed iPhone 5 Apple iPhone 5 "['S' ' and' ' iPhone' ' 5' 'C' '.' '\n' '\n' 'The' ' iPhone' ' 5' 'S'
+ ' and' ' iPhone' ' 5' 'C' ' are' ' the' ' two' ' new']" "S and iPhone 5 C .
+
+ The iPhone 5 S and iPhone 5 C are the two new" False connectivity via the iPhone 5 to their customers 4 [' connectivity', ' via', ' the', ' iPhone', ' 5']
+2 3 The company which developed x -1 The company which developed iPhone 5 Apple iPhone 5 "['S' ' and' ' iPhone' ' 5' 'C' '.' '\n' '\n' 'The' ' iPhone' ' 5' 'S'
+ ' and' ' iPhone' ' 5' 'C' ' are' ' the' ' two' ' new']" "S and iPhone 5 C .
+
+ The iPhone 5 S and iPhone 5 C are the two new" False Initial demand for the iPhone 5 exceeded the supply 5 [' Initial', ' demand', ' for', ' the', ' iPhone', ' 5']
+3 3 The company which developed x -1 The company which developed iPhone 5 Apple iPhone 5 "['S' ' and' ' iPhone' ' 5' 'C' '.' '\n' '\n' 'The' ' iPhone' ' 5' 'S'
+ ' and' ' iPhone' ' 5' 'C' ' are' ' the' ' two' ' new']" "S and iPhone 5 C .
+
+ The iPhone 5 S and iPhone 5 C are the two new" False pictures taken on iPhone 5 camera. Apple's 4 [' pictures', ' taken', ' on', ' iPhone', ' 5']
+4 3 The company which developed x -1 The company which developed iPhone 5 Apple iPhone 5 "['S' ' and' ' iPhone' ' 5' 'C' '.' '\n' '\n' 'The' ' iPhone' ' 5' 'S'
+ ' and' ' iPhone' ' 5' 'C' ' are' ' the' ' two' ' new']" "S and iPhone 5 C .
+
+ The iPhone 5 S and iPhone 5 C are the two new" False 1 ['iPhone', ' 5']
+5 4 The company which developed x -1 The company which developed Wii Balance Board Nintendo Wii Balance Board "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' while' ',' ' and' ' has' ' been' ' making' ' a' ' name' ' for'
+ ' itself' ' in']" is a company that has been around for a while , and has been making a name for itself in False unveiled the Wii Balance Board at E3 2007 with Wii 4 [' unveiled', ' the', ' Wii', ' Balance', ' Board']
+6 6 The company which developed x -1 The company which developed Lancia 037 Fiat Lancia 037 "[',' ' the' ' first' ' car' ' to' ' be' ' built' ' in' ' Italy' ' since'
+ ' the' ' war' ',' ' was' ' a' ' joint' ' venture' ' between' ' Lanc' 'ia']" , the first car to be built in Italy since the war , was a joint venture between Lanc ia False Toivonen had his Lancia 037 debut in Portugal 8 [' To', 'iv', 'onen', ' had', ' his', ' Lanc', 'ia', ' 0', '37']
+7 6 The company which developed x -1 The company which developed Lancia 037 Fiat Lancia 037 "[',' ' the' ' first' ' car' ' to' ' be' ' built' ' in' ' Italy' ' since'
+ ' the' ' war' ',' ' was' ' a' ' joint' ' venture' ' between' ' Lanc' 'ia']" , the first car to be built in Italy since the war , was a joint venture between Lanc ia False dominated by the Lancia 037 and Mouton finished 6 [' dominated', ' by', ' the', ' Lanc', 'ia', ' 0', '37']
+8 6 The company which developed x -1 The company which developed Lancia 037 Fiat Lancia 037 "[',' ' the' ' first' ' car' ' to' ' be' ' built' ' in' ' Italy' ' since'
+ ' the' ' war' ',' ' was' ' a' ' joint' ' venture' ' between' ' Lanc' 'ia']" , the first car to be built in Italy since the war , was a joint venture between Lanc ia False Toivonen had his Lancia 037 debut in Portugal 8 [' To', 'iv', 'onen', ' had', ' his', ' Lanc', 'ia', ' 0', '37']
+9 6 The company which developed x -1 The company which developed Lancia 037 Fiat Lancia 037 "[',' ' the' ' first' ' car' ' to' ' be' ' built' ' in' ' Italy' ' since'
+ ' the' ' war' ',' ' was' ' a' ' joint' ' venture' ' between' ' Lanc' 'ia']" , the first car to be built in Italy since the war , was a joint venture between Lanc ia False Toivonen had his Lancia 037 debut in Portugal 8 [' To', 'iv', 'onen', ' had', ' his', ' Lanc', 'ia', ' 0', '37']
+10 6 The company which developed x -1 The company which developed Lancia 037 Fiat Lancia 037 "[',' ' the' ' first' ' car' ' to' ' be' ' built' ' in' ' Italy' ' since'
+ ' the' ' war' ',' ' was' ' a' ' joint' ' venture' ' between' ' Lanc' 'ia']" , the first car to be built in Italy since the war , was a joint venture between Lanc ia False Toivonen crashed his Lancia 037 into a brick 8 [' To', 'iv', 'onen', ' crashed', ' his', ' Lanc', 'ia', ' 0', '37']
+11 7 The company which developed x -1 The company which developed Windows Server 2003 Microsoft Windows Server 2003 "[' and' ' Windows' ' Server' ' 2008' ' R' '2' ',' ' Microsoft' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' Windows' ' Server']" and Windows Server 2008 R 2 , Microsoft has released a new version of the operating system , Windows Server True " 2003 (unsupported)
+" 7 [' 2003', ' (', 'un', 'supported', ')', 'Windows', ' Server', ' 2003']
+12 7 The company which developed x -1 The company which developed Windows Server 2003 Microsoft Windows Server 2003 "[' and' ' Windows' ' Server' ' 2008' ' R' '2' ',' ' Microsoft' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' Windows' ' Server']" and Windows Server 2008 R 2 , Microsoft has released a new version of the operating system , Windows Server True " 2003 (unsupported)
+" 7 [' 2003', ' (', 'un', 'supported', ')', 'Windows', ' Server', ' 2003']
+13 7 The company which developed x -1 The company which developed Windows Server 2003 Microsoft Windows Server 2003 "[' and' ' Windows' ' Server' ' 2008' ' R' '2' ',' ' Microsoft' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' Windows' ' Server']" and Windows Server 2008 R 2 , Microsoft has released a new version of the operating system , Windows Server True " 2003 (unsupported)
+" 7 [' 2003', ' (', 'un', 'supported', ')', 'Windows', ' Server', ' 2003']
+14 7 The company which developed x -1 The company which developed Windows Server 2003 Microsoft Windows Server 2003 "[' and' ' Windows' ' Server' ' 2008' ' R' '2' ',' ' Microsoft' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' Windows' ' Server']" and Windows Server 2008 R 2 , Microsoft has released a new version of the operating system , Windows Server True " 2003 (unsupported)
+" 7 [' 2003', ' (', 'un', 'supported', ')', 'Windows', ' Server', ' 2003']
+15 7 The company which developed x -1 The company which developed Windows Server 2003 Microsoft Windows Server 2003 "[' and' ' Windows' ' Server' ' 2008' ' R' '2' ',' ' Microsoft' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' Windows' ' Server']" and Windows Server 2008 R 2 , Microsoft has released a new version of the operating system , Windows Server True " 2003 (unsupported)
+" 7 [' 2003', ' (', 'un', 'supported', ')', 'Windows', ' Server', ' 2003']
+16 8 The company which developed x -1 The company which developed DualShock Sony DualShock "[' 4' ' is' ' Sony' '.' '\n' '\n' 'The' ' PS' '4' ' is' ' a' ' console'
+ ' that' ' is' ' not' ' only' ' a' ' gaming' ' console' ',']" " 4 is Sony .
+
+ The PS 4 is a console that is not only a gaming console ," True frames per second, DualShock 3 rumble features, 5 [' frames', ' per', ' second', ',', ' Dual', 'Shock']
+17 8 The company which developed x -1 The company which developed DualShock Sony DualShock "[' 4' ' is' ' Sony' '.' '\n' '\n' 'The' ' PS' '4' ' is' ' a' ' console'
+ ' that' ' is' ' not' ' only' ' a' ' gaming' ' console' ',']" " 4 is Sony .
+
+ The PS 4 is a console that is not only a gaming console ," True of War: Ascension DualShock 3 controller, 5 [' of', ' War', ':', ' Ascension', ' Dual', 'Shock']
+18 8 The company which developed x -1 The company which developed DualShock Sony DualShock "[' 4' ' is' ' Sony' '.' '\n' '\n' 'The' ' PS' '4' ' is' ' a' ' console'
+ ' that' ' is' ' not' ' only' ' a' ' gaming' ' console' ',']" " 4 is Sony .
+
+ The PS 4 is a console that is not only a gaming console ," True default one uses the DualShock 2's analog sticks 5 [' default', ' one', ' uses', ' the', ' Dual', 'Shock']
+19 8 The company which developed x -1 The company which developed DualShock Sony DualShock "[' 4' ' is' ' Sony' '.' '\n' '\n' 'The' ' PS' '4' ' is' ' a' ' console'
+ ' that' ' is' ' not' ' only' ' a' ' gaming' ' console' ',']" " 4 is Sony .
+
+ The PS 4 is a console that is not only a gaming console ," True 1 ['Dual', 'Shock']
+20 8 The company which developed x -1 The company which developed DualShock Sony DualShock "[' 4' ' is' ' Sony' '.' '\n' '\n' 'The' ' PS' '4' ' is' ' a' ' console'
+ ' that' ' is' ' not' ' only' ' a' ' gaming' ' console' ',']" " 4 is Sony .
+
+ The PS 4 is a console that is not only a gaming console ," True the Dual Analog, the DualShock and DualShock 6 [' the', ' Dual', ' Analog', ',', ' the', ' Dual', 'Shock']
+21 9 The company which developed x -1 The company which developed Mac Pro Apple Mac Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False recording software on his Mac Pro and often worked late 5 [' recording', ' software', ' on', ' his', ' Mac', ' Pro']
+22 9 The company which developed x -1 The company which developed Mac Pro Apple Mac Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False The first-generation Mac Pro has a rectangular 5 [' The', ' first', '-', 'generation', ' Mac', ' Pro']
+23 9 The company which developed x -1 The company which developed Mac Pro Apple Mac Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False redesigned Mac Pro occupies less 2 [' redesigned', ' Mac', ' Pro']
+24 9 The company which developed x -1 The company which developed Mac Pro Apple Mac Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False second generation Mac Pro has a redesigned 3 [' second', ' generation', ' Mac', ' Pro']
+25 9 The company which developed x -1 The company which developed Mac Pro Apple Mac Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False introduced the Mac Pro Server, which 3 [' introduced', ' the', ' Mac', ' Pro']
+26 10 The company which developed x -1 The company which developed Game Boy Advance SP Nintendo Game Boy Advance SP "[',' ' Nintendo' ' DS' ',' ' and' ' Game' ' Boy' ' Advance' ' SP' '2' '.'
+ '\n' '\n' 'The' ' Game' ' Boy' ' Advance' ' SP' ' was' ' released']" ", Nintendo DS , and Game Boy Advance SP 2 .
+
+ The Game Boy Advance SP was released" True " Blue"" -colored Game Boy Advance SP was released" 7 "[' Blue', '""', ' -', 'colored', ' Game', ' Boy', ' Advance', ' SP']"
+27 10 The company which developed x -1 The company which developed Game Boy Advance SP Nintendo Game Boy Advance SP "[',' ' Nintendo' ' DS' ',' ' and' ' Game' ' Boy' ' Advance' ' SP' '2' '.'
+ '\n' '\n' 'The' ' Game' ' Boy' ' Advance' ' SP' ' was' ' released']" ", Nintendo DS , and Game Boy Advance SP 2 .
+
+ The Game Boy Advance SP was released" True knock-knock joke to win a Game Boy Advance SP and a copy of 11 [' knock', '-', 'kn', 'ock', ' joke', ' to', ' win', ' a', ' Game', ' Boy', ' Advance', ' SP']
+28 10 The company which developed x -1 The company which developed Game Boy Advance SP Nintendo Game Boy Advance SP "[',' ' Nintendo' ' DS' ',' ' and' ' Game' ' Boy' ' Advance' ' SP' '2' '.'
+ '\n' '\n' 'The' ' Game' ' Boy' ' Advance' ' SP' ' was' ' released']" ", Nintendo DS , and Game Boy Advance SP 2 .
+
+ The Game Boy Advance SP was released" True " Zelda Triforce Game Boy Advance SP ===
+" 7 [' Zelda', ' Tr', 'if', 'orce', ' Game', ' Boy', ' Advance', ' SP']
+29 10 The company which developed x -1 The company which developed Game Boy Advance SP Nintendo Game Boy Advance SP "[',' ' Nintendo' ' DS' ',' ' and' ' Game' ' Boy' ' Advance' ' SP' '2' '.'
+ '\n' '\n' 'The' ' Game' ' Boy' ' Advance' ' SP' ' was' ' released']" ", Nintendo DS , and Game Boy Advance SP 2 .
+
+ The Game Boy Advance SP was released" True Advance and the new Game Boy Advance SP respectively, and there's 7 [' Advance', ' and', ' the', ' new', ' Game', ' Boy', ' Advance', ' SP']
+30 10 The company which developed x -1 The company which developed Game Boy Advance SP Nintendo Game Boy Advance SP "[',' ' Nintendo' ' DS' ',' ' and' ' Game' ' Boy' ' Advance' ' SP' '2' '.'
+ '\n' '\n' 'The' ' Game' ' Boy' ' Advance' ' SP' ' was' ' released']" ", Nintendo DS , and Game Boy Advance SP 2 .
+
+ The Game Boy Advance SP was released" True Advance-themed Game Boy Advance SP package to commemorate 6 [' Advance', '-', 'themed', ' Game', ' Boy', ' Advance', ' SP']
+31 12 The company which developed x -1 The company which developed PGM-19 Jupiter Chrysler PGM-19 Jupiter "[',' ' a' ' new' ' type' ' of' ' missile' ' that' ' can' ' be' ' launched'
+ ' from' ' a' ' submarine' ',' ' has' ' been' ' awarded' ' a' ' contract'
+ ' to']" , a new type of missile that can be launched from a submarine , has been awarded a contract to False rival programs, PGM-19 Jupiter and PGM-17 Thor respectively, 7 [' rival', ' programs', ',', ' P', 'GM', '-', '19', ' Jupiter']
+32 12 The company which developed x -1 The company which developed PGM-19 Jupiter Chrysler PGM-19 Jupiter "[',' ' a' ' new' ' type' ' of' ' missile' ' that' ' can' ' be' ' launched'
+ ' from' ' a' ' submarine' ',' ' has' ' been' ' awarded' ' a' ' contract'
+ ' to']" , a new type of missile that can be launched from a submarine , has been awarded a contract to False rival programs, PGM-19 Jupiter and PGM-17 Thor respectively, 7 [' rival', ' programs', ',', ' P', 'GM', '-', '19', ' Jupiter']
+33 18 The company which developed x -1 The company which developed DC-3 Douglas DC-3 "[' was' ' the' ' first' ' to' ' use' ' the' ' term' ' ""' 'air' 'craft' '""'
+ ' to' ' describe' ' the' ' aircraft' '.' '\n' '\n' 'The' ' first']" " was the first to use the term "" air craft "" to describe the aircraft .
+
+ The first" False and grubbing. A DC-3 was able to land 8 [' and', ' gr', 'ub', 'bing', '.', ' A', ' DC', '-', '3']
+34 18 The company which developed x -1 The company which developed DC-3 Douglas DC-3 "[' was' ' the' ' first' ' to' ' use' ' the' ' term' ' ""' 'air' 'craft' '""'
+ ' to' ' describe' ' the' ' aircraft' '.' '\n' '\n' 'The' ' first']" " was the first to use the term "" air craft "" to describe the aircraft .
+
+ The first" False Shots of the Douglas DC-3 Jones flies on to 6 [' Shots', ' of', ' the', ' Douglas', ' DC', '-', '3']
+35 18 The company which developed x -1 The company which developed DC-3 Douglas DC-3 "[' was' ' the' ' first' ' to' ' use' ' the' ' term' ' ""' 'air' 'craft' '""'
+ ' to' ' describe' ' the' ' aircraft' '.' '\n' '\n' 'The' ' first']" " was the first to use the term "" air craft "" to describe the aircraft .
+
+ The first" False Chandler Airport and a DC-3 on the morning 6 [' Chandler', ' Airport', ' and', ' a', ' DC', '-', '3']
+36 18 The company which developed x -1 The company which developed DC-3 Douglas DC-3 "[' was' ' the' ' first' ' to' ' use' ' the' ' term' ' ""' 'air' 'craft' '""'
+ ' to' ' describe' ' the' ' aircraft' '.' '\n' '\n' 'The' ' first']" " was the first to use the term "" air craft "" to describe the aircraft .
+
+ The first" False and Douglas DC-3 were the first 4 [' and', ' Douglas', ' DC', '-', '3']
+37 18 The company which developed x -1 The company which developed DC-3 Douglas DC-3 "[' was' ' the' ' first' ' to' ' use' ' the' ' term' ' ""' 'air' 'craft' '""'
+ ' to' ' describe' ' the' ' aircraft' '.' '\n' '\n' 'The' ' first']" " was the first to use the term "" air craft "" to describe the aircraft .
+
+ The first" False Fokker F27, Douglas DC-3 and later Hawker Siddeley 9 [' F', 'ok', 'ker', ' F', '27', ',', ' Douglas', ' DC', '-', '3']
+38 19 The company which developed x -1 The company which developed Kindle Fire Amazon Kindle Fire "[' HD' 'X' ' 8' '.' '9' '""' ' Tablet' ' is' ' a' ' great' ' device' ' for'
+ ' reading' ' books' ',' ' watching' ' movies' ',' ' listening' ' to']" " HD X 8 . 9 "" Tablet is a great device for reading books , watching movies , listening to" False tablets such as the Kindle Fire and Nexus 7. The hardware 5 [' tablets', ' such', ' as', ' the', ' Kindle', ' Fire']
+39 19 The company which developed x -1 The company which developed Kindle Fire Amazon Kindle Fire "[' HD' 'X' ' 8' '.' '9' '""' ' Tablet' ' is' ' a' ' great' ' device' ' for'
+ ' reading' ' books' ',' ' watching' ' movies' ',' ' listening' ' to']" " HD X 8 . 9 "" Tablet is a great device for reading books , watching movies , listening to" False launch, with the Kindle Fire and Kindle Paperwhite 5 [' launch', ',', ' with', ' the', ' Kindle', ' Fire']
+40 19 The company which developed x -1 The company which developed Kindle Fire Amazon Kindle Fire "[' HD' 'X' ' 8' '.' '9' '""' ' Tablet' ' is' ' a' ' great' ' device' ' for'
+ ' reading' ' books' ',' ' watching' ' movies' ',' ' listening' ' to']" " HD X 8 . 9 "" Tablet is a great device for reading books , watching movies , listening to" False Amazon.com and Amazon Kindle Fire HDX advertisement 6 [' Amazon', '.', 'com', ' and', ' Amazon', ' Kindle', ' Fire']
+41 19 The company which developed x -1 The company which developed Kindle Fire Amazon Kindle Fire "[' HD' 'X' ' 8' '.' '9' '""' ' Tablet' ' is' ' a' ' great' ' device' ' for'
+ ' reading' ' books' ',' ' watching' ' movies' ',' ' listening' ' to']" " HD X 8 . 9 "" Tablet is a great device for reading books , watching movies , listening to" False port for Android and Kindle Fire was announced 5 [' port', ' for', ' Android', ' and', ' Kindle', ' Fire']
+42 19 The company which developed x -1 The company which developed Kindle Fire Amazon Kindle Fire "[' HD' 'X' ' 8' '.' '9' '""' ' Tablet' ' is' ' a' ' great' ' device' ' for'
+ ' reading' ' books' ',' ' watching' ' movies' ',' ' listening' ' to']" " HD X 8 . 9 "" Tablet is a great device for reading books , watching movies , listening to" False port for Android and Kindle Fire was announced on 5 [' port', ' for', ' Android', ' and', ' Kindle', ' Fire']
+43 20 The company which developed x -1 The company which developed AGM-86 ALCM Boeing AGM-86 ALCM "[' (' 'Air' ' Laun' 'ched' ' Cruise' ' Missile' ')' ' is' ' now'
+ ' developing' ' a' ' new' ' missile' ',' ' the' ' AG' 'M' '-' '158' ' J']" ( Air Laun ched Cruise Missile ) is now developing a new missile , the AG M - 158 J False with 50 to 100 AGM-86 ALCM cruise missiles 10 [' with', ' 50', ' to', ' 100', ' AG', 'M', '-', '86', ' A', 'LC', 'M']
+44 20 The company which developed x -1 The company which developed AGM-86 ALCM Boeing AGM-86 ALCM "[' (' 'Air' ' Laun' 'ched' ' Cruise' ' Missile' ')' ' is' ' now'
+ ' developing' ' a' ' new' ' missile' ',' ' the' ' AG' 'M' '-' '158' ' J']" ( Air Laun ched Cruise Missile ) is now developing a new missile , the AG M - 158 J False equipped with 50 to 100 AGM-86 ALCM cruise missiles 11 [' equipped', ' with', ' 50', ' to', ' 100', ' AG', 'M', '-', '86', ' A', 'LC', 'M']
+45 20 The company which developed x -1 The company which developed AGM-86 ALCM Boeing AGM-86 ALCM "[' (' 'Air' ' Laun' 'ched' ' Cruise' ' Missile' ')' ' is' ' now'
+ ' developing' ' a' ' new' ' missile' ',' ' the' ' AG' 'M' '-' '158' ' J']" ( Air Laun ched Cruise Missile ) is now developing a new missile , the AG M - 158 J False with 50 to 100 AGM-86 ALCM cruise missiles 10 [' with', ' 50', ' to', ' 100', ' AG', 'M', '-', '86', ' A', 'LC', 'M']
+46 23 The company which developed x -1 The company which developed Isetta BMW Isetta "[',' ' a' ' new' ' kind' ' of' ' car' ' that' ' can' ' be' ' driven' ' by'
+ ' a' ' single' ' person' ',' ' has' ' been' ' working' ' on' ' a']" , a new kind of car that can be driven by a single person , has been working on a False " resembling an Isetta at the front.
+" 3 [' resembling', ' an', ' Is', 'etta']
+47 23 The company which developed x -1 The company which developed Isetta BMW Isetta "[',' ' a' ' new' ' kind' ' of' ' car' ' that' ' can' ' be' ' driven' ' by'
+ ' a' ' single' ' person' ',' ' has' ' been' ' working' ' on' ' a']" , a new kind of car that can be driven by a single person , has been working on a False " forklift, resembling an Isetta at the front.
+" 6 [' fork', 'lift', ',', ' resembling', ' an', ' Is', 'etta']
+48 23 The company which developed x -1 The company which developed Isetta BMW Isetta "[',' ' a' ' new' ' kind' ' of' ' car' ' that' ' can' ' be' ' driven' ' by'
+ ' a' ' single' ' person' ',' ' has' ' been' ' working' ' on' ' a']" , a new kind of car that can be driven by a single person , has been working on a False " forklift, resembling an Isetta at the front.
+" 6 [' fork', 'lift', ',', ' resembling', ' an', ' Is', 'etta']
+49 27 The company which developed x -1 The company which developed SNES-CD Nintendo SNES-CD "[',' ' the' ' first' ' CD' '-' 'ROM' ' game' ' console' ',' ' has'
+ ' announced' ' that' ' it' ' will' ' be' ' releasing' ' a' ' new'
+ ' version' ' of']" , the first CD - ROM game console , has announced that it will be releasing a new version of False system known as the SNES-CD (also known as the 7 [' system', ' known', ' as', ' the', ' SN', 'ES', '-', 'CD']
+50 27 The company which developed x -1 The company which developed SNES-CD Nintendo SNES-CD "[',' ' the' ' first' ' CD' '-' 'ROM' ' game' ' console' ',' ' has'
+ ' announced' ' that' ' it' ' will' ' be' ' releasing' ' a' ' new'
+ ' version' ' of']" , the first CD - ROM game console , has announced that it will be releasing a new version of False known as the SNES-CD (also known 6 [' known', ' as', ' the', ' SN', 'ES', '-', 'CD']
+51 27 The company which developed x -1 The company which developed SNES-CD Nintendo SNES-CD "[',' ' the' ' first' ' CD' '-' 'ROM' ' game' ' console' ',' ' has'
+ ' announced' ' that' ' it' ' will' ' be' ' releasing' ' a' ' new'
+ ' version' ' of']" , the first CD - ROM game console , has announced that it will be releasing a new version of False system known as the SNES-CD (also known as the 7 [' system', ' known', ' as', ' the', ' SN', 'ES', '-', 'CD']
+52 28 The company which developed x -1 The company which developed Ferrari 250 GTO Fiat Ferrari 250 GTO "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' was' ' founded' ' in' ' 1946' ' and' ' has'
+ ' been']" is a company that has been around for a long time . It was founded in 1946 and has been False with Mason using his Ferrari 250 GTO as collateral. 7 [' with', ' Mason', ' using', ' his', ' Ferrari', ' 250', ' GT', 'O']
+53 28 The company which developed x -1 The company which developed Ferrari 250 GTO Fiat Ferrari 250 GTO "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' was' ' founded' ' in' ' 1946' ' and' ' has'
+ ' been']" is a company that has been around for a long time . It was founded in 1946 and has been False wife, used his Ferrari 250 GTO as collateral). Some 7 [' wife', ',', ' used', ' his', ' Ferrari', ' 250', ' GT', 'O']
+54 28 The company which developed x -1 The company which developed Ferrari 250 GTO Fiat Ferrari 250 GTO "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' was' ' founded' ' in' ' 1946' ' and' ' has'
+ ' been']" is a company that has been around for a long time . It was founded in 1946 and has been False highly collectible Ferrari 250 GTO sold for just 6 [' highly', ' collect', 'ible', ' Ferrari', ' 250', ' GT', 'O']
+55 28 The company which developed x -1 The company which developed Ferrari 250 GTO Fiat Ferrari 250 GTO "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' was' ' founded' ' in' ' 1946' ' and' ' has'
+ ' been']" is a company that has been around for a long time . It was founded in 1946 and has been False collectible Ferrari 250 GTO sold for just 5 [' collect', 'ible', ' Ferrari', ' 250', ' GT', 'O']
+56 28 The company which developed x -1 The company which developed Ferrari 250 GTO Fiat Ferrari 250 GTO "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' was' ' founded' ' in' ' 1946' ' and' ' has'
+ ' been']" is a company that has been around for a long time . It was founded in 1946 and has been False wife, used his Ferrari 250 GTO as collateral). 7 [' wife', ',', ' used', ' his', ' Ferrari', ' 250', ' GT', 'O']
+57 31 The company which developed x -1 The company which developed Bluesmobile Dodge Bluesmobile "[',' ' a' ' new' ' kind' ' of' ' mobile' ' app' ' that' ' allows' ' you'
+ ' to' ' play' ' the' ' game' ' on' ' your' ' phone' ',' ' tablet' ' or']" , a new kind of mobile app that allows you to play the game on your phone , tablet or False loudspeaker atop the Bluesmobile and drive all 5 [' loudspe', 'aker', ' atop', ' the', ' Blues', 'mobile']
+58 31 The company which developed x -1 The company which developed Bluesmobile Dodge Bluesmobile "[',' ' a' ' new' ' kind' ' of' ' mobile' ' app' ' that' ' allows' ' you'
+ ' to' ' play' ' the' ' game' ' on' ' your' ' phone' ',' ' tablet' ' or']" , a new kind of mobile app that allows you to play the game on your phone , tablet or False filming. Although the Bluesmobile was allowed to be 5 [' filming', '.', ' Although', ' the', ' Blues', 'mobile']
+59 31 The company which developed x -1 The company which developed Bluesmobile Dodge Bluesmobile "[',' ' a' ' new' ' kind' ' of' ' mobile' ' app' ' that' ' allows' ' you'
+ ' to' ' play' ' the' ' game' ' on' ' your' ' phone' ',' ' tablet' ' or']" , a new kind of mobile app that allows you to play the game on your phone , tablet or False " Bluesmobile ===
+" 1 [' Blues', 'mobile']
+60 31 The company which developed x -1 The company which developed Bluesmobile Dodge Bluesmobile "[',' ' a' ' new' ' kind' ' of' ' mobile' ' app' ' that' ' allows' ' you'
+ ' to' ' play' ' the' ' game' ' on' ' your' ' phone' ',' ' tablet' ' or']" , a new kind of mobile app that allows you to play the game on your phone , tablet or False loudspeaker atop the Bluesmobile and drive all over 5 [' loudspe', 'aker', ' atop', ' the', ' Blues', 'mobile']
+61 31 The company which developed x -1 The company which developed Bluesmobile Dodge Bluesmobile "[',' ' a' ' new' ' kind' ' of' ' mobile' ' app' ' that' ' allows' ' you'
+ ' to' ' play' ' the' ' game' ' on' ' your' ' phone' ',' ' tablet' ' or']" , a new kind of mobile app that allows you to play the game on your phone , tablet or False loudspeaker atop the Bluesmobile and drive all over 5 [' loudspe', 'aker', ' atop', ' the', ' Blues', 'mobile']
+62 33 The company which developed x -1 The company which developed AMC 35 Renault AMC 35 "[',' ' a' ' new' ' generation' ' of' ' the' ' popular' ' AMC' ' 35'
+ ' series' ' of' ' high' '-' 'performance' ',' ' high' '-' 'capacity' ','
+ ' high']" , a new generation of the popular AMC 35 series of high - performance , high - capacity , high False but just 10 AMC 35 tanks. However, 4 [' but', ' just', ' 10', ' AMC', ' 35']
+63 38 The company which developed x -1 The company which developed Acura MDX Honda Acura MDX "[' is' ' a' ' Japanese' ' company' ' which' ' is' ' a' ' subsidiary' ' of'
+ ' Honda' '.' ' The' ' MD' 'X' ' is' ' a' ' luxury' ' SUV' ' which' ' is']" is a Japanese company which is a subsidiary of Honda . The MD X is a luxury SUV which is True based on the Acura MDX with modifications 6 [' based', ' on', ' the', ' Ac', 'ura', ' MD', 'X']
+64 38 The company which developed x -1 The company which developed Acura MDX Honda Acura MDX "[' is' ' a' ' Japanese' ' company' ' which' ' is' ' a' ' subsidiary' ' of'
+ ' Honda' '.' ' The' ' MD' 'X' ' is' ' a' ' luxury' ' SUV' ' which' ' is']" is a Japanese company which is a subsidiary of Honda . The MD X is a luxury SUV which is True SUVs, based on the Acura MDX with modifications 9 [' SU', 'Vs', ',', ' based', ' on', ' the', ' Ac', 'ura', ' MD', 'X']
+65 38 The company which developed x -1 The company which developed Acura MDX Honda Acura MDX "[' is' ' a' ' Japanese' ' company' ' which' ' is' ' a' ' subsidiary' ' of'
+ ' Honda' '.' ' The' ' MD' 'X' ' is' ' a' ' luxury' ' SUV' ' which' ' is']" is a Japanese company which is a subsidiary of Honda . The MD X is a luxury SUV which is True based on the Acura MDX with modifications 6 [' based', ' on', ' the', ' Ac', 'ura', ' MD', 'X']
+66 38 The company which developed x -1 The company which developed Acura MDX Honda Acura MDX "[' is' ' a' ' Japanese' ' company' ' which' ' is' ' a' ' subsidiary' ' of'
+ ' Honda' '.' ' The' ' MD' 'X' ' is' ' a' ' luxury' ' SUV' ' which' ' is']" is a Japanese company which is a subsidiary of Honda . The MD X is a luxury SUV which is True based on the Acura MDX with modifications 6 [' based', ' on', ' the', ' Ac', 'ura', ' MD', 'X']
+67 40 The company which developed x -1 The company which developed Xbox Microsoft Xbox "[' One' ' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a'
+ ' powerful' ' console' ' that' ' will' ' be' ' released' ' on'
+ ' November' ' 7']" One X , the new Xbox One X , is a powerful console that will be released on November 7 False (such as the Xbox 360 and Surround sound 4 [' (', 'such', ' as', ' the', ' Xbox']
+68 40 The company which developed x -1 The company which developed Xbox Microsoft Xbox "[' One' ' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a'
+ ' powerful' ' console' ' that' ' will' ' be' ' released' ' on'
+ ' November' ' 7']" One X , the new Xbox One X , is a powerful console that will be released on November 7 False beat 'em up game for Xbox One, Xbox 360, 6 "[' beat', "" '"", 'em', ' up', ' game', ' for', ' Xbox']"
+69 40 The company which developed x -1 The company which developed Xbox Microsoft Xbox "[' One' ' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a'
+ ' powerful' ' console' ' that' ' will' ' be' ' released' ' on'
+ ' November' ' 7']" One X , the new Xbox One X , is a powerful console that will be released on November 7 False Pack (2000), as an Xbox version (2003), 6 [' Pack', ' (', '2000', '),', ' as', ' an', ' Xbox']
+70 40 The company which developed x -1 The company which developed Xbox Microsoft Xbox "[' One' ' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a'
+ ' powerful' ' console' ' that' ' will' ' be' ' released' ' on'
+ ' November' ' 7']" One X , the new Xbox One X , is a powerful console that will be released on November 7 False 0 ['Xbox']
+71 40 The company which developed x -1 The company which developed Xbox Microsoft Xbox "[' One' ' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a'
+ ' powerful' ' console' ' that' ' will' ' be' ' released' ' on'
+ ' November' ' 7']" One X , the new Xbox One X , is a powerful console that will be released on November 7 False " PlayStation 4, and Xbox One platforms.
+" 4 [' PlayStation', ' 4', ',', ' and', ' Xbox']
+72 42 The company which developed x -1 The company which developed iPhone 4s Apple iPhone 4s "[',' ' the' ' iPhone' ' 5' ',' ' and' ' the' ' iPhone' ' 5' 'c' '.' '\n'
+ '\n' 'The' ' iPhone' ' 5' 'c' ' is' ' a' ' budget']" ", the iPhone 5 , and the iPhone 5 c .
+
+ The iPhone 5 c is a budget" False lowercase's' as iPhone 4s as of September 7 "[' lower', 'case', ""'s"", ""'"", ' as', ' iPhone', ' 4', 's']"
+73 42 The company which developed x -1 The company which developed iPhone 4s Apple iPhone 4s "[',' ' the' ' iPhone' ' 5' ',' ' and' ' the' ' iPhone' ' 5' 'c' '.' '\n'
+ '\n' 'The' ' iPhone' ' 5' 'c' ' is' ' a' ' budget']" ", the iPhone 5 , and the iPhone 5 c .
+
+ The iPhone 5 c is a budget" False not supported. The iPhone 4s can also run iOS 8 6 [' not', ' supported', '.', ' The', ' iPhone', ' 4', 's']
+74 42 The company which developed x -1 The company which developed iPhone 4s Apple iPhone 4s "[',' ' the' ' iPhone' ' 5' ',' ' and' ' the' ' iPhone' ' 5' 'c' '.' '\n'
+ '\n' 'The' ' iPhone' ' 5' 'c' ' is' ' a' ' budget']" ", the iPhone 5 , and the iPhone 5 c .
+
+ The iPhone 5 c is a budget" False supported. The iPhone 4s can also run iOS 8 5 [' supported', '.', ' The', ' iPhone', ' 4', 's']
+75 42 The company which developed x -1 The company which developed iPhone 4s Apple iPhone 4s "[',' ' the' ' iPhone' ' 5' ',' ' and' ' the' ' iPhone' ' 5' 'c' '.' '\n'
+ '\n' 'The' ' iPhone' ' 5' 'c' ' is' ' a' ' budget']" ", the iPhone 5 , and the iPhone 5 c .
+
+ The iPhone 5 c is a budget" False supported. The iPhone 4s can also run iOS 5 [' supported', '.', ' The', ' iPhone', ' 4', 's']
+76 44 The company which developed x -1 The company which developed Dino Ferrari Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False On 10 May 2011, Dino announced via Divine 5 [' On', ' 10', ' May', ' 2011', ',', ' Dino']
+77 44 The company which developed x -1 The company which developed Dino Ferrari Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False " = Dino Crisis =
+" 1 [' =', ' Dino']
+78 44 The company which developed x -1 The company which developed Dino Ferrari Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False film producer Dino De Laurentiis, 2 [' film', ' producer', ' Dino']
+79 44 The company which developed x -1 The company which developed Dino Ferrari Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False Intimacy & Liberty by Dino Hodge, written with 6 [' Int', 'im', 'acy', ' &', ' Liberty', ' by', ' Dino']
+80 44 The company which developed x -1 The company which developed Dino Ferrari Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False the project to Dino De Laurentiis. Milius 3 [' the', ' project', ' to', ' Dino']
+81 45 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False Apple's iPhone and iPod Touch mobile 4 "[' Apple', ""'s"", ' iPhone', ' and', ' iPod']"
+82 45 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False God had a iPod, I'd be on 3 [' God', ' had', ' a', ' iPod']
+83 45 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False collapse who requested an iPod with In Your 4 [' collapse', ' who', ' requested', ' an', ' iPod']
+84 45 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False case In re Apple iPod iTunes Antitrust 4 [' case', ' In', ' re', ' Apple', ' iPod']
+85 45 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False the iPhone, iPod Touch, and 3 [' the', ' iPhone', ',', ' iPod']
+86 46 The company which developed x -1 The company which developed iPad Mini Apple iPad Mini "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False playing on an iPad Mini felt less 4 [' playing', ' on', ' an', ' iPad', ' Mini']
+87 46 The company which developed x -1 The company which developed iPad Mini Apple iPad Mini "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False weekend of sales of the iPad Mini and fourth-generation 6 [' weekend', ' of', ' sales', ' of', ' the', ' iPad', ' Mini']
+88 46 The company which developed x -1 The company which developed iPad Mini Apple iPad Mini "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False generation), and iPad Mini feature a new 4 [' generation', '),', ' and', ' iPad', ' Mini']
+89 46 The company which developed x -1 The company which developed iPad Mini Apple iPad Mini "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False versions of the iPad Mini. The first generation 4 [' versions', ' of', ' the', ' iPad', ' Mini']
+90 46 The company which developed x -1 The company which developed iPad Mini Apple iPad Mini "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False that playing on an iPad Mini felt less precise 5 [' that', ' playing', ' on', ' an', ' iPad', ' Mini']
+91 47 The company which developed x -1 The company which developed AGM-69 SRAM Boeing AGM-69 SRAM "[' Red' 'shift' ' is' ' a' ' small' ',' ' family' '-' 'owned' ' company'
+ ' based' ' in' ' the' ' United' ' States' '.' ' The' ' company' ' was'
+ ' founded']" Red shift is a small , family - owned company based in the United States . The company was founded False carry up to 20 AGM-69 SRAM nuclear missiles 9 [' carry', ' up', ' to', ' 20', ' AG', 'M', '-', '69', ' SR', 'AM']
+92 48 The company which developed x -1 The company which developed LGM-30 Minuteman Boeing LGM-30 Minuteman "[' III' ' inter' 'continental' ' ballistic' ' missiles' ',' ' the' ' most'
+ ' powerful' ' in' ' the' ' world' ',' ' has' ' been' ' awarded' ' a'
+ ' contract' ' to' ' build']" III inter continental ballistic missiles , the most powerful in the world , has been awarded a contract to build False successor to the LGM-30 Minuteman ICBM then in 9 [' successor', ' to', ' the', ' L', 'GM', '-', '30', ' Min', 'ut', 'eman']
+93 48 The company which developed x -1 The company which developed LGM-30 Minuteman Boeing LGM-30 Minuteman "[' III' ' inter' 'continental' ' ballistic' ' missiles' ',' ' the' ' most'
+ ' powerful' ' in' ' the' ' world' ',' ' has' ' been' ' awarded' ' a'
+ ' contract' ' to' ' build']" III inter continental ballistic missiles , the most powerful in the world , has been awarded a contract to build False replacement for the LGM-30 Minuteman as the Air Force's 9 [' replacement', ' for', ' the', ' L', 'GM', '-', '30', ' Min', 'ut', 'eman']
+94 48 The company which developed x -1 The company which developed LGM-30 Minuteman Boeing LGM-30 Minuteman "[' III' ' inter' 'continental' ' ballistic' ' missiles' ',' ' the' ' most'
+ ' powerful' ' in' ' the' ' world' ',' ' has' ' been' ' awarded' ' a'
+ ' contract' ' to' ' build']" III inter continental ballistic missiles , the most powerful in the world , has been awarded a contract to build False replacement for the LGM-30 Minuteman as the Air Force's 9 [' replacement', ' for', ' the', ' L', 'GM', '-', '30', ' Min', 'ut', 'eman']
+95 48 The company which developed x -1 The company which developed LGM-30 Minuteman Boeing LGM-30 Minuteman "[' III' ' inter' 'continental' ' ballistic' ' missiles' ',' ' the' ' most'
+ ' powerful' ' in' ' the' ' world' ',' ' has' ' been' ' awarded' ' a'
+ ' contract' ' to' ' build']" III inter continental ballistic missiles , the most powerful in the world , has been awarded a contract to build False successor to the LGM-30 Minuteman ICBM then in United 9 [' successor', ' to', ' the', ' L', 'GM', '-', '30', ' Min', 'ut', 'eman']
+96 48 The company which developed x -1 The company which developed LGM-30 Minuteman Boeing LGM-30 Minuteman "[' III' ' inter' 'continental' ' ballistic' ' missiles' ',' ' the' ' most'
+ ' powerful' ' in' ' the' ' world' ',' ' has' ' been' ' awarded' ' a'
+ ' contract' ' to' ' build']" III inter continental ballistic missiles , the most powerful in the world , has been awarded a contract to build False solution was the LGM-30 Minuteman missile, which 9 [' solution', ' was', ' the', ' L', 'GM', '-', '30', ' Min', 'ut', 'eman']
+97 52 The company which developed x -1 The company which developed TorqueFlite Chrysler TorqueFlite "[',' ' a' ' new' ' type' ' of' ' engine' ' that' ' is' ' designed' ' to'
+ ' be' ' a' ' direct' ' replacement' ' for' ' the' ' current' ' engine'
+ ' in' ' the']" , a new type of engine that is designed to be a direct replacement for the current engine in the False with Chrysler's TorqueFlite automatic transmission 6 "[' with', ' Chrysler', ""'s"", ' Tor', 'que', 'Fl', 'ite']"
+98 52 The company which developed x -1 The company which developed TorqueFlite Chrysler TorqueFlite "[',' ' a' ' new' ' type' ' of' ' engine' ' that' ' is' ' designed' ' to'
+ ' be' ' a' ' direct' ' replacement' ' for' ' the' ' current' ' engine'
+ ' in' ' the']" , a new type of engine that is designed to be a direct replacement for the current engine in the False with Chrysler's TorqueFlite automatic transmission 6 "[' with', ' Chrysler', ""'s"", ' Tor', 'que', 'Fl', 'ite']"
+99 52 The company which developed x -1 The company which developed TorqueFlite Chrysler TorqueFlite "[',' ' a' ' new' ' type' ' of' ' engine' ' that' ' is' ' designed' ' to'
+ ' be' ' a' ' direct' ' replacement' ' for' ' the' ' current' ' engine'
+ ' in' ' the']" , a new type of engine that is designed to be a direct replacement for the current engine in the False Chrysler's TorqueFlite automatic transmission 5 "[' Chrysler', ""'s"", ' Tor', 'que', 'Fl', 'ite']"
+100 58 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False was the Newton MessagePad 100, introduced 4 [' was', ' the', ' Newton', ' Message', 'Pad']
+101 58 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False computer was the Newton MessagePad 100, introduced 5 [' computer', ' was', ' the', ' Newton', ' Message', 'Pad']
+102 58 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False was the Newton MessagePad 100, introduced 4 [' was', ' the', ' Newton', ' Message', 'Pad']
+103 58 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False was the Newton MessagePad 100, introduced 4 [' was', ' the', ' Newton', ' Message', 'Pad']
+104 58 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False computer was the Newton MessagePad 100, introduced 5 [' computer', ' was', ' the', ' Newton', ' Message', 'Pad']
+105 59 The company which developed x -1 The company which developed GAM-87 Skybolt Douglas GAM-87 Skybolt "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' GAM' '-' '87' ' was' ' a'
+ ' surface']" was a joint venture between the United States and the Soviet Union . The GAM - 87 was a surface False development of the GAM-87 Skybolt ALBM. In addition, 7 [' development', ' of', ' the', ' GAM', '-', '87', ' Sky', 'bolt']
+106 59 The company which developed x -1 The company which developed GAM-87 Skybolt Douglas GAM-87 Skybolt "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' GAM' '-' '87' ' was' ' a'
+ ' surface']" was a joint venture between the United States and the Soviet Union . The GAM - 87 was a surface False was made for four GAM-87 Skybolt ballistic missiles. 8 [' was', ' made', ' for', ' four', ' GAM', '-', '87', ' Sky', 'bolt']
+107 59 The company which developed x -1 The company which developed GAM-87 Skybolt Douglas GAM-87 Skybolt "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' GAM' '-' '87' ' was' ' a'
+ ' surface']" was a joint venture between the United States and the Soviet Union . The GAM - 87 was a surface False what would become the GAM-87 Skybolt missile, which 8 [' what', ' would', ' become', ' the', ' GAM', '-', '87', ' Sky', 'bolt']
+108 59 The company which developed x -1 The company which developed GAM-87 Skybolt Douglas GAM-87 Skybolt "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' GAM' '-' '87' ' was' ' a'
+ ' surface']" was a joint venture between the United States and the Soviet Union . The GAM - 87 was a surface False would become the GAM-87 Skybolt missile, which incorporated 7 [' would', ' become', ' the', ' GAM', '-', '87', ' Sky', 'bolt']
+109 59 The company which developed x -1 The company which developed GAM-87 Skybolt Douglas GAM-87 Skybolt "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' GAM' '-' '87' ' was' ' a'
+ ' surface']" was a joint venture between the United States and the Soviet Union . The GAM - 87 was a surface False " two or four GAM-87 Skybolt ballistic missiles.
+" 7 [' two', ' or', ' four', ' GAM', '-', '87', ' Sky', 'bolt']
+110 65 The company which developed x -1 The company which developed Rolls-Royce Phantom BMW Rolls-Royce Phantom "[' VI' ' was' ' the' ' first' ' to' ' introduce' ' the' ' concept' ' of'
+ ' the' ' ""' 'flying' ' car' '""' ' in' ' the' ' 1930' 's' '.' ' The']" " VI was the first to introduce the concept of the "" flying car "" in the 1930 s . The" False convertible Rolls-Royce Phantom Drophead Coupés then 5 [' convertible', ' Rolls', '-', 'Roy', 'ce', ' Phantom']
+111 65 The company which developed x -1 The company which developed Rolls-Royce Phantom BMW Rolls-Royce Phantom "[' VI' ' was' ' the' ' first' ' to' ' introduce' ' the' ' concept' ' of'
+ ' the' ' ""' 'flying' ' car' '""' ' in' ' the' ' 1930' 's' '.' ' The']" " VI was the first to introduce the concept of the "" flying car "" in the 1930 s . The" False Three convertible Rolls-Royce Phantom Drophead Coupés 6 [' Three', ' convertible', ' Rolls', '-', 'Roy', 'ce', ' Phantom']
+112 65 The company which developed x -1 The company which developed Rolls-Royce Phantom BMW Rolls-Royce Phantom "[' VI' ' was' ' the' ' first' ' to' ' introduce' ' the' ' concept' ' of'
+ ' the' ' ""' 'flying' ' car' '""' ' in' ' the' ' 1930' 's' '.' ' The']" " VI was the first to introduce the concept of the "" flying car "" in the 1930 s . The" False " with me?"" Next, a Rolls-Royce Phantom is shown pulling up" 10 "[' with', ' me', '?""', ' Next', ',', ' a', ' Rolls', '-', 'Roy', 'ce', ' Phantom']"
+113 65 The company which developed x -1 The company which developed Rolls-Royce Phantom BMW Rolls-Royce Phantom "[' VI' ' was' ' the' ' first' ' to' ' introduce' ' the' ' concept' ' of'
+ ' the' ' ""' 'flying' ' car' '""' ' in' ' the' ' 1930' 's' '.' ' The']" " VI was the first to introduce the concept of the "" flying car "" in the 1930 s . The" False shipped one Rolls-Royce Phantom there, and filmed 6 [' shipped', ' one', ' Rolls', '-', 'Roy', 'ce', ' Phantom']
+114 65 The company which developed x -1 The company which developed Rolls-Royce Phantom BMW Rolls-Royce Phantom "[' VI' ' was' ' the' ' first' ' to' ' introduce' ' the' ' concept' ' of'
+ ' the' ' ""' 'flying' ' car' '""' ' in' ' the' ' 1930' 's' '.' ' The']" " VI was the first to introduce the concept of the "" flying car "" in the 1930 s . The" False " obsessed with me?"" Next, a Rolls-Royce Phantom is shown pulling up" 11 "[' obsessed', ' with', ' me', '?""', ' Next', ',', ' a', ' Rolls', '-', 'Roy', 'ce', ' Phantom']"
+115 66 The company which developed x -1 The company which developed iPod shuffle Apple iPod shuffle "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " resistant to any iPod shuffle function.""" 4 [' resistant', ' to', ' any', ' iPod', ' shuffle']
+116 66 The company which developed x -1 The company which developed iPod shuffle Apple iPod shuffle "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " resistant to any iPod shuffle function.""" 4 [' resistant', ' to', ' any', ' iPod', ' shuffle']
+117 66 The company which developed x -1 The company which developed iPod shuffle Apple iPod shuffle "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " resistant to any iPod shuffle function."" Flak" 4 [' resistant', ' to', ' any', ' iPod', ' shuffle']
+118 66 The company which developed x -1 The company which developed iPod shuffle Apple iPod shuffle "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " resistant to any iPod shuffle function."" Flak Magazine" 4 [' resistant', ' to', ' any', ' iPod', ' shuffle']
+119 70 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False group of American B-17 Flying Fortress bombers arriving 7 [' group', ' of', ' American', ' B', '-', '17', ' Flying', ' Fortress']
+120 70 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False over Boeing B-17 Flying Fortress bombers. On 6 [' over', ' Boeing', ' B', '-', '17', ' Flying', ' Fortress']
+121 70 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False Tora !. A Boeing B-17 Flying Fortress used in the production 10 [' Tor', 'a', '!', '.', ' A', ' Boeing', ' B', '-', '17', ' Flying', ' Fortress']
+122 70 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False Wilcke shot down a B-17 Flying Fortress bomber and a North 10 [' Wil', 'c', 'ke', ' shot', ' down', ' a', ' B', '-', '17', ' Flying', ' Fortress']
+123 70 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False of American B-17 Flying Fortress bombers arriving 6 [' of', ' American', ' B', '-', '17', ' Flying', ' Fortress']
+124 74 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False which fixed the Windows 2000 buffer overrun 4 [' which', ' fixed', ' the', ' Windows', ' 2000']
+125 74 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False drops support for Windows 2000 and was compatible 4 [' drops', ' support', ' for', ' Windows', ' 2000']
+126 74 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False an article about Windows 2000 / NT4 source-code 4 [' an', ' article', ' about', ' Windows', ' 2000']
+127 74 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False 98, Windows ME, Windows 2000 and Windows 6 [' 98', ',', ' Windows', ' ME', ',', ' Windows', ' 2000']
+128 74 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False Windows 98, Windows ME, Windows 2000 and Windows 7 [' Windows', ' 98', ',', ' Windows', ' ME', ',', ' Windows', ' 2000']
+129 77 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False ˈɡeɪ /) is a Boeing B-29 Superfortress bomber, named 17 [' ', 'ˈ', '�', '�', 'e', '�', '�', ' /', ')', ' is', ' a', ' Boeing', ' B', '-', '29', ' Super', 'fort', 'ress']
+130 77 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False operations for the B-29 Superfortress in the Mariana 8 [' operations', ' for', ' the', ' B', '-', '29', ' Super', 'fort', 'ress']
+131 77 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False the range of the B-29 Superfortress bombers, and 9 [' the', ' range', ' of', ' the', ' B', '-', '29', ' Super', 'fort', 'ress']
+132 77 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False holding a series of B-29 Superfortress commands in the 9 [' holding', ' a', ' series', ' of', ' B', '-', '29', ' Super', 'fort', 'ress']
+133 77 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False is a Boeing B-29 Superfortress bomber, named 8 [' is', ' a', ' Boeing', ' B', '-', '29', ' Super', 'fort', 'ress']
+134 79 The company which developed x -1 The company which developed SNES-CD Sony SNES-CD "[',' ' the' ' first' ' CD' '-' 'ROM' ' game' ' console' ',' ' has'
+ ' announced' ' that' ' it' ' will' ' be' ' releasing' ' a' ' new'
+ ' version' ' of']" , the first CD - ROM game console , has announced that it will be releasing a new version of False title for the SNES-CD add-on. After 6 [' title', ' for', ' the', ' SN', 'ES', '-', 'CD']
+135 79 The company which developed x -1 The company which developed SNES-CD Sony SNES-CD "[',' ' the' ' first' ' CD' '-' 'ROM' ' game' ' console' ',' ' has'
+ ' announced' ' that' ' it' ' will' ' be' ' releasing' ' a' ' new'
+ ' version' ' of']" , the first CD - ROM game console , has announced that it will be releasing a new version of False known as the SNES-CD (also known as the 6 [' known', ' as', ' the', ' SN', 'ES', '-', 'CD']
+136 79 The company which developed x -1 The company which developed SNES-CD Sony SNES-CD "[',' ' the' ' first' ' CD' '-' 'ROM' ' game' ' console' ',' ' has'
+ ' announced' ' that' ' it' ' will' ' be' ' releasing' ' a' ' new'
+ ' version' ' of']" , the first CD - ROM game console , has announced that it will be releasing a new version of False system known as the SNES-CD (also known as the 7 [' system', ' known', ' as', ' the', ' SN', 'ES', '-', 'CD']
+137 79 The company which developed x -1 The company which developed SNES-CD Sony SNES-CD "[',' ' the' ' first' ' CD' '-' 'ROM' ' game' ' console' ',' ' has'
+ ' announced' ' that' ' it' ' will' ' be' ' releasing' ' a' ' new'
+ ' version' ' of']" , the first CD - ROM game console , has announced that it will be releasing a new version of False known as the SNES-CD (also known 6 [' known', ' as', ' the', ' SN', 'ES', '-', 'CD']
+138 81 The company which developed x -1 The company which developed CIM-10 Bomarc Boeing CIM-10 Bomarc "[' was' ' the' ' first' ' to' ' use' ' the' ' term' ' ""' 'miss' 'ile' '""'
+ ' to' ' describe' ' the' ' weapon' '.' ' The' ' term' ' ""' 'b']" " was the first to use the term "" miss ile "" to describe the weapon . The term "" b" False the Boeing CIM-10 Bomarc long-range surface-to-air 7 [' the', ' Boeing', ' C', 'IM', '-', '10', ' Bom', 'arc']
+139 83 The company which developed x -1 The company which developed Acura Legend Honda Acura Legend "[' is' ' a' ' Japanese' ' company' ' that' ' has' ' been' ' in' ' the'
+ ' business' ' of' ' manufacturing' ' cars' ' for' ' over' ' 50' ' years'
+ '.' ' The' ' company']" is a Japanese company that has been in the business of manufacturing cars for over 50 years . The company False Marketed as the Acura Legend in the U.S., the 6 [' Market', 'ed', ' as', ' the', ' Ac', 'ura', ' Legend']
+140 83 The company which developed x -1 The company which developed Acura Legend Honda Acura Legend "[' is' ' a' ' Japanese' ' company' ' that' ' has' ' been' ' in' ' the'
+ ' business' ' of' ' manufacturing' ' cars' ' for' ' over' ' 50' ' years'
+ '.' ' The' ' company']" is a Japanese company that has been in the business of manufacturing cars for over 50 years . The company False Marketed as the Acura Legend in the U.S., the 6 [' Market', 'ed', ' as', ' the', ' Ac', 'ura', ' Legend']
+141 83 The company which developed x -1 The company which developed Acura Legend Honda Acura Legend "[' is' ' a' ' Japanese' ' company' ' that' ' has' ' been' ' in' ' the'
+ ' business' ' of' ' manufacturing' ' cars' ' for' ' over' ' 50' ' years'
+ '.' ' The' ' company']" is a Japanese company that has been in the business of manufacturing cars for over 50 years . The company False Marketed as the Acura Legend in the U.S., 6 [' Market', 'ed', ' as', ' the', ' Ac', 'ura', ' Legend']
+142 83 The company which developed x -1 The company which developed Acura Legend Honda Acura Legend "[' is' ' a' ' Japanese' ' company' ' that' ' has' ' been' ' in' ' the'
+ ' business' ' of' ' manufacturing' ' cars' ' for' ' over' ' 50' ' years'
+ '.' ' The' ' company']" is a Japanese company that has been in the business of manufacturing cars for over 50 years . The company False Marketed as the Acura Legend in the U.S., the 6 [' Market', 'ed', ' as', ' the', ' Ac', 'ura', ' Legend']
+143 85 The company which developed x -1 The company which developed PGM-17 Thor Douglas PGM-17 Thor "['ium' '-' '232' ' is' ' a' ' nuclear' ' reactor' ' that' ' uses' ' thor'
+ 'ium' ' as' ' fuel' '.' ' It' ' is' ' a' ' bre' 'eder' ' reactor']" ium - 232 is a nuclear reactor that uses thor ium as fuel . It is a bre eder reactor False Jupiter and PGM-17 Thor respectively, 6 [' Jupiter', ' and', ' P', 'GM', '-', '17', ' Thor']
+144 87 The company which developed x -1 The company which developed Dino Fiat Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False " Crisis =
+" 3 [' Crisis', ' =', 'D', 'ino']
+145 87 The company which developed x -1 The company which developed Dino Fiat Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False " Shoemake, and Dino ""SpeedoVee""" 5 [' Sho', 'em', 'ake', ',', ' and', ' Dino']
+146 87 The company which developed x -1 The company which developed Dino Fiat Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False WrestleMania IV between Dino Bravo and Don 4 [' Wrestle', 'Mania', ' IV', ' between', ' Dino']
+147 87 The company which developed x -1 The company which developed Dino Fiat Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False Brazilian attacker Dino Sani was signed 2 [' Brazilian', ' attacker', ' Dino']
+148 87 The company which developed x -1 The company which developed Dino Fiat Dino "['-' 'mite' ',' ' a' ' new' ' kind' ' of' ' pet' ' food' ' that' ' is'
+ ' made' ' from' ' the' ' same' ' ingredients' ' as' ' dog' ' food' ',']" - mite , a new kind of pet food that is made from the same ingredients as dog food , False by Matt Hyde, Dino Paredes, Rick Rubin, 4 [' by', ' Matt', ' Hyde', ',', ' Dino']
+149 89 The company which developed x -1 The company which developed iPhone 3GS Apple iPhone 3GS "[',' ' the' ' iPhone' ' 4' ',' ' and' ' the' ' iPhone' ' 4' 'S' '.' '\n'
+ '\n' 'The' ' iPhone' ' 4' 'S' ' is' ' the' ' first']" ", the iPhone 4 , and the iPhone 4 S .
+
+ The iPhone 4 S is the first" False " introduction of the iPhone 3GS in 2009.
+" 5 [' introduction', ' of', ' the', ' iPhone', ' 3', 'GS']
+150 89 The company which developed x -1 The company which developed iPhone 3GS Apple iPhone 3GS "[',' ' the' ' iPhone' ' 4' ',' ' and' ' the' ' iPhone' ' 4' 'S' '.' '\n'
+ '\n' 'The' ' iPhone' ' 4' 'S' ' is' ' the' ' first']" ", the iPhone 4 , and the iPhone 4 S .
+
+ The iPhone 4 S is the first" False happened with the iPhone 3GS and the iPod Touch 5 [' happened', ' with', ' the', ' iPhone', ' 3', 'GS']
+151 89 The company which developed x -1 The company which developed iPhone 3GS Apple iPhone 3GS "[',' ' the' ' iPhone' ' 4' ',' ' and' ' the' ' iPhone' ' 4' 'S' '.' '\n'
+ '\n' 'The' ' iPhone' ' 4' 'S' ' is' ' the' ' first']" ", the iPhone 4 , and the iPhone 4 S .
+
+ The iPhone 4 S is the first" False happened with the iPhone 3GS and the iPod Touch 5 [' happened', ' with', ' the', ' iPhone', ' 3', 'GS']
+152 90 The company which developed x -1 The company which developed iPad Apple iPad "[' apps' ' for' ' the' ' first' ' time' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' working' ' on' ' the' ' app' ' for' ' the' ' past'
+ ' year']" " apps for the first time .
+
+ The company has been working on the app for the past year" False support the iPad 3rd Generation. 2 [' support', ' the', ' iPad']
+153 90 The company which developed x -1 The company which developed iPad Apple iPad "[' apps' ' for' ' the' ' first' ' time' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' working' ' on' ' the' ' app' ' for' ' the' ' past'
+ ' year']" " apps for the first time .
+
+ The company has been working on the app for the past year" False released on the iPad on November 2, 3 [' released', ' on', ' the', ' iPad']
+154 90 The company which developed x -1 The company which developed iPad Apple iPad "[' apps' ' for' ' the' ' first' ' time' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' working' ' on' ' the' ' app' ' for' ' the' ' past'
+ ' year']" " apps for the first time .
+
+ The company has been working on the app for the past year" False first generation iPad had no camera; 2 [' first', ' generation', ' iPad']
+155 90 The company which developed x -1 The company which developed iPad Apple iPad "[' apps' ' for' ' the' ' first' ' time' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' working' ' on' ' the' ' app' ' for' ' the' ' past'
+ ' year']" " apps for the first time .
+
+ The company has been working on the app for the past year" False announced the iPad Mini. With 2 [' announced', ' the', ' iPad']
+156 90 The company which developed x -1 The company which developed iPad Apple iPad "[' apps' ' for' ' the' ' first' ' time' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' working' ' on' ' the' ' app' ' for' ' the' ' past'
+ ' year']" " apps for the first time .
+
+ The company has been working on the app for the past year" False begun developing the iPad before the iPhone. 3 [' begun', ' developing', ' the', ' iPad']
+157 92 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True introduced with Windows Phone 8.1 in 2014. Cortana replaced 6 [' introduced', ' with', ' Windows', ' Phone', ' 8', '.', '1']
+158 92 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True introduced with Windows Phone 8.1 in 2014. Cortana 6 [' introduced', ' with', ' Windows', ' Phone', ' 8', '.', '1']
+159 92 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True first introduced with Windows Phone 8.1 in 2014. Cortana 7 [' first', ' introduced', ' with', ' Windows', ' Phone', ' 8', '.', '1']
+160 92 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True to be ported to Windows Phone 8.1 and Xbox One while 8 [' to', ' be', ' ported', ' to', ' Windows', ' Phone', ' 8', '.', '1']
+161 92 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True be ported to Windows Phone 8.1 and Xbox One while 7 [' be', ' ported', ' to', ' Windows', ' Phone', ' 8', '.', '1']
+162 93 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False and Boeing B-47 Stratojet bombers, then on the 7 [' and', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+163 93 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False the Boeing B-47 Stratojet in the early 1950s. 7 [' the', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+164 93 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False from a Boeing B-47 Stratojet carrier aircraft, 8 [' from', ' a', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+165 93 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False XB-48, and Boeing B-47 Stratojet bombers, then on 12 [' X', 'B', '-', '48', ',', ' and', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+166 93 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False Australia 24 Boeing B-47 Stratojet bombers until the 8 [' Australia', ' 24', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+167 96 The company which developed x -1 The company which developed YF-22 Boeing YF-22 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' in' ' the' ' business'
+ ' of' ' making']" is a company that has been around for a long time . They have been in the business of making False battling the Lockheed YF-22 for a production 6 [' battling', ' the', ' Lockheed', ' Y', 'F', '-', '22']
+168 96 The company which developed x -1 The company which developed YF-22 Boeing YF-22 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' in' ' the' ' business'
+ ' of' ' making']" is a company that has been around for a long time . They have been in the business of making False " = Lockheed YF-22 =
+" 5 [' =', ' Lockheed', ' Y', 'F', '-', '22']
+169 96 The company which developed x -1 The company which developed YF-22 Boeing YF-22 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' in' ' the' ' business'
+ ' of' ' making']" is a company that has been around for a long time . They have been in the business of making False announced that the YF-22 was the winner. The 6 [' announced', ' that', ' the', ' Y', 'F', '-', '22']
+170 96 The company which developed x -1 The company which developed YF-22 Boeing YF-22 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' in' ' the' ' business'
+ ' of' ' making']" is a company that has been around for a long time . They have been in the business of making False Rice announced the YF-22 as the winner of the 6 [' Rice', ' announced', ' the', ' Y', 'F', '-', '22']
+171 96 The company which developed x -1 The company which developed YF-22 Boeing YF-22 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' in' ' the' ' business'
+ ' of' ' making']" is a company that has been around for a long time . They have been in the business of making False the Lockheed YF-22 for a production 5 [' the', ' Lockheed', ' Y', 'F', '-', '22']
+172 98 The company which developed x -1 The company which developed CineAlta Sony CineAlta "[',' ' a' ' new' ' film' '-' 'based' ' technology' ' that' ' allows'
+ ' the' ' creation' ' of' ' a' ' 3' 'D' ' image' ' from' ' a' ' single'
+ ' 2']" , a new film - based technology that allows the creation of a 3 D image from a single 2 False using Sony CineAlta high-definition video 5 [' using', ' Sony', ' C', 'ine', 'Al', 'ta']
+173 98 The company which developed x -1 The company which developed CineAlta Sony CineAlta "[',' ' a' ' new' ' film' '-' 'based' ' technology' ' that' ' allows'
+ ' the' ' creation' ' of' ' a' ' 3' 'D' ' image' ' from' ' a' ' single'
+ ' 2']" , a new film - based technology that allows the creation of a 3 D image from a single 2 False Sony HDW-F900 CineAlta HDCAM high definition 9 [' Sony', ' HD', 'W', '-', 'F', '900', ' C', 'ine', 'Al', 'ta']
+174 98 The company which developed x -1 The company which developed CineAlta Sony CineAlta "[',' ' a' ' new' ' film' '-' 'based' ' technology' ' that' ' allows'
+ ' the' ' creation' ' of' ' a' ' 3' 'D' ' image' ' from' ' a' ' single'
+ ' 2']" , a new film - based technology that allows the creation of a 3 D image from a single 2 False Sony HDW-F900 CineAlta HDCAM high 9 [' Sony', ' HD', 'W', '-', 'F', '900', ' C', 'ine', 'Al', 'ta']
+175 98 The company which developed x -1 The company which developed CineAlta Sony CineAlta "[',' ' a' ' new' ' film' '-' 'based' ' technology' ' that' ' allows'
+ ' the' ' creation' ' of' ' a' ' 3' 'D' ' image' ' from' ' a' ' single'
+ ' 2']" , a new film - based technology that allows the creation of a 3 D image from a single 2 False total of 18 Sony CineAlta HDC-F950 cameras 7 [' total', ' of', ' 18', ' Sony', ' C', 'ine', 'Al', 'ta']
+176 98 The company which developed x -1 The company which developed CineAlta Sony CineAlta "[',' ' a' ' new' ' film' '-' 'based' ' technology' ' that' ' allows'
+ ' the' ' creation' ' of' ' a' ' 3' 'D' ' image' ' from' ' a' ' single'
+ ' 2']" , a new film - based technology that allows the creation of a 3 D image from a single 2 False Sony HDW-F900 CineAlta HDCAM high definition 9 [' Sony', ' HD', 'W', '-', 'F', '900', ' C', 'ine', 'Al', 'ta']
+177 101 The company which developed x -1 The company which developed PGM-11 Redstone Chrysler PGM-11 Redstone "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' Red' 'stone' ' was' ' a'
+ ' missile' ' with']" was a joint venture between the United States and the Soviet Union . The Red stone was a missile with False Rogers stories and the PGM-11 Redstone rocket. That early 9 [' Rogers', ' stories', ' and', ' the', ' P', 'GM', '-', '11', ' Red', 'stone']
+178 101 The company which developed x -1 The company which developed PGM-11 Redstone Chrysler PGM-11 Redstone "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' Red' 'stone' ' was' ' a'
+ ' missile' ' with']" was a joint venture between the United States and the Soviet Union . The Red stone was a missile with False Rogers stories and the PGM-11 Redstone rocket. That early 9 [' Rogers', ' stories', ' and', ' the', ' P', 'GM', '-', '11', ' Red', 'stone']
+179 101 The company which developed x -1 The company which developed PGM-11 Redstone Chrysler PGM-11 Redstone "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' Red' 'stone' ' was' ' a'
+ ' missile' ' with']" was a joint venture between the United States and the Soviet Union . The Red stone was a missile with False stories and the PGM-11 Redstone rocket. That 8 [' stories', ' and', ' the', ' P', 'GM', '-', '11', ' Red', 'stone']
+180 101 The company which developed x -1 The company which developed PGM-11 Redstone Chrysler PGM-11 Redstone "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' Red' 'stone' ' was' ' a'
+ ' missile' ' with']" was a joint venture between the United States and the Soviet Union . The Red stone was a missile with False stories and the PGM-11 Redstone rocket. That early 8 [' stories', ' and', ' the', ' P', 'GM', '-', '11', ' Red', 'stone']
+181 101 The company which developed x -1 The company which developed PGM-11 Redstone Chrysler PGM-11 Redstone "[' was' ' a' ' joint' ' venture' ' between' ' the' ' United' ' States'
+ ' and' ' the' ' Soviet' ' Union' '.' ' The' ' Red' 'stone' ' was' ' a'
+ ' missile' ' with']" was a joint venture between the United States and the Soviet Union . The Red stone was a missile with False stories and the PGM-11 Redstone rocket. That early 8 [' stories', ' and', ' the', ' P', 'GM', '-', '11', ' Red', 'stone']
+182 102 The company which developed x -1 The company which developed RC-135 Boeing RC-135 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' around' ' since' ' the'
+ ' early' ' 90']" is a company that has been around for a long time . They have been around since the early 90 False presence of a USAF RC-135 surveillance 6 [' presence', ' of', ' a', ' USAF', ' RC', '-', '135']
+183 102 The company which developed x -1 The company which developed RC-135 Boeing RC-135 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' around' ' since' ' the'
+ ' early' ' 90']" is a company that has been around for a long time . They have been around since the early 90 False the role of a USAF RC-135 surveillance aircraft, 7 [' the', ' role', ' of', ' a', ' USAF', ' RC', '-', '135']
+184 102 The company which developed x -1 The company which developed RC-135 Boeing RC-135 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' around' ' since' ' the'
+ ' early' ' 90']" is a company that has been around for a long time . They have been around since the early 90 False with the USAF RC-135 in the context of 5 [' with', ' the', ' USAF', ' RC', '-', '135']
+185 102 The company which developed x -1 The company which developed RC-135 Boeing RC-135 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' around' ' since' ' the'
+ ' early' ' 90']" is a company that has been around for a long time . They have been around since the early 90 False the U.S. via RC-135 or naval aircraft 8 [' the', ' U', '.', 'S', '.', ' via', ' RC', '-', '135']
+186 102 The company which developed x -1 The company which developed RC-135 Boeing RC-135 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' They' ' have' ' been' ' around' ' since' ' the'
+ ' early' ' 90']" is a company that has been around for a long time . They have been around since the early 90 False plane to be an RC-135 reconnaissance 6 [' plane', ' to', ' be', ' an', ' RC', '-', '135']
+187 103 The company which developed x -1 The company which developed Chromecast Google Chromecast "[' is' ' Google' '.' '\n' '\n' 'The' ' Chrom' 'ecast' ' is' ' a' ' small'
+ ' device' ' that' ' plugs' ' into' ' your' ' TV' ' and' ' allows' ' you']" " is Google .
+
+ The Chrom ecast is a small device that plugs into your TV and allows you" True including Google Chromecast and the Amazon Appstore 3 [' including', ' Google', ' Chrom', 'ecast']
+188 103 The company which developed x -1 The company which developed Chromecast Google Chromecast "[' is' ' Google' '.' '\n' '\n' 'The' ' Chrom' 'ecast' ' is' ' a' ' small'
+ ' device' ' that' ' plugs' ' into' ' your' ' TV' ' and' ' allows' ' you']" " is Google .
+
+ The Chrom ecast is a small device that plugs into your TV and allows you" True including Google Chromecast and the Amazon 3 [' including', ' Google', ' Chrom', 'ecast']
+189 103 The company which developed x -1 The company which developed Chromecast Google Chromecast "[' is' ' Google' '.' '\n' '\n' 'The' ' Chrom' 'ecast' ' is' ' a' ' small'
+ ' device' ' that' ' plugs' ' into' ' your' ' TV' ' and' ' allows' ' you']" " is Google .
+
+ The Chrom ecast is a small device that plugs into your TV and allows you" True devices including Google Chromecast and the Amazon 4 [' devices', ' including', ' Google', ' Chrom', 'ecast']
+190 103 The company which developed x -1 The company which developed Chromecast Google Chromecast "[' is' ' Google' '.' '\n' '\n' 'The' ' Chrom' 'ecast' ' is' ' a' ' small'
+ ' device' ' that' ' plugs' ' into' ' your' ' TV' ' and' ' allows' ' you']" " is Google .
+
+ The Chrom ecast is a small device that plugs into your TV and allows you" True including Google Chromecast and the Amazon Appstore 3 [' including', ' Google', ' Chrom', 'ecast']
+191 112 The company which developed x -1 The company which developed MacBook Air Apple MacBook Air "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False closer to a MacBook Air or ultrabook, one of 4 [' closer', ' to', ' a', ' MacBook', ' Air']
+192 112 The company which developed x -1 The company which developed MacBook Air Apple MacBook Air "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False introduced with the MacBook Air earlier that year, 4 [' introduced', ' with', ' the', ' MacBook', ' Air']
+193 112 The company which developed x -1 The company which developed MacBook Air Apple MacBook Air "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False thinner than the MacBook Air and removes the traditional 4 [' thinner', ' than', ' the', ' MacBook', ' Air']
+194 112 The company which developed x -1 The company which developed MacBook Air Apple MacBook Air "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False traits of the MacBook Air which were also 4 [' traits', ' of', ' the', ' MacBook', ' Air']
+195 112 The company which developed x -1 The company which developed MacBook Air Apple MacBook Air "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False premium ultra-portable MacBook Air and the powerful 6 [' premium', ' ultra', '-', 'port', 'able', ' MacBook', ' Air']
+196 114 The company which developed x -1 The company which developed Jeep Grand Cherokee Chrysler Jeep Grand Cherokee "[' is' ' the' ' most' ' popular' ' SUV' ' in' ' the' ' world' '.' ' The'
+ ' Jeep' ' Grand' ' Cherokee' ' is' ' a' ' full' '-' 'size' ' SUV' ' that']" is the most popular SUV in the world . The Jeep Grand Cherokee is a full - size SUV that False on Jones's Jeep Grand Cherokee without a valid warrant. 5 "[' on', ' Jones', ""'s"", ' Jeep', ' Grand', ' Cherokee']"
+197 114 The company which developed x -1 The company which developed Jeep Grand Cherokee Chrysler Jeep Grand Cherokee "[' is' ' the' ' most' ' popular' ' SUV' ' in' ' the' ' world' '.' ' The'
+ ' Jeep' ' Grand' ' Cherokee' ' is' ' a' ' full' '-' 'size' ' SUV' ' that']" is the most popular SUV in the world . The Jeep Grand Cherokee is a full - size SUV that False installed on Jones's Jeep Grand Cherokee without a 6 "[' installed', ' on', ' Jones', ""'s"", ' Jeep', ' Grand', ' Cherokee']"
+198 114 The company which developed x -1 The company which developed Jeep Grand Cherokee Chrysler Jeep Grand Cherokee "[' is' ' the' ' most' ' popular' ' SUV' ' in' ' the' ' world' '.' ' The'
+ ' Jeep' ' Grand' ' Cherokee' ' is' ' a' ' full' '-' 'size' ' SUV' ' that']" is the most popular SUV in the world . The Jeep Grand Cherokee is a full - size SUV that False on Jones's Jeep Grand Cherokee without a valid 5 "[' on', ' Jones', ""'s"", ' Jeep', ' Grand', ' Cherokee']"
+199 114 The company which developed x -1 The company which developed Jeep Grand Cherokee Chrysler Jeep Grand Cherokee "[' is' ' the' ' most' ' popular' ' SUV' ' in' ' the' ' world' '.' ' The'
+ ' Jeep' ' Grand' ' Cherokee' ' is' ' a' ' full' '-' 'size' ' SUV' ' that']" is the most popular SUV in the world . The Jeep Grand Cherokee is a full - size SUV that False DiCaprio's first car was a Jeep Grand Cherokee and that his current 10 "[' Di', 'Cap', 'rio', ""'s"", ' first', ' car', ' was', ' a', ' Jeep', ' Grand', ' Cherokee']"
+200 115 The company which developed x -1 The company which developed Game & Watch Nintendo Game & Watch "[' games' ' for' ' the' ' Nintendo' ' Entertainment' ' System' ',' ' Game'
+ ' &' ' Watch' ',' ' and' ' Game' ' Boy' '.' '\n' '\n' 'The' ' company'
+ ' was']" " games for the Nintendo Entertainment System , Game & Watch , and Game Boy .
+
+ The company was" True the DSi follows Game & Watch and Game Boy 6 [' the', ' DS', 'i', ' follows', ' Game', ' &', ' Watch']
+201 115 The company which developed x -1 The company which developed Game & Watch Nintendo Game & Watch "[' games' ' for' ' the' ' Nintendo' ' Entertainment' ' System' ',' ' Game'
+ ' &' ' Watch' ',' ' and' ' Game' ' Boy' '.' '\n' '\n' 'The' ' company'
+ ' was']" " games for the Nintendo Entertainment System , Game & Watch , and Game Boy .
+
+ The company was" True the Game Boy and Game & Watch handheld systems, 6 [' the', ' Game', ' Boy', ' and', ' Game', ' &', ' Watch']
+202 115 The company which developed x -1 The company which developed Game & Watch Nintendo Game & Watch "[' games' ' for' ' the' ' Nintendo' ' Entertainment' ' System' ',' ' Game'
+ ' &' ' Watch' ',' ' and' ' Game' ' Boy' '.' '\n' '\n' 'The' ' company'
+ ' was']" " games for the Nintendo Entertainment System , Game & Watch , and Game Boy .
+
+ The company was" True " to the original Game & Watch games.
+" 5 [' to', ' the', ' original', ' Game', ' &', ' Watch']
+203 115 The company which developed x -1 The company which developed Game & Watch Nintendo Game & Watch "[' games' ' for' ' the' ' Nintendo' ' Entertainment' ' System' ',' ' Game'
+ ' &' ' Watch' ',' ' and' ' Game' ' Boy' '.' '\n' '\n' 'The' ' company'
+ ' was']" " games for the Nintendo Entertainment System , Game & Watch , and Game Boy .
+
+ The company was" True released for the Game & Watch range of handheld LCD 5 [' released', ' for', ' the', ' Game', ' &', ' Watch']
+204 115 The company which developed x -1 The company which developed Game & Watch Nintendo Game & Watch "[' games' ' for' ' the' ' Nintendo' ' Entertainment' ' System' ',' ' Game'
+ ' &' ' Watch' ',' ' and' ' Game' ' Boy' '.' '\n' '\n' 'The' ' company'
+ ' was']" " games for the Nintendo Entertainment System , Game & Watch , and Game Boy .
+
+ The company was" True DSi follows Game & Watch and Game Boy creator 5 [' DS', 'i', ' follows', ' Game', ' &', ' Watch']
+205 116 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True for the Nintendo Wii U Virtual Console on 4 [' for', ' the', ' Nintendo', ' Wii', ' U']
+206 116 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True " belong"" in the Wii U version. Totilo" 5 "[' belong', '""', ' in', ' the', ' Wii', ' U']"
+207 116 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True Nintendo 3DS and Wii U at E3 2011 in 5 [' Nintendo', ' 3', 'DS', ' and', ' Wii', ' U']
+208 116 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True Super Smash Bros. for Wii U and Super Smash Bros. 6 [' Super', ' Smash', ' Bros', '.', ' for', ' Wii', ' U']
+209 116 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True controller for the Wii U version without 4 [' controller', ' for', ' the', ' Wii', ' U']
+210 118 The company which developed x -1 The company which developed Game Boy Advance Nintendo Game Boy Advance "[' games' ' for' ' the' ' Nintendo' ' DS' ',' ' the' ' company' ' has'
+ ' been' ' working' ' on' ' a' ' new' ' game' ' for' ' the' ' Nintendo'
+ ' 3' 'DS']" games for the Nintendo DS , the company has been working on a new game for the Nintendo 3 DS True of 100 for the Game Boy Advance version; 64 6 [' of', ' 100', ' for', ' the', ' Game', ' Boy', ' Advance']
+211 118 The company which developed x -1 The company which developed Game Boy Advance Nintendo Game Boy Advance "[' games' ' for' ' the' ' Nintendo' ' DS' ',' ' the' ' company' ' has'
+ ' been' ' working' ' on' ' a' ' new' ' game' ' for' ' the' ' Nintendo'
+ ' 3' 'DS']" games for the Nintendo DS , the company has been working on a new game for the Nintendo 3 DS True the world. The Game Boy Advance version of the game 6 [' the', ' world', '.', ' The', ' Game', ' Boy', ' Advance']
+212 118 The company which developed x -1 The company which developed Game Boy Advance Nintendo Game Boy Advance "[' games' ' for' ' the' ' Nintendo' ' DS' ',' ' the' ' company' ' has'
+ ' been' ' working' ' on' ' a' ' new' ' game' ' for' ' the' ' Nintendo'
+ ' 3' 'DS']" games for the Nintendo DS , the company has been working on a new game for the Nintendo 3 DS True launch title for the Game Boy Advance portable game console, 6 [' launch', ' title', ' for', ' the', ' Game', ' Boy', ' Advance']
+213 118 The company which developed x -1 The company which developed Game Boy Advance Nintendo Game Boy Advance "[' games' ' for' ' the' ' Nintendo' ' DS' ',' ' the' ' company' ' has'
+ ' been' ' working' ' on' ' a' ' new' ' game' ' for' ' the' ' Nintendo'
+ ' 3' 'DS']" games for the Nintendo DS , the company has been working on a new game for the Nintendo 3 DS True while the Japanese Game Boy Advance version has sold 5 [' while', ' the', ' Japanese', ' Game', ' Boy', ' Advance']
+214 118 The company which developed x -1 The company which developed Game Boy Advance Nintendo Game Boy Advance "[' games' ' for' ' the' ' Nintendo' ' DS' ',' ' the' ' company' ' has'
+ ' been' ' working' ' on' ' a' ' new' ' game' ' for' ' the' ' Nintendo'
+ ' 3' 'DS']" games for the Nintendo DS , the company has been working on a new game for the Nintendo 3 DS True unchanged. The Game Boy Advance re-release 5 [' unchanged', '.', ' The', ' Game', ' Boy', ' Advance']
+215 119 The company which developed x -1 The company which developed DC-9 Douglas DC-9 "[' was' ' founded' ' in' ' the' ' year' ' of' ' the' ' Great'
+ ' Depression' ',' ' and' ' the' ' company' ' was' ' named' ' after'
+ ' the' ' DC' '-' '3']" was founded in the year of the Great Depression , and the company was named after the DC - 3 False operations so that all DC-9 services from Kirkenes 6 [' operations', ' so', ' that', ' all', ' DC', '-', '9']
+216 119 The company which developed x -1 The company which developed DC-9 Douglas DC-9 "[' was' ' founded' ' in' ' the' ' year' ' of' ' the' ' Great'
+ ' Depression' ',' ' and' ' the' ' company' ' was' ' named' ' after'
+ ' the' ' DC' '-' '3']" was founded in the year of the Great Depression , and the company was named after the DC - 3 False Boeing 737 or Douglas DC-9 aircraft for 6 [' Boeing', ' 737', ' or', ' Douglas', ' DC', '-', '9']
+217 119 The company which developed x -1 The company which developed DC-9 Douglas DC-9 "[' was' ' founded' ' in' ' the' ' year' ' of' ' the' ' Great'
+ ' Depression' ',' ' and' ' the' ' company' ' was' ' named' ' after'
+ ' the' ' DC' '-' '3']" was founded in the year of the Great Depression , and the company was named after the DC - 3 False TWA McDonnell Douglas DC-9 suffered a 6 [' T', 'WA', ' McDonnell', ' Douglas', ' DC', '-', '9']
+218 119 The company which developed x -1 The company which developed DC-9 Douglas DC-9 "[' was' ' founded' ' in' ' the' ' year' ' of' ' the' ' Great'
+ ' Depression' ',' ' and' ' the' ' company' ' was' ' named' ' after'
+ ' the' ' DC' '-' '3']" was founded in the year of the Great Depression , and the company was named after the DC - 3 False 379 (DAL 379), a DC-9, was coming in 9 [' 379', ' (', 'D', 'AL', ' 379', '),', ' a', ' DC', '-', '9']
+219 119 The company which developed x -1 The company which developed DC-9 Douglas DC-9 "[' was' ' founded' ' in' ' the' ' year' ' of' ' the' ' Great'
+ ' Depression' ',' ' and' ' the' ' company' ' was' ' named' ' after'
+ ' the' ' DC' '-' '3']" was founded in the year of the Great Depression , and the company was named after the DC - 3 False observed the DC-9 in his peripheral 4 [' observed', ' the', ' DC', '-', '9']
+220 121 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True Microsoft Windows. Windows Vista supports at 4 [' Microsoft', ' Windows', '.', ' Windows', ' Vista']
+221 121 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True compatible for Windows Vista PCs was released 3 [' compatible', ' for', ' Windows', ' Vista']
+222 121 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True ran on Windows XP, Windows Vista and Windows 6 [' ran', ' on', ' Windows', ' XP', ',', ' Windows', ' Vista']
+223 121 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True compatible for Windows Vista PCs was released 3 [' compatible', ' for', ' Windows', ' Vista']
+224 121 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True Windows XP, Windows Vista and Mac OS X. 4 [' Windows', ' XP', ',', ' Windows', ' Vista']
+225 125 The company which developed x -1 The company which developed iPhone Apple iPhone "[' X' 'S' ' Max' ',' ' iPhone' ' X' 'S' ',' ' iPhone' ' X' 'R' ','
+ ' iPhone' ' X' 'S' ' Max' ',' ' iPhone' ' X' 'S']" X S Max , iPhone X S , iPhone X R , iPhone X S Max , iPhone X S False button on the iPhone 5S incorporates 3 [' button', ' on', ' the', ' iPhone']
+226 125 The company which developed x -1 The company which developed iPhone Apple iPhone "[' X' 'S' ' Max' ',' ' iPhone' ' X' 'S' ',' ' iPhone' ' X' 'R' ','
+ ' iPhone' ' X' 'S' ' Max' ',' ' iPhone' ' X' 'S']" X S Max , iPhone X S , iPhone X R , iPhone X S Max , iPhone X S False announced the iPhone 5 and also 2 [' announced', ' the', ' iPhone']
+227 125 The company which developed x -1 The company which developed iPhone Apple iPhone "[' X' 'S' ' Max' ',' ' iPhone' ' X' 'S' ',' ' iPhone' ' X' 'R' ','
+ ' iPhone' ' X' 'S' ' Max' ',' ' iPhone' ' X' 'S']" X S Max , iPhone X S , iPhone X R , iPhone X S Max , iPhone X S False 0 ['iPhone']
+228 125 The company which developed x -1 The company which developed iPhone Apple iPhone "[' X' 'S' ' Max' ',' ' iPhone' ' X' 'S' ',' ' iPhone' ' X' 'R' ','
+ ' iPhone' ' X' 'S' ' Max' ',' ' iPhone' ' X' 'S']" X S Max , iPhone X S , iPhone X R , iPhone X S Max , iPhone X S False likened the iPhone to a Tiger handheld. 2 [' likened', ' the', ' iPhone']
+229 125 The company which developed x -1 The company which developed iPhone Apple iPhone "[' X' 'S' ' Max' ',' ' iPhone' ' X' 'S' ',' ' iPhone' ' X' 'R' ','
+ ' iPhone' ' X' 'S' ' Max' ',' ' iPhone' ' X' 'S']" X S Max , iPhone X S , iPhone X R , iPhone X S Max , iPhone X S False place for the iPhone 4S to support 3 [' place', ' for', ' the', ' iPhone']
+230 129 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False Cut for iPhone and iPod Touch was released 5 [' Cut', ' for', ' iPhone', ' and', ' iPod', ' Touch']
+231 129 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False compatibility (the iPhone / iPod Touch and iPad versions 6 [' compatibility', ' (', 'the', ' iPhone', ' /', ' iPod', ' Touch']
+232 129 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False introduced new iPod Nano and iPod Touch models. They also stated 6 [' introduced', ' new', ' iPod', ' Nano', ' and', ' iPod', ' Touch']
+233 129 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False major iOS versions) and iPod Touch 5G (four major 6 [' major', ' iOS', ' versions', ')', ' and', ' iPod', ' Touch']
+234 129 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False well as the iPod Touch (5th generation), 4 [' well', ' as', ' the', ' iPod', ' Touch']
+235 133 The company which developed x -1 The company which developed Xeon Intel Xeon "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False x86-64 extensions, a Xeon processor codenamed 7 [' x', '86', '-', '64', ' extensions', ',', ' a', ' Xeon']
+236 133 The company which developed x -1 The company which developed Xeon Intel Xeon "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False equipped with Intel Xeon processors, and 3 [' equipped', ' with', ' Intel', ' Xeon']
+237 133 The company which developed x -1 The company which developed Xeon Intel Xeon "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False comparison with its Xeon family of server 3 [' comparison', ' with', ' its', ' Xeon']
+238 133 The company which developed x -1 The company which developed Xeon Intel Xeon "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False to a 12-core Xeon E5 CPU), four 1866 5 [' to', ' a', ' 12', '-', 'core', ' Xeon']
+239 133 The company which developed x -1 The company which developed Xeon Intel Xeon "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False one quad-core Xeon 3500 at 2.66 GHz or 4 [' one', ' quad', '-', 'core', ' Xeon']
+240 134 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False XP, Vista, and Windows 7 32-bit & 64-bit (hardware 6 [' XP', ',', ' Vista', ',', ' and', ' Windows', ' 7']
+241 134 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False eligible edition of Windows 7 or Windows 8, and 4 [' eligible', ' edition', ' of', ' Windows', ' 7']
+242 134 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False usability. Version 4.0 for Windows 7 SP1 (x64) was 8 [' usability', '.', ' Version', ' 4', '.', '0', ' for', ' Windows', ' 7']
+243 134 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False computers running Windows 7 or later with 3 [' computers', ' running', ' Windows', ' 7']
+244 134 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False compatibility for Windows 7 PCs under the 3 [' compatibility', ' for', ' Windows', ' 7']
+245 135 The company which developed x -1 The company which developed EA-18G Growler Boeing EA-18G Growler "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the' ' US'
+ ' Navy' '.' ' The' ' EA' '-' '18' 'G' ' Grow' 'ler' ' is' ' a']" is a joint venture between Boeing and the US Navy . The EA - 18 G Grow ler is a True the Boeing EA-18G Growler electronic warfare 7 [' the', ' Boeing', ' EA', '-', '18', 'G', ' Grow', 'ler']
+246 136 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False " the twelve original Amiibo in November 2014.
+" 5 [' the', ' twelve', ' original', ' Am', 'i', 'ibo']
+247 136 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False " ""indifference toward the Amiibo experience""" 7 "[' ""', 'ind', 'ifference', ' toward', ' the', ' Am', 'i', 'ibo']"
+248 136 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False Bundles containing Amiibo figures were available 5 [' Bund', 'les', ' containing', ' Am', 'i', 'ibo']
+249 136 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False developed into the Amiibo line of figures 5 [' developed', ' into', ' the', ' Am', 'i', 'ibo']
+250 136 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False " produced Ness and Lucas Amiibo figurines.
+" 6 [' produced', ' Ness', ' and', ' Lucas', ' Am', 'i', 'ibo']
+251 138 The company which developed x -1 The company which developed MacBook Pro Apple MacBook Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False Display on the MacBook Pro have been criticized 4 [' Display', ' on', ' the', ' MacBook', ' Pro']
+252 138 The company which developed x -1 The company which developed MacBook Pro Apple MacBook Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False ports on the 17-inch MacBook Pro are the same in 7 [' ports', ' on', ' the', ' 17', '-', 'inch', ' MacBook', ' Pro']
+253 138 The company which developed x -1 The company which developed MacBook Pro Apple MacBook Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False the Haswell MacBook Pro with Retina 4 [' the', ' Has', 'well', ' MacBook', ' Pro']
+254 138 The company which developed x -1 The company which developed MacBook Pro Apple MacBook Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False on the 17-inch MacBook Pro are the same in type 6 [' on', ' the', ' 17', '-', 'inch', ' MacBook', ' Pro']
+255 138 The company which developed x -1 The company which developed MacBook Pro Apple MacBook Pro "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False 2 ['Mac', 'Book', ' Pro']
+256 140 The company which developed x -1 The company which developed Jeep Wrangler Chrysler Jeep Wrangler "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " first five-door Jeep Wrangler derivative.
+" 6 [' first', ' five', '-', 'door', ' Jeep', ' Wr', 'angler']
+257 140 The company which developed x -1 The company which developed Jeep Wrangler Chrysler Jeep Wrangler "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " first five-door Jeep Wrangler derivative.
+" 6 [' first', ' five', '-', 'door', ' Jeep', ' Wr', 'angler']
+258 140 The company which developed x -1 The company which developed Jeep Wrangler Chrysler Jeep Wrangler "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False variant of the 1991 Jeep Wrangler YJ and the older 6 [' variant', ' of', ' the', ' 1991', ' Jeep', ' Wr', 'angler']
+259 140 The company which developed x -1 The company which developed Jeep Wrangler Chrysler Jeep Wrangler "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " first five-door Jeep Wrangler derivative.
+" 6 [' first', ' five', '-', 'door', ' Jeep', ' Wr', 'angler']
+260 140 The company which developed x -1 The company which developed Jeep Wrangler Chrysler Jeep Wrangler "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False new four-door Jeep Wrangler JK design, the 6 [' new', ' four', '-', 'door', ' Jeep', ' Wr', 'angler']
+261 142 The company which developed x -1 The company which developed Game Boy Micro Nintendo Game Boy Micro "[',' ' the' ' company' ' that' ' made' ' the' ' Game' ' Boy' ',' ' has'
+ ' been' ' acquired' ' by' ' Nintendo' '.' '\n' '\n' 'The' ' company'
+ ' was']" ", the company that made the Game Boy , has been acquired by Nintendo .
+
+ The company was" True limited edition Game Boy Micro with a themed 4 [' limited', ' edition', ' Game', ' Boy', ' Micro']
+262 142 The company which developed x -1 The company which developed Game Boy Micro Nintendo Game Boy Micro "[',' ' the' ' company' ' that' ' made' ' the' ' Game' ' Boy' ',' ' has'
+ ' been' ' acquired' ' by' ' Nintendo' '.' '\n' '\n' 'The' ' company'
+ ' was']" ", the company that made the Game Boy , has been acquired by Nintendo .
+
+ The company was" True limited edition Game Boy Micro with a themed face 4 [' limited', ' edition', ' Game', ' Boy', ' Micro']
+263 142 The company which developed x -1 The company which developed Game Boy Micro Nintendo Game Boy Micro "[',' ' the' ' company' ' that' ' made' ' the' ' Game' ' Boy' ',' ' has'
+ ' been' ' acquired' ' by' ' Nintendo' '.' '\n' '\n' 'The' ' company'
+ ' was']" ", the company that made the Game Boy , has been acquired by Nintendo .
+
+ The company was" True special edition Game Boy Micro and Franklin Badge 4 [' special', ' edition', ' Game', ' Boy', ' Micro']
+264 142 The company which developed x -1 The company which developed Game Boy Micro Nintendo Game Boy Micro "[',' ' the' ' company' ' that' ' made' ' the' ' Game' ' Boy' ',' ' has'
+ ' been' ' acquired' ' by' ' Nintendo' '.' '\n' '\n' 'The' ' company'
+ ' was']" ", the company that made the Game Boy , has been acquired by Nintendo .
+
+ The company was" True included a limited edition Game Boy Micro with a themed face 6 [' included', ' a', ' limited', ' edition', ' Game', ' Boy', ' Micro']
+265 142 The company which developed x -1 The company which developed Game Boy Micro Nintendo Game Boy Micro "[',' ' the' ' company' ' that' ' made' ' the' ' Game' ' Boy' ',' ' has'
+ ' been' ' acquired' ' by' ' Nintendo' '.' '\n' '\n' 'The' ' company'
+ ' was']" ", the company that made the Game Boy , has been acquired by Nintendo .
+
+ The company was" True special edition Game Boy Micro and Franklin Badge 4 [' special', ' edition', ' Game', ' Boy', ' Micro']
+266 143 The company which developed x -1 The company which developed PlayStation Eye Sony PlayStation Eye "[',' ' the' ' company' ' that' ' brought' ' us' ' the' ' PlayStation' ' 3'
+ ',' ' PlayStation' ' 4' ',' ' and' ' PlayStation' ' Vita' ',' ' has'
+ ' announced' ' that']" , the company that brought us the PlayStation 3 , PlayStation 4 , and PlayStation Vita , has announced that False Move uses the PlayStation Eye webcam to track 4 [' Move', ' uses', ' the', ' PlayStation', ' Eye']
+267 143 The company which developed x -1 The company which developed PlayStation Eye Sony PlayStation Eye "[',' ' the' ' company' ' that' ' brought' ' us' ' the' ' PlayStation' ' 3'
+ ',' ' PlayStation' ' 4' ',' ' and' ' PlayStation' ' Vita' ',' ' has'
+ ' announced' ' that']" , the company that brought us the PlayStation 3 , PlayStation 4 , and PlayStation Vita , has announced that False PlayStation Move uses the PlayStation Eye webcam to track 5 [' PlayStation', ' Move', ' uses', ' the', ' PlayStation', ' Eye']
+268 143 The company which developed x -1 The company which developed PlayStation Eye Sony PlayStation Eye "[',' ' the' ' company' ' that' ' brought' ' us' ' the' ' PlayStation' ' 3'
+ ',' ' PlayStation' ' 4' ',' ' and' ' PlayStation' ' Vita' ',' ' has'
+ ' announced' ' that']" , the company that brought us the PlayStation 3 , PlayStation 4 , and PlayStation Vita , has announced that False PlayStation Move uses the PlayStation Eye webcam to track the 5 [' PlayStation', ' Move', ' uses', ' the', ' PlayStation', ' Eye']
+269 149 The company which developed x -1 The company which developed Family Computer Disk System Nintendo Family Computer Disk System "[' (' 'F' 'DS' ')' ' was' ' founded' ' in' ' 1983' ' and' ' is' ' a'
+ ' Japanese' ' video' ' game' ' developer' ' and' ' publisher' '.' ' The'
+ ' company']" ( F DS ) was founded in 1983 and is a Japanese video game developer and publisher . The company False originally released on the Family Computer Disk System (FDS) before 7 [' originally', ' released', ' on', ' the', ' Family', ' Computer', ' Disk', ' System']
+270 149 The company which developed x -1 The company which developed Family Computer Disk System Nintendo Family Computer Disk System "[' (' 'F' 'DS' ')' ' was' ' founded' ' in' ' 1983' ' and' ' is' ' a'
+ ' Japanese' ' video' ' game' ' developer' ' and' ' publisher' '.' ' The'
+ ' company']" ( F DS ) was founded in 1983 and is a Japanese video game developer and publisher . The company False action game for the Family Computer Disk System that was released 7 [' action', ' game', ' for', ' the', ' Family', ' Computer', ' Disk', ' System']
+271 149 The company which developed x -1 The company which developed Family Computer Disk System Nintendo Family Computer Disk System "[' (' 'F' 'DS' ')' ' was' ' founded' ' in' ' 1983' ' and' ' is' ' a'
+ ' Japanese' ' video' ' game' ' developer' ' and' ' publisher' '.' ' The'
+ ' company']" ( F DS ) was founded in 1983 and is a Japanese video game developer and publisher . The company False video game for the Family Computer Disk System in Japan and the Nintendo 7 [' video', ' game', ' for', ' the', ' Family', ' Computer', ' Disk', ' System']
+272 149 The company which developed x -1 The company which developed Family Computer Disk System Nintendo Family Computer Disk System "[' (' 'F' 'DS' ')' ' was' ' founded' ' in' ' 1983' ' and' ' is' ' a'
+ ' Japanese' ' video' ' game' ' developer' ' and' ' publisher' '.' ' The'
+ ' company']" ( F DS ) was founded in 1983 and is a Japanese video game developer and publisher . The company False launch title for the Family Computer Disk System peripheral 7 [' launch', ' title', ' for', ' the', ' Family', ' Computer', ' Disk', ' System']
+273 149 The company which developed x -1 The company which developed Family Computer Disk System Nintendo Family Computer Disk System "[' (' 'F' 'DS' ')' ' was' ' founded' ' in' ' 1983' ' and' ' is' ' a'
+ ' Japanese' ' video' ' game' ' developer' ' and' ' publisher' '.' ' The'
+ ' company']" ( F DS ) was founded in 1983 and is a Japanese video game developer and publisher . The company False by Nintendo for the Family Computer Disk System on January 7 [' by', ' Nintendo', ' for', ' the', ' Family', ' Computer', ' Disk', ' System']
+274 150 The company which developed x -1 The company which developed Alfa Romeo 164 Fiat Alfa Romeo 164 "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' was' ' founded' ' in' ' 1910' ' and' ' has'
+ ' been']" is a company that has been around for a long time . It was founded in 1910 and has been False engine from the Alfa Romeo 164 Procar, and developed 6 [' engine', ' from', ' the', ' Al', 'fa', ' Romeo', ' 164']
+275 152 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True systems, with a Windows Phone 8 version following 6 [' systems', ',', ' with', ' a', ' Windows', ' Phone', ' 8']
+276 152 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True responsible for the Windows Phone 8 game mechanic. 5 [' responsible', ' for', ' the', ' Windows', ' Phone', ' 8']
+277 152 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True Microsoft to optimize Windows Phone 8 for Snapdragon 5 [' Microsoft', ' to', ' optimize', ' Windows', ' Phone', ' 8']
+278 152 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True natively on the Windows Phone 8 Operating System. 6 [' native', 'ly', ' on', ' the', ' Windows', ' Phone', ' 8']
+279 152 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True Android systems, with a Windows Phone 8 version following 7 [' Android', ' systems', ',', ' with', ' a', ' Windows', ' Phone', ' 8']
+280 154 The company which developed x -1 The company which developed Super Game Boy Nintendo Super Game Boy "[',' ' the' ' first' ' handheld' ' game' ' console' ' to' ' be'
+ ' released' ' in' ' the' ' United' ' States' ',' ' has' ' announced'
+ ' that' ' it' ' will' ' be']" , the first handheld game console to be released in the United States , has announced that it will be False section in the Super Game Boy Nintendo Strategy 5 [' section', ' in', ' the', ' Super', ' Game', ' Boy']
+281 154 The company which developed x -1 The company which developed Super Game Boy Nintendo Super Game Boy "[',' ' the' ' first' ' handheld' ' game' ' console' ' to' ' be'
+ ' released' ' in' ' the' ' United' ' States' ',' ' has' ' announced'
+ ' that' ' it' ' will' ' be']" , the first handheld game console to be released in the United States , has announced that it will be False enhancement for the Super Game Boy accessory. The 5 [' enhancement', ' for', ' the', ' Super', ' Game', ' Boy']
+282 154 The company which developed x -1 The company which developed Super Game Boy Nintendo Super Game Boy "[',' ' the' ' first' ' handheld' ' game' ' console' ' to' ' be'
+ ' released' ' in' ' the' ' United' ' States' ',' ' has' ' announced'
+ ' that' ' it' ' will' ' be']" , the first handheld game console to be released in the United States , has announced that it will be False enhancement for the Super Game Boy accessory. The arcade 5 [' enhancement', ' for', ' the', ' Super', ' Game', ' Boy']
+283 154 The company which developed x -1 The company which developed Super Game Boy Nintendo Super Game Boy "[',' ' the' ' first' ' handheld' ' game' ' console' ' to' ' be'
+ ' released' ' in' ' the' ' United' ' States' ',' ' has' ' announced'
+ ' that' ' it' ' will' ' be']" , the first handheld game console to be released in the United States , has announced that it will be False enhancement for the Super Game Boy accessory. The 5 [' enhancement', ' for', ' the', ' Super', ' Game', ' Boy']
+284 154 The company which developed x -1 The company which developed Super Game Boy Nintendo Super Game Boy "[',' ' the' ' first' ' handheld' ' game' ' console' ' to' ' be'
+ ' released' ' in' ' the' ' United' ' States' ',' ' has' ' announced'
+ ' that' ' it' ' will' ' be']" , the first handheld game console to be released in the United States , has announced that it will be False enhancement for the Super Game Boy accessory. The 5 [' enhancement', ' for', ' the', ' Super', ' Game', ' Boy']
+285 157 The company which developed x -1 The company which developed Mini Countryman BMW Mini Countryman "[' is' ' a' ' small' ',' ' family' '-' 'owned' ' business' ' that' ' has'
+ ' been' ' in' ' the' ' business' ' of' ' making' ' quality' ','
+ ' affordable' ',']" is a small , family - owned business that has been in the business of making quality , affordable , False 2 ['Mini', ' Country', 'man']
+286 157 The company which developed x -1 The company which developed Mini Countryman BMW Mini Countryman "[' is' ' a' ' small' ',' ' family' '-' 'owned' ' business' ' that' ' has'
+ ' been' ' in' ' the' ' business' ' of' ' making' ' quality' ','
+ ' affordable' ',']" is a small , family - owned business that has been in the business of making quality , affordable , False 2 ['Mini', ' Country', 'man']
+287 157 The company which developed x -1 The company which developed Mini Countryman BMW Mini Countryman "[' is' ' a' ' small' ',' ' family' '-' 'owned' ' business' ' that' ' has'
+ ' been' ' in' ' the' ' business' ' of' ' making' ' quality' ','
+ ' affordable' ',']" is a small , family - owned business that has been in the business of making quality , affordable , False 2 ['Mini', ' Country', 'man']
+288 157 The company which developed x -1 The company which developed Mini Countryman BMW Mini Countryman "[' is' ' a' ' small' ',' ' family' '-' 'owned' ' business' ' that' ' has'
+ ' been' ' in' ' the' ' business' ' of' ' making' ' quality' ','
+ ' affordable' ',']" is a small , family - owned business that has been in the business of making quality , affordable , False 2 ['Mini', ' Country', 'man']
+289 162 The company which developed x -1 The company which developed Infiniti Q45 Nissan Infiniti Q45 "[' is' ' a' ' Japanese' ' company' ' which' ' is' ' a' ' subsidiary' ' of'
+ ' Nissan' '.' ' The' ' company' ' is' ' a' ' manufacturer' ' of'
+ ' luxury' ' cars' '.']" is a Japanese company which is a subsidiary of Nissan . The company is a manufacturer of luxury cars . True 1989, the V8-powered Infiniti Q45 closely rivaled 11 [' 1989', ',', ' the', ' V', '8', '-', 'powered', ' Inf', 'in', 'iti', ' Q', '45']
+290 164 The company which developed x -1 The company which developed Sandy Bridge Intel Sandy Bridge "[' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Sandy' ' Bridge' ' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed']" " is Intel .
+
+ The company which developed Sandy Bridge is Intel .
+
+ The company which developed" True previous generation Sandy Bridge processor option, 3 [' previous', ' generation', ' Sandy', ' Bridge']
+291 164 The company which developed x -1 The company which developed Sandy Bridge Intel Sandy Bridge "[' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Sandy' ' Bridge' ' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed']" " is Intel .
+
+ The company which developed Sandy Bridge is Intel .
+
+ The company which developed" True contrast, the Sandy Bridge chip, a previous-generation 4 [' contrast', ',', ' the', ' Sandy', ' Bridge']
+292 164 The company which developed x -1 The company which developed Sandy Bridge Intel Sandy Bridge "[' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Sandy' ' Bridge' ' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed']" " is Intel .
+
+ The company which developed Sandy Bridge is Intel .
+
+ The company which developed" True from the SSD and Sandy Bridge processors. However, 5 [' from', ' the', ' SSD', ' and', ' Sandy', ' Bridge']
+293 164 The company which developed x -1 The company which developed Sandy Bridge Intel Sandy Bridge "[' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Sandy' ' Bridge' ' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed']" " is Intel .
+
+ The company which developed Sandy Bridge is Intel .
+
+ The company which developed" True resulting from the SSD and Sandy Bridge processors. However, 6 [' resulting', ' from', ' the', ' SSD', ' and', ' Sandy', ' Bridge']
+294 164 The company which developed x -1 The company which developed Sandy Bridge Intel Sandy Bridge "[' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Sandy' ' Bridge' ' is' ' Intel' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed']" " is Intel .
+
+ The company which developed Sandy Bridge is Intel .
+
+ The company which developed" True previous generation Sandy Bridge processor option, 3 [' previous', ' generation', ' Sandy', ' Bridge']
+295 167 The company which developed x -1 The company which developed Final Fantasy Square Final Fantasy "[' XV' ':' ' A' ' New' ' Empire' ' is' ' a' ' new' ' company' ' that'
+ ' is' ' currently' ' working' ' on' ' a' ' new' ' game' '.' ' The'
+ ' game']" XV : A New Empire is a new company that is currently working on a new game . The game False Dirge of Cerberus: Final Fantasy VII Multiplayer 6 [' Dir', 'ge', ' of', ' Cerberus', ':', ' Final', ' Fantasy']
+296 167 The company which developed x -1 The company which developed Final Fantasy Square Final Fantasy "[' XV' ':' ' A' ' New' ' Empire' ' is' ' a' ' new' ' company' ' that'
+ ' is' ' currently' ' working' ' on' ' a' ' new' ' game' '.' ' The'
+ ' game']" XV : A New Empire is a new company that is currently working on a new game . The game False to the six in Final Fantasy XII: Kytes and 5 [' to', ' the', ' six', ' in', ' Final', ' Fantasy']
+297 167 The company which developed x -1 The company which developed Final Fantasy Square Final Fantasy "[' XV' ':' ' A' ' New' ' Empire' ' is' ' a' ' new' ' company' ' that'
+ ' is' ' currently' ' working' ' on' ' a' ' new' ' game' '.' ' The'
+ ' game']" XV : A New Empire is a new company that is currently working on a new game . The game False character from Final Fantasy VII. At the time, 3 [' character', ' from', ' Final', ' Fantasy']
+298 167 The company which developed x -1 The company which developed Final Fantasy Square Final Fantasy "[' XV' ':' ' A' ' New' ' Empire' ' is' ' a' ' new' ' company' ' that'
+ ' is' ' currently' ' working' ' on' ' a' ' new' ' game' '.' ' The'
+ ' game']" XV : A New Empire is a new company that is currently working on a new game . The game False 1 ['Final', ' Fantasy']
+299 167 The company which developed x -1 The company which developed Final Fantasy Square Final Fantasy "[' XV' ':' ' A' ' New' ' Empire' ' is' ' a' ' new' ' company' ' that'
+ ' is' ' currently' ' working' ' on' ' a' ' new' ' game' '.' ' The'
+ ' game']" XV : A New Empire is a new company that is currently working on a new game . The game False characters from the Final Fantasy franchise 4 [' characters', ' from', ' the', ' Final', ' Fantasy']
+300 170 The company which developed x -1 The company which developed Symbian Nokia Symbian "[' OS' ' is' ' now' ' owned' ' by' ' Nokia' ',' ' and' ' the' ' Symb'
+ 'ian' ' OS' ' is' ' now' ' called' ' the' ' Nokia' ' Series' ' 60' '.']" OS is now owned by Nokia , and the Symb ian OS is now called the Nokia Series 60 . True 2011 for Nokia Symbian smartphones, Apple 4 [' 2011', ' for', ' Nokia', ' Symb', 'ian']
+301 170 The company which developed x -1 The company which developed Symbian Nokia Symbian "[' OS' ' is' ' now' ' owned' ' by' ' Nokia' ',' ' and' ' the' ' Symb'
+ 'ian' ' OS' ' is' ' now' ' called' ' the' ' Nokia' ' Series' ' 60' '.']" OS is now owned by Nokia , and the Symb ian OS is now called the Nokia Series 60 . True handsets using the Symbian operating system, 5 [' hands', 'ets', ' using', ' the', ' Symb', 'ian']
+302 170 The company which developed x -1 The company which developed Symbian Nokia Symbian "[' OS' ' is' ' now' ' owned' ' by' ' Nokia' ',' ' and' ' the' ' Symb'
+ 'ian' ' OS' ' is' ' now' ' called' ' the' ' Nokia' ' Series' ' 60' '.']" OS is now owned by Nokia , and the Symb ian OS is now called the Nokia Series 60 . True implementation for the Symbian S60 operating system; 4 [' implementation', ' for', ' the', ' Symb', 'ian']
+303 170 The company which developed x -1 The company which developed Symbian Nokia Symbian "[' OS' ' is' ' now' ' owned' ' by' ' Nokia' ',' ' and' ' the' ' Symb'
+ 'ian' ' OS' ' is' ' now' ' called' ' the' ' Nokia' ' Series' ' 60' '.']" OS is now owned by Nokia , and the Symb ian OS is now called the Nokia Series 60 . True implementation for the Symbian S60 operating system; 4 [' implementation', ' for', ' the', ' Symb', 'ian']
+304 170 The company which developed x -1 The company which developed Symbian Nokia Symbian "[' OS' ' is' ' now' ' owned' ' by' ' Nokia' ',' ' and' ' the' ' Symb'
+ 'ian' ' OS' ' is' ' now' ' called' ' the' ' Nokia' ' Series' ' 60' '.']" OS is now owned by Nokia , and the Symb ian OS is now called the Nokia Series 60 . True implementation for the Symbian S60 operating system; 4 [' implementation', ' for', ' the', ' Symb', 'ian']
+305 171 The company which developed x -1 The company which developed Windows 98 Microsoft Windows 98 "[',' ' Windows' ' 2000' ',' ' Windows' ' XP' ',' ' Windows' ' Vista' ','
+ ' Windows' ' 7' ',' ' Windows' ' 8' ',' ' Windows' ' 8' '.' '1']" , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows 8 , Windows 8 . 1 False in 2000 for the Windows 98 platform (and in 5 [' in', ' 2000', ' for', ' the', ' Windows', ' 98']
+306 171 The company which developed x -1 The company which developed Windows 98 Microsoft Windows 98 "[',' ' Windows' ' 2000' ',' ' Windows' ' XP' ',' ' Windows' ' Vista' ','
+ ' Windows' ' 7' ',' ' Windows' ' 8' ',' ' Windows' ' 8' '.' '1']" , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows 8 , Windows 8 . 1 False 1 ['Windows', ' 98']
+307 171 The company which developed x -1 The company which developed Windows 98 Microsoft Windows 98 "[',' ' Windows' ' 2000' ',' ' Windows' ' XP' ',' ' Windows' ' Vista' ','
+ ' Windows' ' 7' ',' ' Windows' ' 8' ',' ' Windows' ' 8' '.' '1']" , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows 8 , Windows 8 . 1 False Published in 2000 for the Windows 98 platform (and in 6 [' Published', ' in', ' 2000', ' for', ' the', ' Windows', ' 98']
+308 171 The company which developed x -1 The company which developed Windows 98 Microsoft Windows 98 "[',' ' Windows' ' 2000' ',' ' Windows' ' XP' ',' ' Windows' ' Vista' ','
+ ' Windows' ' 7' ',' ' Windows' ' 8' ',' ' Windows' ' 8' '.' '1']" , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows 8 , Windows 8 . 1 False 1 ['Windows', ' 98']
+309 171 The company which developed x -1 The company which developed Windows 98 Microsoft Windows 98 "[',' ' Windows' ' 2000' ',' ' Windows' ' XP' ',' ' Windows' ' Vista' ','
+ ' Windows' ' 7' ',' ' Windows' ' 8' ',' ' Windows' ' 8' '.' '1']" , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows 8 , Windows 8 . 1 False in 2000 for the Windows 98 platform (and in the 5 [' in', ' 2000', ' for', ' the', ' Windows', ' 98']
+310 176 The company which developed x -1 The company which developed Windows XP Microsoft Windows XP "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of'
+ ' Windows' ' XP' ',' ' Windows' ' XP' ' Service' ' Pack' ' 3' '.' ' The'
+ ' new' ' version']" , Microsoft has released a new version of Windows XP , Windows XP Service Pack 3 . The new version True enabled to play on Windows XP through an unauthorized 5 [' enabled', ' to', ' play', ' on', ' Windows', ' XP']
+311 176 The company which developed x -1 The company which developed Windows XP Microsoft Windows XP "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of'
+ ' Windows' ' XP' ',' ' Windows' ' XP' ' Service' ' Pack' ' 3' '.' ' The'
+ ' new' ' version']" , Microsoft has released a new version of Windows XP , Windows XP Service Pack 3 . The new version True versions of Windows XP and higher. The patch 3 [' versions', ' of', ' Windows', ' XP']
+312 176 The company which developed x -1 The company which developed Windows XP Microsoft Windows XP "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of'
+ ' Windows' ' XP' ',' ' Windows' ' XP' ' Service' ' Pack' ' 3' '.' ' The'
+ ' new' ' version']" , Microsoft has released a new version of Windows XP , Windows XP Service Pack 3 . The new version True campaign for Windows XP. It begins with 3 [' campaign', ' for', ' Windows', ' XP']
+313 176 The company which developed x -1 The company which developed Windows XP Microsoft Windows XP "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of'
+ ' Windows' ' XP' ',' ' Windows' ' XP' ' Service' ' Pack' ' 3' '.' ' The'
+ ' new' ' version']" , Microsoft has released a new version of Windows XP , Windows XP Service Pack 3 . The new version True this version, Windows XP would not be 4 [' this', ' version', ',', ' Windows', ' XP']
+314 176 The company which developed x -1 The company which developed Windows XP Microsoft Windows XP "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of'
+ ' Windows' ' XP' ',' ' Windows' ' XP' ' Service' ' Pack' ' 3' '.' ' The'
+ ' new' ' version']" , Microsoft has released a new version of Windows XP , Windows XP Service Pack 3 . The new version True screen running Windows XP Professional, 3 [' screen', ' running', ' Windows', ' XP']
+315 177 The company which developed x -1 The company which developed NTFS Microsoft NTFS "[' is' ' Microsoft' '.' '\n' '\n' 'I' ' have' ' a' ' question' ' about'
+ ' the' ' N' 'TF' 'S' ' file' ' system' '.' ' I' ' have' ' a']" " is Microsoft .
+
+ I have a question about the N TF S file system . I have a" True improves on NTFS in some respects. 4 [' improves', ' on', ' N', 'TF', 'S']
+316 177 The company which developed x -1 The company which developed NTFS Microsoft NTFS "[' is' ' Microsoft' '.' '\n' '\n' 'I' ' have' ' a' ' question' ' about'
+ ' the' ' N' 'TF' 'S' ' file' ' system' '.' ' I' ' have' ' a']" " is Microsoft .
+
+ I have a question about the N TF S file system . I have a" True servers that improves on NTFS in some respects. 6 [' servers', ' that', ' improves', ' on', ' N', 'TF', 'S']
+317 177 The company which developed x -1 The company which developed NTFS Microsoft NTFS "[' is' ' Microsoft' '.' '\n' '\n' 'I' ' have' ' a' ' question' ' about'
+ ' the' ' N' 'TF' 'S' ' file' ' system' '.' ' I' ' have' ' a']" " is Microsoft .
+
+ I have a question about the N TF S file system . I have a" True systems, including the NTFS filesystem which 6 [' systems', ',', ' including', ' the', ' N', 'TF', 'S']
+318 177 The company which developed x -1 The company which developed NTFS Microsoft NTFS "[' is' ' Microsoft' '.' '\n' '\n' 'I' ' have' ' a' ' question' ' about'
+ ' the' ' N' 'TF' 'S' ' file' ' system' '.' ' I' ' have' ' a']" " is Microsoft .
+
+ I have a question about the N TF S file system . I have a" True 2 ['N', 'TF', 'S']
+319 177 The company which developed x -1 The company which developed NTFS Microsoft NTFS "[' is' ' Microsoft' '.' '\n' '\n' 'I' ' have' ' a' ' question' ' about'
+ ' the' ' N' 'TF' 'S' ' file' ' system' '.' ' I' ' have' ' a']" " is Microsoft .
+
+ I have a question about the N TF S file system . I have a" True 2 ['N', 'TF', 'S']
+320 179 The company which developed x -1 The company which developed Heavyweight Champ Sega Heavyweight Champ "['agne' ',' ' a' ' new' ' beer' ' that' ' is' ' a' ' blend' ' of'
+ ' champagne' ' and' ' beer' '.' '\n' '\n' 'The' ' beer' ' is' ' a']" "agne , a new beer that is a blend of champagne and beer .
+
+ The beer is a" False arcade game Heavyweight Champ in 1976, but it 4 [' arcade', ' game', ' Heavy', 'weight', ' Champ']
+321 179 The company which developed x -1 The company which developed Heavyweight Champ Sega Heavyweight Champ "['agne' ',' ' a' ' new' ' beer' ' that' ' is' ' a' ' blend' ' of'
+ ' champagne' ' and' ' beer' '.' '\n' '\n' 'The' ' beer' ' is' ' a']" "agne , a new beer that is a blend of champagne and beer .
+
+ The beer is a" False arcade game Heavyweight Champ in 1976, but 4 [' arcade', ' game', ' Heavy', 'weight', ' Champ']
+322 179 The company which developed x -1 The company which developed Heavyweight Champ Sega Heavyweight Champ "['agne' ',' ' a' ' new' ' beer' ' that' ' is' ' a' ' blend' ' of'
+ ' champagne' ' and' ' beer' '.' '\n' '\n' 'The' ' beer' ' is' ' a']" "agne , a new beer that is a blend of champagne and beer .
+
+ The beer is a" False boxing game Heavyweight Champ, which was released 4 [' boxing', ' game', ' Heavy', 'weight', ' Champ']
+323 179 The company which developed x -1 The company which developed Heavyweight Champ Sega Heavyweight Champ "['agne' ',' ' a' ' new' ' beer' ' that' ' is' ' a' ' blend' ' of'
+ ' champagne' ' and' ' beer' '.' '\n' '\n' 'The' ' beer' ' is' ' a']" "agne , a new beer that is a blend of champagne and beer .
+
+ The beer is a" False and white boxing game Heavyweight Champ, which was released 6 [' and', ' white', ' boxing', ' game', ' Heavy', 'weight', ' Champ']
+324 179 The company which developed x -1 The company which developed Heavyweight Champ Sega Heavyweight Champ "['agne' ',' ' a' ' new' ' beer' ' that' ' is' ' a' ' blend' ' of'
+ ' champagne' ' and' ' beer' '.' '\n' '\n' 'The' ' beer' ' is' ' a']" "agne , a new beer that is a blend of champagne and beer .
+
+ The beer is a" False was arcade game Heavyweight Champ in 1976, but 5 [' was', ' arcade', ' game', ' Heavy', 'weight', ' Champ']
+325 181 The company which developed x -1 The company which developed Logic Pro Apple Logic Pro "[' X' ' is' ' Logic' ' Pro' ' X' '.' '\n' '\n' 'I' ' have' ' a'
+ ' question' ' about' ' the' ' new' ' version' ' of' ' Logic' ' Pro' ' X']" " X is Logic Pro X .
+
+ I have a question about the new version of Logic Pro X" False instrument in the Logic Pro audio production 4 [' instrument', ' in', ' the', ' Logic', ' Pro']
+326 181 The company which developed x -1 The company which developed Logic Pro Apple Logic Pro "[' X' ' is' ' Logic' ' Pro' ' X' '.' '\n' '\n' 'I' ' have' ' a'
+ ' question' ' about' ' the' ' new' ' version' ' of' ' Logic' ' Pro' ' X']" " X is Logic Pro X .
+
+ I have a question about the new version of Logic Pro X" False electronic sound. He used Logic Pro recording software, 6 [' electronic', ' sound', '.', ' He', ' used', ' Logic', ' Pro']
+327 181 The company which developed x -1 The company which developed Logic Pro Apple Logic Pro "[' X' ' is' ' Logic' ' Pro' ' X' '.' '\n' '\n' 'I' ' have' ' a'
+ ' question' ' about' ' the' ' new' ' version' ' of' ' Logic' ' Pro' ' X']" " X is Logic Pro X .
+
+ I have a question about the new version of Logic Pro X" False sound. He used Logic Pro recording software, 5 [' sound', '.', ' He', ' used', ' Logic', ' Pro']
+328 181 The company which developed x -1 The company which developed Logic Pro Apple Logic Pro "[' X' ' is' ' Logic' ' Pro' ' X' '.' '\n' '\n' 'I' ' have' ' a'
+ ' question' ' about' ' the' ' new' ' version' ' of' ' Logic' ' Pro' ' X']" " X is Logic Pro X .
+
+ I have a question about the new version of Logic Pro X" False programming the drums in Logic Pro and changing the 5 [' programming', ' the', ' drums', ' in', ' Logic', ' Pro']
+329 181 The company which developed x -1 The company which developed Logic Pro Apple Logic Pro "[' X' ' is' ' Logic' ' Pro' ' X' '.' '\n' '\n' 'I' ' have' ' a'
+ ' question' ' about' ' the' ' new' ' version' ' of' ' Logic' ' Pro' ' X']" " X is Logic Pro X .
+
+ I have a question about the new version of Logic Pro X" False the drums in Logic Pro and changing the 4 [' the', ' drums', ' in', ' Logic', ' Pro']
+330 185 The company which developed x -1 The company which developed Final Fantasy V Square Final Fantasy V "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False " ""while playing Final Fantasy V is a chore on the" 5 "[' ""', 'while', ' playing', ' Final', ' Fantasy', ' V']"
+331 185 The company which developed x -1 The company which developed Final Fantasy V Square Final Fantasy V "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False versions of Final Fantasy V and Final Fantasy 4 [' versions', ' of', ' Final', ' Fantasy', ' V']
+332 185 The company which developed x -1 The company which developed Final Fantasy V Square Final Fantasy V "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False the soundtrack for Final Fantasy V was much better 5 [' the', ' soundtrack', ' for', ' Final', ' Fantasy', ' V']
+333 185 The company which developed x -1 The company which developed Final Fantasy V Square Final Fantasy V "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False 2 ['Final', ' Fantasy', ' V']
+334 185 The company which developed x -1 The company which developed Final Fantasy V Square Final Fantasy V "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False " Dear Friends ==
+" 5 [' Dear', ' Friends', ' ==', 'Final', ' Fantasy', ' V']
+335 186 The company which developed x -1 The company which developed DirectX Microsoft DirectX "[' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' DirectX' ' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " 12 is Microsoft .
+
+ The company which developed DirectX 12 is Microsoft .
+
+ The company which" True version when DirectX 11 features 2 [' version', ' when', ' DirectX']
+336 186 The company which developed x -1 The company which developed DirectX Microsoft DirectX "[' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' DirectX' ' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " 12 is Microsoft .
+
+ The company which developed DirectX 12 is Microsoft .
+
+ The company which" True also has both DirectX 9 and DirectX 10 mode. 3 [' also', ' has', ' both', ' DirectX']
+337 186 The company which developed x -1 The company which developed DirectX Microsoft DirectX "[' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' DirectX' ' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " 12 is Microsoft .
+
+ The company which developed DirectX 12 is Microsoft .
+
+ The company which" True " ====
+" 3 [' =', '===', 'Direct', 'X']
+338 186 The company which developed x -1 The company which developed DirectX Microsoft DirectX "[' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' DirectX' ' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " 12 is Microsoft .
+
+ The company which developed DirectX 12 is Microsoft .
+
+ The company which" True incorporates DirectX 10 level effects 1 [' incorporates', ' DirectX']
+339 186 The company which developed x -1 The company which developed DirectX Microsoft DirectX "[' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' DirectX' ' 12' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " 12 is Microsoft .
+
+ The company which developed DirectX 12 is Microsoft .
+
+ The company which" True hardware and DirectX 11-compatible platforms, 2 [' hardware', ' and', ' DirectX']
+340 188 The company which developed x -1 The company which developed Core 2 Intel Core 2 "['.' '0' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for'
+ ' a' ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in']" . 0 is a company that has been around for a long time , and has been a leader in False updates and Core 2 Duo processors later 3 [' updates', ' and', ' Core', ' 2']
+341 188 The company which developed x -1 The company which developed Core 2 Intel Core 2 "['.' '0' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for'
+ ' a' ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in']" . 0 is a company that has been around for a long time , and has been a leader in False MacBook moved to the Core 2 Duo processor and the 5 [' MacBook', ' moved', ' to', ' the', ' Core', ' 2']
+342 188 The company which developed x -1 The company which developed Core 2 Intel Core 2 "['.' '0' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for'
+ ' a' ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in']" . 0 is a company that has been around for a long time , and has been a leader in False MacBook moved to the Core 2 Duo processor 5 [' MacBook', ' moved', ' to', ' the', ' Core', ' 2']
+343 188 The company which developed x -1 The company which developed Core 2 Intel Core 2 "['.' '0' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for'
+ ' a' ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in']" . 0 is a company that has been around for a long time , and has been a leader in False 13-inch retained the Core 2 Duo with a speed 6 [' 13', '-', 'inch', ' retained', ' the', ' Core', ' 2']
+344 188 The company which developed x -1 The company which developed Core 2 Intel Core 2 "['.' '0' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for'
+ ' a' ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in']" . 0 is a company that has been around for a long time , and has been a leader in False include Intel Core 2 Duo processors. 3 [' include', ' Intel', ' Core', ' 2']
+345 189 The company which developed x -1 The company which developed Internet Explorer Microsoft Internet Explorer "[' 8' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Internet' ' Explorer' ' 8' ' is' ' Microsoft' '.' '\n'
+ '\n' 'The' ' company']" " 8 is Microsoft .
+
+ The company which developed Internet Explorer 8 is Microsoft .
+
+ The company" True Microsoft's Internet Explorer in late 1995 as 3 "[' Microsoft', ""'s"", ' Internet', ' Explorer']"
+346 189 The company which developed x -1 The company which developed Internet Explorer Microsoft Internet Explorer "[' 8' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Internet' ' Explorer' ' 8' ' is' ' Microsoft' '.' '\n'
+ '\n' 'The' ' company']" " 8 is Microsoft .
+
+ The company which developed Internet Explorer 8 is Microsoft .
+
+ The company" True page zooming in Internet Explorer and 'Domain Enroll' 5 [' page', ' z', 'ooming', ' in', ' Internet', ' Explorer']
+347 189 The company which developed x -1 The company which developed Internet Explorer Microsoft Internet Explorer "[' 8' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Internet' ' Explorer' ' 8' ' is' ' Microsoft' '.' '\n'
+ '\n' 'The' ' company']" " 8 is Microsoft .
+
+ The company which developed Internet Explorer 8 is Microsoft .
+
+ The company" True percent for Internet Explorer 7 and 77 percent 3 [' percent', ' for', ' Internet', ' Explorer']
+348 189 The company which developed x -1 The company which developed Internet Explorer Microsoft Internet Explorer "[' 8' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Internet' ' Explorer' ' 8' ' is' ' Microsoft' '.' '\n'
+ '\n' 'The' ' company']" " 8 is Microsoft .
+
+ The company which developed Internet Explorer 8 is Microsoft .
+
+ The company" True that Microsoft Internet Explorer did not properly 3 [' that', ' Microsoft', ' Internet', ' Explorer']
+349 189 The company which developed x -1 The company which developed Internet Explorer Microsoft Internet Explorer "[' 8' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Internet' ' Explorer' ' 8' ' is' ' Microsoft' '.' '\n'
+ '\n' 'The' ' company']" " 8 is Microsoft .
+
+ The company which developed Internet Explorer 8 is Microsoft .
+
+ The company" True deployed using Internet Explorer Mobile or a 3 [' deployed', ' using', ' Internet', ' Explorer']
+350 191 The company which developed x -1 The company which developed Gmail Google Gmail "[' is' ' Google' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Gmail' ' is' ' Google' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Gmail']" " is Google .
+
+ The company which developed Gmail is Google .
+
+ The company which developed Gmail" True services, such as Gmail (with push email 4 [' services', ',', ' such', ' as', ' Gmail']
+351 191 The company which developed x -1 The company which developed Gmail Google Gmail "[' is' ' Google' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Gmail' ' is' ' Google' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Gmail']" " is Google .
+
+ The company which developed Gmail is Google .
+
+ The company which developed Gmail" True outside of Gmail and Twitter — throughout 2 [' outside', ' of', ' Gmail']
+352 191 The company which developed x -1 The company which developed Gmail Google Gmail "[' is' ' Google' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Gmail' ' is' ' Google' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Gmail']" " is Google .
+
+ The company which developed Gmail is Google .
+
+ The company which developed Gmail" True services, such as Gmail (with push 4 [' services', ',', ' such', ' as', ' Gmail']
+353 192 The company which developed x -1 The company which developed Mac OS X Panther Apple Mac OS X Panther "[',' ' the' ' first' ' version' ' of' ' the' ' Mac' ' OS' ' X'
+ ' operating' ' system' ',' ' has' ' been' ' acquired' ' by' ' Apple' '.'
+ '\n' '\n']" ", the first version of the Mac OS X operating system , has been acquired by Apple .
+
+" True the company's Mac OS X Panther operating system, 6 "[' the', ' company', ""'s"", ' Mac', ' OS', ' X', ' Panther']"
+354 192 The company which developed x -1 The company which developed Mac OS X Panther Apple Mac OS X Panther "[',' ' the' ' first' ' version' ' of' ' the' ' Mac' ' OS' ' X'
+ ' operating' ' system' ',' ' has' ' been' ' acquired' ' by' ' Apple' '.'
+ '\n' '\n']" ", the first version of the Mac OS X operating system , has been acquired by Apple .
+
+" True the company's Mac OS X Panther operating system, 6 "[' the', ' company', ""'s"", ' Mac', ' OS', ' X', ' Panther']"
+355 192 The company which developed x -1 The company which developed Mac OS X Panther Apple Mac OS X Panther "[',' ' the' ' first' ' version' ' of' ' the' ' Mac' ' OS' ' X'
+ ' operating' ' system' ',' ' has' ' been' ' acquired' ' by' ' Apple' '.'
+ '\n' '\n']" ", the first version of the Mac OS X operating system , has been acquired by Apple .
+
+" True music to the company's Mac OS X Panther operating system, 8 "[' music', ' to', ' the', ' company', ""'s"", ' Mac', ' OS', ' X', ' Panther']"
+356 192 The company which developed x -1 The company which developed Mac OS X Panther Apple Mac OS X Panther "[',' ' the' ' first' ' version' ' of' ' the' ' Mac' ' OS' ' X'
+ ' operating' ' system' ',' ' has' ' been' ' acquired' ' by' ' Apple' '.'
+ '\n' '\n']" ", the first version of the Mac OS X operating system , has been acquired by Apple .
+
+" True company's Mac OS X Panther operating system, 5 "[' company', ""'s"", ' Mac', ' OS', ' X', ' Panther']"
+357 193 The company which developed x -1 The company which developed ISPF IBM ISPF "[' is' ' a' ' leading' ' provider' ' of' ' software' ' solutions' ' for'
+ ' the' ' design' ' and' ' manufacture' ' of' ' high' '-' 'performance'
+ ',' ' high' '-' 'pre']" is a leading provider of software solutions for the design and manufacture of high - performance , high - pre False received embroidered ISPF and Moonbase patches 4 [' received', ' embro', 'idered', ' ISP', 'F']
+358 193 The company which developed x -1 The company which developed ISPF IBM ISPF "[' is' ' a' ' leading' ' provider' ' of' ' software' ' solutions' ' for'
+ ' the' ' design' ' and' ' manufacture' ' of' ' high' '-' 'performance'
+ ',' ' high' '-' 'pre']" is a leading provider of software solutions for the design and manufacture of high - performance , high - pre False Moonbase where the ISPF have their 5 [' Moon', 'base', ' where', ' the', ' ISP', 'F']
+359 193 The company which developed x -1 The company which developed ISPF IBM ISPF "[' is' ' a' ' leading' ' provider' ' of' ' software' ' solutions' ' for'
+ ' the' ' design' ' and' ' manufacture' ' of' ' high' '-' 'performance'
+ ',' ' high' '-' 'pre']" is a leading provider of software solutions for the design and manufacture of high - performance , high - pre False received embroidered ISPF and Moonbase 4 [' received', ' embro', 'idered', ' ISP', 'F']
+360 193 The company which developed x -1 The company which developed ISPF IBM ISPF "[' is' ' a' ' leading' ' provider' ' of' ' software' ' solutions' ' for'
+ ' the' ' design' ' and' ' manufacture' ' of' ' high' '-' 'performance'
+ ',' ' high' '-' 'pre']" is a leading provider of software solutions for the design and manufacture of high - performance , high - pre False received embroidered ISPF and Moonbase patches 4 [' received', ' embro', 'idered', ' ISP', 'F']
+361 193 The company which developed x -1 The company which developed ISPF IBM ISPF "[' is' ' a' ' leading' ' provider' ' of' ' software' ' solutions' ' for'
+ ' the' ' design' ' and' ' manufacture' ' of' ' high' '-' 'performance'
+ ',' ' high' '-' 'pre']" is a leading provider of software solutions for the design and manufacture of high - performance , high - pre False embroidered ISPF and Moonbase patches 3 [' embro', 'idered', ' ISP', 'F']
+362 196 The company which developed x -1 The company which developed Windows Media Video Microsoft Windows Media Video "[' 9' ' is' ' Microsoft' ' Corporation' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed' ' Windows' ' Media' ' Player' ' is' ' Microsoft'
+ ' Corporation' '.' '\n' '\n']" " 9 is Microsoft Corporation .
+
+ The company which developed Windows Media Player is Microsoft Corporation .
+
+" True QuickTime and Windows Media Video formats. Limited 5 [' Quick', 'Time', ' and', ' Windows', ' Media', ' Video']
+363 196 The company which developed x -1 The company which developed Windows Media Video Microsoft Windows Media Video "[' 9' ' is' ' Microsoft' ' Corporation' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed' ' Windows' ' Media' ' Player' ' is' ' Microsoft'
+ ' Corporation' '.' '\n' '\n']" " 9 is Microsoft Corporation .
+
+ The company which developed Windows Media Player is Microsoft Corporation .
+
+" True in QuickTime and Windows Media Video formats. Limited 6 [' in', ' Quick', 'Time', ' and', ' Windows', ' Media', ' Video']
+364 196 The company which developed x -1 The company which developed Windows Media Video Microsoft Windows Media Video "[' 9' ' is' ' Microsoft' ' Corporation' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed' ' Windows' ' Media' ' Player' ' is' ' Microsoft'
+ ' Corporation' '.' '\n' '\n']" " 9 is Microsoft Corporation .
+
+ The company which developed Windows Media Player is Microsoft Corporation .
+
+" True QuickTime and Windows Media Video formats. Limited 5 [' Quick', 'Time', ' and', ' Windows', ' Media', ' Video']
+365 196 The company which developed x -1 The company which developed Windows Media Video Microsoft Windows Media Video "[' 9' ' is' ' Microsoft' ' Corporation' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed' ' Windows' ' Media' ' Player' ' is' ' Microsoft'
+ ' Corporation' '.' '\n' '\n']" " 9 is Microsoft Corporation .
+
+ The company which developed Windows Media Player is Microsoft Corporation .
+
+" True QuickTime and Windows Media Video formats. Limited 5 [' Quick', 'Time', ' and', ' Windows', ' Media', ' Video']
+366 197 The company which developed x -1 The company which developed Xbox One Microsoft Xbox One "[' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a' ' powerful'
+ ' console' ' that' ' will' ' be' ' released' ' on' ' November' ' 7' 'th']" X , the new Xbox One X , is a powerful console that will be released on November 7 th False " released on the Xbox One on August 4, 2015.
+" 4 [' released', ' on', ' the', ' Xbox', ' One']
+367 197 The company which developed x -1 The company which developed Xbox One Microsoft Xbox One "[' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a' ' powerful'
+ ' console' ' that' ' will' ' be' ' released' ' on' ' November' ' 7' 'th']" X , the new Xbox One X , is a powerful console that will be released on November 7 th False PlayStation 4 and Xbox One versions fit onto 4 [' PlayStation', ' 4', ' and', ' Xbox', ' One']
+368 197 The company which developed x -1 The company which developed Xbox One Microsoft Xbox One "[' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a' ' powerful'
+ ' console' ' that' ' will' ' be' ' released' ' on' ' November' ' 7' 'th']" X , the new Xbox One X , is a powerful console that will be released on November 7 th False was released on Xbox One and Windows PC (via 4 [' was', ' released', ' on', ' Xbox', ' One']
+369 197 The company which developed x -1 The company which developed Xbox One Microsoft Xbox One "[' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a' ' powerful'
+ ' console' ' that' ' will' ' be' ' released' ' on' ' November' ' 7' 'th']" X , the new Xbox One X , is a powerful console that will be released on November 7 th False title for the Xbox One just prior to the 4 [' title', ' for', ' the', ' Xbox', ' One']
+370 197 The company which developed x -1 The company which developed Xbox One Microsoft Xbox One "[' X' ',' ' the' ' new' ' Xbox' ' One' ' X' ',' ' is' ' a' ' powerful'
+ ' console' ' that' ' will' ' be' ' released' ' on' ' November' ' 7' 'th']" X , the new Xbox One X , is a powerful console that will be released on November 7 th False Turtle Beach Xbox One headsets, USB flash 3 [' Turtle', ' Beach', ' Xbox', ' One']
+371 198 The company which developed x -1 The company which developed Windows RT Microsoft Windows RT "[',' ' Windows' ' Phone' ' 8' ',' ' and' ' Windows' ' 8' '.' '1' '.' '\n'
+ '\n' 'The' ' company' ' has' ' also' ' announced' ' that' ' it']" ", Windows Phone 8 , and Windows 8 . 1 .
+
+ The company has also announced that it" False 1 ['Windows', ' RT']
+372 198 The company which developed x -1 The company which developed Windows RT Microsoft Windows RT "[',' ' Windows' ' Phone' ' 8' ',' ' and' ' Windows' ' 8' '.' '1' '.' '\n'
+ '\n' 'The' ' company' ' has' ' also' ' announced' ' that' ' it']" ", Windows Phone 8 , and Windows 8 . 1 .
+
+ The company has also announced that it" False 1 ['Windows', ' RT']
+373 199 The company which developed x -1 The company which developed Tamarin Adobe Tamarin "[',' ' a' ' new' ' programming' ' language' ' for' ' the' ' J' 'VM' ','
+ ' has' ' released' ' a' ' new' ' version' ' of' ' Tam' 'arin' ',' ' a']" , a new programming language for the J VM , has released a new version of Tam arin , a False this island was in Tamarin Bay, on the 5 [' this', ' island', ' was', ' in', ' Tam', 'arin']
+374 199 The company which developed x -1 The company which developed Tamarin Adobe Tamarin "[',' ' a' ' new' ' programming' ' language' ' for' ' the' ' J' 'VM' ','
+ ' has' ' released' ' a' ' new' ' version' ' of' ' Tam' 'arin' ',' ' a']" , a new programming language for the J VM , has released a new version of Tam arin , a False island was in Tamarin Bay, on the west 4 [' island', ' was', ' in', ' Tam', 'arin']
+375 200 The company which developed x -1 The company which developed Windows Server 2008 Microsoft Windows Server 2008 "[' R' '2' ',' ' Windows' ' Server' ' 2012' ',' ' and' ' Windows' ' Server'
+ ' 2012' ' R' '2' '.' '\n' '\n' 'The' ' new' ' features' ' in']" " R 2 , Windows Server 2012 , and Windows Server 2012 R 2 .
+
+ The new features in" False 2 ['Windows', ' Server', ' 2008']
+376 200 The company which developed x -1 The company which developed Windows Server 2008 Microsoft Windows Server 2008 "[' R' '2' ',' ' Windows' ' Server' ' 2012' ',' ' and' ' Windows' ' Server'
+ ' 2012' ' R' '2' '.' '\n' '\n' 'The' ' new' ' features' ' in']" " R 2 , Windows Server 2012 , and Windows Server 2012 R 2 .
+
+ The new features in" False " compared to around 200 in Windows Server 2008 R2.
+" 7 [' compared', ' to', ' around', ' 200', ' in', ' Windows', ' Server', ' 2008']
+377 200 The company which developed x -1 The company which developed Windows Server 2008 Microsoft Windows Server 2008 "[' R' '2' ',' ' Windows' ' Server' ' 2012' ',' ' and' ' Windows' ' Server'
+ ' 2012' ' R' '2' '.' '\n' '\n' 'The' ' new' ' features' ' in']" " R 2 , Windows Server 2012 , and Windows Server 2012 R 2 .
+
+ The new features in" False announced that Windows Server 2008 R2 would be the 4 [' announced', ' that', ' Windows', ' Server', ' 2008']
+378 200 The company which developed x -1 The company which developed Windows Server 2008 Microsoft Windows Server 2008 "[' R' '2' ',' ' Windows' ' Server' ' 2012' ',' ' and' ' Windows' ' Server'
+ ' 2012' ' R' '2' '.' '\n' '\n' 'The' ' new' ' features' ' in']" " R 2 , Windows Server 2012 , and Windows Server 2012 R 2 .
+
+ The new features in" False " predecessor Windows Server 2008 R2:
+" 3 [' predecessor', ' Windows', ' Server', ' 2008']
+379 200 The company which developed x -1 The company which developed Windows Server 2008 Microsoft Windows Server 2008 "[' R' '2' ',' ' Windows' ' Server' ' 2012' ',' ' and' ' Windows' ' Server'
+ ' 2012' ' R' '2' '.' '\n' '\n' 'The' ' new' ' features' ' in']" " R 2 , Windows Server 2012 , and Windows Server 2012 R 2 .
+
+ The new features in" False Microsoft announced that Windows Server 2008 R2 would be the last 5 [' Microsoft', ' announced', ' that', ' Windows', ' Server', ' 2008']
+380 203 The company which developed x -1 The company which developed iCloud Apple iCloud "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iCloud' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iCloud']" " is Apple .
+
+ The company which developed iCloud is Apple .
+
+ The company which developed iCloud" True compromised the Apple iCloud accounts of 3 [' compromised', ' the', ' Apple', ' iCloud']
+381 203 The company which developed x -1 The company which developed iCloud Apple iCloud "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iCloud' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iCloud']" " is Apple .
+
+ The company which developed iCloud is Apple .
+
+ The company which developed iCloud" True iPad to the computer. iCloud also allows users 5 [' iPad', ' to', ' the', ' computer', '.', ' iCloud']
+382 203 The company which developed x -1 The company which developed iCloud Apple iCloud "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iCloud' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iCloud']" " is Apple .
+
+ The company which developed iCloud is Apple .
+
+ The company which developed iCloud" True iPad to the computer. iCloud also allows users 5 [' iPad', ' to', ' the', ' computer', '.', ' iCloud']
+383 203 The company which developed x -1 The company which developed iCloud Apple iCloud "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iCloud' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iCloud']" " is Apple .
+
+ The company which developed iCloud is Apple .
+
+ The company which developed iCloud" True wave of stolen nude iCloud photographs. In 4 [' wave', ' of', ' stolen', ' nude', ' iCloud']
+384 207 The company which developed x -1 The company which developed MacPaint Apple MacPaint "[' was' ' founded' ' in' ' 1983' ' by' ' a' ' group' ' of' ' artists'
+ ' and' ' designers' ' who' ' wanted' ' to' ' create' ' a' ' new' ' kind'
+ ' of' ' paint']" was founded in 1983 by a group of artists and designers who wanted to create a new kind of paint False 2 ['Mac', 'P', 'aint']
+385 207 The company which developed x -1 The company which developed MacPaint Apple MacPaint "[' was' ' founded' ' in' ' 1983' ' by' ' a' ' group' ' of' ' artists'
+ ' and' ' designers' ' who' ' wanted' ' to' ' create' ' a' ' new' ' kind'
+ ' of' ' paint']" was founded in 1983 by a group of artists and designers who wanted to create a new kind of paint False 2 ['Mac', 'P', 'aint']
+386 207 The company which developed x -1 The company which developed MacPaint Apple MacPaint "[' was' ' founded' ' in' ' 1983' ' by' ' a' ' group' ' of' ' artists'
+ ' and' ' designers' ' who' ' wanted' ' to' ' create' ' a' ' new' ' kind'
+ ' of' ' paint']" was founded in 1983 by a group of artists and designers who wanted to create a new kind of paint False The original MacPaint was developed by 4 [' The', ' original', ' Mac', 'P', 'aint']
+387 207 The company which developed x -1 The company which developed MacPaint Apple MacPaint "[' was' ' founded' ' in' ' 1983' ' by' ' a' ' group' ' of' ' artists'
+ ' and' ' designers' ' who' ' wanted' ' to' ' create' ' a' ' new' ' kind'
+ ' of' ' paint']" was founded in 1983 by a group of artists and designers who wanted to create a new kind of paint False " MacPaint =
+" 2 [' Mac', 'P', 'aint']
+388 207 The company which developed x -1 The company which developed MacPaint Apple MacPaint "[' was' ' founded' ' in' ' 1983' ' by' ' a' ' group' ' of' ' artists'
+ ' and' ' designers' ' who' ' wanted' ' to' ' create' ' a' ' new' ' kind'
+ ' of' ' paint']" was founded in 1983 by a group of artists and designers who wanted to create a new kind of paint False 2 ['Mac', 'P', 'aint']
+389 208 The company which developed x -1 The company which developed MiniDisc Sony MiniDisc "['s' ',' ' the' ' company' ' that' ' brought' ' us' ' the' ' Mini' 'Disc'
+ ',' ' the' ' first' ' portable' ' music' ' player' ',' ' has' ' been'
+ ' acquired']" s , the company that brought us the Mini Disc , the first portable music player , has been acquired False lyrics using a MiniDisc player. The rest of 4 [' lyrics', ' using', ' a', ' Mini', 'Disc']
+390 208 The company which developed x -1 The company which developed MiniDisc Sony MiniDisc "['s' ',' ' the' ' company' ' that' ' brought' ' us' ' the' ' Mini' 'Disc'
+ ',' ' the' ' first' ' portable' ' music' ' player' ',' ' has' ' been'
+ ' acquired']" s , the company that brought us the Mini Disc , the first portable music player , has been acquired False initial lyrics using a MiniDisc player. The rest 5 [' initial', ' lyrics', ' using', ' a', ' Mini', 'Disc']
+391 209 The company which developed x -1 The company which developed QuickTime Apple QuickTime "[' VR' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the']" VR is a company that has been around for a long time , and has been a leader in the False receive low-grade QuickTime clips over the Internet 5 [' receive', ' low', '-', 'grade', ' Quick', 'Time']
+392 209 The company which developed x -1 The company which developed QuickTime Apple QuickTime "[' VR' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the']" VR is a company that has been around for a long time , and has been a leader in the False litigation 4 [' lit', 'ig', 'atio', 'Quick', 'Time']
+393 209 The company which developed x -1 The company which developed QuickTime Apple QuickTime "[' VR' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the']" VR is a company that has been around for a long time , and has been a leader in the False receive low-grade QuickTime clips over the 5 [' receive', ' low', '-', 'grade', ' Quick', 'Time']
+394 209 The company which developed x -1 The company which developed QuickTime Apple QuickTime "[' VR' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the']" VR is a company that has been around for a long time , and has been a leader in the False " archive (streaming QuickTime format).
+" 5 [' archive', ' (', 'stream', 'ing', ' Quick', 'Time']
+395 209 The company which developed x -1 The company which developed QuickTime Apple QuickTime "[' VR' ' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the']" VR is a company that has been around for a long time , and has been a leader in the False that featured grainy QuickTime videos of Metroid 5 [' that', ' featured', ' grain', 'y', ' Quick', 'Time']
+396 212 The company which developed x -1 The company which developed Newsstand Apple Newsstand "[' is' ' a' ' great' ' example' ' of' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' a' ' long' ' time' ' and' ' has' ' a' ' great'
+ ' reputation']" is a great example of a company that has been around for a long time and has a great reputation False Reader, a Modern Newsstand on Southeast 5 [' Reader', ',', ' a', ' Modern', ' News', 'stand']
+397 212 The company which developed x -1 The company which developed Newsstand Apple Newsstand "[' is' ' a' ' great' ' example' ' of' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' a' ' long' ' time' ' and' ' has' ' a' ' great'
+ ' reputation']" is a great example of a company that has been around for a long time and has a great reputation False Carnival. During the Newsstand Day (Día del Canillita) 5 [' Carnival', '.', ' During', ' the', ' News', 'stand']
+398 212 The company which developed x -1 The company which developed Newsstand Apple Newsstand "[' is' ' a' ' great' ' example' ' of' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' a' ' long' ' time' ' and' ' has' ' a' ' great'
+ ' reputation']" is a great example of a company that has been around for a long time and has a great reputation False December 1965, Newsstand Special 1989 (100 Beautiful 4 [' December', ' 1965', ',', ' News', 'stand']
+399 212 The company which developed x -1 The company which developed Newsstand Apple Newsstand "[' is' ' a' ' great' ' example' ' of' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' a' ' long' ' time' ' and' ' has' ' a' ' great'
+ ' reputation']" is a great example of a company that has been around for a long time and has a great reputation False Playboy), December 1965, Newsstand Special 1989 (100 6 [' Playboy', '),', ' December', ' 1965', ',', ' News', 'stand']
+400 212 The company which developed x -1 The company which developed Newsstand Apple Newsstand "[' is' ' a' ' great' ' example' ' of' ' a' ' company' ' that' ' has'
+ ' been' ' around' ' for' ' a' ' long' ' time' ' and' ' has' ' a' ' great'
+ ' reputation']" is a great example of a company that has been around for a long time and has a great reputation False During the Newsstand Day (Día del Canillita) 3 [' During', ' the', ' News', 'stand']
+401 213 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False compatible for Windows 7 PCs called Kanon Memorial 3 [' compatible', ' for', ' Windows', ' 7']
+402 213 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False usability. Version 4.0 for Windows 7 SP1 (x64) 8 [' usability', '.', ' Version', ' 4', '.', '0', ' for', ' Windows', ' 7']
+403 213 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False August 2015 for Windows 7 it was last, with 4 [' August', ' 2015', ' for', ' Windows', ' 7']
+404 213 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False all compatible Windows 7 or 8.1 computers with 3 [' all', ' compatible', ' Windows', ' 7']
+405 213 The company which developed x -1 The company which developed Windows 7 Microsoft Windows 7 "[',' ' Windows' ' 8' ',' ' and' ' Windows' ' 10' '.' '\n' '\n' 'The'
+ ' company' ' which' ' developed' ' Windows' ' 7' ',' ' Windows' ' 8' ',']" ", Windows 8 , and Windows 10 .
+
+ The company which developed Windows 7 , Windows 8 ," False XP, Vista, and Windows 7 32-bit & 64-bit (hardware 6 [' XP', ',', ' Vista', ',', ' and', ' Windows', ' 7']
+406 215 The company which developed x -1 The company which developed Donkey Kong Nintendo Donkey Kong "['a' ',' ' a' ' game' ' that' ' was' ' released' ' in' ' the' ' arc'
+ 'ades' ' in' ' the' ' early' ' 1980' 's' ',' ' is' ' now' ' working']" a , a game that was released in the arc ades in the early 1980 s , is now working False owned or leased Donkey Kong machines. To 4 [' owned', ' or', ' leased', ' Donkey', ' Kong']
+407 215 The company which developed x -1 The company which developed Donkey Kong Nintendo Donkey Kong "['a' ',' ' a' ' game' ' that' ' was' ' released' ' in' ' the' ' arc'
+ 'ades' ' in' ' the' ' early' ' 1980' 's' ',' ' is' ' now' ' working']" a , a game that was released in the arc ades in the early 1980 s , is now working False that arrive on Donkey Kong Island, and 4 [' that', ' arrive', ' on', ' Donkey', ' Kong']
+408 215 The company which developed x -1 The company which developed Donkey Kong Nintendo Donkey Kong "['a' ',' ' a' ' game' ' that' ' was' ' released' ' in' ' the' ' arc'
+ 'ades' ' in' ' the' ' early' ' 1980' 's' ',' ' is' ' now' ' working']" a , a game that was released in the arc ades in the early 1980 s , is now working False people worked on Donkey Kong Country over an 18-month 4 [' people', ' worked', ' on', ' Donkey', ' Kong']
+409 215 The company which developed x -1 The company which developed Donkey Kong Nintendo Donkey Kong "['a' ',' ' a' ' game' ' that' ' was' ' released' ' in' ' the' ' arc'
+ 'ades' ' in' ' the' ' early' ' 1980' 's' ',' ' is' ' now' ' working']" a , a game that was released in the arc ades in the early 1980 s , is now working False 1994, and Donkey Kong Country helping 4 [' 1994', ',', ' and', ' Donkey', ' Kong']
+410 215 The company which developed x -1 The company which developed Donkey Kong Nintendo Donkey Kong "['a' ',' ' a' ' game' ' that' ' was' ' released' ' in' ' the' ' arc'
+ 'ades' ' in' ' the' ' early' ' 1980' 's' ',' ' is' ' now' ' working']" a , a game that was released in the arc ades in the early 1980 s , is now working False franchises, including Donkey Kong and Goldeneye 007, 4 [' franchises', ',', ' including', ' Donkey', ' Kong']
+411 216 The company which developed x -1 The company which developed iPod Nano Apple iPod Nano "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False introduced new iPod Nano and iPod Touch 3 [' introduced', ' new', ' iPod', ' Nano']
+412 216 The company which developed x -1 The company which developed iPod Nano Apple iPod Nano "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False introduced new iPod Nano and iPod Touch models. 3 [' introduced', ' new', ' iPod', ' Nano']
+413 216 The company which developed x -1 The company which developed iPod Nano Apple iPod Nano "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False generation), iPod Nano (7th generation), 3 [' generation', '),', ' iPod', ' Nano']
+414 216 The company which developed x -1 The company which developed iPod Nano Apple iPod Nano "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False (5th generation), iPod Nano (7th generation), 6 [' (', '5', 'th', ' generation', '),', ' iPod', ' Nano']
+415 219 The company which developed x -1 The company which developed PowerShell Microsoft PowerShell "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' PowerShell' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' PowerShell']" " is Microsoft .
+
+ The company which developed PowerShell is Microsoft .
+
+ The company which developed PowerShell" True by default. Windows PowerShell in this version 4 [' by', ' default', '.', ' Windows', ' PowerShell']
+416 219 The company which developed x -1 The company which developed PowerShell Microsoft PowerShell "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' PowerShell' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' PowerShell']" " is Microsoft .
+
+ The company which developed PowerShell is Microsoft .
+
+ The company which developed PowerShell" True 1 ['Power', 'Shell']
+417 219 The company which developed x -1 The company which developed PowerShell Microsoft PowerShell "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' PowerShell' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' PowerShell']" " is Microsoft .
+
+ The company which developed PowerShell is Microsoft .
+
+ The company which developed PowerShell" True default. Windows PowerShell in this version 3 [' default', '.', ' Windows', ' PowerShell']
+418 219 The company which developed x -1 The company which developed PowerShell Microsoft PowerShell "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' PowerShell' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' PowerShell']" " is Microsoft .
+
+ The company which developed PowerShell is Microsoft .
+
+ The company which developed PowerShell" True 1 ['Power', 'Shell']
+419 219 The company which developed x -1 The company which developed PowerShell Microsoft PowerShell "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' PowerShell' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' PowerShell']" " is Microsoft .
+
+ The company which developed PowerShell is Microsoft .
+
+ The company which developed PowerShell" True default. Windows PowerShell in this version has 3 [' default', '.', ' Windows', ' PowerShell']
+420 225 The company which developed x -1 The company which developed Windows Live Messenger Microsoft Windows Live Messenger "[' is' ' a' ' Microsoft' ' company' ',' ' and' ' it' ' is' ' a' ' very'
+ ' popular' ' instant' ' messaging' ' client' '.' ' It' ' is' ' a' ' free'
+ ' application']" is a Microsoft company , and it is a very popular instant messaging client . It is a free application True " & D 'oh! Nuts"". Windows Live Messenger presented their users" 10 "[' &', ' D', "" '"", 'oh', '!', ' N', 'uts', '"".', ' Windows', ' Live', ' Messenger']"
+421 227 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True system that works on Windows Vista and Windows 5 [' system', ' that', ' works', ' on', ' Windows', ' Vista']
+422 227 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True support for Windows Vista 64-bit. New features 3 [' support', ' for', ' Windows', ' Vista']
+423 227 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True for Windows XP, Windows Vista and Mac OS X 5 [' for', ' Windows', ' XP', ',', ' Windows', ' Vista']
+424 227 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True Planetarian compatible for Windows Vista PCs was released 5 [' Planet', 'arian', ' compatible', ' for', ' Windows', ' Vista']
+425 227 The company which developed x -1 The company which developed Windows Vista Microsoft Windows Vista "[',' ' Microsoft' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system' ',' ' Windows' ' 7' ',' ' which' ' is'
+ ' available' ' for' ' download']" , Microsoft has released a new version of its operating system , Windows 7 , which is available for download True 1, 2007 as a Windows Vista compatible edition 6 [' 1', ',', ' 2007', ' as', ' a', ' Windows', ' Vista']
+426 228 The company which developed x -1 The company which developed Windows 8 Microsoft Windows 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+" True advantage of Windows 8 after its release 3 [' advantage', ' of', ' Windows', ' 8']
+427 228 The company which developed x -1 The company which developed Windows 8 Microsoft Windows 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+" True balanced aspects of Windows 8 with those 4 [' balanced', ' aspects', ' of', ' Windows', ' 8']
+428 228 The company which developed x -1 The company which developed Windows 8 Microsoft Windows 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+" True released along with the Windows 8 Consumer Preview 5 [' released', ' along', ' with', ' the', ' Windows', ' 8']
+429 228 The company which developed x -1 The company which developed Windows 8 Microsoft Windows 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+" True " an update to Windows 8 codenamed ""Threshold""," 4 [' an', ' update', ' to', ' Windows', ' 8']
+430 228 The company which developed x -1 The company which developed Windows 8 Microsoft Windows 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+" True not run on Windows 8 and later, which 4 [' not', ' run', ' on', ' Windows', ' 8']
+431 232 The company which developed x -1 The company which developed Safari Apple Safari "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False Holiday World & Splashin' Safari announced Thunderbird, 6 "[' Holiday', ' World', ' &', ' Splash', 'in', ""'"", ' Safari']"
+432 232 The company which developed x -1 The company which developed Safari Apple Safari "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False San Diego Zoo Safari Park and San Pasqual 3 [' San', ' Diego', ' Zoo', ' Safari']
+433 232 The company which developed x -1 The company which developed Safari Apple Safari "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False exposed at the Wild Safari Fault where the 4 [' exposed', ' at', ' the', ' Wild', ' Safari']
+434 232 The company which developed x -1 The company which developed Safari Apple Safari "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False 2 ['S', 'af', 'ari']
+435 232 The company which developed x -1 The company which developed Safari Apple Safari "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False National Park, Tuli Safari Area and Chimanimani 5 [' National', ' Park', ',', ' Tul', 'i', ' Safari']
+436 234 The company which developed x -1 The company which developed The Revenge of Shinobi Sega The Revenge of Shinobi "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' The'
+ ' Legend' ' of' ' Heroes' ':' ' Trails' ' of' ' Cold' ' Steel' ' III' '.'
+ ' The']" is now working on a new game called The Legend of Heroes : Trails of Cold Steel III . The False scenes, similar to The Revenge of Shinobi {{} } s introduction 8 [' scenes', ',', ' similar', ' to', ' The', ' Revenge', ' of', ' Shin', 'obi']
+437 234 The company which developed x -1 The company which developed The Revenge of Shinobi Sega The Revenge of Shinobi "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' The'
+ ' Legend' ' of' ' Heroes' ':' ' Trails' ' of' ' Cold' ' Steel' ' III' '.'
+ ' The']" is now working on a new game called The Legend of Heroes : Trails of Cold Steel III . The False similar to The Revenge of Shinobi {{} } s introduction 6 [' similar', ' to', ' The', ' Revenge', ' of', ' Shin', 'obi']
+438 234 The company which developed x -1 The company which developed The Revenge of Shinobi Sega The Revenge of Shinobi "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' The'
+ ' Legend' ' of' ' Heroes' ':' ' Trails' ' of' ' Cold' ' Steel' ' III' '.'
+ ' The']" is now working on a new game called The Legend of Heroes : Trails of Cold Steel III . The False scenes, similar to The Revenge of Shinobi {{} } s introduction 8 [' scenes', ',', ' similar', ' to', ' The', ' Revenge', ' of', ' Shin', 'obi']
+439 234 The company which developed x -1 The company which developed The Revenge of Shinobi Sega The Revenge of Shinobi "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' The'
+ ' Legend' ' of' ' Heroes' ':' ' Trails' ' of' ' Cold' ' Steel' ' III' '.'
+ ' The']" is now working on a new game called The Legend of Heroes : Trails of Cold Steel III . The False scenes, similar to The Revenge of Shinobi {{} } s introduction 8 [' scenes', ',', ' similar', ' to', ' The', ' Revenge', ' of', ' Shin', 'obi']
+440 238 The company which developed x -1 The company which developed Gear VR Samsung Gear VR "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' while' ',' ' and' ' has' ' been' ' making' ' VR' ' headsets' ' for'
+ ' a' ' while']" is a company that has been around for a while , and has been making VR headsets for a while False the Samsung Gear VR device. On 5 December 3 [' the', ' Samsung', ' Gear', ' VR']
+441 238 The company which developed x -1 The company which developed Gear VR Samsung Gear VR "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' while' ',' ' and' ' has' ' been' ' making' ' VR' ' headsets' ' for'
+ ' a' ' while']" is a company that has been around for a while , and has been making VR headsets for a while False Tower via its Gear VR virtual reality device. 4 [' Tower', ' via', ' its', ' Gear', ' VR']
+442 238 The company which developed x -1 The company which developed Gear VR Samsung Gear VR "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' while' ',' ' and' ' has' ' been' ' making' ' VR' ' headsets' ' for'
+ ' a' ' while']" is a company that has been around for a while , and has been making VR headsets for a while False Go for the Samsung Gear VR device. On 5 December 5 [' Go', ' for', ' the', ' Samsung', ' Gear', ' VR']
+443 239 The company which developed x -1 The company which developed Clang Apple Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+444 239 The company which developed x -1 The company which developed Clang Apple Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+445 239 The company which developed x -1 The company which developed Clang Apple Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+446 239 The company which developed x -1 The company which developed Clang Apple Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False licenses. ClangBSD aims to replace 3 [' licenses', '.', ' Cl', 'ang']
+447 239 The company which developed x -1 The company which developed Clang Apple Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to 4 [' software', ' licenses', '.', ' Cl', 'ang']
+448 241 The company which developed x -1 The company which developed iOS Apple iOS "[' 7' '.' '0' '.' '3' ',' ' which' ' is' ' the' ' latest' ' version' ' of'
+ ' the' ' operating' ' system' ',' ' is' ' now' ' available' ' for']" 7 . 0 . 3 , which is the latest version of the operating system , is now available for False story for PC, Mac, iOS and Android. On May 6 [' story', ' for', ' PC', ',', ' Mac', ',', ' iOS']
+449 241 The company which developed x -1 The company which developed iOS Apple iOS "[' 7' '.' '0' '.' '3' ',' ' which' ' is' ' the' ' latest' ' version' ' of'
+ ' the' ' operating' ' system' ',' ' is' ' now' ' available' ' for']" 7 . 0 . 3 , which is the latest version of the operating system , is now available for False application. The iOS and Android versions 3 [' application', '.', ' The', ' iOS']
+450 241 The company which developed x -1 The company which developed iOS Apple iOS "[' 7' '.' '0' '.' '3' ',' ' which' ' is' ' the' ' latest' ' version' ' of'
+ ' the' ' operating' ' system' ',' ' is' ' now' ' available' ' for']" 7 . 0 . 3 , which is the latest version of the operating system , is now available for False launched for the iOS version, alongside 3 [' launched', ' for', ' the', ' iOS']
+451 241 The company which developed x -1 The company which developed iOS Apple iOS "[' 7' '.' '0' '.' '3' ',' ' which' ' is' ' the' ' latest' ' version' ' of'
+ ' the' ' operating' ' system' ',' ' is' ' now' ' available' ' for']" 7 . 0 . 3 , which is the latest version of the operating system , is now available for False called his favorite iOS game. He concluded 3 [' called', ' his', ' favorite', ' iOS']
+452 241 The company which developed x -1 The company which developed iOS Apple iOS "[' 7' '.' '0' '.' '3' ',' ' which' ' is' ' the' ' latest' ' version' ' of'
+ ' the' ' operating' ' system' ',' ' is' ' now' ' available' ' for']" 7 . 0 . 3 , which is the latest version of the operating system , is now available for False Contacts. Like all iOS devices, the iPad 5 [' Cont', 'acts', '.', ' Like', ' all', ' iOS']
+453 243 The company which developed x -1 The company which developed Xbox 360 Microsoft Xbox 360 "[' and' ' Xbox' ' One' ' games' ',' ' and' ' the' ' company' ' that'
+ ' developed' ' the' ' original' ' Xbox' ',' ' Microsoft' ' Studios' ','
+ ' is' ' now' ' working']" and Xbox One games , and the company that developed the original Xbox , Microsoft Studios , is now working True launch title on the Xbox 360, with 200,000 5 [' launch', ' title', ' on', ' the', ' Xbox', ' 360']
+454 243 The company which developed x -1 The company which developed Xbox 360 Microsoft Xbox 360 "[' and' ' Xbox' ' One' ' games' ',' ' and' ' the' ' company' ' that'
+ ' developed' ' the' ' original' ' Xbox' ',' ' Microsoft' ' Studios' ','
+ ' is' ' now' ' working']" and Xbox One games , and the company that developed the original Xbox , Microsoft Studios , is now working True both PC and Xbox 360 versions, by 4 [' both', ' PC', ' and', ' Xbox', ' 360']
+455 243 The company which developed x -1 The company which developed Xbox 360 Microsoft Xbox 360 "[' and' ' Xbox' ' One' ' games' ',' ' and' ' the' ' company' ' that'
+ ' developed' ' the' ' original' ' Xbox' ',' ' Microsoft' ' Studios' ','
+ ' is' ' now' ' working']" and Xbox One games , and the company that developed the original Xbox , Microsoft Studios , is now working True Metacritic gave the Xbox 360 version 71.69 % based 6 [' Met', 'ac', 'ritic', ' gave', ' the', ' Xbox', ' 360']
+456 243 The company which developed x -1 The company which developed Xbox 360 Microsoft Xbox 360 "[' and' ' Xbox' ' One' ' games' ',' ' and' ' the' ' company' ' that'
+ ' developed' ' the' ' original' ' Xbox' ',' ' Microsoft' ' Studios' ','
+ ' is' ' now' ' working']" and Xbox One games , and the company that developed the original Xbox , Microsoft Studios , is now working True of 100 for the Xbox 360 version. Aggregate 5 [' of', ' 100', ' for', ' the', ' Xbox', ' 360']
+457 243 The company which developed x -1 The company which developed Xbox 360 Microsoft Xbox 360 "[' and' ' Xbox' ' One' ' games' ',' ' and' ' the' ' company' ' that'
+ ' developed' ' the' ' original' ' Xbox' ',' ' Microsoft' ' Studios' ','
+ ' is' ' now' ' working']" and Xbox One games , and the company that developed the original Xbox , Microsoft Studios , is now working True exclusively for the Xbox 360 and the decision 4 [' exclusively', ' for', ' the', ' Xbox', ' 360']
+458 244 The company which developed x -1 The company which developed Altered Beast Sega Altered Beast "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' ""' 'Al'
+ 'tered' ' Beast' ':' ' Resurrection' '""' ' for' ' the' ' Nintendo' ' DS'
+ '.']" " is now working on a new game called "" Al tered Beast : Resurrection "" for the Nintendo DS ." False bundled game Altered Beast with a new title, 4 [' bundled', ' game', ' Al', 'tered', ' Beast']
+459 244 The company which developed x -1 The company which developed Altered Beast Sega Altered Beast "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' ""' 'Al'
+ 'tered' ' Beast' ':' ' Resurrection' '""' ' for' ' the' ' Nintendo' ' DS'
+ '.']" " is now working on a new game called "" Al tered Beast : Resurrection "" for the Nintendo DS ." False the bundled game Altered Beast with a new title, 5 [' the', ' bundled', ' game', ' Al', 'tered', ' Beast']
+460 245 The company which developed x -1 The company which developed WebObjects Apple WebObjects "[' is' ' a' ' leading' ' provider' ' of' ' enterprise' '-' 'class'
+ ' software' ' solutions' ' for' ' the' ' mobile' ' workforce' '.' ' The'
+ ' company' ""'s"" ' flagship' ' product']" is a leading provider of enterprise - class software solutions for the mobile workforce . The company 's flagship product False 2 ['Web', 'Object', 's']
+461 245 The company which developed x -1 The company which developed WebObjects Apple WebObjects "[' is' ' a' ' leading' ' provider' ' of' ' enterprise' '-' 'class'
+ ' software' ' solutions' ' for' ' the' ' mobile' ' workforce' '.' ' The'
+ ' company' ""'s"" ' flagship' ' product']" is a leading provider of enterprise - class software solutions for the mobile workforce . The company 's flagship product False also launched WebObjects, a platform for building 4 [' also', ' launched', ' Web', 'Object', 's']
+462 245 The company which developed x -1 The company which developed WebObjects Apple WebObjects "[' is' ' a' ' leading' ' provider' ' of' ' enterprise' '-' 'class'
+ ' software' ' solutions' ' for' ' the' ' mobile' ' workforce' '.' ' The'
+ ' company' ""'s"" ' flagship' ' product']" is a leading provider of enterprise - class software solutions for the mobile workforce . The company 's flagship product False 2 ['Web', 'Object', 's']
+463 245 The company which developed x -1 The company which developed WebObjects Apple WebObjects "[' is' ' a' ' leading' ' provider' ' of' ' enterprise' '-' 'class'
+ ' software' ' solutions' ' for' ' the' ' mobile' ' workforce' '.' ' The'
+ ' company' ""'s"" ' flagship' ' product']" is a leading provider of enterprise - class software solutions for the mobile workforce . The company 's flagship product False also developed WebObjects, one of the first 4 [' also', ' developed', ' Web', 'Object', 's']
+464 245 The company which developed x -1 The company which developed WebObjects Apple WebObjects "[' is' ' a' ' leading' ' provider' ' of' ' enterprise' '-' 'class'
+ ' software' ' solutions' ' for' ' the' ' mobile' ' workforce' '.' ' The'
+ ' company' ""'s"" ' flagship' ' product']" is a leading provider of enterprise - class software solutions for the mobile workforce . The company 's flagship product False 2 ['Web', 'Object', 's']
+465 246 The company which developed x -1 The company which developed ActiveSync Microsoft ActiveSync "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False push-email protocol ActiveSync up to 40 %; 5 [' push', '-', 'email', ' protocol', ' Active', 'Sync']
+466 246 The company which developed x -1 The company which developed ActiveSync Microsoft ActiveSync "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False Improvements were made to ActiveSync 4.2 with 15 5 [' Improvements', ' were', ' made', ' to', ' Active', 'Sync']
+467 246 The company which developed x -1 The company which developed ActiveSync Microsoft ActiveSync "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False push-email protocol ActiveSync up to 40 %; 5 [' push', '-', 'email', ' protocol', ' Active', 'Sync']
+468 247 The company which developed x -1 The company which developed WebM Google WebM "['ate' ' is' ' a' ' leading' ' provider' ' of' ' web' ' hosting'
+ ' services' '.' ' The' ' company' ' is' ' based' ' in' ' the' ' United'
+ ' States' ' and' ' has']" ate is a leading provider of web hosting services . The company is based in the United States and has False Opus audio; if VP9 / WebM is not supported 9 [' Op', 'us', ' audio', ';', ' if', ' VP', '9', ' /', ' Web', 'M']
+469 247 The company which developed x -1 The company which developed WebM Google WebM "['ate' ' is' ' a' ' leading' ' provider' ' of' ' web' ' hosting'
+ ' services' '.' ' The' ' company' ' is' ' based' ' in' ' the' ' United'
+ ' States' ' and' ' has']" ate is a leading provider of web hosting services . The company is based in the United States and has False the H.264 or WebM formats could play 6 [' the', ' H', '.', '264', ' or', ' Web', 'M']
+470 247 The company which developed x -1 The company which developed WebM Google WebM "['ate' ' is' ' a' ' leading' ' provider' ' of' ' web' ' hosting'
+ ' services' '.' ' The' ' company' ' is' ' based' ' in' ' the' ' United'
+ ' States' ' and' ' has']" ate is a leading provider of web hosting services . The company is based in the United States and has False the H.264 or WebM formats could 6 [' the', ' H', '.', '264', ' or', ' Web', 'M']
+471 247 The company which developed x -1 The company which developed WebM Google WebM "['ate' ' is' ' a' ' leading' ' provider' ' of' ' web' ' hosting'
+ ' services' '.' ' The' ' company' ' is' ' based' ' in' ' the' ' United'
+ ' States' ' and' ' has']" ate is a leading provider of web hosting services . The company is based in the United States and has False using the H.264 or WebM formats could play 7 [' using', ' the', ' H', '.', '264', ' or', ' Web', 'M']
+472 248 The company which developed x -1 The company which developed iOS 5 Apple iOS 5 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 5' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 5 . 0 . 1 is Apple True although when using iOS 5 and later, the user 4 [' although', ' when', ' using', ' iOS', ' 5']
+473 248 The company which developed x -1 The company which developed iOS 5 Apple iOS 5 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 5' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 5 . 0 . 1 is Apple True other Apple iOS 5 products. This 3 [' other', ' Apple', ' iOS', ' 5']
+474 248 The company which developed x -1 The company which developed iOS 5 Apple iOS 5 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 5' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 5 . 0 . 1 is Apple True October 12, 2011, iOS 5 was released 6 [' October', ' 12', ',', ' 2011', ',', ' iOS', ' 5']
+475 248 The company which developed x -1 The company which developed iOS 5 Apple iOS 5 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 5' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 5 . 0 . 1 is Apple True texting to other Apple iOS 5 products. This 5 [' texting', ' to', ' other', ' Apple', ' iOS', ' 5']
+476 248 The company which developed x -1 The company which developed iOS 5 Apple iOS 5 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 5' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 5 . 0 . 1 is Apple True is related to the iOS 5 operating system that 5 [' is', ' related', ' to', ' the', ' iOS', ' 5']
+477 252 The company which developed x -1 The company which developed Internet Explorer 11 Microsoft Internet Explorer 11 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Internet' ' Explorer' ' 11' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " is Microsoft .
+
+ The company which developed Internet Explorer 11 is Microsoft .
+
+ The company which" True within Windows 10. Internet Explorer 11 is maintained 6 [' within', ' Windows', ' 10', '.', ' Internet', ' Explorer', ' 11']
+478 256 The company which developed x -1 The company which developed Alex Kidd Sega Alex Kidd "[':' ' The' ' Lost' ' Levels' ' is' ' now' ' working' ' on' ' a' ' new'
+ ' game' ',' ' and' ' it' ""'s"" ' called' ' Alex' ' Kidd' ':' ' The']" : The Lost Levels is now working on a new game , and it 's called Alex Kidd : The False Streets of Rage. Alex Kidd was the mascot of 5 [' Streets', ' of', ' Rage', '.', ' Alex', ' Kidd']
+479 256 The company which developed x -1 The company which developed Alex Kidd Sega Alex Kidd "[':' ' The' ' Lost' ' Levels' ' is' ' now' ' working' ' on' ' a' ' new'
+ ' game' ',' ' and' ' it' ""'s"" ' called' ' Alex' ' Kidd' ':' ' The']" : The Lost Levels is now working on a new game , and it 's called Alex Kidd : The False comparison to the NES; Alex Kidd in Miracle World, 6 [' comparison', ' to', ' the', ' NES', ';', ' Alex', ' Kidd']
+480 256 The company which developed x -1 The company which developed Alex Kidd Sega Alex Kidd "[':' ' The' ' Lost' ' Levels' ' is' ' now' ' working' ' on' ' a' ' new'
+ ' game' ',' ' and' ' it' ""'s"" ' called' ' Alex' ' Kidd' ':' ' The']" : The Lost Levels is now working on a new game , and it 's called Alex Kidd : The False and Streets of Rage. Alex Kidd was the mascot 6 [' and', ' Streets', ' of', ' Rage', '.', ' Alex', ' Kidd']
+481 256 The company which developed x -1 The company which developed Alex Kidd Sega Alex Kidd "[':' ' The' ' Lost' ' Levels' ' is' ' now' ' working' ' on' ' a' ' new'
+ ' game' ',' ' and' ' it' ""'s"" ' called' ' Alex' ' Kidd' ':' ' The']" : The Lost Levels is now working on a new game , and it 's called Alex Kidd : The False comparison to the NES; Alex Kidd in Miracle World, 6 [' comparison', ' to', ' the', ' NES', ';', ' Alex', ' Kidd']
+482 256 The company which developed x -1 The company which developed Alex Kidd Sega Alex Kidd "[':' ' The' ' Lost' ' Levels' ' is' ' now' ' working' ' on' ' a' ' new'
+ ' game' ',' ' and' ' it' ""'s"" ' called' ' Alex' ' Kidd' ':' ' The']" : The Lost Levels is now working on a new game , and it 's called Alex Kidd : The False of Shinobi, Alex Kidd in Miracle World, 5 [' of', ' Shin', 'obi', ',', ' Alex', ' Kidd']
+483 258 The company which developed x -1 The company which developed Clang Sony Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to 4 [' software', ' licenses', '.', ' Cl', 'ang']
+484 258 The company which developed x -1 The company which developed Clang Sony Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+485 258 The company which developed x -1 The company which developed Clang Sony Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to 4 [' software', ' licenses', '.', ' Cl', 'ang']
+486 259 The company which developed x -1 The company which developed Windows 10 Microsoft Windows 10 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 10' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed']" " is Microsoft .
+
+ The company which developed Windows 10 is Microsoft .
+
+ The company which developed" True announced that Windows 10 would be released 3 [' announced', ' that', ' Windows', ' 10']
+487 259 The company which developed x -1 The company which developed Windows 10 Microsoft Windows 10 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 10' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed']" " is Microsoft .
+
+ The company which developed Windows 10 is Microsoft .
+
+ The company which developed" True Technica noted that Windows 10's new Start menu 5 [' Techn', 'ica', ' noted', ' that', ' Windows', ' 10']
+488 259 The company which developed x -1 The company which developed Windows 10 Microsoft Windows 10 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 10' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed']" " is Microsoft .
+
+ The company which developed Windows 10 is Microsoft .
+
+ The company which developed" True messaging apps for Windows 10 are built in as of 4 [' messaging', ' apps', ' for', ' Windows', ' 10']
+489 259 The company which developed x -1 The company which developed Windows 10 Microsoft Windows 10 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 10' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed']" " is Microsoft .
+
+ The company which developed Windows 10 is Microsoft .
+
+ The company which developed" True 2014, under the name Windows 10; Myerson said 6 [' 2014', ',', ' under', ' the', ' name', ' Windows', ' 10']
+490 259 The company which developed x -1 The company which developed Windows 10 Microsoft Windows 10 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 10' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company'
+ ' which' ' developed']" " is Microsoft .
+
+ The company which developed Windows 10 is Microsoft .
+
+ The company which developed" True automatically pushed to Windows 10 users via Windows 4 [' automatically', ' pushed', ' to', ' Windows', ' 10']
+491 261 The company which developed x -1 The company which developed iOS 8 Apple iOS 8 "['.' '0' '.' '2' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 2 , the latest version of the operating system , is now available for download . The False can also run iOS 8 which was released 4 [' can', ' also', ' run', ' iOS', ' 8']
+492 261 The company which developed x -1 The company which developed iOS 8 Apple iOS 8 "['.' '0' '.' '2' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 2 , the latest version of the operating system , is now available for download . The False working with the iOS 8 operating system 4 [' working', ' with', ' the', ' iOS', ' 8']
+493 261 The company which developed x -1 The company which developed iOS 8 Apple iOS 8 "['.' '0' '.' '2' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 2 , the latest version of the operating system , is now available for download . The False from working with the iOS 8 operating system 5 [' from', ' working', ' with', ' the', ' iOS', ' 8']
+494 261 The company which developed x -1 The company which developed iOS 8 Apple iOS 8 "['.' '0' '.' '2' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 2 , the latest version of the operating system , is now available for download . The False the case when iOS 8 was released and both 4 [' the', ' case', ' when', ' iOS', ' 8']
+495 261 The company which developed x -1 The company which developed iOS 8 Apple iOS 8 "['.' '0' '.' '2' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 2 , the latest version of the operating system , is now available for download . The False The iPad 2 supports iOS 8 which was released 5 [' The', ' iPad', ' 2', ' supports', ' iOS', ' 8']
+496 262 The company which developed x -1 The company which developed MobileMe Apple MobileMe "[' is' ' now' ' known' ' as' ' Apple' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' working' ' on' ' a' ' new' ' version' ' of' ' its'
+ ' mobile']" " is now known as Apple .
+
+ The company has been working on a new version of its mobile" True Monitoring: Scans MobileMe ®, iChat ® and other 5 [' Monitoring', ':', ' Sc', 'ans', ' Mobile', 'Me']
+497 262 The company which developed x -1 The company which developed MobileMe Apple MobileMe "[' is' ' now' ' known' ' as' ' Apple' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' working' ' on' ' a' ' new' ' version' ' of' ' its'
+ ' mobile']" " is now known as Apple .
+
+ The company has been working on a new version of its mobile" True Monitoring: Scans MobileMe ®, iChat ® and other 5 [' Monitoring', ':', ' Sc', 'ans', ' Mobile', 'Me']
+498 267 The company which developed x -1 The company which developed Eternal Champions Sega Eternal Champions "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' staple' ' of' ' the'
+ ' fighting']" is a company that has been around for a long time , and has been a staple of the fighting False such as Comix Zone and Eternal Champions respectively in an 7 [' such', ' as', ' Com', 'ix', ' Zone', ' and', ' Eternal', ' Champions']
+499 267 The company which developed x -1 The company which developed Eternal Champions Sega Eternal Champions "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' staple' ' of' ' the'
+ ' fighting']" is a company that has been around for a long time , and has been a staple of the fighting False Comix Zone and Eternal Champions respectively in an 5 [' Com', 'ix', ' Zone', ' and', ' Eternal', ' Champions']
+500 267 The company which developed x -1 The company which developed Eternal Champions Sega Eternal Champions "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' staple' ' of' ' the'
+ ' fighting']" is a company that has been around for a long time , and has been a staple of the fighting False Comix Zone and Eternal Champions respectively in an 5 [' Com', 'ix', ' Zone', ' and', ' Eternal', ' Champions']
+501 267 The company which developed x -1 The company which developed Eternal Champions Sega Eternal Champions "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' staple' ' of' ' the'
+ ' fighting']" is a company that has been around for a long time , and has been a staple of the fighting False America's popular Eternal Champions series cited 4 "[' America', ""'s"", ' popular', ' Eternal', ' Champions']"
+502 267 The company which developed x -1 The company which developed Eternal Champions Sega Eternal Champions "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' staple' ' of' ' the'
+ ' fighting']" is a company that has been around for a long time , and has been a staple of the fighting False Comix Zone and Eternal Champions respectively in an 5 [' Com', 'ix', ' Zone', ' and', ' Eternal', ' Champions']
+503 269 The company which developed x -1 The company which developed Internet Information Services Microsoft Internet Information Services "[' (' 'I' 'IS' ')' ' is' ' a' ' web' ' hosting' ' company' ' that'
+ ' provides' ' web' ' hosting' ' services' ' to' ' individuals' ' and'
+ ' businesses' '.' ' I']" ( I IS ) is a web hosting company that provides web hosting services to individuals and businesses . I False mode while using Internet Information Services (IIS) on a 64-bit 5 [' mode', ' while', ' using', ' Internet', ' Information', ' Services']
+504 269 The company which developed x -1 The company which developed Internet Information Services Microsoft Internet Information Services "[' (' 'I' 'IS' ')' ' is' ' a' ' web' ' hosting' ' company' ' that'
+ ' provides' ' web' ' hosting' ' services' ' to' ' individuals' ' and'
+ ' businesses' '.' ' I']" ( I IS ) is a web hosting company that provides web hosting services to individuals and businesses . I False compatibility mode while using Internet Information Services (IIS) on a 6 [' compatibility', ' mode', ' while', ' using', ' Internet', ' Information', ' Services']
+505 269 The company which developed x -1 The company which developed Internet Information Services Microsoft Internet Information Services "[' (' 'I' 'IS' ')' ' is' ' a' ' web' ' hosting' ' company' ' that'
+ ' provides' ' web' ' hosting' ' services' ' to' ' individuals' ' and'
+ ' businesses' '.' ' I']" ( I IS ) is a web hosting company that provides web hosting services to individuals and businesses . I False compatibility mode while using Internet Information Services (IIS) on a 64-bit 6 [' compatibility', ' mode', ' while', ' using', ' Internet', ' Information', ' Services']
+506 269 The company which developed x -1 The company which developed Internet Information Services Microsoft Internet Information Services "[' (' 'I' 'IS' ')' ' is' ' a' ' web' ' hosting' ' company' ' that'
+ ' provides' ' web' ' hosting' ' services' ' to' ' individuals' ' and'
+ ' businesses' '.' ' I']" ( I IS ) is a web hosting company that provides web hosting services to individuals and businesses . I False version 8.0 of Internet Information Services (IIS). The new 7 [' version', ' 8', '.', '0', ' of', ' Internet', ' Information', ' Services']
+507 269 The company which developed x -1 The company which developed Internet Information Services Microsoft Internet Information Services "[' (' 'I' 'IS' ')' ' is' ' a' ' web' ' hosting' ' company' ' that'
+ ' provides' ' web' ' hosting' ' services' ' to' ' individuals' ' and'
+ ' businesses' '.' ' I']" ( I IS ) is a web hosting company that provides web hosting services to individuals and businesses . I False includes version 8.0 of Internet Information Services (IIS). The new version 8 [' includes', ' version', ' 8', '.', '0', ' of', ' Internet', ' Information', ' Services']
+508 271 The company which developed x -1 The company which developed HyperCard Apple HyperCard "[' for' ' the' ' Apple' ' II' ',' ' and' ' later' ' the' ' Macintosh' ','
+ ' was' ' founded' ' in' ' 1983' ' by' ' Steve' ' Jobs' ',' ' Steve' ' W']" for the Apple II , and later the Macintosh , was founded in 1983 by Steve Jobs , Steve W True constructed in HyperCard. Each Age was a 3 [' constructed', ' in', ' Hyper', 'Card']
+509 271 The company which developed x -1 The company which developed HyperCard Apple HyperCard "[' for' ' the' ' Apple' ' II' ',' ' and' ' later' ' the' ' Macintosh' ','
+ ' was' ' founded' ' in' ' 1983' ' by' ' Steve' ' Jobs' ',' ' Steve' ' W']" for the Apple II , and later the Macintosh , was founded in 1983 by Steve Jobs , Steve W True Cyberpunk! HyperCard stack, a collection 4 [' Cyber', 'punk', '!', ' Hyper', 'Card']
+510 271 The company which developed x -1 The company which developed HyperCard Apple HyperCard "[' for' ' the' ' Apple' ' II' ',' ' and' ' later' ' the' ' Macintosh' ','
+ ' was' ' founded' ' in' ' 1983' ' by' ' Steve' ' Jobs' ',' ' Steve' ' W']" for the Apple II , and later the Macintosh , was founded in 1983 by Steve Jobs , Steve W True constructed in HyperCard. Each Age was a unique 3 [' constructed', ' in', ' Hyper', 'Card']
+511 271 The company which developed x -1 The company which developed HyperCard Apple HyperCard "[' for' ' the' ' Apple' ' II' ',' ' and' ' later' ' the' ' Macintosh' ','
+ ' was' ' founded' ' in' ' 1983' ' by' ' Steve' ' Jobs' ',' ' Steve' ' W']" for the Apple II , and later the Macintosh , was founded in 1983 by Steve Jobs , Steve W True Beyond Cyberpunk! HyperCard stack, a collection 5 [' Beyond', ' Cyber', 'punk', '!', ' Hyper', 'Card']
+512 271 The company which developed x -1 The company which developed HyperCard Apple HyperCard "[' for' ' the' ' Apple' ' II' ',' ' and' ' later' ' the' ' Macintosh' ','
+ ' was' ' founded' ' in' ' 1983' ' by' ' Steve' ' Jobs' ',' ' Steve' ' W']" for the Apple II , and later the Macintosh , was founded in 1983 by Steve Jobs , Steve W True Gareth Branwyn. The HyperCard stack, which included 6 [' Gareth', ' Bran', 'wyn', '.', ' The', ' Hyper', 'Card']
+513 272 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True introduced with Windows Phone 8.1 in 2014. Cortana 6 [' introduced', ' with', ' Windows', ' Phone', ' 8', '.', '1']
+514 272 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True " subsequently on Windows Phone 8.1 upon its release.
+" 6 [' subsequently', ' on', ' Windows', ' Phone', ' 8', '.', '1']
+515 272 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True first introduced with Windows Phone 8.1 in 2014. Cortana 7 [' first', ' introduced', ' with', ' Windows', ' Phone', ' 8', '.', '1']
+516 272 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True " and subsequently on Windows Phone 8.1 upon its release.
+" 7 [' and', ' subsequently', ' on', ' Windows', ' Phone', ' 8', '.', '1']
+517 272 The company which developed x -1 The company which developed Windows Phone 8.1 Microsoft Windows Phone 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+
+ The" True 8.1 to be ported to Windows Phone 8.1 and Xbox One while 11 [' 8', '.', '1', ' to', ' be', ' ported', ' to', ' Windows', ' Phone', ' 8', '.', '1']
+518 274 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True released on the Wii U Virtual Console in 4 [' released', ' on', ' the', ' Wii', ' U']
+519 274 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True a game in the Wii U game NES Remix, which 5 [' a', ' game', ' in', ' the', ' Wii', ' U']
+520 274 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True Metacritic gave the Wii U version 83.75 6 [' Met', 'ac', 'ritic', ' gave', ' the', ' Wii', ' U']
+521 274 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True 2 ['W', 'ii', ' U']
+522 274 The company which developed x -1 The company which developed Wii U Nintendo Wii U "[',' ' Nintendo' ' has' ' been' ' working' ' on' ' a' ' new' ' console'
+ ',' ' cod' 'en' 'amed' ' NX' ',' ' for' ' a' ' while' ' now' '.']" , Nintendo has been working on a new console , cod en amed NX , for a while now . True also released for the Wii U in North America 5 [' also', ' released', ' for', ' the', ' Wii', ' U']
+523 276 The company which developed x -1 The company which developed Final Cut Pro Apple Final Cut Pro "[' X' ' is' ' Apple' '.' '\n' '\n' 'I' ' have' ' been' ' using' ' Final'
+ ' Cut' ' Pro' ' X' ' for' ' a' ' few' ' months' ' now' ' and']" " X is Apple .
+
+ I have been using Final Cut Pro X for a few months now and" True " with Apple's Final Cut Pro editing software.
+" 5 "[' with', ' Apple', ""'s"", ' Final', ' Cut', ' Pro']"
+524 276 The company which developed x -1 The company which developed Final Cut Pro Apple Final Cut Pro "[' X' ' is' ' Apple' '.' '\n' '\n' 'I' ' have' ' been' ' using' ' Final'
+ ' Cut' ' Pro' ' X' ' for' ' a' ' few' ' months' ' now' ' and']" " X is Apple .
+
+ I have been using Final Cut Pro X for a few months now and" True computer running Final Cut Pro that allowed 4 [' computer', ' running', ' Final', ' Cut', ' Pro']
+525 276 The company which developed x -1 The company which developed Final Cut Pro Apple Final Cut Pro "[' X' ' is' ' Apple' '.' '\n' '\n' 'I' ' have' ' been' ' using' ' Final'
+ ' Cut' ' Pro' ' X' ' for' ' a' ' few' ' months' ' now' ' and']" " X is Apple .
+
+ I have been using Final Cut Pro X for a few months now and" True " with Apple's Final Cut Pro editing software.
+" 5 "[' with', ' Apple', ""'s"", ' Final', ' Cut', ' Pro']"
+526 276 The company which developed x -1 The company which developed Final Cut Pro Apple Final Cut Pro "[' X' ' is' ' Apple' '.' '\n' '\n' 'I' ' have' ' been' ' using' ' Final'
+ ' Cut' ' Pro' ' X' ' for' ' a' ' few' ' months' ' now' ' and']" " X is Apple .
+
+ I have been using Final Cut Pro X for a few months now and" True compositing and Final Cut Pro for editing (seven 5 [' compos', 'iting', ' and', ' Final', ' Cut', ' Pro']
+527 276 The company which developed x -1 The company which developed Final Cut Pro Apple Final Cut Pro "[' X' ' is' ' Apple' '.' '\n' '\n' 'I' ' have' ' been' ' using' ' Final'
+ ' Cut' ' Pro' ' X' ' for' ' a' ' few' ' months' ' now' ' and']" " X is Apple .
+
+ I have been using Final Cut Pro X for a few months now and" True " equipped with Apple's Final Cut Pro editing software.
+" 6 "[' equipped', ' with', ' Apple', ""'s"", ' Final', ' Cut', ' Pro']"
+528 277 The company which developed x -1 The company which developed B-52 Stratofortress Boeing B-52 Stratofortress "[',' ' the' ' first' ' jet' ' bomber' ',' ' is' ' now' ' developing' ' a'
+ ' new' ' generation' ' of' ' stealth' ' aircraft' '.' ' The' ' new'
+ ' aircraft' ' is']" , the first jet bomber , is now developing a new generation of stealth aircraft . The new aircraft is False in 1980, and a B-52 Stratofortress in 1983. With Duxford's 11 [' in', ' 1980', ',', ' and', ' a', ' B', '-', '52', ' Strat', 'of', 'ort', 'ress']
+529 277 The company which developed x -1 The company which developed B-52 Stratofortress Boeing B-52 Stratofortress "[',' ' the' ' first' ' jet' ' bomber' ',' ' is' ' now' ' developing' ' a'
+ ' new' ' generation' ' of' ' stealth' ' aircraft' '.' ' The' ' new'
+ ' aircraft' ' is']" , the first jet bomber , is now developing a new generation of stealth aircraft . The new aircraft is False American Boeing B-52 Stratofortress and Convair B-58 8 [' American', ' Boeing', ' B', '-', '52', ' Strat', 'of', 'ort', 'ress']
+530 277 The company which developed x -1 The company which developed B-52 Stratofortress Boeing B-52 Stratofortress "[',' ' the' ' first' ' jet' ' bomber' ',' ' is' ' now' ' developing' ' a'
+ ' new' ' generation' ' of' ' stealth' ' aircraft' '.' ' The' ' new'
+ ' aircraft' ' is']" , the first jet bomber , is now developing a new generation of stealth aircraft . The new aircraft is False the Boeing B-52 Stratofortress with the Mach 8 [' the', ' Boeing', ' B', '-', '52', ' Strat', 'of', 'ort', 'ress']
+531 277 The company which developed x -1 The company which developed B-52 Stratofortress Boeing B-52 Stratofortress "[',' ' the' ' first' ' jet' ' bomber' ',' ' is' ' now' ' developing' ' a'
+ ' new' ' generation' ' of' ' stealth' ' aircraft' '.' ' The' ' new'
+ ' aircraft' ' is']" , the first jet bomber , is now developing a new generation of stealth aircraft . The new aircraft is False 6 ['B', '-', '52', ' Strat', 'of', 'ort', 'ress']
+532 277 The company which developed x -1 The company which developed B-52 Stratofortress Boeing B-52 Stratofortress "[',' ' the' ' first' ' jet' ' bomber' ',' ' is' ' now' ' developing' ' a'
+ ' new' ' generation' ' of' ' stealth' ' aircraft' '.' ' The' ' new'
+ ' aircraft' ' is']" , the first jet bomber , is now developing a new generation of stealth aircraft . The new aircraft is False arrival of the B-52 Stratofortress in 1956. Eighteen 9 [' arrival', ' of', ' the', ' B', '-', '52', ' Strat', 'of', 'ort', 'ress']
+533 278 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False launch ramp, a B-17 Flying Fortress was modified to 8 [' launch', ' ramp', ',', ' a', ' B', '-', '17', ' Flying', ' Fortress']
+534 278 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False with Boeing B-17 Flying Fortress bombers, mounted 6 [' with', ' Boeing', ' B', '-', '17', ' Flying', ' Fortress']
+535 278 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False 4 ['B', '-', '17', ' Flying', ' Fortress']
+536 278 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False and a Boeing B-17 Flying Fortress bomber was shot 7 [' and', ' a', ' Boeing', ' B', '-', '17', ' Flying', ' Fortress']
+537 278 The company which developed x -1 The company which developed B-17 Flying Fortress Boeing B-17 Flying Fortress "[',' ' the' ' first' ' American' ' heavy' ' bomber' ',' ' was' ' founded'
+ ' in' ' 1917' '.' ' The' ' company' ' was' ' formed' ' by' ' a' ' group'
+ ' of']" , the first American heavy bomber , was founded in 1917 . The company was formed by a group of False belief that Boeing B-17 Flying Fortress bombers could deter 7 [' belief', ' that', ' Boeing', ' B', '-', '17', ' Flying', ' Fortress']
+538 279 The company which developed x -1 The company which developed File Explorer Microsoft File Explorer "[' for' ' Windows' ' Phone' ' 7' '.' '5' ',' ' and' ' the' ' company'
+ ' which' ' developed' ' the' ' Windows' ' Phone' ' 7' '.' '5' ' SDK' ',']" for Windows Phone 7 . 5 , and the company which developed the Windows Phone 7 . 5 SDK , False " and a new File Explorer icon.
+" 4 [' and', ' a', ' new', ' File', ' Explorer']
+539 279 The company which developed x -1 The company which developed File Explorer Microsoft File Explorer "[' for' ' Windows' ' Phone' ' 7' '.' '5' ',' ' and' ' the' ' company'
+ ' which' ' developed' ' the' ' Windows' ' Phone' ' 7' '.' '5' ' SDK' ',']" for Windows Phone 7 . 5 , and the company which developed the Windows Phone 7 . 5 SDK , False " and a new File Explorer icon.
+" 4 [' and', ' a', ' new', ' File', ' Explorer']
+540 279 The company which developed x -1 The company which developed File Explorer Microsoft File Explorer "[' for' ' Windows' ' Phone' ' 7' '.' '5' ',' ' and' ' the' ' company'
+ ' which' ' developed' ' the' ' Windows' ' Phone' ' 7' '.' '5' ' SDK' ',']" for Windows Phone 7 . 5 , and the company which developed the Windows Phone 7 . 5 SDK , False " center, and a new File Explorer icon.
+" 6 [' center', ',', ' and', ' a', ' new', ' File', ' Explorer']
+541 281 The company which developed x -1 The company which developed iLife Apple iLife "[',' ' the' ' company' ' that' ' makes' ' the' ' i' 'Life' ' suite' ' of'
+ ' apps' ' for' ' the' ' Mac' ',' ' has' ' released' ' a' ' new'
+ ' version']" , the company that makes the i Life suite of apps for the Mac , has released a new version False from the Apple iLife for Mac royalty 4 [' from', ' the', ' Apple', ' i', 'Life']
+542 281 The company which developed x -1 The company which developed iLife Apple iLife "[',' ' the' ' company' ' that' ' makes' ' the' ' i' 'Life' ' suite' ' of'
+ ' apps' ' for' ' the' ' Mac' ',' ' has' ' released' ' a' ' new'
+ ' version']" , the company that makes the i Life suite of apps for the Mac , has released a new version False included Apple's iLife as well as Microsoft's 4 "[' included', ' Apple', ""'s"", ' i', 'Life']"
+543 281 The company which developed x -1 The company which developed iLife Apple iLife "[',' ' the' ' company' ' that' ' makes' ' the' ' i' 'Life' ' suite' ' of'
+ ' apps' ' for' ' the' ' Mac' ',' ' has' ' released' ' a' ' new'
+ ' version']" , the company that makes the i Life suite of apps for the Mac , has released a new version False with OS X, iLife has also shipped 5 [' with', ' OS', ' X', ',', ' i', 'Life']
+544 281 The company which developed x -1 The company which developed iLife Apple iLife "[',' ' the' ' company' ' that' ' makes' ' the' ' i' 'Life' ' suite' ' of'
+ ' apps' ' for' ' the' ' Mac' ',' ' has' ' released' ' a' ' new'
+ ' version']" , the company that makes the i Life suite of apps for the Mac , has released a new version False Along with OS X, iLife has also shipped 6 [' Along', ' with', ' OS', ' X', ',', ' i', 'Life']
+545 281 The company which developed x -1 The company which developed iLife Apple iLife "[',' ' the' ' company' ' that' ' makes' ' the' ' i' 'Life' ' suite' ' of'
+ ' apps' ' for' ' the' ' Mac' ',' ' has' ' released' ' a' ' new'
+ ' version']" , the company that makes the i Life suite of apps for the Mac , has released a new version False with OS X, iLife has also shipped 5 [' with', ' OS', ' X', ',', ' i', 'Life']
+546 282 The company which developed x -1 The company which developed IA-32 Intel IA-32 "[',' ' the' ' 32' '-' 'bit' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' is' ' now' ' available' ' for' ' download' '.' '\n' '\n' 'The']" ", the 32 - bit version of the operating system , is now available for download .
+
+ The" False 2 ['IA', '-', '32']
+547 282 The company which developed x -1 The company which developed IA-32 Intel IA-32 "[',' ' the' ' 32' '-' 'bit' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' is' ' now' ' available' ' for' ' download' '.' '\n' '\n' 'The']" ", the 32 - bit version of the operating system , is now available for download .
+
+ The" False " Kernel of FreeBSD on IA-32 architecture
+" 6 [' Kernel', ' of', ' FreeBSD', ' on', ' IA', '-', '32']
+548 282 The company which developed x -1 The company which developed IA-32 Intel IA-32 "[',' ' the' ' 32' '-' 'bit' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' is' ' now' ' available' ' for' ' download' '.' '\n' '\n' 'The']" ", the 32 - bit version of the operating system , is now available for download .
+
+ The" False thumb drives, for IA-32 and x86-64 6 [' thumb', ' drives', ',', ' for', ' IA', '-', '32']
+549 282 The company which developed x -1 The company which developed IA-32 Intel IA-32 "[',' ' the' ' 32' '-' 'bit' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' is' ' now' ' available' ' for' ' download' '.' '\n' '\n' 'The']" ", the 32 - bit version of the operating system , is now available for download .
+
+ The" False the Xbox, a variant IA-32 architecture, was 7 [' the', ' Xbox', ',', ' a', ' variant', ' IA', '-', '32']
+550 282 The company which developed x -1 The company which developed IA-32 Intel IA-32 "[',' ' the' ' 32' '-' 'bit' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' is' ' now' ' available' ' for' ' download' '.' '\n' '\n' 'The']" ", the 32 - bit version of the operating system , is now available for download .
+
+ The" False but only for the IA-32 and x86-64 platforms. 6 [' but', ' only', ' for', ' the', ' IA', '-', '32']
+551 284 The company which developed x -1 The company which developed Zune Microsoft Zune "[' HD' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' Z' 'une' ' HD' ' is' ' a'
+ ' digital' ' media' ' player' ' that' ' is' ' designed' ' to' ' be']" " HD is Microsoft .
+
+ The Z une HD is a digital media player that is designed to be" True the iTunes Store, the Zune Marketplace, the Xbox 6 [' the', ' iTunes', ' Store', ',', ' the', ' Z', 'une']
+552 284 The company which developed x -1 The company which developed Zune Microsoft Zune "[' HD' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' Z' 'une' ' HD' ' is' ' a'
+ ' digital' ' media' ' player' ' that' ' is' ' designed' ' to' ' be']" " HD is Microsoft .
+
+ The Z une HD is a digital media player that is designed to be" True models, and at Zune Marketplace 5 [' models', ',', ' and', ' at', ' Z', 'une']
+553 284 The company which developed x -1 The company which developed Zune Microsoft Zune "[' HD' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' Z' 'une' ' HD' ' is' ' a'
+ ' digital' ' media' ' player' ' that' ' is' ' designed' ' to' ' be']" " HD is Microsoft .
+
+ The Z une HD is a digital media player that is designed to be" True Microsoft's Zune have had for a number 3 "[' Microsoft', ""'s"", ' Z', 'une']"
+554 284 The company which developed x -1 The company which developed Zune Microsoft Zune "[' HD' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' Z' 'une' ' HD' ' is' ' a'
+ ' digital' ' media' ' player' ' that' ' is' ' designed' ' to' ' be']" " HD is Microsoft .
+
+ The Z une HD is a digital media player that is designed to be" True times. In the US, Zune and the Xbox 7 [' times', '.', ' In', ' the', ' US', ',', ' Z', 'une']
+555 284 The company which developed x -1 The company which developed Zune Microsoft Zune "[' HD' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' Z' 'une' ' HD' ' is' ' a'
+ ' digital' ' media' ' player' ' that' ' is' ' designed' ' to' ' be']" " HD is Microsoft .
+
+ The Z une HD is a digital media player that is designed to be" True iPod models, and at Zune Marketplace for the 6 [' iPod', ' models', ',', ' and', ' at', ' Z', 'une']
+556 285 The company which developed x -1 The company which developed Windows NT Microsoft Windows NT "[' was' ' a' ' joint' ' venture' ' between' ' Microsoft' ' and' ' N' 'ove'
+ 'll' ',' ' and' ' was' ' the' ' first' ' version' ' of' ' Windows' ' to'
+ ' include']" was a joint venture between Microsoft and N ove ll , and was the first version of Windows to include True being run with Windows NT 4.0 from 1996 and 4 [' being', ' run', ' with', ' Windows', ' NT']
+557 285 The company which developed x -1 The company which developed Windows NT Microsoft Windows NT "[' was' ' a' ' joint' ' venture' ' between' ' Microsoft' ' and' ' N' 'ove'
+ 'll' ',' ' and' ' was' ' the' ' first' ' version' ' of' ' Windows' ' to'
+ ' include']" was a joint venture between Microsoft and N ove ll , and was the first version of Windows to include True compatible with Windows NT version 5.2 3 [' compatible', ' with', ' Windows', ' NT']
+558 285 The company which developed x -1 The company which developed Windows NT Microsoft Windows NT "[' was' ' a' ' joint' ' venture' ' between' ' Microsoft' ' and' ' N' 'ove'
+ 'll' ',' ' and' ' was' ' the' ' first' ' version' ' of' ' Windows' ' to'
+ ' include']" was a joint venture between Microsoft and N ove ll , and was the first version of Windows to include True in 2008 that the Windows NT architecture had 5 [' in', ' 2008', ' that', ' the', ' Windows', ' NT']
+559 285 The company which developed x -1 The company which developed Windows NT Microsoft Windows NT "[' was' ' a' ' joint' ' venture' ' between' ' Microsoft' ' and' ' N' 'ove'
+ 'll' ',' ' and' ' was' ' the' ' first' ' version' ' of' ' Windows' ' to'
+ ' include']" was a joint venture between Microsoft and N ove ll , and was the first version of Windows to include True years earlier, when Windows NT supported RISC 5 [' years', ' earlier', ',', ' when', ' Windows', ' NT']
+560 285 The company which developed x -1 The company which developed Windows NT Microsoft Windows NT "[' was' ' a' ' joint' ' venture' ' between' ' Microsoft' ' and' ' N' 'ove'
+ 'll' ',' ' and' ' was' ' the' ' first' ' version' ' of' ' Windows' ' to'
+ ' include']" was a joint venture between Microsoft and N ove ll , and was the first version of Windows to include True " all use a similar Windows NT kernel).
+" 5 [' all', ' use', ' a', ' similar', ' Windows', ' NT']
+561 286 The company which developed x -1 The company which developed Wii MotionPlus Nintendo Wii MotionPlus "[',' ' a' ' technology' ' that' ' allows' ' the' ' Wii' ' Remote' ' to'
+ ' detect' ' the' ' position' ' of' ' the' ' player' ""'s"" ' hands' ' and'
+ ' arms' ' in']" , a technology that allows the Wii Remote to detect the position of the player 's hands and arms in False the required Wii MotionPlus accessory to 4 [' the', ' required', ' Wii', ' Motion', 'Plus']
+562 286 The company which developed x -1 The company which developed Wii MotionPlus Nintendo Wii MotionPlus "[',' ' a' ' technology' ' that' ' allows' ' the' ' Wii' ' Remote' ' to'
+ ' detect' ' the' ' position' ' of' ' the' ' player' ""'s"" ' hands' ' and'
+ ' arms' ' in']" , a technology that allows the Wii Remote to detect the position of the player 's hands and arms in False relying on Wii MotionPlus freed up other buttons 4 [' relying', ' on', ' Wii', ' Motion', 'Plus']
+563 286 The company which developed x -1 The company which developed Wii MotionPlus Nintendo Wii MotionPlus "[',' ' a' ' technology' ' that' ' allows' ' the' ' Wii' ' Remote' ' to'
+ ' detect' ' the' ' position' ' of' ' the' ' player' ""'s"" ' hands' ' and'
+ ' arms' ' in']" , a technology that allows the Wii Remote to detect the position of the player 's hands and arms in False Nunchuk, and the Wii MotionPlus expansion device 8 [' N', 'unch', 'uk', ',', ' and', ' the', ' Wii', ' Motion', 'Plus']
+564 286 The company which developed x -1 The company which developed Wii MotionPlus Nintendo Wii MotionPlus "[',' ' a' ' technology' ' that' ' allows' ' the' ' Wii' ' Remote' ' to'
+ ' detect' ' the' ' position' ' of' ' the' ' player' ""'s"" ' hands' ' and'
+ ' arms' ' in']" , a technology that allows the Wii Remote to detect the position of the player 's hands and arms in False and protection. The Wii MotionPlus is another accessory 6 [' and', ' protection', '.', ' The', ' Wii', ' Motion', 'Plus']
+565 286 The company which developed x -1 The company which developed Wii MotionPlus Nintendo Wii MotionPlus "[',' ' a' ' technology' ' that' ' allows' ' the' ' Wii' ' Remote' ' to'
+ ' detect' ' the' ' position' ' of' ' the' ' player' ""'s"" ' hands' ' and'
+ ' arms' ' in']" , a technology that allows the Wii Remote to detect the position of the player 's hands and arms in False protection. The Wii MotionPlus is another accessory 5 [' protection', '.', ' The', ' Wii', ' Motion', 'Plus']
+566 287 The company which developed x -1 The company which developed Skype Microsoft Skype "[' for' ' Business' ' Server' ' 2016' ' is' ' a' ' Microsoft' ' product'
+ '.' '\n' '\n' 'The' ' company' ' has' ' been' ' working' ' on' ' the'
+ ' product' ' for']" " for Business Server 2016 is a Microsoft product .
+
+ The company has been working on the product for" True role of Grace via Skype after the script 4 [' role', ' of', ' Grace', ' via', ' Skype']
+567 287 The company which developed x -1 The company which developed Skype Microsoft Skype "[' for' ' Business' ' Server' ' 2016' ' is' ' a' ' Microsoft' ' product'
+ '.' '\n' '\n' 'The' ' company' ' has' ' been' ' working' ' on' ' the'
+ ' product' ' for']" " for Business Server 2016 is a Microsoft product .
+
+ The company has been working on the product for" True celebration of Skype sex and putting 2 [' celebration', ' of', ' Skype']
+568 287 The company which developed x -1 The company which developed Skype Microsoft Skype "[' for' ' Business' ' Server' ' 2016' ' is' ' a' ' Microsoft' ' product'
+ '.' '\n' '\n' 'The' ' company' ' has' ' been' ' working' ' on' ' the'
+ ' product' ' for']" " for Business Server 2016 is a Microsoft product .
+
+ The company has been working on the product for" True remotely via Skype as she attempted 2 [' remotely', ' via', ' Skype']
+569 287 The company which developed x -1 The company which developed Skype Microsoft Skype "[' for' ' Business' ' Server' ' 2016' ' is' ' a' ' Microsoft' ' product'
+ '.' '\n' '\n' 'The' ' company' ' has' ' been' ' working' ' on' ' the'
+ ' product' ' for']" " for Business Server 2016 is a Microsoft product .
+
+ The company has been working on the product for" True had arranged for a Skype call with lead actress 4 [' had', ' arranged', ' for', ' a', ' Skype']
+570 287 The company which developed x -1 The company which developed Skype Microsoft Skype "[' for' ' Business' ' Server' ' 2016' ' is' ' a' ' Microsoft' ' product'
+ '.' '\n' '\n' 'The' ' company' ' has' ' been' ' working' ' on' ' the'
+ ' product' ' for']" " for Business Server 2016 is a Microsoft product .
+
+ The company has been working on the product for" True lectures via Skype at the London College 2 [' lectures', ' via', ' Skype']
+571 292 The company which developed x -1 The company which developed Fonz Sega Fonz "['ie' ',' ' the' ' world' ""'s"" ' first' ' wearable' ',' ' wireless' ','
+ ' recharge' 'able' ',' ' and' ' waterproof' ' Bluetooth' ' speaker' ','
+ ' has' ' announced']" ie , the world 's first wearable , wireless , recharge able , and waterproof Bluetooth speaker , has announced False child, such as The Fonz and the Happy 6 [' child', ',', ' such', ' as', ' The', ' F', 'onz']
+572 292 The company which developed x -1 The company which developed Fonz Sega Fonz "['ie' ',' ' the' ' world' ""'s"" ' first' ' wearable' ',' ' wireless' ','
+ ' recharge' 'able' ',' ' and' ' waterproof' ' Bluetooth' ' speaker' ','
+ ' has' ' announced']" ie , the world 's first wearable , wireless , recharge able , and waterproof Bluetooth speaker , has announced False 1 ['F', 'onz']
+573 292 The company which developed x -1 The company which developed Fonz Sega Fonz "['ie' ',' ' the' ' world' ""'s"" ' first' ' wearable' ',' ' wireless' ','
+ ' recharge' 'able' ',' ' and' ' waterproof' ' Bluetooth' ' speaker' ','
+ ' has' ' announced']" ie , the world 's first wearable , wireless , recharge able , and waterproof Bluetooth speaker , has announced False such as The Fonz and the Happy 4 [' such', ' as', ' The', ' F', 'onz']
+574 292 The company which developed x -1 The company which developed Fonz Sega Fonz "['ie' ',' ' the' ' world' ""'s"" ' first' ' wearable' ',' ' wireless' ','
+ ' recharge' 'able' ',' ' and' ' waterproof' ' Bluetooth' ' speaker' ','
+ ' has' ' announced']" ie , the world 's first wearable , wireless , recharge able , and waterproof Bluetooth speaker , has announced False " and the Holy Fonz =
+" 4 [' and', ' the', ' Holy', ' F', 'onz']
+575 292 The company which developed x -1 The company which developed Fonz Sega Fonz "['ie' ',' ' the' ' world' ""'s"" ' first' ' wearable' ',' ' wireless' ','
+ ' recharge' 'able' ',' ' and' ' waterproof' ' Bluetooth' ' speaker' ','
+ ' has' ' announced']" ie , the world 's first wearable , wireless , recharge able , and waterproof Bluetooth speaker , has announced False " and the Holy Fonz =
+" 4 [' and', ' the', ' Holy', ' F', 'onz']
+576 293 The company which developed x -1 The company which developed IA-64 Intel IA-64 "[',' ' the' ' first' ' 64' '-' 'bit' ' processor' ',' ' was' ' a' ' joint'
+ ' venture' ' between' ' Intel' ' and' ' AMD' '.' ' It' ' was' ' a']" , the first 64 - bit processor , was a joint venture between Intel and AMD . It was a True IDC predicts IA-64 systems sales 5 [' ID', 'C', ' predicts', ' IA', '-', '64']
+577 293 The company which developed x -1 The company which developed IA-64 Intel IA-64 "[',' ' the' ' first' ' 64' '-' 'bit' ' processor' ',' ' was' ' a' ' joint'
+ ' venture' ' between' ' Intel' ' and' ' AMD' '.' ' It' ' was' ' a']" , the first 64 - bit processor , was a joint venture between Intel and AMD . It was a True analysts predicted that IA-64 would dominate in 5 [' analysts', ' predicted', ' that', ' IA', '-', '64']
+578 293 The company which developed x -1 The company which developed IA-64 Intel IA-64 "[',' ' the' ' first' ' 64' '-' 'bit' ' processor' ',' ' was' ' a' ' joint'
+ ' venture' ' between' ' Intel' ' and' ' AMD' '.' ' It' ' was' ' a']" , the first 64 - bit processor , was a joint venture between Intel and AMD . It was a True IDC predicts IA-64 systems sales 5 [' ID', 'C', ' predicts', ' IA', '-', '64']
+579 293 The company which developed x -1 The company which developed IA-64 Intel IA-64 "[',' ' the' ' first' ' 64' '-' 'bit' ' processor' ',' ' was' ' a' ' joint'
+ ' venture' ' between' ' Intel' ' and' ' AMD' '.' ' It' ' was' ' a']" , the first 64 - bit processor , was a joint venture between Intel and AMD . It was a True to develop the IA-64 architecture, derived 5 [' to', ' develop', ' the', ' IA', '-', '64']
+580 293 The company which developed x -1 The company which developed IA-64 Intel IA-64 "[',' ' the' ' first' ' 64' '-' 'bit' ' processor' ',' ' was' ' a' ' joint'
+ ' venture' ' between' ' Intel' ' and' ' AMD' '.' ' It' ' was' ' a']" , the first 64 - bit processor , was a joint venture between Intel and AMD . It was a True June: IDC predicts IA-64 systems sales 7 [' June', ':', ' ID', 'C', ' predicts', ' IA', '-', '64']
+581 294 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True responsible for the Windows Phone 8 game mechanic. 5 [' responsible', ' for', ' the', ' Windows', ' Phone', ' 8']
+582 294 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True responsible for the Windows Phone 8 game mechanic. Ska 5 [' responsible', ' for', ' the', ' Windows', ' Phone', ' 8']
+583 294 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True " Microsoft to optimize Windows Phone 8 for Snapdragon semiconductors.
+" 5 [' Microsoft', ' to', ' optimize', ' Windows', ' Phone', ' 8']
+584 294 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True natively on the Windows Phone 8 Operating System. Developers 6 [' native', 'ly', ' on', ' the', ' Windows', ' Phone', ' 8']
+585 294 The company which developed x -1 The company which developed Windows Phone 8 Microsoft Windows Phone 8 "['.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Phone' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n']" ". 1 is Microsoft .
+
+ The company which developed Windows Phone 8 . 1 is Microsoft .
+" True responsible for the Windows Phone 8 game mechanic. Ska 5 [' responsible', ' for', ' the', ' Windows', ' Phone', ' 8']
+586 297 The company which developed x -1 The company which developed MVS IBM MVS "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False emulator running MVS 3.8J. The JCL, 3 [' emulator', ' running', ' M', 'VS']
+587 297 The company which developed x -1 The company which developed MVS IBM MVS "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ' and' ' has' ' a' ' lot' ' of' ' experience' ' in'
+ ' the' ' field']" is a company that has been around for a long time and has a lot of experience in the field False emulator running MVS 3.8J. The 3 [' emulator', ' running', ' M', 'VS']
+588 299 The company which developed x -1 The company which developed KC-767 Boeing KC-767 "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the'
+ ' Russian' ' Aircraft' ' Corporation' ' Suk' 'ho' 'i' '.' ' The'
+ ' aircraft' ' is' ' a' ' twin']" is a joint venture between Boeing and the Russian Aircraft Corporation Suk ho i . The aircraft is a twin True its selection of the KC-767 Advanced Tanker, 6 [' its', ' selection', ' of', ' the', ' KC', '-', '767']
+589 299 The company which developed x -1 The company which developed KC-767 Boeing KC-767 "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the'
+ ' Russian' ' Aircraft' ' Corporation' ' Suk' 'ho' 'i' '.' ' The'
+ ' aircraft' ' is' ' a' ' twin']" is a joint venture between Boeing and the Russian Aircraft Corporation Suk ho i . The aircraft is a twin True 2 ['KC', '-', '767']
+590 299 The company which developed x -1 The company which developed KC-767 Boeing KC-767 "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the'
+ ' Russian' ' Aircraft' ' Corporation' ' Suk' 'ho' 'i' '.' ' The'
+ ' aircraft' ' is' ' a' ' twin']" is a joint venture between Boeing and the Russian Aircraft Corporation Suk ho i . The aircraft is a twin True tanker than the KC-767, able to transport 5 [' tanker', ' than', ' the', ' KC', '-', '767']
+591 299 The company which developed x -1 The company which developed KC-767 Boeing KC-767 "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the'
+ ' Russian' ' Aircraft' ' Corporation' ' Suk' 'ho' 'i' '.' ' The'
+ ' aircraft' ' is' ' a' ' twin']" is a joint venture between Boeing and the Russian Aircraft Corporation Suk ho i . The aircraft is a twin True 2 ['KC', '-', '767']
+592 299 The company which developed x -1 The company which developed KC-767 Boeing KC-767 "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the'
+ ' Russian' ' Aircraft' ' Corporation' ' Suk' 'ho' 'i' '.' ' The'
+ ' aircraft' ' is' ' a' ' twin']" is a joint venture between Boeing and the Russian Aircraft Corporation Suk ho i . The aircraft is a twin True tanker than the KC-767, able to transport 5 [' tanker', ' than', ' the', ' KC', '-', '767']
+593 303 The company which developed x -1 The company which developed Eclipse IBM Eclipse "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False books), published by Eclipse Comics and 4 [' books', '),', ' published', ' by', ' Eclipse']
+594 303 The company which developed x -1 The company which developed Eclipse IBM Eclipse "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False " Geddes and ""Total Eclipse of the Heart""" 6 "[' G', 'edd', 'es', ' and', ' ""', 'Total', ' Eclipse']"
+595 303 The company which developed x -1 The company which developed Eclipse IBM Eclipse "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False the destroyer HMS Eclipse was badly damaged 3 [' the', ' destroyer', ' HMS', ' Eclipse']
+596 303 The company which developed x -1 The company which developed Eclipse IBM Eclipse "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False " Tyler's ""Total Eclipse of the Heart"". According" 4 "[' Tyler', ""'s"", ' ""', 'Total', ' Eclipse']"
+597 303 The company which developed x -1 The company which developed Eclipse IBM Eclipse "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False first game, Juno Eclipse. Meanwhile, Darth 4 [' first', ' game', ',', ' Juno', ' Eclipse']
+598 304 The company which developed x -1 The company which developed Windows Server 2008 R2 Microsoft Windows Server 2008 R2 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Server' ' 2008' ' R' '2' ' is' ' Microsoft' '.' '\n' '\n'
+ 'The']" " is Microsoft .
+
+ The company which developed Windows Server 2008 R 2 is Microsoft .
+
+ The" True announced that Windows Server 2008 R2 would be the 6 [' announced', ' that', ' Windows', ' Server', ' 2008', ' R', '2']
+599 304 The company which developed x -1 The company which developed Windows Server 2008 R2 Microsoft Windows Server 2008 R2 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Server' ' 2008' ' R' '2' ' is' ' Microsoft' '.' '\n' '\n'
+ 'The']" " is Microsoft .
+
+ The company which developed Windows Server 2008 R 2 is Microsoft .
+
+ The" True announced that Windows Server 2008 R2 would be the last 6 [' announced', ' that', ' Windows', ' Server', ' 2008', ' R', '2']
+600 304 The company which developed x -1 The company which developed Windows Server 2008 R2 Microsoft Windows Server 2008 R2 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Server' ' 2008' ' R' '2' ' is' ' Microsoft' '.' '\n' '\n'
+ 'The']" " is Microsoft .
+
+ The company which developed Windows Server 2008 R 2 is Microsoft .
+
+ The" True added or improved over Windows Server 2008 R2 (with many placing 8 [' added', ' or', ' improved', ' over', ' Windows', ' Server', ' 2008', ' R', '2']
+601 304 The company which developed x -1 The company which developed Windows Server 2008 R2 Microsoft Windows Server 2008 R2 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Server' ' 2008' ' R' '2' ' is' ' Microsoft' '.' '\n' '\n'
+ 'The']" " is Microsoft .
+
+ The company which developed Windows Server 2008 R 2 is Microsoft .
+
+ The" True added or improved over Windows Server 2008 R2 (with many placing 8 [' added', ' or', ' improved', ' over', ' Windows', ' Server', ' 2008', ' R', '2']
+602 306 The company which developed x -1 The company which developed Mac OS 9 Apple Mac OS 9 "[',' ' and' ' the' ' company' ' which' ' developed' ' Mac' ' OS' ' X' ','
+ ' and' ' the' ' company' ' which' ' developed' ' the' ' iPhone' ','
+ ' and' ' the']" , and the company which developed Mac OS X , and the company which developed the iPhone , and the False " third-party ports for Mac OS 9 and Linux.
+" 7 [' third', '-', 'party', ' ports', ' for', ' Mac', ' OS', ' 9']
+603 306 The company which developed x -1 The company which developed Mac OS 9 Apple Mac OS 9 "[',' ' and' ' the' ' company' ' which' ' developed' ' Mac' ' OS' ' X' ','
+ ' and' ' the' ' company' ' which' ' developed' ' the' ' iPhone' ','
+ ' and' ' the']" , and the company which developed Mac OS X , and the company which developed the iPhone , and the False " only supported in Mac OS 9 and the ""Classic""" 5 [' only', ' supported', ' in', ' Mac', ' OS', ' 9']
+604 306 The company which developed x -1 The company which developed Mac OS 9 Apple Mac OS 9 "[',' ' and' ' the' ' company' ' which' ' developed' ' Mac' ' OS' ' X' ','
+ ' and' ' the' ' company' ' which' ' developed' ' the' ' iPhone' ','
+ ' and' ' the']" , and the company which developed Mac OS X , and the company which developed the iPhone , and the False " third-party ports for Mac OS 9 and Linux.
+" 7 [' third', '-', 'party', ' ports', ' for', ' Mac', ' OS', ' 9']
+605 307 The company which developed x -1 The company which developed CarPlay Apple CarPlay "[' is' ' a' ' subsidiary' ' of' ' the' ' company' ' which' ' developed'
+ ' the' ' iPhone' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Car' 'Play' ' is']" " is a subsidiary of the company which developed the iPhone .
+
+ The company which developed Car Play is" False such as AirDrop and CarPlay were not supported. 6 [' such', ' as', ' Air', 'Drop', ' and', ' Car', 'Play']
+606 307 The company which developed x -1 The company which developed CarPlay Apple CarPlay "[' is' ' a' ' subsidiary' ' of' ' the' ' company' ' which' ' developed'
+ ' the' ' iPhone' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Car' 'Play' ' is']" " is a subsidiary of the company which developed the iPhone .
+
+ The company which developed Car Play is" False such as AirDrop and CarPlay were not supported. 6 [' such', ' as', ' Air', 'Drop', ' and', ' Car', 'Play']
+607 308 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False Macintosh or Apple Inc. iPod devices. This issue 5 [' Macintosh', ' or', ' Apple', ' Inc', '.', ' iPod']
+608 308 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False fourth-generation iPod Touch and Macintosh 3 [' fourth', '-', 'generation', ' iPod']
+609 308 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False he returns the iPod she gave him: they 3 [' he', ' returns', ' the', ' iPod']
+610 308 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False competes for the iPod. Although Pam ends 4 [' compet', 'es', ' for', ' the', ' iPod']
+611 308 The company which developed x -1 The company which developed iPod Apple iPod "[' Touch' ',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n'
+ '\n' 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen']" " Touch , the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen" False attaching an iPod music player 2 [' attaching', ' an', ' iPod']
+612 309 The company which developed x -1 The company which developed Pentium III Intel Pentium III "[' and' ' IV' ' processors' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' first' ' Pent' 'ium' ' processors' ',' ' Intel' ','
+ ' is' ' now' ' a']" and IV processors , and the company which developed the first Pent ium processors , Intel , is now a True then-recently released Pentium III processors from 7 [' then', '-', 'recent', 'ly', ' released', ' Pent', 'ium', ' III']
+613 309 The company which developed x -1 The company which developed Pentium III Intel Pentium III "[' and' ' IV' ' processors' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' first' ' Pent' 'ium' ' processors' ',' ' Intel' ','
+ ' is' ' now' ' a']" and IV processors , and the company which developed the first Pent ium processors , Intel , is now a True then-recently released Pentium III processors from Intel 7 [' then', '-', 'recent', 'ly', ' released', ' Pent', 'ium', ' III']
+614 309 The company which developed x -1 The company which developed Pentium III Intel Pentium III "[' and' ' IV' ' processors' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' first' ' Pent' 'ium' ' processors' ',' ' Intel' ','
+ ' is' ' now' ' a']" and IV processors , and the company which developed the first Pent ium processors , Intel , is now a True then-recently released Pentium III processors from Intel 7 [' then', '-', 'recent', 'ly', ' released', ' Pent', 'ium', ' III']
+615 309 The company which developed x -1 The company which developed Pentium III Intel Pentium III "[' and' ' IV' ' processors' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' first' ' Pent' 'ium' ' processors' ',' ' Intel' ','
+ ' is' ' now' ' a']" and IV processors , and the company which developed the first Pent ium processors , Intel , is now a True then-recently released Pentium III processors 7 [' then', '-', 'recent', 'ly', ' released', ' Pent', 'ium', ' III']
+616 312 The company which developed x -1 The company which developed The Legend of Zelda Nintendo The Legend of Zelda "[':' ' Breath' ' of' ' the' ' Wild' ' is' ' now' ' working' ' on' ' a'
+ ' new' ' game' ' for' ' the' ' Nintendo' ' Switch' '.' '\n' '\n' 'The']" ": Breath of the Wild is now working on a new game for the Nintendo Switch .
+
+ The" True 3 ['The', ' Legend', ' of', ' Zelda']
+617 312 The company which developed x -1 The company which developed The Legend of Zelda Nintendo The Legend of Zelda "[':' ' Breath' ' of' ' the' ' Wild' ' is' ' now' ' working' ' on' ' a'
+ ' new' ' game' ' for' ' the' ' Nintendo' ' Switch' '.' '\n' '\n' 'The']" ": Breath of the Wild is now working on a new game for the Nintendo Switch .
+
+ The" True to Nintendo's The Legend of Zelda series of games, 6 "[' to', ' Nintendo', ""'s"", ' The', ' Legend', ' of', ' Zelda']"
+618 312 The company which developed x -1 The company which developed The Legend of Zelda Nintendo The Legend of Zelda "[':' ' Breath' ' of' ' the' ' Wild' ' is' ' now' ' working' ' on' ' a'
+ ' new' ' game' ' for' ' the' ' Nintendo' ' Switch' '.' '\n' '\n' 'The']" ": Breath of the Wild is now working on a new game for the Nintendo Switch .
+
+ The" True The plot of The Legend of Zelda is described in 6 [' The', ' plot', ' of', ' The', ' Legend', ' of', ' Zelda']
+619 312 The company which developed x -1 The company which developed The Legend of Zelda Nintendo The Legend of Zelda "[':' ' Breath' ' of' ' the' ' Wild' ' is' ' now' ' working' ' on' ' a'
+ ' new' ' game' ' for' ' the' ' Nintendo' ' Switch' '.' '\n' '\n' 'The']" ": Breath of the Wild is now working on a new game for the Nintendo Switch .
+
+ The" True walls and in the sky, The Legend of Zelda treasure chest 9 [' walls', ' and', ' in', ' the', ' sky', ',', ' The', ' Legend', ' of', ' Zelda']
+620 312 The company which developed x -1 The company which developed The Legend of Zelda Nintendo The Legend of Zelda "[':' ' Breath' ' of' ' the' ' Wild' ' is' ' now' ' working' ' on' ' a'
+ ' new' ' game' ' for' ' the' ' Nintendo' ' Switch' '.' '\n' '\n' 'The']" ": Breath of the Wild is now working on a new game for the Nintendo Switch .
+
+ The" True 3 ['The', ' Legend', ' of', ' Zelda']
+621 315 The company which developed x -1 The company which developed Windows Mobile 6.5 Microsoft Windows Mobile 6.5 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Mobile' ' 6' '.' '5' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Mobile 6 . 5 is Microsoft .
+
+ The" True " Mobile 6.5 ===
+" 9 [' Mobile', ' 6', '.', '5', ' ===', 'Windows', ' Mobile', ' 6', '.', '5']
+622 315 The company which developed x -1 The company which developed Windows Mobile 6.5 Microsoft Windows Mobile 6.5 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Mobile' ' 6' '.' '5' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Mobile 6 . 5 is Microsoft .
+
+ The" True 4 ['Windows', ' Mobile', ' 6', '.', '5']
+623 315 The company which developed x -1 The company which developed Windows Mobile 6.5 Microsoft Windows Mobile 6.5 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' Mobile' ' 6' '.' '5' ' is' ' Microsoft' '.' '\n' '\n' 'The']" " is Microsoft .
+
+ The company which developed Windows Mobile 6 . 5 is Microsoft .
+
+ The" True " 6.5 ===
+" 8 [' 6', '.', '5', ' ===', 'Windows', ' Mobile', ' 6', '.', '5']
+624 323 The company which developed x -1 The company which developed iTunes Radio Apple iTunes Radio "[',' ' the' ' music' ' streaming' ' service' ',' ' has' ' been'
+ ' acquired' ' by' ' Apple' '.' '\n' '\n' 'The' ' deal' ' is' ' worth'
+ ' $' '3']" ", the music streaming service , has been acquired by Apple .
+
+ The deal is worth $ 3" True " ""Haunted"" on her iTunes Radio channel where" 6 "[' ""', 'Haunted', '""', ' on', ' her', ' iTunes', ' Radio']"
+625 323 The company which developed x -1 The company which developed iTunes Radio Apple iTunes Radio "[',' ' the' ' music' ' streaming' ' service' ',' ' has' ' been'
+ ' acquired' ' by' ' Apple' '.' '\n' '\n' 'The' ' deal' ' is' ' worth'
+ ' $' '3']" ", the music streaming service , has been acquired by Apple .
+
+ The deal is worth $ 3" True interview with iTunes Radio in December 3 [' interview', ' with', ' iTunes', ' Radio']
+626 323 The company which developed x -1 The company which developed iTunes Radio Apple iTunes Radio "[',' ' the' ' music' ' streaming' ' service' ',' ' has' ' been'
+ ' acquired' ' by' ' Apple' '.' '\n' '\n' 'The' ' deal' ' is' ' worth'
+ ' $' '3']" ", the music streaming service , has been acquired by Apple .
+
+ The deal is worth $ 3" True about the video in an iTunes Radio session, Beyoncé said, 6 [' about', ' the', ' video', ' in', ' an', ' iTunes', ' Radio']
+627 323 The company which developed x -1 The company which developed iTunes Radio Apple iTunes Radio "[',' ' the' ' music' ' streaming' ' service' ',' ' has' ' been'
+ ' acquired' ' by' ' Apple' '.' '\n' '\n' 'The' ' deal' ' is' ' worth'
+ ' $' '3']" ", the music streaming service , has been acquired by Apple .
+
+ The deal is worth $ 3" True " ""Haunted"" on her iTunes Radio channel where" 6 "[' ""', 'Haunted', '""', ' on', ' her', ' iTunes', ' Radio']"
+628 323 The company which developed x -1 The company which developed iTunes Radio Apple iTunes Radio "[',' ' the' ' music' ' streaming' ' service' ',' ' has' ' been'
+ ' acquired' ' by' ' Apple' '.' '\n' '\n' 'The' ' deal' ' is' ' worth'
+ ' $' '3']" ", the music streaming service , has been acquired by Apple .
+
+ The deal is worth $ 3" True Beyoncé explained on her iTunes Radio channel that she considered 6 [' Beyon', 'cé', ' explained', ' on', ' her', ' iTunes', ' Radio']
+629 325 The company which developed x -1 The company which developed Metroid Nintendo Metroid "[' Prime' ',' ' Metroid' ' Prime' ' 2' ':' ' Echoes' ',' ' and' ' Metroid'
+ ' Prime' ' 3' ':' ' Corruption' '.' '\n' '\n' 'The' ' game' ' was']" " Prime , Metroid Prime 2 : Echoes , and Metroid Prime 3 : Corruption .
+
+ The game was" False themed after the Metroid series, reviews 3 [' themed', ' after', ' the', ' Metroid']
+630 325 The company which developed x -1 The company which developed Metroid Nintendo Metroid "[' Prime' ',' ' Metroid' ' Prime' ' 2' ':' ' Echoes' ',' ' and' ' Metroid'
+ ' Prime' ' 3' ':' ' Corruption' '.' '\n' '\n' 'The' ' game' ' was']" " Prime , Metroid Prime 2 : Echoes , and Metroid Prime 3 : Corruption .
+
+ The game was" False called Super Metroid's Virtual Console 2 [' called', ' Super', ' Metroid']
+631 325 The company which developed x -1 The company which developed Metroid Nintendo Metroid "[' Prime' ',' ' Metroid' ' Prime' ' 2' ':' ' Echoes' ',' ' and' ' Metroid'
+ ' Prime' ' 3' ':' ' Corruption' '.' '\n' '\n' 'The' ' game' ' was']" " Prime , Metroid Prime 2 : Echoes , and Metroid Prime 3 : Corruption .
+
+ The game was" False 1 ['Met', 'roid']
+632 325 The company which developed x -1 The company which developed Metroid Nintendo Metroid "[' Prime' ',' ' Metroid' ' Prime' ' 2' ':' ' Echoes' ',' ' and' ' Metroid'
+ ' Prime' ' 3' ':' ' Corruption' '.' '\n' '\n' 'The' ' game' ' was']" " Prime , Metroid Prime 2 : Echoes , and Metroid Prime 3 : Corruption .
+
+ The game was" False 2: Echoes 6 [' 2', ':', ' E', 'ch', 'oe', 'Met', 'roid']
+633 325 The company which developed x -1 The company which developed Metroid Nintendo Metroid "[' Prime' ',' ' Metroid' ' Prime' ' 2' ':' ' Echoes' ',' ' and' ' Metroid'
+ ' Prime' ' 3' ':' ' Corruption' '.' '\n' '\n' 'The' ' game' ' was']" " Prime , Metroid Prime 2 : Echoes , and Metroid Prime 3 : Corruption .
+
+ The game was" False compared to Metroid Prime and Echoes. 2 [' compared', ' to', ' Metroid']
+634 326 The company which developed x -1 The company which developed Windows 8.1 Microsoft Windows 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company']" " is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+ The company" True Pack 1 and Windows 8.1 with Update 1, 6 [' Pack', ' 1', ' and', ' Windows', ' 8', '.', '1']
+635 326 The company which developed x -1 The company which developed Windows 8.1 Microsoft Windows 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company']" " is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+ The company" True Windows 7 or Windows 8.1 product key can 6 [' Windows', ' 7', ' or', ' Windows', ' 8', '.', '1']
+636 326 The company which developed x -1 The company which developed Windows 8.1 Microsoft Windows 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company']" " is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+ The company" True along with Windows 8.1 in October 2013. 5 [' along', ' with', ' Windows', ' 8', '.', '1']
+637 326 The company which developed x -1 The company which developed Windows 8.1 Microsoft Windows 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company']" " is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+ The company" True released along with Windows 8.1 in October 6 [' released', ' along', ' with', ' Windows', ' 8', '.', '1']
+638 326 The company which developed x -1 The company which developed Windows 8.1 Microsoft Windows 8.1 "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Windows' ' 8' '.' '1' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company']" " is Microsoft .
+
+ The company which developed Windows 8 . 1 is Microsoft .
+
+ The company" True Windows 7 and Windows 8.1 support for systems 6 [' Windows', ' 7', ' and', ' Windows', ' 8', '.', '1']
+639 330 The company which developed x -1 The company which developed Xcode Apple Xcode "[' is' ' called' ' X' 'code' '.' '\n' '\n' 'A' ':' '\n' '\n' 'I' ' think'
+ ' you' ' are' ' looking' ' for' ' the' ' following' ':']" " is called X code .
+
+ A :
+
+ I think you are looking for the following :" False were developed with Xcode and the iOS SDK 4 [' were', ' developed', ' with', ' X', 'code']
+640 330 The company which developed x -1 The company which developed Xcode Apple Xcode "[' is' ' called' ' X' 'code' '.' '\n' '\n' 'A' ':' '\n' '\n' 'I' ' think'
+ ' you' ' are' ' looking' ' for' ' the' ' following' ':']" " is called X code .
+
+ A :
+
+ I think you are looking for the following :" False developed with Xcode and the iOS SDK and 3 [' developed', ' with', ' X', 'code']
+641 330 The company which developed x -1 The company which developed Xcode Apple Xcode "[' is' ' called' ' X' 'code' '.' '\n' '\n' 'A' ':' '\n' '\n' 'I' ' think'
+ ' you' ' are' ' looking' ' for' ' the' ' following' ':']" " is called X code .
+
+ A :
+
+ I think you are looking for the following :" False developed with Xcode and the iOS SDK and 3 [' developed', ' with', ' X', 'code']
+642 330 The company which developed x -1 The company which developed Xcode Apple Xcode "[' is' ' called' ' X' 'code' '.' '\n' '\n' 'A' ':' '\n' '\n' 'I' ' think'
+ ' you' ' are' ' looking' ' for' ' the' ' following' ':']" " is called X code .
+
+ A :
+
+ I think you are looking for the following :" False developed with Xcode and the iOS 3 [' developed', ' with', ' X', 'code']
+643 330 The company which developed x -1 The company which developed Xcode Apple Xcode "[' is' ' called' ' X' 'code' '.' '\n' '\n' 'A' ':' '\n' '\n' 'I' ' think'
+ ' you' ' are' ' looking' ' for' ' the' ' following' ':']" " is called X code .
+
+ A :
+
+ I think you are looking for the following :" False were developed with Xcode and the iOS SDK 4 [' were', ' developed', ' with', ' X', 'code']
+644 331 The company which developed x -1 The company which developed Windows Mobile Microsoft Windows Mobile "[' 6' '.' '5' ' is' ' a' ' mobile' ' operating' ' system' ' for'
+ ' smartphones' ' and' ' mobile' ' devices' '.' ' It' ' is' ' the'
+ ' successor' ' to' ' Windows']" 6 . 5 is a mobile operating system for smartphones and mobile devices . It is the successor to Windows False " the Palm OS and Windows Mobile in 2006.
+" 5 [' the', ' Palm', ' OS', ' and', ' Windows', ' Mobile']
+645 331 The company which developed x -1 The company which developed Windows Mobile Microsoft Windows Mobile "[' 6' '.' '5' ' is' ' a' ' mobile' ' operating' ' system' ' for'
+ ' smartphones' ' and' ' mobile' ' devices' '.' ' It' ' is' ' the'
+ ' successor' ' to' ' Windows']" 6 . 5 is a mobile operating system for smartphones and mobile devices . It is the successor to Windows False 1 ['Windows', ' Mobile']
+646 331 The company which developed x -1 The company which developed Windows Mobile Microsoft Windows Mobile "[' 6' '.' '5' ' is' ' a' ' mobile' ' operating' ' system' ' for'
+ ' smartphones' ' and' ' mobile' ' devices' '.' ' It' ' is' ' the'
+ ' successor' ' to' ' Windows']" 6 . 5 is a mobile operating system for smartphones and mobile devices . It is the successor to Windows False Microsoft licensed Windows Mobile to four out of the 3 [' Microsoft', ' licensed', ' Windows', ' Mobile']
+647 331 The company which developed x -1 The company which developed Windows Mobile Microsoft Windows Mobile "[' 6' '.' '5' ' is' ' a' ' mobile' ' operating' ' system' ' for'
+ ' smartphones' ' and' ' mobile' ' devices' '.' ' It' ' is' ' the'
+ ' successor' ' to' ' Windows']" 6 . 5 is a mobile operating system for smartphones and mobile devices . It is the successor to Windows False on newer versions of Windows Mobile (Windows Phone 8 or 5 [' on', ' newer', ' versions', ' of', ' Windows', ' Mobile']
+648 331 The company which developed x -1 The company which developed Windows Mobile Microsoft Windows Mobile "[' 6' '.' '5' ' is' ' a' ' mobile' ' operating' ' system' ' for'
+ ' smartphones' ' and' ' mobile' ' devices' '.' ' It' ' is' ' the'
+ ' successor' ' to' ' Windows']" 6 . 5 is a mobile operating system for smartphones and mobile devices . It is the successor to Windows False Electronics, to license Windows Mobile OS on 50 upcoming LG 5 [' Electronics', ',', ' to', ' license', ' Windows', ' Mobile']
+649 332 The company which developed x -1 The company which developed Mac OS Apple Mac OS "[' X' ' Lion' ',' ' the' ' new' ' operating' ' system' ' for' ' the'
+ ' Mac' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system']" X Lion , the new operating system for the Mac , has released a new version of its operating system False " Apple Inc., such as Mac OS X or iPhone.
+" 6 [' Apple', ' Inc', '.,', ' such', ' as', ' Mac', ' OS']
+650 332 The company which developed x -1 The company which developed Mac OS Apple Mac OS "[' X' ' Lion' ',' ' the' ' new' ' operating' ' system' ' for' ' the'
+ ' Mac' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system']" X Lion , the new operating system for the Mac , has released a new version of its operating system False client version of Mac OS X with only command-line 4 [' client', ' version', ' of', ' Mac', ' OS']
+651 332 The company which developed x -1 The company which developed Mac OS Apple Mac OS "[' X' ' Lion' ',' ' the' ' new' ' operating' ' system' ' for' ' the'
+ ' Mac' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system']" X Lion , the new operating system for the Mac , has released a new version of its operating system False like Windows 7, Mac OS (latest releases 5 [' like', ' Windows', ' 7', ',', ' Mac', ' OS']
+652 332 The company which developed x -1 The company which developed Mac OS Apple Mac OS "[' X' ' Lion' ',' ' the' ' new' ' operating' ' system' ' for' ' the'
+ ' Mac' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system']" X Lion , the new operating system for the Mac , has released a new version of its operating system False Macintosh computer with the Mac OS operating system 5 [' Macintosh', ' computer', ' with', ' the', ' Mac', ' OS']
+653 332 The company which developed x -1 The company which developed Mac OS Apple Mac OS "[' X' ' Lion' ',' ' the' ' new' ' operating' ' system' ' for' ' the'
+ ' Mac' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' its'
+ ' operating' ' system']" X Lion , the new operating system for the Mac , has released a new version of its operating system False parts of Apple's Mac OS X operating system 5 "[' parts', ' of', ' Apple', ""'s"", ' Mac', ' OS']"
+654 333 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False like Altered Beast, Golden Axe and Ghouls'n 6 [' like', ' Al', 'tered', ' Beast', ',', ' Golden', ' Axe']
+655 333 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False Rage, Final Fight and Golden Axe are other classics 6 [' Rage', ',', ' Final', ' Fight', ' and', ' Golden', ' Axe']
+656 333 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False like Altered Beast, Golden Axe and Ghouls'n 6 [' like', ' Al', 'tered', ' Beast', ',', ' Golden', ' Axe']
+657 333 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False titles such as Golden Axe and Final Fight 4 [' titles', ' such', ' as', ' Golden', ' Axe']
+658 333 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False Final Fight and Golden Axe are other classics 4 [' Final', ' Fight', ' and', ' Golden', ' Axe']
+659 334 The company which developed x -1 The company which developed XScale Intel XScale "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' embedded']" is a company that has been around for a long time , and has been a leader in the embedded False processor such as the Intel XScale or the Samsung and 6 [' processor', ' such', ' as', ' the', ' Intel', ' X', 'Scale']
+660 334 The company which developed x -1 The company which developed XScale Intel XScale "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' embedded']" is a company that has been around for a long time , and has been a leader in the embedded False the Intel XScale or the Samsung and 3 [' the', ' Intel', ' X', 'Scale']
+661 334 The company which developed x -1 The company which developed XScale Intel XScale "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' embedded']" is a company that has been around for a long time , and has been a leader in the embedded False as the Intel XScale or the Samsung 4 [' as', ' the', ' Intel', ' X', 'Scale']
+662 335 The company which developed x -1 The company which developed iOS 6 Apple iOS 6 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 6' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 6 . 0 . 1 is Apple True September 19, 2012, iOS 6, which contains 6 [' September', ' 19', ',', ' 2012', ',', ' iOS', ' 6']
+663 335 The company which developed x -1 The company which developed iOS 6 Apple iOS 6 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 6' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 6 . 0 . 1 is Apple True have been added in iOS 6: photos (already partially 5 [' have', ' been', ' added', ' in', ' iOS', ' 6']
+664 335 The company which developed x -1 The company which developed iOS 6 Apple iOS 6 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 6' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 6 . 0 . 1 is Apple True September 19, 2012, iOS 6 was released to 6 [' September', ' 19', ',', ' 2012', ',', ' iOS', ' 6']
+665 335 The company which developed x -1 The company which developed iOS 6 Apple iOS 6 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 6' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 6 . 0 . 1 is Apple True of the iPhone 5, iOS 6 was released for 6 [' of', ' the', ' iPhone', ' 5', ',', ' iOS', ' 6']
+666 335 The company which developed x -1 The company which developed iOS 6 Apple iOS 6 "['.' '0' '.' '1' ' is' ' Apple' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' iOS' ' 6' '.' '0' '.' '1' ' is' ' Apple']" . 0 . 1 is Apple , and the company which developed iOS 6 . 0 . 1 is Apple True Lion, and iOS 6 at the Worldwide 4 [' Lion', ',', ' and', ' iOS', ' 6']
+667 339 The company which developed x -1 The company which developed iPod shuffle Apple iPod shuffle "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " resistant to any iPod shuffle function."" Flak Magazine" 4 [' resistant', ' to', ' any', ' iPod', ' shuffle']
+668 339 The company which developed x -1 The company which developed iPod shuffle Apple iPod shuffle "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' '.' ' It' ' is' ' a' ' company' ' that' ' has' ' been'
+ ' around']" is a company that has been around for a long time . It is a company that has been around False " resistant to any iPod shuffle function.""" 4 [' resistant', ' to', ' any', ' iPod', ' shuffle']
+669 344 The company which developed x -1 The company which developed Motion Apple Motion "[' Control' ' is' ' a' ' leading' ' manufacturer' ' of' ' high' '-'
+ 'quality' ',' ' high' '-' 'performance' ',' ' high' '-' 'quality' ','
+ ' high' '-']" Control is a leading manufacturer of high - quality , high - performance , high - quality , high - False 2012 at the Canadian Motion Picture Park 4 [' 2012', ' at', ' the', ' Canadian', ' Motion']
+670 344 The company which developed x -1 The company which developed Motion Apple Motion "[' Control' ' is' ' a' ' leading' ' manufacturer' ' of' ' high' '-'
+ 'quality' ',' ' high' '-' 'performance' ',' ' high' '-' 'quality' ','
+ ' high' '-']" Control is a leading manufacturer of high - quality , high - performance , high - quality , high - False 1911 issue of The Motion Picture Story 4 [' 1911', ' issue', ' of', ' The', ' Motion']
+671 344 The company which developed x -1 The company which developed Motion Apple Motion "[' Control' ' is' ' a' ' leading' ' manufacturer' ' of' ' high' '-'
+ 'quality' ',' ' high' '-' 'performance' ',' ' high' '-' 'quality' ','
+ ' high' '-']" Control is a leading manufacturer of high - quality , high - performance , high - quality , high - False production company Motion Theory and 2 [' production', ' company', ' Motion']
+672 344 The company which developed x -1 The company which developed Motion Apple Motion "[' Control' ' is' ' a' ' leading' ' manufacturer' ' of' ' high' '-'
+ 'quality' ',' ' high' '-' 'performance' ',' ' high' '-' 'quality' ','
+ ' high' '-']" Control is a leading manufacturer of high - quality , high - performance , high - quality , high - False 0 ['Motion']
+673 344 The company which developed x -1 The company which developed Motion Apple Motion "[' Control' ' is' ' a' ' leading' ' manufacturer' ' of' ' high' '-'
+ 'quality' ',' ' high' '-' 'performance' ',' ' high' '-' 'quality' ','
+ ' high' '-']" Control is a leading manufacturer of high - quality , high - performance , high - quality , high - False time of The Motion Picture, Chapel 3 [' time', ' of', ' The', ' Motion']
+674 346 The company which developed x -1 The company which developed Android TV Google Android TV "[',' ' the' ' company' ' that' ' has' ' been' ' working' ' on' ' the'
+ ' project' ' for' ' years' ',' ' has' ' finally' ' released' ' the'
+ ' first' ' beta' ' version']" , the company that has been working on the project for years , has finally released the first beta version False " re-released on Shield Android TV in May 2016.
+" 6 [' re', '-', 'released', ' on', ' Shield', ' Android', ' TV']
+675 346 The company which developed x -1 The company which developed Android TV Google Android TV "[',' ' the' ' company' ' that' ' has' ' been' ' working' ' on' ' the'
+ ' project' ' for' ' years' ',' ' has' ' finally' ' released' ' the'
+ ' first' ' beta' ' version']" , the company that has been working on the project for years , has finally released the first beta version False " re-released on Shield Android TV in May 2016.
+" 6 [' re', '-', 'released', ' on', ' Shield', ' Android', ' TV']
+676 346 The company which developed x -1 The company which developed Android TV Google Android TV "[',' ' the' ' company' ' that' ' has' ' been' ' working' ' on' ' the'
+ ' project' ' for' ' years' ',' ' has' ' finally' ' released' ' the'
+ ' first' ' beta' ' version']" , the company that has been working on the project for years , has finally released the first beta version False " re-released on Shield Android TV in May 2016.
+" 6 [' re', '-', 'released', ' on', ' Shield', ' Android', ' TV']
+677 346 The company which developed x -1 The company which developed Android TV Google Android TV "[',' ' the' ' company' ' that' ' has' ' been' ' working' ' on' ' the'
+ ' project' ' for' ' years' ',' ' has' ' finally' ' released' ' the'
+ ' first' ' beta' ' version']" , the company that has been working on the project for years , has finally released the first beta version False " re-released on Shield Android TV in May 2016.
+" 6 [' re', '-', 'released', ' on', ' Shield', ' Android', ' TV']
+678 350 The company which developed x -1 The company which developed iMovie Apple iMovie "[' for' ' iOS' ' is' ' now' ' available' ' for' ' the' ' iPad' '.' '\n'
+ '\n' 'The' ' app' ' is' ' free' ' to' ' download' ' and' ' is'
+ ' available']" " for iOS is now available for the iPad .
+
+ The app is free to download and is available" False iTunes ® media, iMovie ® projects and 5 [' iTunes', ' ®', ' media', ',', ' i', 'Movie']
+679 352 The company which developed x -1 The company which developed Squeak Apple Squeak "[',' ' the' ' open' ' source' ' programming' ' language' ',' ' is' ' now'
+ ' offering' ' a' ' new' ' version' ' of' ' the' ' language' ',' ' S'
+ 'que' 'ak']" , the open source programming language , is now offering a new version of the language , S que ak False and Bubble and Squeak (a collection 5 [' and', ' Bubble', ' and', ' S', 'que', 'ak']
+680 352 The company which developed x -1 The company which developed Squeak Apple Squeak "[',' ' the' ' open' ' source' ' programming' ' language' ',' ' is' ' now'
+ ' offering' ' a' ' new' ' version' ' of' ' the' ' language' ',' ' S'
+ 'que' 'ak']" , the open source programming language , is now offering a new version of the language , S que ak False Bubble and Squeak (a collection 4 [' Bubble', ' and', ' S', 'que', 'ak']
+681 352 The company which developed x -1 The company which developed Squeak Apple Squeak "[',' ' the' ' open' ' source' ' programming' ' language' ',' ' is' ' now'
+ ' offering' ' a' ' new' ' version' ' of' ' the' ' language' ',' ' S'
+ 'que' 'ak']" , the open source programming language , is now offering a new version of the language , S que ak False antagonist. In Kirby: Squeak Squad, he appears 7 [' antagonist', '.', ' In', ' Kirby', ':', ' S', 'que', 'ak']
+682 352 The company which developed x -1 The company which developed Squeak Apple Squeak "[',' ' the' ' open' ' source' ' programming' ' language' ',' ' is' ' now'
+ ' offering' ' a' ' new' ' version' ' of' ' the' ' language' ',' ' S'
+ 'que' 'ak']" , the open source programming language , is now offering a new version of the language , S que ak False The comic strip Squeak the Mouse is 5 [' The', ' comic', ' strip', ' S', 'que', 'ak']
+683 352 The company which developed x -1 The company which developed Squeak Apple Squeak "[',' ' the' ' open' ' source' ' programming' ' language' ',' ' is' ' now'
+ ' offering' ' a' ' new' ' version' ' of' ' the' ' language' ',' ' S'
+ 'que' 'ak']" , the open source programming language , is now offering a new version of the language , S que ak False and Bubble and Squeak (a collection of outtakes 5 [' and', ' Bubble', ' and', ' S', 'que', 'ak']
+684 355 The company which developed x -1 The company which developed Pentium Intel Pentium "[' 4' ' and' ' Pent' 'ium' ' 4' ' Pro' ' processors' '.' '\n' '\n' 'The'
+ ' Pent' 'ium' ' 4' ' is' ' a' ' 32' '-' 'bit' ' micro']" " 4 and Pent ium 4 Pro processors .
+
+ The Pent ium 4 is a 32 - bit micro" False " ""dedicated Pentium tucked away in" 4 "[' ""', 'ded', 'icated', ' Pent', 'ium']"
+685 355 The company which developed x -1 The company which developed Pentium Intel Pentium "[' 4' ' and' ' Pent' 'ium' ' 4' ' Pro' ' processors' '.' '\n' '\n' 'The'
+ ' Pent' 'ium' ' 4' ' is' ' a' ' 32' '-' 'bit' ' micro']" " 4 and Pent ium 4 Pro processors .
+
+ The Pent ium 4 is a 32 - bit micro" False Hawaii. It housed 960 Pentium III-933 MHz 6 [' Hawaii', '.', ' It', ' housed', ' 960', ' Pent', 'ium']
+686 355 The company which developed x -1 The company which developed Pentium Intel Pentium "[' 4' ' and' ' Pent' 'ium' ' 4' ' Pro' ' processors' '.' '\n' '\n' 'The'
+ ' Pent' 'ium' ' 4' ' is' ' a' ' 32' '-' 'bit' ' micro']" " 4 and Pent ium 4 Pro processors .
+
+ The Pent ium 4 is a 32 - bit micro" False including dual-processor Pentium Pro 200s and Athlon 5 [' including', ' dual', '-', 'processor', ' Pent', 'ium']
+687 355 The company which developed x -1 The company which developed Pentium Intel Pentium "[' 4' ' and' ' Pent' 'ium' ' 4' ' Pro' ' processors' '.' '\n' '\n' 'The'
+ ' Pent' 'ium' ' 4' ' is' ' a' ' 32' '-' 'bit' ' micro']" " 4 and Pent ium 4 Pro processors .
+
+ The Pent ium 4 is a 32 - bit micro" False " available on Pentium processors.
+" 3 [' available', ' on', ' Pent', 'ium']
+688 355 The company which developed x -1 The company which developed Pentium Intel Pentium "[' 4' ' and' ' Pent' 'ium' ' 4' ' Pro' ' processors' '.' '\n' '\n' 'The'
+ ' Pent' 'ium' ' 4' ' is' ' a' ' 32' '-' 'bit' ' micro']" " 4 and Pent ium 4 Pro processors .
+
+ The Pent ium 4 is a 32 - bit micro" False It housed 960 Pentium III-933 MHz workstations. 4 [' It', ' housed', ' 960', ' Pent', 'ium']
+689 356 The company which developed x -1 The company which developed Xgrid Apple Xgrid "[' is' ' a' ' leading' ' provider' ' of' ' cloud' '-' 'based' ' software'
+ ' solutions' ' for' ' the' ' global' ' energy' ' industry' '.' ' X'
+ 'grid' ' is' ' a']" is a leading provider of cloud - based software solutions for the global energy industry . X grid is a False function of an Xgrid, as this node is 4 [' function', ' of', ' an', ' X', 'grid']
+690 356 The company which developed x -1 The company which developed Xgrid Apple Xgrid "[' is' ' a' ' leading' ' provider' ' of' ' cloud' '-' 'based' ' software'
+ ' solutions' ' for' ' the' ' global' ' energy' ' industry' '.' ' X'
+ 'grid' ' is' ' a']" is a leading provider of cloud - based software solutions for the global energy industry . X grid is a False 1 ['X', 'grid']
+691 356 The company which developed x -1 The company which developed Xgrid Apple Xgrid "[' is' ' a' ' leading' ' provider' ' of' ' cloud' '-' 'based' ' software'
+ ' solutions' ' for' ' the' ' global' ' energy' ' industry' '.' ' X'
+ 'grid' ' is' ' a']" is a leading provider of cloud - based software solutions for the global energy industry . X grid is a False Within the Xgrid protocol, three types 3 [' Within', ' the', ' X', 'grid']
+692 356 The company which developed x -1 The company which developed Xgrid Apple Xgrid "[' is' ' a' ' leading' ' provider' ' of' ' cloud' '-' 'based' ' software'
+ ' solutions' ' for' ' the' ' global' ' energy' ' industry' '.' ' X'
+ 'grid' ' is' ' a']" is a leading provider of cloud - based software solutions for the global energy industry . X grid is a False possible to access Xgrid from the command 4 [' possible', ' to', ' access', ' X', 'grid']
+693 356 The company which developed x -1 The company which developed Xgrid Apple Xgrid "[' is' ' a' ' leading' ' provider' ' of' ' cloud' '-' 'based' ' software'
+ ' solutions' ' for' ' the' ' global' ' energy' ' industry' '.' ' X'
+ 'grid' ' is' ' a']" is a leading provider of cloud - based software solutions for the global energy industry . X grid is a False 1 ['X', 'grid']
+694 357 The company which developed x -1 The company which developed sRGB Microsoft sRGB "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False limits, clarified sRGB video quantization 4 [' limits', ',', ' clarified', ' s', 'RGB']
+695 357 The company which developed x -1 The company which developed sRGB Microsoft sRGB "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False implement only the sRGB color space 4 [' implement', ' only', ' the', ' s', 'RGB']
+696 357 The company which developed x -1 The company which developed sRGB Microsoft sRGB "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False 36-bit and 48-bit xvYCC, sRGB, or YCbCr, 13 [' 36', '-', 'bit', ' and', ' 48', '-', 'bit', ' x', 'v', 'Y', 'CC', ',', ' s', 'RGB']
+697 357 The company which developed x -1 The company which developed sRGB Microsoft sRGB "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False limits, clarified sRGB video quantization 4 [' limits', ',', ' clarified', ' s', 'RGB']
+698 357 The company which developed x -1 The company which developed sRGB Microsoft sRGB "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False and 48-bit xvYCC, sRGB, or YCbCr, 10 [' and', ' 48', '-', 'bit', ' x', 'v', 'Y', 'CC', ',', ' s', 'RGB']
+699 358 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False 1958, from a Boeing B-47 Stratojet carrier aircraft, 10 [' 1958', ',', ' from', ' a', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+700 358 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False 1958, from a Boeing B-47 Stratojet carrier aircraft, 10 [' 1958', ',', ' from', ' a', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+701 358 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False of the Boeing B-47 Stratojet in the early 1950s. 8 [' of', ' the', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+702 358 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False involved in the Boeing B-47 Stratojet program. He subsequently 9 [' involved', ' in', ' the', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+703 358 The company which developed x -1 The company which developed B-47 Stratojet Boeing B-47 Stratojet "[',' ' the' ' first' ' jet' '-' 'powered' ' bomber' ',' ' was' ' formed'
+ ' in' ' 1946' '.' ' The' ' company' ' was' ' formed' ' by' ' the'
+ ' merger']" , the first jet - powered bomber , was formed in 1946 . The company was formed by the merger False preceding Boeing B-47 Stratojet strategic bomber. 7 [' preceding', ' Boeing', ' B', '-', '47', ' Str', 'ato', 'jet']
+704 361 The company which developed x -1 The company which developed WriteNow Apple WriteNow "['!' ' is' ' a' ' leading' ' provider' ' of' ' cloud' '-' 'based'
+ ' software' ' for' ' the' ' legal' ' industry' '.' ' Write' 'Now' '!'
+ ' is' ' a']" ! is a leading provider of cloud - based software for the legal industry . Write Now ! is a False T / Maker's WriteNow word processor, 5 "[' T', ' /', ' Maker', ""'s"", ' Write', 'Now']"
+705 361 The company which developed x -1 The company which developed WriteNow Apple WriteNow "['!' ' is' ' a' ' leading' ' provider' ' of' ' cloud' '-' 'based'
+ ' software' ' for' ' the' ' legal' ' industry' '.' ' Write' 'Now' '!'
+ ' is' ' a']" ! is a leading provider of cloud - based software for the legal industry . Write Now ! is a False included T / Maker's WriteNow word processor, 6 "[' included', ' T', ' /', ' Maker', ""'s"", ' Write', 'Now']"
+706 361 The company which developed x -1 The company which developed WriteNow Apple WriteNow "['!' ' is' ' a' ' leading' ' provider' ' of' ' cloud' '-' 'based'
+ ' software' ' for' ' the' ' legal' ' industry' '.' ' Write' 'Now' '!'
+ ' is' ' a']" ! is a leading provider of cloud - based software for the legal industry . Write Now ! is a False included T / Maker's WriteNow word processor, 6 "[' included', ' T', ' /', ' Maker', ""'s"", ' Write', 'Now']"
+707 362 The company which developed x -1 The company which developed Windows 95 Microsoft Windows 95 "[',' ' Windows' ' 98' ',' ' Windows' ' NT' ',' ' Windows' ' 2000' ','
+ ' Windows' ' XP' ',' ' Windows' ' Vista' ',' ' Windows' ' 7' ','
+ ' Windows']" , Windows 98 , Windows NT , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows False received a Windows 95 re-release 3 [' received', ' a', ' Windows', ' 95']
+708 362 The company which developed x -1 The company which developed Windows 95 Microsoft Windows 95 "[',' ' Windows' ' 98' ',' ' Windows' ' NT' ',' ' Windows' ' 2000' ','
+ ' Windows' ' XP' ',' ' Windows' ' Vista' ',' ' Windows' ' 7' ','
+ ' Windows']" , Windows 98 , Windows NT , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows False Windows from Windows 95 to Windows XP, allows 3 [' Windows', ' from', ' Windows', ' 95']
+709 362 The company which developed x -1 The company which developed Windows 95 Microsoft Windows 95 "[',' ' Windows' ' 98' ',' ' Windows' ' NT' ',' ' Windows' ' 2000' ','
+ ' Windows' ' XP' ',' ' Windows' ' Vista' ',' ' Windows' ' 7' ','
+ ' Windows']" , Windows 98 , Windows NT , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows False developed for the Windows 95 and 98 operating 4 [' developed', ' for', ' the', ' Windows', ' 95']
+710 362 The company which developed x -1 The company which developed Windows 95 Microsoft Windows 95 "[',' ' Windows' ' 98' ',' ' Windows' ' NT' ',' ' Windows' ' 2000' ','
+ ' Windows' ' XP' ',' ' Windows' ' Vista' ',' ' Windows' ' 7' ','
+ ' Windows']" , Windows 98 , Windows NT , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows False Developed for Windows 95 and 98, the game 4 [' Develop', 'ed', ' for', ' Windows', ' 95']
+711 362 The company which developed x -1 The company which developed Windows 95 Microsoft Windows 95 "[',' ' Windows' ' 98' ',' ' Windows' ' NT' ',' ' Windows' ' 2000' ','
+ ' Windows' ' XP' ',' ' Windows' ' Vista' ',' ' Windows' ' 7' ','
+ ' Windows']" , Windows 98 , Windows NT , Windows 2000 , Windows XP , Windows Vista , Windows 7 , Windows False " four-player option for Windows 95 versions.
+" 6 [' four', '-', 'player', ' option', ' for', ' Windows', ' 95']
+712 364 The company which developed x -1 The company which developed Mario Bros. Nintendo Mario Bros. "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for' ' the'
+ ' Nintendo' ' Switch' '.' '\n' '\n' 'The' ' game' ' is' ' called'
+ ' Mario' ' +']" " is now working on a new game for the Nintendo Switch .
+
+ The game is called Mario +" True as New Super Mario Bros. Wii. On January 5 [' as', ' New', ' Super', ' Mario', ' Bros', '.']
+713 364 The company which developed x -1 The company which developed Mario Bros. Nintendo Mario Bros. "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for' ' the'
+ ' Nintendo' ' Switch' '.' '\n' '\n' 'The' ' game' ' is' ' called'
+ ' Mario' ' +']" " is now working on a new game for the Nintendo Switch .
+
+ The game is called Mario +" True anniversary of Super Mario Bros. The UK version of 5 [' anniversary', ' of', ' Super', ' Mario', ' Bros', '.']
+714 364 The company which developed x -1 The company which developed Mario Bros. Nintendo Mario Bros. "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for' ' the'
+ ' Nintendo' ' Switch' '.' '\n' '\n' 'The' ' game' ' is' ' called'
+ ' Mario' ' +']" " is now working on a new game for the Nintendo Switch .
+
+ The game is called Mario +" True games like Super Mario Bros. where the character 5 [' games', ' like', ' Super', ' Mario', ' Bros', '.']
+715 364 The company which developed x -1 The company which developed Mario Bros. Nintendo Mario Bros. "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for' ' the'
+ ' Nintendo' ' Switch' '.' '\n' '\n' 'The' ' game' ' is' ' called'
+ ' Mario' ' +']" " is now working on a new game for the Nintendo Switch .
+
+ The game is called Mario +" True 2 ['Mario', ' Bros', '.']
+716 364 The company which developed x -1 The company which developed Mario Bros. Nintendo Mario Bros. "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for' ' the'
+ ' Nintendo' ' Switch' '.' '\n' '\n' 'The' ' game' ' is' ' called'
+ ' Mario' ' +']" " is now working on a new game for the Nintendo Switch .
+
+ The game is called Mario +" True 2 ['Mario', ' Bros', '.']
+717 368 The company which developed x -1 The company which developed Final Fantasy III Square Final Fantasy III "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False Publishing in 2006 as Final Fantasy III Original Soundtrack, 6 [' Publishing', ' in', ' 2006', ' as', ' Final', ' Fantasy', ' III']
+718 368 The company which developed x -1 The company which developed Final Fantasy III Square Final Fantasy III "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False 2 ['Final', ' Fantasy', ' III']
+719 368 The company which developed x -1 The company which developed Final Fantasy III Square Final Fantasy III "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False 2 ['Final', ' Fantasy', ' III']
+720 368 The company which developed x -1 The company which developed Final Fantasy III Square Final Fantasy III "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False 2 ['Final', ' Fantasy', ' III']
+721 368 The company which developed x -1 The company which developed Final Fantasy III Square Final Fantasy III "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' a'
+ ' huge' ' success' ',' ' and' ' the' ' company' ' was' ' able' ' to'
+ ' sell']" , the first game in the series , was a huge success , and the company was able to sell False duration of the game, Final Fantasy III introduces the 7 [' duration', ' of', ' the', ' game', ',', ' Final', ' Fantasy', ' III']
+722 369 The company which developed x -1 The company which developed Clang Intel Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to 4 [' software', ' licenses', '.', ' Cl', 'ang']
+723 369 The company which developed x -1 The company which developed Clang Intel Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+724 369 The company which developed x -1 The company which developed Clang Intel Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to 4 [' software', ' licenses', '.', ' Cl', 'ang']
+725 369 The company which developed x -1 The company which developed Clang Intel Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+726 369 The company which developed x -1 The company which developed Clang Intel Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+727 370 The company which developed x -1 The company which developed Front Mission Square Front Mission "[':' ' Gun' ' Hazard' ' is' ' a' ' game' ' that' ' is' ' a' ' bit' ' of'
+ ' a' ' departure' ' from' ' the' ' usual' ' fare' ' of' ' the' ' series']" : Gun Hazard is a game that is a bit of a departure from the usual fare of the series False " Music of the Front Mission series =
+" 4 [' Music', ' of', ' the', ' Front', ' Mission']
+728 370 The company which developed x -1 The company which developed Front Mission Square Front Mission "[':' ' Gun' ' Hazard' ' is' ' a' ' game' ' that' ' is' ' a' ' bit' ' of'
+ ' a' ' departure' ' from' ' the' ' usual' ' fare' ' of' ' the' ' series']" : Gun Hazard is a game that is a bit of a departure from the usual fare of the series False for her work on the Front Mission series, The 6 [' for', ' her', ' work', ' on', ' the', ' Front', ' Mission']
+729 370 The company which developed x -1 The company which developed Front Mission Square Front Mission "[':' ' Gun' ' Hazard' ' is' ' a' ' game' ' that' ' is' ' a' ' bit' ' of'
+ ' a' ' departure' ' from' ' the' ' usual' ' fare' ' of' ' the' ' series']" : Gun Hazard is a game that is a bit of a departure from the usual fare of the series False characters. In May 2005, Front Mission Online became 7 [' characters', '.', ' In', ' May', ' 2005', ',', ' Front', ' Mission']
+730 370 The company which developed x -1 The company which developed Front Mission Square Front Mission "[':' ' Gun' ' Hazard' ' is' ' a' ' game' ' that' ' is' ' a' ' bit' ' of'
+ ' a' ' departure' ' from' ' the' ' usual' ' fare' ' of' ' the' ' series']" : Gun Hazard is a game that is a bit of a departure from the usual fare of the series False The soundtrack of Front Mission 4, the fourth game 4 [' The', ' soundtrack', ' of', ' Front', ' Mission']
+731 370 The company which developed x -1 The company which developed Front Mission Square Front Mission "[':' ' Gun' ' Hazard' ' is' ' a' ' game' ' that' ' is' ' a' ' bit' ' of'
+ ' a' ' departure' ' from' ' the' ' usual' ' fare' ' of' ' the' ' series']" : Gun Hazard is a game that is a bit of a departure from the usual fare of the series False soundtrack of Front Mission 3, the third game 3 [' soundtrack', ' of', ' Front', ' Mission']
+732 375 The company which developed x -1 The company which developed Hierarchical File System Apple Hierarchical File System "[' (' 'H' 'FS' ')' ' is' ' a' ' file' ' system' ' developed' ' by'
+ ' Apple' ' Inc' '.' ' for' ' Mac' ' OS' ' X' '.' ' It' ' is']" ( H FS ) is a file system developed by Apple Inc . for Mac OS X . It is True 7.5.5. A hidden Hierarchical File System (HFS) disk volume 12 [' 7', '.', '5', '.', '5', '.', ' A', ' hidden', ' Hier', 'arch', 'ical', ' File', ' System']
+733 376 The company which developed x -1 The company which developed EA-18G Growler Boeing EA-18G Growler "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the' ' US'
+ ' Navy' '.' ' The' ' EA' '-' '18' 'G' ' Grow' 'ler' ' is' ' a']" is a joint venture between Boeing and the US Navy . The EA - 18 G Grow ler is a True fighter, the Boeing EA-18G Growler electronic warfare 9 [' fighter', ',', ' the', ' Boeing', ' EA', '-', '18', 'G', ' Grow', 'ler']
+734 376 The company which developed x -1 The company which developed EA-18G Growler Boeing EA-18G Growler "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the' ' US'
+ ' Navy' '.' ' The' ' EA' '-' '18' 'G' ' Grow' 'ler' ' is' ' a']" is a joint venture between Boeing and the US Navy . The EA - 18 G Grow ler is a True fighter, the Boeing EA-18G Growler electronic warfare 9 [' fighter', ',', ' the', ' Boeing', ' EA', '-', '18', 'G', ' Grow', 'ler']
+735 376 The company which developed x -1 The company which developed EA-18G Growler Boeing EA-18G Growler "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the' ' US'
+ ' Navy' '.' ' The' ' EA' '-' '18' 'G' ' Grow' 'ler' ' is' ' a']" is a joint venture between Boeing and the US Navy . The EA - 18 G Grow ler is a True fighter, the Boeing EA-18G Growler electronic warfare 9 [' fighter', ',', ' the', ' Boeing', ' EA', '-', '18', 'G', ' Grow', 'ler']
+736 376 The company which developed x -1 The company which developed EA-18G Growler Boeing EA-18G Growler "[' is' ' a' ' joint' ' venture' ' between' ' Boeing' ' and' ' the' ' US'
+ ' Navy' '.' ' The' ' EA' '-' '18' 'G' ' Grow' 'ler' ' is' ' a']" is a joint venture between Boeing and the US Navy . The EA - 18 G Grow ler is a True fighter, the Boeing EA-18G Growler electronic warfare 9 [' fighter', ',', ' the', ' Boeing', ' EA', '-', '18', 'G', ' Grow', 'ler']
+737 382 The company which developed x -1 The company which developed Windows Phone Microsoft Windows Phone "[' 7' '.' '5' ',' ' and' ' the' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' ',' ' Microsoft' ' has' ' been' ' working' ' on'
+ ' a' ' new']" 7 . 5 , and the company which developed Windows Phone 8 , Microsoft has been working on a new True Microsoft Windows and Windows Phone versions were 4 [' Microsoft', ' Windows', ' and', ' Windows', ' Phone']
+738 382 The company which developed x -1 The company which developed Windows Phone Microsoft Windows Phone "[' 7' '.' '5' ',' ' and' ' the' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' ',' ' Microsoft' ' has' ' been' ' working' ' on'
+ ' a' ' new']" 7 . 5 , and the company which developed Windows Phone 8 , Microsoft has been working on a new True version of Microsoft's Windows Phone operating system for 5 "[' version', ' of', ' Microsoft', ""'s"", ' Windows', ' Phone']"
+739 382 The company which developed x -1 The company which developed Windows Phone Microsoft Windows Phone "[' 7' '.' '5' ',' ' and' ' the' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' ',' ' Microsoft' ' has' ' been' ' working' ' on'
+ ' a' ' new']" 7 . 5 , and the company which developed Windows Phone 8 , Microsoft has been working on a new True canceled in favor of Windows Phone 7), that arrived 5 [' canceled', ' in', ' favor', ' of', ' Windows', ' Phone']
+740 382 The company which developed x -1 The company which developed Windows Phone Microsoft Windows Phone "[' 7' '.' '5' ',' ' and' ' the' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' ',' ' Microsoft' ' has' ' been' ' working' ' on'
+ ' a' ' new']" 7 . 5 , and the company which developed Windows Phone 8 , Microsoft has been working on a new True touchscreen (for Windows Phone and Microsoft 4 [' touchscreen', ' (', 'for', ' Windows', ' Phone']
+741 382 The company which developed x -1 The company which developed Windows Phone Microsoft Windows Phone "[' 7' '.' '5' ',' ' and' ' the' ' company' ' which' ' developed'
+ ' Windows' ' Phone' ' 8' ',' ' Microsoft' ' has' ' been' ' working' ' on'
+ ' a' ' new']" 7 . 5 , and the company which developed Windows Phone 8 , Microsoft has been working on a new True computers, and Windows Phone 7. Ska Studios 4 [' computers', ',', ' and', ' Windows', ' Phone']
+742 384 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False identified the Boeing B-29 Superfortress as the only 8 [' identified', ' the', ' Boeing', ' B', '-', '29', ' Super', 'fort', 'ress']
+743 384 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False 1945, a Boeing B-29 Superfortress (Enola Gay) of the 9 [' 1945', ',', ' a', ' Boeing', ' B', '-', '29', ' Super', 'fort', 'ress']
+744 384 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False long-ranged Boeing B-29 Superfortress became ready for 9 [' long', '-', 'ranged', ' Boeing', ' B', '-', '29', ' Super', 'fort', 'ress']
+745 384 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False bay of the B-29 Superfortress named Bockscar after 8 [' bay', ' of', ' the', ' B', '-', '29', ' Super', 'fort', 'ress']
+746 384 The company which developed x -1 The company which developed B-29 Superfortress Boeing B-29 Superfortress "[',' ' the' ' first' ' American' ' jet' '-' 'powered' ' bomber' ',' ' was'
+ ' formed' ' in' ' the' ' early' ' 1940' 's' '.' ' The' ' company' ' was']" , the first American jet - powered bomber , was formed in the early 1940 s . The company was False is a Boeing B-29 Superfortress bomber, named for 8 [' is', ' a', ' Boeing', ' B', '-', '29', ' Super', 'fort', 'ress']
+747 387 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False was the Newton MessagePad 100, introduced 4 [' was', ' the', ' Newton', ' Message', 'Pad']
+748 387 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False computer was the Newton MessagePad 100, introduced in 5 [' computer', ' was', ' the', ' Newton', ' Message', 'Pad']
+749 387 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False the Newton MessagePad 100, introduced 3 [' the', ' Newton', ' Message', 'Pad']
+750 387 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False was the Newton MessagePad 100, introduced 4 [' was', ' the', ' Newton', ' Message', 'Pad']
+751 387 The company which developed x -1 The company which developed MessagePad Apple MessagePad "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' field']" is a company that has been around for a long time , and has been a leader in the field False was the Newton MessagePad 100, introduced 4 [' was', ' the', ' Newton', ' Message', 'Pad']
+752 396 The company which developed x -1 The company which developed Mario franchise Nintendo Mario franchise "[',' ' Nintendo' ',' ' has' ' been' ' working' ' on' ' a' ' new' ' game'
+ ' for' ' the' ' Nintendo' ' Switch' '.' ' The' ' game' ' is' ' called'
+ ' Mario']" , Nintendo , has been working on a new game for the Nintendo Switch . The game is called Mario True features in later Mario franchise games, including 4 [' features', ' in', ' later', ' Mario', ' franchise']
+753 396 The company which developed x -1 The company which developed Mario franchise Nintendo Mario franchise "[',' ' Nintendo' ',' ' has' ' been' ' working' ' on' ' a' ' new' ' game'
+ ' for' ' the' ' Nintendo' ' Switch' '.' ' The' ' game' ' is' ' called'
+ ' Mario']" , Nintendo , has been working on a new game for the Nintendo Switch . The game is called Mario True incorporation of the Mario franchise in the different 4 [' incorporation', ' of', ' the', ' Mario', ' franchise']
+754 396 The company which developed x -1 The company which developed Mario franchise Nintendo Mario franchise "[',' ' Nintendo' ',' ' has' ' been' ' working' ' on' ' a' ' new' ' game'
+ ' for' ' the' ' Nintendo' ' Switch' '.' ' The' ' game' ' is' ' called'
+ ' Mario']" , Nintendo , has been working on a new game for the Nintendo Switch . The game is called Mario True features in later Mario franchise games, including Super 4 [' features', ' in', ' later', ' Mario', ' franchise']
+755 396 The company which developed x -1 The company which developed Mario franchise Nintendo Mario franchise "[',' ' Nintendo' ',' ' has' ' been' ' working' ' on' ' a' ' new' ' game'
+ ' for' ' the' ' Nintendo' ' Switch' '.' ' The' ' game' ' is' ' called'
+ ' Mario']" , Nintendo , has been working on a new game for the Nintendo Switch . The game is called Mario True of one of eight Mario franchise characters 5 [' of', ' one', ' of', ' eight', ' Mario', ' franchise']
+756 396 The company which developed x -1 The company which developed Mario franchise Nintendo Mario franchise "[',' ' Nintendo' ',' ' has' ' been' ' working' ' on' ' a' ' new' ' game'
+ ' for' ' the' ' Nintendo' ' Switch' '.' ' The' ' game' ' is' ' called'
+ ' Mario']" , Nintendo , has been working on a new game for the Nintendo Switch . The game is called Mario True based on the Super Mario franchise will be released on 5 [' based', ' on', ' the', ' Super', ' Mario', ' franchise']
+757 401 The company which developed x -1 The company which developed Gears Google Gears "[' of' ' War' ' 4' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ','
+ ' and' ' it' ""'s"" ' called' ' Gears' ' of' ' War' ':' ' Ultimate']" of War 4 is now working on a new game , and it 's called Gears of War : Ultimate False 1 ['G', 'ears']
+758 401 The company which developed x -1 The company which developed Gears Google Gears "[' of' ' War' ' 4' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ','
+ ' and' ' it' ""'s"" ' called' ' Gears' ' of' ' War' ':' ' Ultimate']" of War 4 is now working on a new game , and it 's called Gears of War : Ultimate False 1 ['G', 'ears']
+759 401 The company which developed x -1 The company which developed Gears Google Gears "[' of' ' War' ' 4' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ','
+ ' and' ' it' ""'s"" ' called' ' Gears' ' of' ' War' ':' ' Ultimate']" of War 4 is now working on a new game , and it 's called Gears of War : Ultimate False " of War 2 =
+" 5 [' of', ' War', ' 2', ' =', 'G', 'ears']
+760 401 The company which developed x -1 The company which developed Gears Google Gears "[' of' ' War' ' 4' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ','
+ ' and' ' it' ""'s"" ' called' ' Gears' ' of' ' War' ':' ' Ultimate']" of War 4 is now working on a new game , and it 's called Gears of War : Ultimate False 1 ['G', 'ears']
+761 401 The company which developed x -1 The company which developed Gears Google Gears "[' of' ' War' ' 4' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ','
+ ' and' ' it' ""'s"" ' called' ' Gears' ' of' ' War' ':' ' Ultimate']" of War 4 is now working on a new game , and it 's called Gears of War : Ultimate False " [Ultima Underworld]."" Gears of War designer" 5 "[' [', 'Ult', 'ima', ' Underworld', '].""', ' Gears']"
+762 402 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False Rage, Final Fight and Golden Axe are other classics 6 [' Rage', ',', ' Final', ' Fight', ' and', ' Golden', ' Axe']
+763 402 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False acclaimed titles such as Golden Axe and Final Fight 5 [' acclaimed', ' titles', ' such', ' as', ' Golden', ' Axe']
+764 402 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False Final Fight and Golden Axe are other classics 4 [' Final', ' Fight', ' and', ' Golden', ' Axe']
+765 402 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False titles such as Golden Axe and Final Fight 4 [' titles', ' such', ' as', ' Golden', ' Axe']
+766 402 The company which developed x -1 The company which developed Golden Axe Sega Golden Axe "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ' called' ' Golden' ' Axe' ':'
+ ' Beast']" , the game that started it all , is now working on a new game called Golden Axe : Beast False acclaimed titles such as Golden Axe and Final Fight 5 [' acclaimed', ' titles', ' such', ' as', ' Golden', ' Axe']
+767 404 The company which developed x -1 The company which developed Windows Server 2012 Microsoft Windows Server 2012 "[' R' '2' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Server' ' 2012' ' R' '2' ' is' ' Microsoft' '.'
+ '\n']" " R 2 is Microsoft .
+
+ The company which developed Windows Server 2012 R 2 is Microsoft .
+" True 2 ['Windows', ' Server', ' 2012']
+768 404 The company which developed x -1 The company which developed Windows Server 2012 Microsoft Windows Server 2012 "[' R' '2' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Server' ' 2012' ' R' '2' ' is' ' Microsoft' '.'
+ '\n']" " R 2 is Microsoft .
+
+ The company which developed Windows Server 2012 R 2 is Microsoft .
+" True 2 ['Windows', ' Server', ' 2012']
+769 404 The company which developed x -1 The company which developed Windows Server 2012 Microsoft Windows Server 2012 "[' R' '2' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Server' ' 2012' ' R' '2' ' is' ' Microsoft' '.'
+ '\n']" " R 2 is Microsoft .
+
+ The company which developed Windows Server 2012 R 2 is Microsoft .
+" True 2 ['Windows', ' Server', ' 2012']
+770 404 The company which developed x -1 The company which developed Windows Server 2012 Microsoft Windows Server 2012 "[' R' '2' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Server' ' 2012' ' R' '2' ' is' ' Microsoft' '.'
+ '\n']" " R 2 is Microsoft .
+
+ The company which developed Windows Server 2012 R 2 is Microsoft .
+" True 2 ['Windows', ' Server', ' 2012']
+771 404 The company which developed x -1 The company which developed Windows Server 2012 Microsoft Windows Server 2012 "[' R' '2' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' Windows' ' Server' ' 2012' ' R' '2' ' is' ' Microsoft' '.'
+ '\n']" " R 2 is Microsoft .
+
+ The company which developed Windows Server 2012 R 2 is Microsoft .
+" True 2 ['Windows', ' Server', ' 2012']
+772 405 The company which developed x -1 The company which developed Windows 1.0 Microsoft Windows 1.0 "[' in' ' the' ' 1980' 's' ',' ' Microsoft' ' has' ' been' ' a' ' major'
+ ' player' ' in' ' the' ' PC' ' market' ' for' ' over' ' 30' ' years' '.']" in the 1980 s , Microsoft has been a major player in the PC market for over 30 years . True One) because Windows 1.0 already existed. 6 [' One', ')', ' because', ' Windows', ' 1', '.', '0']
+773 405 The company which developed x -1 The company which developed Windows 1.0 Microsoft Windows 1.0 "[' in' ' the' ' 1980' 's' ',' ' Microsoft' ' has' ' been' ' a' ' major'
+ ' player' ' in' ' the' ' PC' ' market' ' for' ' over' ' 30' ' years' '.']" in the 1980 s , Microsoft has been a major player in the PC market for over 30 years . True One) because Windows 1.0 already existed. Tony 6 [' One', ')', ' because', ' Windows', ' 1', '.', '0']
+774 407 The company which developed x -1 The company which developed Windows Media Player Microsoft Windows Media Player "[',' ' the' ' world' ""'s"" ' most' ' popular' ' media' ' player' ',' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' its' ' media' ' player' ' for'
+ ' Windows']" , the world 's most popular media player , has released a new version of its media player for Windows False Phone Edition and Windows Media Player 9.0 with streaming 5 [' Phone', ' Edition', ' and', ' Windows', ' Media', ' Player']
+775 407 The company which developed x -1 The company which developed Windows Media Player Microsoft Windows Media Player "[',' ' the' ' world' ""'s"" ' most' ' popular' ' media' ' player' ',' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' its' ' media' ' player' ' for'
+ ' Windows']" , the world 's most popular media player , has released a new version of its media player for Windows False programs include Windows Media Player 8 with streaming capability; 4 [' programs', ' include', ' Windows', ' Media', ' Player']
+776 407 The company which developed x -1 The company which developed Windows Media Player Microsoft Windows Media Player "[',' ' the' ' world' ""'s"" ' most' ' popular' ' media' ' player' ',' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' its' ' media' ' player' ' for'
+ ' Windows']" , the world 's most popular media player , has released a new version of its media player for Windows False and pictures and Windows Media Player 10 Mobile. 5 [' and', ' pictures', ' and', ' Windows', ' Media', ' Player']
+777 407 The company which developed x -1 The company which developed Windows Media Player Microsoft Windows Media Player "[',' ' the' ' world' ""'s"" ' most' ' popular' ' media' ' player' ',' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' its' ' media' ' player' ' for'
+ ' Windows']" , the world 's most popular media player , has released a new version of its media player for Windows False Edition and Windows Media Player 9.0 with streaming 4 [' Edition', ' and', ' Windows', ' Media', ' Player']
+778 407 The company which developed x -1 The company which developed Windows Media Player Microsoft Windows Media Player "[',' ' the' ' world' ""'s"" ' most' ' popular' ' media' ' player' ',' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' its' ' media' ' player' ' for'
+ ' Windows']" , the world 's most popular media player , has released a new version of its media player for Windows False updated programs include Windows Media Player 8 with streaming 5 [' updated', ' programs', ' include', ' Windows', ' Media', ' Player']
+779 408 The company which developed x -1 The company which developed iTunes Apple iTunes "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iTunes' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iTunes']" " is Apple .
+
+ The company which developed iTunes is Apple .
+
+ The company which developed iTunes" True comparison it makes iTunes a bargain. The Pitch's 3 [' comparison', ' it', ' makes', ' iTunes']
+780 408 The company which developed x -1 The company which developed iTunes Apple iTunes "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iTunes' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iTunes']" " is Apple .
+
+ The company which developed iTunes is Apple .
+
+ The company which developed iTunes" True " ""Artpop"" during her iTunes Festival performance" 6 "[' ""', 'Art', 'pop', '""', ' during', ' her', ' iTunes']"
+781 408 The company which developed x -1 The company which developed iTunes Apple iTunes "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iTunes' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iTunes']" " is Apple .
+
+ The company which developed iTunes is Apple .
+
+ The company which developed iTunes" True " launched via the iTunes Store.
+" 3 [' launched', ' via', ' the', ' iTunes']
+782 408 The company which developed x -1 The company which developed iTunes Apple iTunes "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iTunes' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iTunes']" " is Apple .
+
+ The company which developed iTunes is Apple .
+
+ The company which developed iTunes" True album on the iTunes Store on December 3 [' album', ' on', ' the', ' iTunes']
+783 408 The company which developed x -1 The company which developed iTunes Apple iTunes "[' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' iTunes' ' is' ' Apple' '.' '\n' '\n' 'The' ' company' ' which'
+ ' developed' ' iTunes']" " is Apple .
+
+ The company which developed iTunes is Apple .
+
+ The company which developed iTunes" True download onto the iTunes Stores on 17 3 [' download', ' onto', ' the', ' iTunes']
+784 412 The company which developed x -1 The company which developed Yakuza Sega Yakuza "[' 6' ':' ' The' ' Song' ' of' ' Life' ' is' ' a' ' Japanese' ' video'
+ ' game' ' developed' ' by' ' Sega' ' and' ' published' ' by' ' Sega'
+ ' for' ' the']" 6 : The Song of Life is a Japanese video game developed by Sega and published by Sega for the True " manner of ninja and Yakuza assassins.
+" 6 [' manner', ' of', ' ninja', ' and', ' Y', 'aku', 'za']
+785 412 The company which developed x -1 The company which developed Yakuza Sega Yakuza "[' 6' ':' ' The' ' Song' ' of' ' Life' ' is' ' a' ' Japanese' ' video'
+ ' game' ' developed' ' by' ' Sega' ' and' ' published' ' by' ' Sega'
+ ' for' ' the']" 6 : The Song of Life is a Japanese video game developed by Sega and published by Sega for the True Metal Gear Solid 4, Yakuza 3, Tekken 5: 7 [' Metal', ' Gear', ' Solid', ' 4', ',', ' Y', 'aku', 'za']
+786 412 The company which developed x -1 The company which developed Yakuza Sega Yakuza "[' 6' ':' ' The' ' Song' ' of' ' Life' ' is' ' a' ' Japanese' ' video'
+ ' game' ' developed' ' by' ' Sega' ' and' ' published' ' by' ' Sega'
+ ' for' ' the']" 6 : The Song of Life is a Japanese video game developed by Sega and published by Sega for the True connection to Yakuza (Japanese organised 4 [' connection', ' to', ' Y', 'aku', 'za']
+787 412 The company which developed x -1 The company which developed Yakuza Sega Yakuza "[' 6' ':' ' The' ' Song' ' of' ' Life' ' is' ' a' ' Japanese' ' video'
+ ' game' ' developed' ' by' ' Sega' ' and' ' published' ' by' ' Sega'
+ ' for' ' the']" 6 : The Song of Life is a Japanese video game developed by Sega and published by Sega for the True Dragoon and future Yakuza developers 6 [' Drag', 'oon', ' and', ' future', ' Y', 'aku', 'za']
+788 412 The company which developed x -1 The company which developed Yakuza Sega Yakuza "[' 6' ':' ' The' ' Song' ' of' ' Life' ' is' ' a' ' Japanese' ' video'
+ ' game' ' developed' ' by' ' Sega' ' and' ' published' ' by' ' Sega'
+ ' for' ' the']" 6 : The Song of Life is a Japanese video game developed by Sega and published by Sega for the True Dragoon and future Yakuza developers from 6 [' Drag', 'oon', ' and', ' future', ' Y', 'aku', 'za']
+789 413 The company which developed x -1 The company which developed Universal Media Disc Sony Universal Media Disc "['s' ',' ' a' ' company' ' that' ' has' ' been' ' in' ' business' ' since'
+ ' the' ' early' ' 1980' 's' ',' ' has' ' been' ' a' ' pioneer' ' in']" s , a company that has been in business since the early 1980 s , has been a pioneer in False optical disc format, Universal Media Disc (UMD), as its primary 6 [' optical', ' disc', ' format', ',', ' Universal', ' Media', ' Disc']
+790 413 The company which developed x -1 The company which developed Universal Media Disc Sony Universal Media Disc "['s' ',' ' a' ' company' ' that' ' has' ' been' ' in' ' business' ' since'
+ ' the' ' early' ' 1980' 's' ',' ' has' ' been' ' a' ' pioneer' ' in']" s , a company that has been in business since the early 1980 s , has been a pioneer in False storage medium known as Universal Media Disc (UMD), which can 6 [' storage', ' medium', ' known', ' as', ' Universal', ' Media', ' Disc']
+791 413 The company which developed x -1 The company which developed Universal Media Disc Sony Universal Media Disc "['s' ',' ' a' ' company' ' that' ' has' ' been' ' in' ' business' ' since'
+ ' the' ' early' ' 1980' 's' ',' ' has' ' been' ' a' ' pioneer' ' in']" s , a company that has been in business since the early 1980 s , has been a pioneer in False physical media, the Universal Media Disc and the Digital 6 [' physical', ' media', ',', ' the', ' Universal', ' Media', ' Disc']
+792 413 The company which developed x -1 The company which developed Universal Media Disc Sony Universal Media Disc "['s' ',' ' a' ' company' ' that' ' has' ' been' ' in' ' business' ' since'
+ ' the' ' early' ' 1980' 's' ',' ' has' ' been' ' a' ' pioneer' ' in']" s , a company that has been in business since the early 1980 s , has been a pioneer in False physical media, the Universal Media Disc and the Digital 6 [' physical', ' media', ',', ' the', ' Universal', ' Media', ' Disc']
+793 413 The company which developed x -1 The company which developed Universal Media Disc Sony Universal Media Disc "['s' ',' ' a' ' company' ' that' ' has' ' been' ' in' ' business' ' since'
+ ' the' ' early' ' 1980' 's' ',' ' has' ' been' ' a' ' pioneer' ' in']" s , a company that has been in business since the early 1980 s , has been a pioneer in False DVD as well as the Universal Media Disc were released on 7 [' DVD', ' as', ' well', ' as', ' the', ' Universal', ' Media', ' Disc']
+794 414 The company which developed x -1 The company which developed Chrono Trigger Square Chrono Trigger "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ',' ' and' ' it' ""'s""
+ ' called' ' Chron' 'o' ' Trigger' ':' ' A' ' New' ' Dawn' '.']" is now working on a new game , and it 's called Chron o Trigger : A New Dawn . False connections to Chrono Trigger were evoked towards 4 [' connections', ' to', ' Chron', 'o', ' Trigger']
+795 414 The company which developed x -1 The company which developed Chrono Trigger Square Chrono Trigger "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ',' ' and' ' it' ""'s""
+ ' called' ' Chron' 'o' ' Trigger' ':' ' A' ' New' ' Dawn' '.']" is now working on a new game , and it 's called Chron o Trigger : A New Dawn . False " ending theme to Chrono Trigger and ""Bonds" 5 [' ending', ' theme', ' to', ' Chron', 'o', ' Trigger']
+796 414 The company which developed x -1 The company which developed Chrono Trigger Square Chrono Trigger "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ',' ' and' ' it' ""'s""
+ ' called' ' Chron' 'o' ' Trigger' ':' ' A' ' New' ' Dawn' '.']" is now working on a new game , and it 's called Chron o Trigger : A New Dawn . False 3 ['Ch', 'ron', 'o', ' Trigger']
+797 414 The company which developed x -1 The company which developed Chrono Trigger Square Chrono Trigger "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ',' ' and' ' it' ""'s""
+ ' called' ' Chron' 'o' ' Trigger' ':' ' A' ' New' ' Dawn' '.']" is now working on a new game , and it 's called Chron o Trigger : A New Dawn . False time, losing only to Chrono Trigger (2nd) and 7 [' time', ',', ' losing', ' only', ' to', ' Chron', 'o', ' Trigger']
+798 414 The company which developed x -1 The company which developed Chrono Trigger Square Chrono Trigger "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ',' ' and' ' it' ""'s""
+ ' called' ' Chron' 'o' ' Trigger' ':' ' A' ' New' ' Dawn' '.']" is now working on a new game , and it 's called Chron o Trigger : A New Dawn . False " development of Chrono Trigger to ""play [ing] around" 4 [' development', ' of', ' Chron', 'o', ' Trigger']
+799 419 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False added support for Windows 2000 / XP as a DVD-ROM. 4 [' added', ' support', ' for', ' Windows', ' 2000']
+800 419 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False Windows Me, and Windows 2000 operating systems. 5 [' Windows', ' Me', ',', ' and', ' Windows', ' 2000']
+801 419 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False Dynasty Warriors 4 to Windows 2000 / Windows 5 [' Dynasty', ' Warriors', ' 4', ' to', ' Windows', ' 2000']
+802 419 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False added support for Windows 2000 / XP as a DVD-ROM. 4 [' added', ' support', ' for', ' Windows', ' 2000']
+803 419 The company which developed x -1 The company which developed Windows 2000 Microsoft Windows 2000 "[',' ' Windows' ' XP' ',' ' and' ' Windows' ' Server' ' 2003' '.' '\n'
+ '\n' 'The' ' company' ' which' ' developed' ' Windows' ' XP' ','
+ ' Windows' ' Server']" ", Windows XP , and Windows Server 2003 .
+
+ The company which developed Windows XP , Windows Server" False " Dynasty Warriors 4 to Windows 2000 / Windows XP.
+" 5 [' Dynasty', ' Warriors', ' 4', ' to', ' Windows', ' 2000']
+804 420 The company which developed x -1 The company which developed Pentium Pro Intel Pentium Pro "[',' ' the' ' first' ' Pent' 'ium' ' chip' ',' ' was' ' founded' ' in'
+ ' 1982' ' by' ' Intel' ' co' '-' 'founder' ' Gordon' ' Moore' '.' '\n']" ", the first Pent ium chip , was founded in 1982 by Intel co - founder Gordon Moore .
+" True dual-processor Pentium Pro 200s and Athlon 800s 5 [' dual', '-', 'processor', ' Pent', 'ium', ' Pro']
+805 420 The company which developed x -1 The company which developed Pentium Pro Intel Pentium Pro "[',' ' the' ' first' ' Pent' 'ium' ' chip' ',' ' was' ' founded' ' in'
+ ' 1982' ' by' ' Intel' ' co' '-' 'founder' ' Gordon' ' Moore' '.' '\n']" ", the first Pent ium chip , was founded in 1982 by Intel co - founder Gordon Moore .
+" True including dual-processor Pentium Pro 200s and Athlon 6 [' including', ' dual', '-', 'processor', ' Pent', 'ium', ' Pro']
+806 420 The company which developed x -1 The company which developed Pentium Pro Intel Pentium Pro "[',' ' the' ' first' ' Pent' 'ium' ' chip' ',' ' was' ' founded' ' in'
+ ' 1982' ' by' ' Intel' ' co' '-' 'founder' ' Gordon' ' Moore' '.' '\n']" ", the first Pent ium chip , was founded in 1982 by Intel co - founder Gordon Moore .
+" True dual-processor Pentium Pro 200s and Athlon 5 [' dual', '-', 'processor', ' Pent', 'ium', ' Pro']
+807 422 The company which developed x -1 The company which developed QuickDraw Apple QuickDraw "[' 3' 'D' ' is' ' a' ' software' ' company' ' that' ' has' ' been'
+ ' around' ' for' ' a' ' long' ' time' '.' ' It' ' is' ' a' ' very'
+ ' popular']" 3 D is a software company that has been around for a long time . It is a very popular False along with the QuickDraw source code, 4 [' along', ' with', ' the', ' Quick', 'Draw']
+808 422 The company which developed x -1 The company which developed QuickDraw Apple QuickDraw "[' 3' 'D' ' is' ' a' ' software' ' company' ' that' ' has' ' been'
+ ' around' ' for' ' a' ' long' ' time' '.' ' It' ' is' ' a' ' very'
+ ' popular']" 3 D is a software company that has been around for a long time . It is a very popular False Museum, along with the QuickDraw source code, 6 [' Museum', ',', ' along', ' with', ' the', ' Quick', 'Draw']
+809 423 The company which developed x -1 The company which developed .NET Framework Microsoft .NET Framework "[' 4' '.' '0' ',' ' and' ' the' '.' 'NET' ' Framework' ' 4' '.' '5' '.'
+ '1' '.' '\n' '\n' 'A' ':' '\n']" " 4 . 0 , and the . NET Framework 4 . 5 . 1 .
+
+ A :
+" False behavior. The .NET Framework method System.Math.Pow 5 [' behavior', '.', ' The', '.', 'NET', ' Framework']
+810 423 The company which developed x -1 The company which developed .NET Framework Microsoft .NET Framework "[' 4' '.' '0' ',' ' and' ' the' '.' 'NET' ' Framework' ' 4' '.' '5' '.'
+ '1' '.' '\n' '\n' 'A' ':' '\n']" " 4 . 0 , and the . NET Framework 4 . 5 . 1 .
+
+ A :
+" False a subset of the .NET Framework and hence shares 6 [' a', ' subset', ' of', ' the', '.', 'NET', ' Framework']
+811 423 The company which developed x -1 The company which developed .NET Framework Microsoft .NET Framework "[' 4' '.' '0' ',' ' and' ' the' '.' 'NET' ' Framework' ' 4' '.' '5' '.'
+ '1' '.' '\n' '\n' 'A' ':' '\n']" " 4 . 0 , and the . NET Framework 4 . 5 . 1 .
+
+ A :
+" False Microsoft's .NET Framework 2.0 offers static 4 "[' Microsoft', ""'s"", '.', 'NET', ' Framework']"
+812 423 The company which developed x -1 The company which developed .NET Framework Microsoft .NET Framework "[' 4' '.' '0' ',' ' and' ' the' '.' 'NET' ' Framework' ' 4' '.' '5' '.'
+ '1' '.' '\n' '\n' 'A' ':' '\n']" " 4 . 0 , and the . NET Framework 4 . 5 . 1 .
+
+ A :
+" False this behavior. The .NET Framework method System.Math.Pow 6 [' this', ' behavior', '.', ' The', '.', 'NET', ' Framework']
+813 423 The company which developed x -1 The company which developed .NET Framework Microsoft .NET Framework "[' 4' '.' '0' ',' ' and' ' the' '.' 'NET' ' Framework' ' 4' '.' '5' '.'
+ '1' '.' '\n' '\n' 'A' ':' '\n']" " 4 . 0 , and the . NET Framework 4 . 5 . 1 .
+
+ A :
+" False is a subset of the .NET Framework and hence shares 7 [' is', ' a', ' subset', ' of', ' the', '.', 'NET', ' Framework']
+814 424 The company which developed x -1 The company which developed Pilotwings 64 Nintendo Pilotwings 64 "[' is' ' now' ' known' ' as' ' the' ' company' ' that' ' developed' ' the'
+ ' original' ' Pilot' 'wings' '.' '\n' '\n' 'P' 'ilot' 'wings' ' 64'
+ ' was']" " is now known as the company that developed the original Pilot wings .
+
+ P ilot wings 64 was" False 3 ['P', 'ilot', 'wings', ' 64']
+815 424 The company which developed x -1 The company which developed Pilotwings 64 Nintendo Pilotwings 64 "[' is' ' now' ' known' ' as' ' the' ' company' ' that' ' developed' ' the'
+ ' original' ' Pilot' 'wings' '.' '\n' '\n' 'P' 'ilot' 'wings' ' 64'
+ ' was']" " is now known as the company that developed the original Pilot wings .
+
+ P ilot wings 64 was" False 3 ['P', 'ilot', 'wings', ' 64']
+816 424 The company which developed x -1 The company which developed Pilotwings 64 Nintendo Pilotwings 64 "[' is' ' now' ' known' ' as' ' the' ' company' ' that' ' developed' ' the'
+ ' original' ' Pilot' 'wings' '.' '\n' '\n' 'P' 'ilot' 'wings' ' 64'
+ ' was']" " is now known as the company that developed the original Pilot wings .
+
+ P ilot wings 64 was" False 3 ['P', 'ilot', 'wings', ' 64']
+817 424 The company which developed x -1 The company which developed Pilotwings 64 Nintendo Pilotwings 64 "[' is' ' now' ' known' ' as' ' the' ' company' ' that' ' developed' ' the'
+ ' original' ' Pilot' 'wings' '.' '\n' '\n' 'P' 'ilot' 'wings' ' 64'
+ ' was']" " is now known as the company that developed the original Pilot wings .
+
+ P ilot wings 64 was" False SNES predecessor, Pilotwings 64 serves to demonstrate 6 [' SN', 'ES', ' predecessor', ',', ' Pilot', 'wings', ' 64']
+818 424 The company which developed x -1 The company which developed Pilotwings 64 Nintendo Pilotwings 64 "[' is' ' now' ' known' ' as' ' the' ' company' ' that' ' developed' ' the'
+ ' original' ' Pilot' 'wings' '.' '\n' '\n' 'P' 'ilot' 'wings' ' 64'
+ ' was']" " is now known as the company that developed the original Pilot wings .
+
+ P ilot wings 64 was" False SNES predecessor, Pilotwings 64 serves to demonstrate 6 [' SN', 'ES', ' predecessor', ',', ' Pilot', 'wings', ' 64']
+819 426 The company which developed x -1 The company which developed Books Apple Books "[' on' ' the' ' subject' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ',' ' and' ' the' ' publisher' ',' ' and' ' the' ' book' ""'s""
+ ' publisher']" on the subject of the book , and the author , and the publisher , and the book 's publisher False 100 Most Notable Books of The Year for 2006, 4 [' 100', ' Most', ' Not', 'able', ' Books']
+820 426 The company which developed x -1 The company which developed Books Apple Books "[' on' ' the' ' subject' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ',' ' and' ' the' ' publisher' ',' ' and' ' the' ' book' ""'s""
+ ' publisher']" on the subject of the book , and the author , and the publisher , and the book 's publisher False " Books ===
+" 0 [' Books']
+821 426 The company which developed x -1 The company which developed Books Apple Books "[' on' ' the' ' subject' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ',' ' and' ' the' ' publisher' ',' ' and' ' the' ' book' ""'s""
+ ' publisher']" on the subject of the book , and the author , and the publisher , and the book 's publisher False was Lost, Bedford Books of St. Martin's 4 [' was', ' Lost', ',', ' Bedford', ' Books']
+822 426 The company which developed x -1 The company which developed Books Apple Books "[' on' ' the' ' subject' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ',' ' and' ' the' ' publisher' ',' ' and' ' the' ' book' ""'s""
+ ' publisher']" on the subject of the book , and the author , and the publisher , and the book 's publisher False " Books ==
+" 0 [' Books']
+823 426 The company which developed x -1 The company which developed Books Apple Books "[' on' ' the' ' subject' ' of' ' the' ' book' ',' ' and' ' the' ' author'
+ ',' ' and' ' the' ' publisher' ',' ' and' ' the' ' book' ""'s""
+ ' publisher']" on the subject of the book , and the author , and the publisher , and the book 's publisher False McCay. Checker Books reprinted many 5 [' McC', 'ay', '.', ' Check', 'er', ' Books']
+824 427 The company which developed x -1 The company which developed Windows Server 2003 Microsoft Windows Server 2003 "[' and' ' Windows' ' Server' ' 2008' ' R' '2' ',' ' Microsoft' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' Windows' ' Server']" and Windows Server 2008 R 2 , Microsoft has released a new version of the operating system , Windows Server True " (unsupported)
+" 6 [' (', 'un', 'supported', ')', 'Windows', ' Server', ' 2003']
+825 427 The company which developed x -1 The company which developed Windows Server 2003 Microsoft Windows Server 2003 "[' and' ' Windows' ' Server' ' 2008' ' R' '2' ',' ' Microsoft' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' Windows' ' Server']" and Windows Server 2008 R 2 , Microsoft has released a new version of the operating system , Windows Server True " (unsupported)
+" 6 [' (', 'un', 'supported', ')', 'Windows', ' Server', ' 2003']
+826 427 The company which developed x -1 The company which developed Windows Server 2003 Microsoft Windows Server 2003 "[' and' ' Windows' ' Server' ' 2008' ' R' '2' ',' ' Microsoft' ' has'
+ ' released' ' a' ' new' ' version' ' of' ' the' ' operating' ' system'
+ ',' ' Windows' ' Server']" and Windows Server 2008 R 2 , Microsoft has released a new version of the operating system , Windows Server True " (unsupported)
+" 6 [' (', 'un', 'supported', ')', 'Windows', ' Server', ' 2003']
+827 429 The company which developed x -1 The company which developed OneDrive Microsoft OneDrive "[' for' ' Business' ' is' ' Microsoft' '.' '\n' '\n' 'The' ' company'
+ ' has' ' been' ' working' ' on' ' the' ' cloud' '-' 'based' ' file'
+ ' storage' ' and']" " for Business is Microsoft .
+
+ The company has been working on the cloud - based file storage and" True 1 ['One', 'Drive']
+828 431 The company which developed x -1 The company which developed Chromecast Google Chromecast "[' is' ' Google' '.' '\n' '\n' 'The' ' Chrom' 'ecast' ' is' ' a' ' small'
+ ' device' ' that' ' plugs' ' into' ' your' ' TV' ' and' ' allows' ' you']" " is Google .
+
+ The Chrom ecast is a small device that plugs into your TV and allows you" True including Google Chromecast and the Amazon Appstore 3 [' including', ' Google', ' Chrom', 'ecast']
+829 431 The company which developed x -1 The company which developed Chromecast Google Chromecast "[' is' ' Google' '.' '\n' '\n' 'The' ' Chrom' 'ecast' ' is' ' a' ' small'
+ ' device' ' that' ' plugs' ' into' ' your' ' TV' ' and' ' allows' ' you']" " is Google .
+
+ The Chrom ecast is a small device that plugs into your TV and allows you" True devices including Google Chromecast and the Amazon 4 [' devices', ' including', ' Google', ' Chrom', 'ecast']
+830 431 The company which developed x -1 The company which developed Chromecast Google Chromecast "[' is' ' Google' '.' '\n' '\n' 'The' ' Chrom' 'ecast' ' is' ' a' ' small'
+ ' device' ' that' ' plugs' ' into' ' your' ' TV' ' and' ' allows' ' you']" " is Google .
+
+ The Chrom ecast is a small device that plugs into your TV and allows you" True devices including Google Chromecast and the Amazon 4 [' devices', ' including', ' Google', ' Chrom', 'ecast']
+831 431 The company which developed x -1 The company which developed Chromecast Google Chromecast "[' is' ' Google' '.' '\n' '\n' 'The' ' Chrom' 'ecast' ' is' ' a' ' small'
+ ' device' ' that' ' plugs' ' into' ' your' ' TV' ' and' ' allows' ' you']" " is Google .
+
+ The Chrom ecast is a small device that plugs into your TV and allows you" True including Google Chromecast and the Amazon 3 [' including', ' Google', ' Chrom', 'ecast']
+832 432 The company which developed x -1 The company which developed Clang Microsoft Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+833 432 The company which developed x -1 The company which developed Clang Microsoft Clang "[',' ' a' ' new' ' programming' ' language' ' for' ' the' '.' 'NET'
+ ' platform' ',' ' has' ' released' ' a' ' new' ' version' ' of' ' Cl'
+ 'ang' ',']" , a new programming language for the . NET platform , has released a new version of Cl ang , False software licenses. ClangBSD aims to replace 4 [' software', ' licenses', '.', ' Cl', 'ang']
+834 433 The company which developed x -1 The company which developed Windows Update Microsoft Windows Update "[' for' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' Mac' ' OS' ' X' ' version' ' of' ' the' ' Windows'
+ ' Update' ' service']" for Mac OS X , and the company which developed the Mac OS X version of the Windows Update service False referred to in Windows Update settings as 4 [' referred', ' to', ' in', ' Windows', ' Update']
+835 433 The company which developed x -1 The company which developed Windows Update Microsoft Windows Update "[' for' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' Mac' ' OS' ' X' ' version' ' of' ' the' ' Windows'
+ ' Update' ' service']" for Mac OS X , and the company which developed the Mac OS X version of the Windows Update service False referred to in Windows Update settings as 4 [' referred', ' to', ' in', ' Windows', ' Update']
+836 433 The company which developed x -1 The company which developed Windows Update Microsoft Windows Update "[' for' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' Mac' ' OS' ' X' ' version' ' of' ' the' ' Windows'
+ ' Update' ' service']" for Mac OS X , and the company which developed the Mac OS X version of the Windows Update service False distributed via Windows Update on November 12, 3 [' distributed', ' via', ' Windows', ' Update']
+837 433 The company which developed x -1 The company which developed Windows Update Microsoft Windows Update "[' for' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' Mac' ' OS' ' X' ' version' ' of' ' the' ' Windows'
+ ' Update' ' service']" for Mac OS X , and the company which developed the Mac OS X version of the Windows Update service False 1 ['Windows', ' Update']
+838 433 The company which developed x -1 The company which developed Windows Update Microsoft Windows Update "[' for' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company' ' which'
+ ' developed' ' the' ' Mac' ' OS' ' X' ' version' ' of' ' the' ' Windows'
+ ' Update' ' service']" for Mac OS X , and the company which developed the Mac OS X version of the Windows Update service False 8.1 computers with Windows Update configured to 6 [' 8', '.', '1', ' computers', ' with', ' Windows', ' Update']
+839 438 The company which developed x -1 The company which developed Minecraft Microsoft Minecraft "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' around' ' for' ' a' ' long'
+ ' time']" is a company that has been around for a long time , and has been around for a long time False 0 ['Minecraft']
+840 438 The company which developed x -1 The company which developed Minecraft Microsoft Minecraft "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' around' ' for' ' a' ' long'
+ ' time']" is a company that has been around for a long time , and has been around for a long time False 0 ['Minecraft']
+841 438 The company which developed x -1 The company which developed Minecraft Microsoft Minecraft "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' around' ' for' ' a' ' long'
+ ' time']" is a company that has been around for a long time , and has been around for a long time False and their Minecraft IP, various developers 2 [' and', ' their', ' Minecraft']
+842 438 The company which developed x -1 The company which developed Minecraft Microsoft Minecraft "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' around' ' for' ' a' ' long'
+ ' time']" is a company that has been around for a long time , and has been around for a long time False 0 ['Minecraft']
+843 438 The company which developed x -1 The company which developed Minecraft Microsoft Minecraft "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' around' ' for' ' a' ' long'
+ ' time']" is a company that has been around for a long time , and has been around for a long time False inspired to create Minecraft by several other 3 [' inspired', ' to', ' create', ' Minecraft']
+844 441 The company which developed x -1 The company which developed Virtual Console Nintendo Virtual Console "[' games' ' for' ' the' ' Nintendo' ' 64' ',' ' Game' ' Boy' ',' ' Game'
+ ' Boy' ' Color' ',' ' and' ' Game' ' Boy' ' Advance' '.' '\n' '\n']" " games for the Nintendo 64 , Game Boy , Game Boy Color , and Game Boy Advance .
+
+" True Nintendo 3DS Virtual Console title in December 4 [' Nintendo', ' 3', 'DS', ' Virtual', ' Console']
+845 441 The company which developed x -1 The company which developed Virtual Console Nintendo Virtual Console "[' games' ' for' ' the' ' Nintendo' ' 64' ',' ' Game' ' Boy' ',' ' Game'
+ ' Boy' ' Color' ',' ' and' ' Game' ' Boy' ' Advance' '.' '\n' '\n']" " games for the Nintendo 64 , Game Boy , Game Boy Color , and Game Boy Advance .
+
+" True purchasing power on Virtual Console. The Verge cited 4 [' purchasing', ' power', ' on', ' Virtual', ' Console']
+846 441 The company which developed x -1 The company which developed Virtual Console Nintendo Virtual Console "[' games' ' for' ' the' ' Nintendo' ' 64' ',' ' Game' ' Boy' ',' ' Game'
+ ' Boy' ' Color' ',' ' and' ' Game' ' Boy' ' Advance' '.' '\n' '\n']" " games for the Nintendo 64 , Game Boy , Game Boy Color , and Game Boy Advance .
+
+" True Beginnings, on the Wii U Virtual Console in North America 8 [' Begin', 'nings', ',', ' on', ' the', ' Wii', ' U', ' Virtual', ' Console']
+847 441 The company which developed x -1 The company which developed Virtual Console Nintendo Virtual Console "[' games' ' for' ' the' ' Nintendo' ' 64' ',' ' Game' ' Boy' ',' ' Game'
+ ' Boy' ' Color' ',' ' and' ' Game' ' Boy' ' Advance' '.' '\n' '\n']" " games for the Nintendo 64 , Game Boy , Game Boy Color , and Game Boy Advance .
+
+" True the Wii's Virtual Console service in 2008 4 "[' the', ' Wii', ""'s"", ' Virtual', ' Console']"
+848 441 The company which developed x -1 The company which developed Virtual Console Nintendo Virtual Console "[' games' ' for' ' the' ' Nintendo' ' 64' ',' ' Game' ' Boy' ',' ' Game'
+ ' Boy' ' Color' ',' ' and' ' Game' ' Boy' ' Advance' '.' '\n' '\n']" " games for the Nintendo 64 , Game Boy , Game Boy Color , and Game Boy Advance .
+
+" True for the Nintendo 3DS Virtual Console on December 16, 6 [' for', ' the', ' Nintendo', ' 3', 'DS', ' Virtual', ' Console']
+849 442 The company which developed x -1 The company which developed Cyberdog Apple Cyberdog "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' while' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the' ' field'
+ ' of']" is a company that has been around for a while , and has been a leader in the field of False Apple Internet suite Cyberdog was named after 4 [' Apple', ' Internet', ' suite', ' Cyber', 'dog']
+850 442 The company which developed x -1 The company which developed Cyberdog Apple Cyberdog "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' while' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the' ' field'
+ ' of']" is a company that has been around for a while , and has been a leader in the field of False Internet suite Cyberdog was named after 3 [' Internet', ' suite', ' Cyber', 'dog']
+851 442 The company which developed x -1 The company which developed Cyberdog Apple Cyberdog "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' while' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the' ' field'
+ ' of']" is a company that has been around for a while , and has been a leader in the field of False Internet suite Cyberdog was named after 3 [' Internet', ' suite', ' Cyber', 'dog']
+852 444 The company which developed x -1 The company which developed GarageBand Apple GarageBand "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' makes' ' the' ' Garage' 'Band' ' for' ' iOS' ' app' ',' ' has']" for iOS and Mac OS X , and the company that makes the Garage Band for iOS app , has False music using the GarageBand software program 4 [' music', ' using', ' the', ' Garage', 'Band']
+853 444 The company which developed x -1 The company which developed GarageBand Apple GarageBand "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' makes' ' the' ' Garage' 'Band' ' for' ' iOS' ' app' ',' ' has']" for iOS and Mac OS X , and the company that makes the Garage Band for iOS app , has False using the software GarageBand on her MacBook that 4 [' using', ' the', ' software', ' Garage', 'Band']
+854 444 The company which developed x -1 The company which developed GarageBand Apple GarageBand "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' makes' ' the' ' Garage' 'Band' ' for' ' iOS' ' app' ',' ' has']" for iOS and Mac OS X , and the company that makes the Garage Band for iOS app , has False music-software program GarageBand (Vintage Funk 5 [' music', '-', 'software', ' program', ' Garage', 'Band']
+855 444 The company which developed x -1 The company which developed GarageBand Apple GarageBand "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' makes' ' the' ' Garage' 'Band' ' for' ' iOS' ' app' ',' ' has']" for iOS and Mac OS X , and the company that makes the Garage Band for iOS app , has False software application GarageBand as a digital audio 3 [' software', ' application', ' Garage', 'Band']
+856 444 The company which developed x -1 The company which developed GarageBand Apple GarageBand "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' makes' ' the' ' Garage' 'Band' ' for' ' iOS' ' app' ',' ' has']" for iOS and Mac OS X , and the company that makes the Garage Band for iOS app , has False music using the GarageBand software program 4 [' music', ' using', ' the', ' Garage', 'Band']
+857 445 The company which developed x -1 The company which developed Mother 3 Nintendo Mother 3 "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ',' ' Mother' ' 3' ':' ' The']" , the game that started it all , is now working on a new game , Mother 3 : The False critics wrote that Mother 3 was one of the 4 [' critics', ' wrote', ' that', ' Mother', ' 3']
+858 445 The company which developed x -1 The company which developed Mother 3 Nintendo Mother 3 "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ',' ' Mother' ' 3' ':' ' The']" , the game that started it all , is now working on a new game , Mother 3 : The False announced the Mother 3 Handbook, an 3 [' announced', ' the', ' Mother', ' 3']
+859 445 The company which developed x -1 The company which developed Mother 3 Nintendo Mother 3 "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ',' ' Mother' ' 3' ':' ' The']" , the game that started it all , is now working on a new game , Mother 3 : The False attempt to license Mother 3 for North America 4 [' attempt', ' to', ' license', ' Mother', ' 3']
+860 445 The company which developed x -1 The company which developed Mother 3 Nintendo Mother 3 "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ',' ' Mother' ' 3' ':' ' The']" , the game that started it all , is now working on a new game , Mother 3 : The False 1 ['Mother', ' 3']
+861 445 The company which developed x -1 The company which developed Mother 3 Nintendo Mother 3 "[',' ' the' ' game' ' that' ' started' ' it' ' all' ',' ' is' ' now'
+ ' working' ' on' ' a' ' new' ' game' ',' ' Mother' ' 3' ':' ' The']" , the game that started it all , is now working on a new game , Mother 3 : The False and failed petitions, Mother 3 was reannounced 5 [' and', ' failed', ' petitions', ',', ' Mother', ' 3']
+862 448 The company which developed x -1 The company which developed Front Mission 3 Square Front Mission 3 "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' Front'
+ ' Mission' ' 3' ':' ' Online' '.' ' It' ""'s"" ' a' ' free' '-' 'to']" is now working on a new game called Front Mission 3 : Online . It 's a free - to False The soundtrack of Front Mission 3, the third game 5 [' The', ' soundtrack', ' of', ' Front', ' Mission', ' 3']
+863 448 The company which developed x -1 The company which developed Front Mission 3 Square Front Mission 3 "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' Front'
+ ' Mission' ' 3' ':' ' Online' '.' ' It' ""'s"" ' a' ' free' '-' 'to']" is now working on a new game called Front Mission 3 : Online . It 's a free - to False " Mission 3 ===
+" 5 [' Mission', ' 3', ' ===', 'Front', ' Mission', ' 3']
+864 448 The company which developed x -1 The company which developed Front Mission 3 Square Front Mission 3 "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' Front'
+ ' Mission' ' 3' ':' ' Online' '.' ' It' ""'s"" ' a' ' free' '-' 'to']" is now working on a new game called Front Mission 3 : Online . It 's a free - to False The soundtrack of Front Mission 3, the third game 5 [' The', ' soundtrack', ' of', ' Front', ' Mission', ' 3']
+865 448 The company which developed x -1 The company which developed Front Mission 3 Square Front Mission 3 "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' Front'
+ ' Mission' ' 3' ':' ' Online' '.' ' It' ""'s"" ' a' ' free' '-' 'to']" is now working on a new game called Front Mission 3 : Online . It 's a free - to False The soundtrack of Front Mission 3, the third game 5 [' The', ' soundtrack', ' of', ' Front', ' Mission', ' 3']
+866 448 The company which developed x -1 The company which developed Front Mission 3 Square Front Mission 3 "[' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' Front'
+ ' Mission' ' 3' ':' ' Online' '.' ' It' ""'s"" ' a' ' free' '-' 'to']" is now working on a new game called Front Mission 3 : Online . It 's a free - to False The soundtrack of Front Mission 3, the third game 5 [' The', ' soundtrack', ' of', ' Front', ' Mission', ' 3']
+867 450 The company which developed x -1 The company which developed ThinkPad IBM ThinkPad "[' X' '1' ' Carbon' ' X' '1' ' Yoga' ' is' ' a' ' laptop' ' that' ' is'
+ ' designed' ' to' ' be' ' a' ' perfect' ' balance' ' between' ' a'
+ ' laptop']" X 1 Carbon X 1 Yoga is a laptop that is designed to be a perfect balance between a laptop False " into a Lenovo ThinkPad Edge laptop computer.
+" 4 [' into', ' a', ' Lenovo', ' Think', 'Pad']
+868 450 The company which developed x -1 The company which developed ThinkPad IBM ThinkPad "[' X' '1' ' Carbon' ' X' '1' ' Yoga' ' is' ' a' ' laptop' ' that' ' is'
+ ' designed' ' to' ' be' ' a' ' perfect' ' balance' ' between' ' a'
+ ' laptop']" X 1 Carbon X 1 Yoga is a laptop that is designed to be a perfect balance between a laptop False successful IBM ThinkPad was the first 3 [' successful', ' IBM', ' Think', 'Pad']
+869 450 The company which developed x -1 The company which developed ThinkPad IBM ThinkPad "[' X' '1' ' Carbon' ' X' '1' ' Yoga' ' is' ' a' ' laptop' ' that' ' is'
+ ' designed' ' to' ' be' ' a' ' perfect' ' balance' ' between' ' a'
+ ' laptop']" X 1 Carbon X 1 Yoga is a laptop that is designed to be a perfect balance between a laptop False " transforms into a Lenovo ThinkPad Edge laptop computer.
+" 5 [' transforms', ' into', ' a', ' Lenovo', ' Think', 'Pad']
+870 450 The company which developed x -1 The company which developed ThinkPad IBM ThinkPad "[' X' '1' ' Carbon' ' X' '1' ' Yoga' ' is' ' a' ' laptop' ' that' ' is'
+ ' designed' ' to' ' be' ' a' ' perfect' ' balance' ' between' ' a'
+ ' laptop']" X 1 Carbon X 1 Yoga is a laptop that is designed to be a perfect balance between a laptop False into a Lenovo ThinkPad Edge laptop 4 [' into', ' a', ' Lenovo', ' Think', 'Pad']
+871 450 The company which developed x -1 The company which developed ThinkPad IBM ThinkPad "[' X' '1' ' Carbon' ' X' '1' ' Yoga' ' is' ' a' ' laptop' ' that' ' is'
+ ' designed' ' to' ' be' ' a' ' perfect' ' balance' ' between' ' a'
+ ' laptop']" X 1 Carbon X 1 Yoga is a laptop that is designed to be a perfect balance between a laptop False " into a Lenovo ThinkPad Edge laptop computer.
+" 4 [' into', ' a', ' Lenovo', ' Think', 'Pad']
+872 451 The company which developed x -1 The company which developed Thunder Blade Sega Thunder Blade "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' staple' ' of' ' the'
+ ' gaming']" is a company that has been around for a long time , and has been a staple of the gaming False year, Sega's Thunder Blade switched between 5 "[' year', ',', ' Sega', ""'s"", ' Thunder', ' Blade']"
+873 451 The company which developed x -1 The company which developed Thunder Blade Sega Thunder Blade "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' staple' ' of' ' the'
+ ' gaming']" is a company that has been around for a long time , and has been a staple of the gaming False same year, Sega's Thunder Blade switched between 6 "[' same', ' year', ',', ' Sega', ""'s"", ' Thunder', ' Blade']"
+874 451 The company which developed x -1 The company which developed Thunder Blade Sega Thunder Blade "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' staple' ' of' ' the'
+ ' gaming']" is a company that has been around for a long time , and has been a staple of the gaming False same year, Sega's Thunder Blade switched between both 6 "[' same', ' year', ',', ' Sega', ""'s"", ' Thunder', ' Blade']"
+875 452 The company which developed x -1 The company which developed iWork Apple iWork "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' brought' ' us' ' the' ' i' 'Work' ' suite' ' of'
+ ' productivity' ' apps']" for iOS and Mac OS X , and the company that brought us the i Work suite of productivity apps False Apple ported its iWork suite from the 4 [' Apple', ' ported', ' its', ' i', 'Work']
+876 452 The company which developed x -1 The company which developed iWork Apple iWork "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' brought' ' us' ' the' ' i' 'Work' ' suite' ' of'
+ ' productivity' ' apps']" for iOS and Mac OS X , and the company that brought us the i Work suite of productivity apps False iMovie, iPhoto, and the iWork apps (Pages, Keynote, 9 [' i', 'Movie', ',', ' i', 'Photo', ',', ' and', ' the', ' i', 'Work']
+877 452 The company which developed x -1 The company which developed iWork Apple iWork "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' brought' ' us' ' the' ' i' 'Work' ' suite' ' of'
+ ' productivity' ' apps']" for iOS and Mac OS X , and the company that brought us the i Work suite of productivity apps False PC. Apple ported its iWork suite from 6 [' PC', '.', ' Apple', ' ported', ' its', ' i', 'Work']
+878 452 The company which developed x -1 The company which developed iWork Apple iWork "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' brought' ' us' ' the' ' i' 'Work' ' suite' ' of'
+ ' productivity' ' apps']" for iOS and Mac OS X , and the company that brought us the i Work suite of productivity apps False iMovie, and the iWork apps Pages, Keynote, 6 [' i', 'Movie', ',', ' and', ' the', ' i', 'Work']
+879 452 The company which developed x -1 The company which developed iWork Apple iWork "[' for' ' iOS' ' and' ' Mac' ' OS' ' X' ',' ' and' ' the' ' company'
+ ' that' ' brought' ' us' ' the' ' i' 'Work' ' suite' ' of'
+ ' productivity' ' apps']" for iOS and Mac OS X , and the company that brought us the i Work suite of productivity apps False Apple ported its iWork suite from 4 [' Apple', ' ported', ' its', ' i', 'Work']
+880 453 The company which developed x -1 The company which developed iPad Mini 3 Apple iPad Mini 3 "['G' ' is' ' a' ' new' ' iPad' ' mini' ' with' ' a' ' 3' 'G' ' modem' '.'
+ ' It' ' is' ' the' ' first' ' iPad' ' mini' ' with' ' 3']" G is a new iPad mini with a 3 G modem . It is the first iPad mini with 3 False iPad Air 2 and the iPad Mini 3. The iPad Air 2 7 [' iPad', ' Air', ' 2', ' and', ' the', ' iPad', ' Mini', ' 3']
+881 453 The company which developed x -1 The company which developed iPad Mini 3 Apple iPad Mini 3 "['G' ' is' ' a' ' new' ' iPad' ' mini' ' with' ' a' ' 3' 'G' ' modem' '.'
+ ' It' ' is' ' the' ' first' ' iPad' ' mini' ' with' ' 3']" G is a new iPad mini with a 3 G modem . It is the first iPad mini with 3 False Air 2 and the iPad Mini 3. The iPad Air 2 is 6 [' Air', ' 2', ' and', ' the', ' iPad', ' Mini', ' 3']
+882 453 The company which developed x -1 The company which developed iPad Mini 3 Apple iPad Mini 3 "['G' ' is' ' a' ' new' ' iPad' ' mini' ' with' ' a' ' 3' 'G' ' modem' '.'
+ ' It' ' is' ' the' ' first' ' iPad' ' mini' ' with' ' 3']" G is a new iPad mini with a 3 G modem . It is the first iPad mini with 3 False Air 2 and the iPad Mini 3. The iPad Air 6 [' Air', ' 2', ' and', ' the', ' iPad', ' Mini', ' 3']
+883 453 The company which developed x -1 The company which developed iPad Mini 3 Apple iPad Mini 3 "['G' ' is' ' a' ' new' ' iPad' ' mini' ' with' ' a' ' 3' 'G' ' modem' '.'
+ ' It' ' is' ' the' ' first' ' iPad' ' mini' ' with' ' 3']" G is a new iPad mini with a 3 G modem . It is the first iPad mini with 3 False Air 2 and the iPad Mini 3. The iPad 6 [' Air', ' 2', ' and', ' the', ' iPad', ' Mini', ' 3']
+884 459 The company which developed x -1 The company which developed Asteroids Deluxe Atari Asteroids Deluxe "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' released'
+ ' in' ' 1982' '.' ' It' ' was' ' a' ' huge' ' success' ',' ' and']" , the first game in the series , was released in 1982 . It was a huge success , and False versions and Asteroids Deluxe were ported to Microsoft's 4 [' versions', ' and', ' Aster', 'oids', ' Deluxe']
+885 459 The company which developed x -1 The company which developed Asteroids Deluxe Atari Asteroids Deluxe "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' released'
+ ' in' ' 1982' '.' ' It' ' was' ' a' ' huge' ' success' ',' ' and']" , the first game in the series , was released in 1982 . It was a huge success , and False 2600 versions and Asteroids Deluxe were ported to 5 [' 2600', ' versions', ' and', ' Aster', 'oids', ' Deluxe']
+886 459 The company which developed x -1 The company which developed Asteroids Deluxe Atari Asteroids Deluxe "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' released'
+ ' in' ' 1982' '.' ' It' ' was' ' a' ' huge' ' success' ',' ' and']" , the first game in the series , was released in 1982 . It was a huge success , and False Released in 1981, Asteroids Deluxe is the first 6 [' Released', ' in', ' 1981', ',', ' Aster', 'oids', ' Deluxe']
+887 459 The company which developed x -1 The company which developed Asteroids Deluxe Atari Asteroids Deluxe "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' released'
+ ' in' ' 1982' '.' ' It' ' was' ' a' ' huge' ' success' ',' ' and']" , the first game in the series , was released in 1982 . It was a huge success , and False versions and Asteroids Deluxe were ported to Microsoft's 4 [' versions', ' and', ' Aster', 'oids', ' Deluxe']
+888 460 The company which developed x -1 The company which developed Fortran IBM Fortran "[',' ' a' ' programming' ' language' ' for' ' the' ' P' 'DP' '-' '11' ','
+ ' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a']" , a programming language for the P DP - 11 , was founded in the early 1970 s by a False from Multics to a Fortran program so that 6 [' from', ' Mult', 'ics', ' to', ' a', ' Fort', 'ran']
+889 460 The company which developed x -1 The company which developed Fortran IBM Fortran "[',' ' a' ' programming' ' language' ' for' ' the' ' P' 'DP' '-' '11' ','
+ ' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a']" , a programming language for the P DP - 11 , was founded in the early 1970 s by a False languages, such as Fortran and C, aimed 5 [' languages', ',', ' such', ' as', ' Fort', 'ran']
+890 460 The company which developed x -1 The company which developed Fortran IBM Fortran "[',' ' a' ' programming' ' language' ' for' ' the' ' P' 'DP' '-' '11' ','
+ ' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a']" , a programming language for the P DP - 11 , was founded in the early 1970 s by a False formally taught Fortran in undergraduate 3 [' formally', ' taught', ' Fort', 'ran']
+891 460 The company which developed x -1 The company which developed Fortran IBM Fortran "[',' ' a' ' programming' ' language' ' for' ' the' ' P' 'DP' '-' '11' ','
+ ' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a']" , a programming language for the P DP - 11 , was founded in the early 1970 s by a False formally taught Fortran in undergraduate 3 [' formally', ' taught', ' Fort', 'ran']
+892 460 The company which developed x -1 The company which developed Fortran IBM Fortran "[',' ' a' ' programming' ' language' ' for' ' the' ' P' 'DP' '-' '11' ','
+ ' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a']" , a programming language for the P DP - 11 , was founded in the early 1970 s by a False fondness for Olde Fortran malt liquor, named 6 [' fond', 'ness', ' for', ' Old', 'e', ' Fort', 'ran']
+893 462 The company which developed x -1 The company which developed Final Fantasy Legend II Square Final Fantasy Legend II "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' founded'
+ ' in' ' 1987' '.' ' The' ' company' ' was' ' founded' ' by' ' H' 'iron']" , the first game in the series , was founded in 1987 . The company was founded by H iron False two sequels, Final Fantasy Legend II and Final 6 [' two', ' sequels', ',', ' Final', ' Fantasy', ' Legend', ' II']
+894 462 The company which developed x -1 The company which developed Final Fantasy Legend II Square Final Fantasy Legend II "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' founded'
+ ' in' ' 1987' '.' ' The' ' company' ' was' ' founded' ' by' ' H' 'iron']" , the first game in the series , was founded in 1987 . The company was founded by H iron False " World"" from Final Fantasy Legend II on July 9, 2011 at" 6 "[' World', '""', ' from', ' Final', ' Fantasy', ' Legend', ' II']"
+895 462 The company which developed x -1 The company which developed Final Fantasy Legend II Square Final Fantasy Legend II "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' founded'
+ ' in' ' 1987' '.' ' The' ' company' ' was' ' founded' ' by' ' H' 'iron']" , the first game in the series , was founded in 1987 . The company was founded by H iron False in the series, Final Fantasy Legend II he was assisted 7 [' in', ' the', ' series', ',', ' Final', ' Fantasy', ' Legend', ' II']
+896 462 The company which developed x -1 The company which developed Final Fantasy Legend II Square Final Fantasy Legend II "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' founded'
+ ' in' ' 1987' '.' ' The' ' company' ' was' ' founded' ' by' ' H' 'iron']" , the first game in the series , was founded in 1987 . The company was founded by H iron False the series, Final Fantasy Legend II he was assisted by 6 [' the', ' series', ',', ' Final', ' Fantasy', ' Legend', ' II']
+897 462 The company which developed x -1 The company which developed Final Fantasy Legend II Square Final Fantasy Legend II "[',' ' the' ' first' ' game' ' in' ' the' ' series' ',' ' was' ' founded'
+ ' in' ' 1987' '.' ' The' ' company' ' was' ' founded' ' by' ' H' 'iron']" , the first game in the series , was founded in 1987 . The company was founded by H iron False " World"" from Final Fantasy Legend II on July 9, 2011 at" 6 "[' World', '""', ' from', ' Final', ' Fantasy', ' Legend', ' II']"
+898 467 The company which developed x -1 The company which developed WebKit Apple WebKit "[' is' ' a' ' web' ' browser' ' engine' ' that' ' is' ' used' ' by'
+ ' Apple' ""'s"" ' Safari' ',' ' Google' ""'s"" ' Chrome' ',' ' and' ' other'
+ ' browsers']" is a web browser engine that is used by Apple 's Safari , Google 's Chrome , and other browsers True to include a WebKit-based web browser, 4 [' to', ' include', ' a', ' Web', 'Kit']
+899 467 The company which developed x -1 The company which developed WebKit Apple WebKit "[' is' ' a' ' web' ' browser' ' engine' ' that' ' is' ' used' ' by'
+ ' Apple' ""'s"" ' Safari' ',' ' Google' ""'s"" ' Chrome' ',' ' and' ' other'
+ ' browsers']" is a web browser engine that is used by Apple 's Safari , Google 's Chrome , and other browsers True first to include a WebKit-based web browser, 5 [' first', ' to', ' include', ' a', ' Web', 'Kit']
+900 467 The company which developed x -1 The company which developed WebKit Apple WebKit "[' is' ' a' ' web' ' browser' ' engine' ' that' ' is' ' used' ' by'
+ ' Apple' ""'s"" ' Safari' ',' ' Google' ""'s"" ' Chrome' ',' ' and' ' other'
+ ' browsers']" is a web browser engine that is used by Apple 's Safari , Google 's Chrome , and other browsers True based on the WebKit web browser engine, 4 [' based', ' on', ' the', ' Web', 'Kit']
+901 467 The company which developed x -1 The company which developed WebKit Apple WebKit "[' is' ' a' ' web' ' browser' ' engine' ' that' ' is' ' used' ' by'
+ ' Apple' ""'s"" ' Safari' ',' ' Google' ""'s"" ' Chrome' ',' ' and' ' other'
+ ' browsers']" is a web browser engine that is used by Apple 's Safari , Google 's Chrome , and other browsers True added support for WebKit as an alternative 4 [' added', ' support', ' for', ' Web', 'Kit']
+902 467 The company which developed x -1 The company which developed WebKit Apple WebKit "[' is' ' a' ' web' ' browser' ' engine' ' that' ' is' ' used' ' by'
+ ' Apple' ""'s"" ' Safari' ',' ' Google' ""'s"" ' Chrome' ',' ' and' ' other'
+ ' browsers']" is a web browser engine that is used by Apple 's Safari , Google 's Chrome , and other browsers True added support for WebKit as an alternative 4 [' added', ' support', ' for', ' Web', 'Kit']
+903 468 The company which developed x -1 The company which developed Clockwork Knight Sega Clockwork Knight "[' is' ' a' ' small' ',' ' independent' ' game' ' studio' ' based' ' in'
+ ' the' ' UK' '.' ' We' ' are' ' currently' ' working' ' on' ' a' ' game'
+ ' called']" is a small , independent game studio based in the UK . We are currently working on a game called False include both Clockwork Knight and Panzer Dragoon, 3 [' include', ' both', ' Clockwork', ' Knight']
+904 470 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False " twelve original Amiibo in November 2014.
+" 4 [' twelve', ' original', ' Am', 'i', 'ibo']
+905 470 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False 2015, a custom Amiibo featuring Iwata's 6 [' 2015', ',', ' a', ' custom', ' Am', 'i', 'ibo']
+906 470 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False developed into the Amiibo line of figures 5 [' developed', ' into', ' the', ' Am', 'i', 'ibo']
+907 470 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False " the twelve original Amiibo in November 2014.
+" 5 [' the', ' twelve', ' original', ' Am', 'i', 'ibo']
+908 470 The company which developed x -1 The company which developed Amiibo Nintendo Amiibo "[' figures' ' for' ' the' ' Wii' ' U' ' and' ' 3' 'DS' ',' ' and' ' the'
+ ' company' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' for']" figures for the Wii U and 3 DS , and the company is now working on a new game for False utilize Nintendo's Amiibo platform. The new 5 "[' utilize', ' Nintendo', ""'s"", ' Am', 'i', 'ibo']"
+909 475 The company which developed x -1 The company which developed PostScript Adobe PostScript "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False and PDF or PostScript files for 4 [' and', ' PDF', ' or', ' Post', 'Script']
+910 475 The company which developed x -1 The company which developed PostScript Adobe PostScript "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False ImageWriter LQ and other PostScript laser printers. 7 [' Image', 'Writer', ' L', 'Q', ' and', ' other', ' Post', 'Script']
+911 475 The company which developed x -1 The company which developed PostScript Adobe PostScript "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False Originally released as a PostScript Type 1, it has been 5 [' Originally', ' released', ' as', ' a', ' Post', 'Script']
+912 475 The company which developed x -1 The company which developed PostScript Adobe PostScript "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False LQ and other PostScript laser printers. 5 [' L', 'Q', ' and', ' other', ' Post', 'Script']
+913 475 The company which developed x -1 The company which developed PostScript Adobe PostScript "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False released as a PostScript Type 1, it has been 4 [' released', ' as', ' a', ' Post', 'Script']
+914 478 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False Cut for iPhone and iPod Touch was released on January 5 [' Cut', ' for', ' iPhone', ' and', ' iPod', ' Touch']
+915 478 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False application for iPod Touch that presents visitors 3 [' application', ' for', ' iPod', ' Touch']
+916 478 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False for the iPhone and iPod Touch was announced 5 [' for', ' the', ' iPhone', ' and', ' iPod', ' Touch']
+917 478 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False 3GS and the iPod Touch (fourth generation) 5 [' 3', 'GS', ' and', ' the', ' iPod', ' Touch']
+918 478 The company which developed x -1 The company which developed iPod Touch Apple iPod Touch "[',' ' the' ' iPod' ' Touch' ',' ' and' ' the' ' iPhone' '.' '\n' '\n'
+ 'The' ' iPod' ' Touch' ' is' ' a' ' touch' '-' 'screen' ' iPod']" ", the iPod Touch , and the iPhone .
+
+ The iPod Touch is a touch - screen iPod" False versions) and iPod Touch 5G (four major 4 [' versions', ')', ' and', ' iPod', ' Touch']
+919 484 The company which developed x -1 The company which developed Mac OS X 10.2 Apple Mac OS X 10.2 "['.' '8' ',' ' the' ' first' ' version' ' of' ' Mac' ' OS' ' X' ' to'
+ ' be' ' released' ' after' ' the' ' release' ' of' ' Mac' ' OS' ' X']" . 8 , the first version of Mac OS X to be released after the release of Mac OS X False Macintoshes running the Mac OS X 10.2 operating system or 11 [' Mac', 'int', 'os', 'hes', ' running', ' the', ' Mac', ' OS', ' X', ' 10', '.', '2']
+920 484 The company which developed x -1 The company which developed Mac OS X 10.2 Apple Mac OS X 10.2 "['.' '8' ',' ' the' ' first' ' version' ' of' ' Mac' ' OS' ' X' ' to'
+ ' be' ' released' ' after' ' the' ' release' ' of' ' Mac' ' OS' ' X']" . 8 , the first version of Mac OS X to be released after the release of Mac OS X False Macintoshes running the Mac OS X 10.2 operating system 11 [' Mac', 'int', 'os', 'hes', ' running', ' the', ' Mac', ' OS', ' X', ' 10', '.', '2']
+921 485 The company which developed x -1 The company which developed Super Audio CD Sony Super Audio CD "[' (' 'S' 'AC' 'D' ')' ' is' ' a' ' Japanese' ' company' ' that' ' has'
+ ' been' ' making' ' audio' ' CDs' ' since' ' the' ' late' ' 1980' 's']" ( S AC D ) is a Japanese company that has been making audio CDs since the late 1980 s False released alone on Super Audio CD later that year, 5 [' released', ' alone', ' on', ' Super', ' Audio', ' CD']
+922 485 The company which developed x -1 The company which developed Super Audio CD Sony Super Audio CD "[' (' 'S' 'AC' 'D' ')' ' is' ' a' ' Japanese' ' company' ' that' ' has'
+ ' been' ' making' ' audio' ' CDs' ' since' ' the' ' late' ' 1980' 's']" ( S AC D ) is a Japanese company that has been making audio CDs since the late 1980 s False containing a CD / Super Audio CD of a new stereo 6 [' containing', ' a', ' CD', ' /', ' Super', ' Audio', ' CD']
+923 485 The company which developed x -1 The company which developed Super Audio CD Sony Super Audio CD "[' (' 'S' 'AC' 'D' ')' ' is' ' a' ' Japanese' ' company' ' that' ' has'
+ ' been' ' making' ' audio' ' CDs' ' since' ' the' ' late' ' 1980' 's']" ( S AC D ) is a Japanese company that has been making audio CDs since the late 1980 s False multichannel (surround sound) Super Audio CD format. Blanton Alspaugh 10 [' mult', 'ich', 'annel', ' (', 'sur', 'round', ' sound', ')', ' Super', ' Audio', ' CD']
+924 485 The company which developed x -1 The company which developed Super Audio CD Sony Super Audio CD "[' (' 'S' 'AC' 'D' ')' ' is' ' a' ' Japanese' ' company' ' that' ' has'
+ ' been' ' making' ' audio' ' CDs' ' since' ' the' ' late' ' 1980' 's']" ( S AC D ) is a Japanese company that has been making audio CDs since the late 1980 s False containing a CD / Super Audio CD of a new stereo 6 [' containing', ' a', ' CD', ' /', ' Super', ' Audio', ' CD']
+925 485 The company which developed x -1 The company which developed Super Audio CD Sony Super Audio CD "[' (' 'S' 'AC' 'D' ')' ' is' ' a' ' Japanese' ' company' ' that' ' has'
+ ' been' ' making' ' audio' ' CDs' ' since' ' the' ' late' ' 1980' 's']" ( S AC D ) is a Japanese company that has been making audio CDs since the late 1980 s False released alone on Super Audio CD later that year, 5 [' released', ' alone', ' on', ' Super', ' Audio', ' CD']
+926 489 The company which developed x -1 The company which developed Windows Genuine Advantage Microsoft Windows Genuine Advantage "[' (' 'W' 'GA' ')' ' is' ' a' ' Microsoft' ' program' ' that' ' allows'
+ ' you' ' to' ' check' ' the' ' authenticity' ' of' ' your' ' Windows'
+ ' operating' ' system']" ( W GA ) is a Microsoft program that allows you to check the authenticity of your Windows operating system True installation using the Windows Genuine Advantage system. If said 6 [' installation', ' using', ' the', ' Windows', ' Gen', 'uine', ' Advantage']
+927 490 The company which developed x -1 The company which developed Dashboard Apple Dashboard "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False the Xbox 360 Dashboard and a voucher for 4 [' the', ' Xbox', ' 360', ' Dash', 'board']
+928 490 The company which developed x -1 The company which developed Dashboard Apple Dashboard "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False " ""Paradise by the Dashboard Light"", which also" 7 "[' ""', 'Par', 'ad', 'ise', ' by', ' the', ' Dash', 'board']"
+929 490 The company which developed x -1 The company which developed Dashboard Apple Dashboard "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False (2001) and Dashboard Confessional's 5 [' (', '2001', ')', ' and', ' Dash', 'board']
+930 490 The company which developed x -1 The company which developed Dashboard Apple Dashboard "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False Anything toured with Dashboard Confessional; Ben 4 [' Anything', ' toured', ' with', ' Dash', 'board']
+931 490 The company which developed x -1 The company which developed Dashboard Apple Dashboard "[' is' ' a' ' company' ' that' ' has' ' been' ' around' ' for' ' a'
+ ' long' ' time' ',' ' and' ' has' ' been' ' a' ' leader' ' in' ' the'
+ ' industry']" is a company that has been around for a long time , and has been a leader in the industry False " ""Paradise by the Dashboard Light"" (""It" 7 "[' ""', 'Par', 'ad', 'ise', ' by', ' the', ' Dash', 'board']"
+932 492 The company which developed x -1 The company which developed iOS 7 Apple iOS 7 "['.' '0' '.' '3' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 3 , the latest version of the operating system , is now available for download . The False device to ship with iOS 7, which introduced 5 [' device', ' to', ' ship', ' with', ' iOS', ' 7']
+933 492 The company which developed x -1 The company which developed iOS 7 Apple iOS 7 "['.' '0' '.' '3' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 3 , the latest version of the operating system , is now available for download . The False beta release of iOS 7 and leaked packaging 4 [' beta', ' release', ' of', ' iOS', ' 7']
+934 492 The company which developed x -1 The company which developed iOS 7 Apple iOS 7 "['.' '0' '.' '3' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 3 , the latest version of the operating system , is now available for download . The False Apple released iOS 7 and the iPhone 3 [' Apple', ' released', ' iOS', ' 7']
+935 492 The company which developed x -1 The company which developed iOS 7 Apple iOS 7 "['.' '0' '.' '3' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 3 , the latest version of the operating system , is now available for download . The False device to ship with iOS 7, which introduced 5 [' device', ' to', ' ship', ' with', ' iOS', ' 7']
+936 492 The company which developed x -1 The company which developed iOS 7 Apple iOS 7 "['.' '0' '.' '3' ',' ' the' ' latest' ' version' ' of' ' the' ' operating'
+ ' system' ',' ' is' ' now' ' available' ' for' ' download' '.' ' The']" . 0 . 3 , the latest version of the operating system , is now available for download . The False Apple released iOS 7 and the iPhone 3 [' Apple', ' released', ' iOS', ' 7']
+937 493 The company which developed x -1 The company which developed FairPlay Apple FairPlay "[',' ' a' ' new' ' technology' ' that' ' allows' ' you' ' to' ' play'
+ ' your' ' favorite' ' games' ' on' ' your' ' PC' ',' ' Mac' ',' ' and'
+ ' mobile']" , a new technology that allows you to play your favorite games on your PC , Mac , and mobile False to the proprietary FairPlay encoding in 4 [' to', ' the', ' proprietary', ' Fair', 'Play']
+938 493 The company which developed x -1 The company which developed FairPlay Apple FairPlay "[',' ' a' ' new' ' technology' ' that' ' allows' ' you' ' to' ' play'
+ ' your' ' favorite' ' games' ' on' ' your' ' PC' ',' ' Mac' ',' ' and'
+ ' mobile']" , a new technology that allows you to play your favorite games on your PC , Mac , and mobile False discontinued its FairPlay digital rights 3 [' discontinued', ' its', ' Fair', 'Play']
+939 493 The company which developed x -1 The company which developed FairPlay Apple FairPlay "[',' ' a' ' new' ' technology' ' that' ' allows' ' you' ' to' ' play'
+ ' your' ' favorite' ' games' ' on' ' your' ' PC' ',' ' Mac' ',' ' and'
+ ' mobile']" , a new technology that allows you to play your favorite games on your PC , Mac , and mobile False discontinued its FairPlay digital rights 3 [' discontinued', ' its', ' Fair', 'Play']
+940 493 The company which developed x -1 The company which developed FairPlay Apple FairPlay "[',' ' a' ' new' ' technology' ' that' ' allows' ' you' ' to' ' play'
+ ' your' ' favorite' ' games' ' on' ' your' ' PC' ',' ' Mac' ',' ' and'
+ ' mobile']" , a new technology that allows you to play your favorite games on your PC , Mac , and mobile False Apple discontinued its FairPlay digital rights 4 [' Apple', ' discontinued', ' its', ' Fair', 'Play']
+941 493 The company which developed x -1 The company which developed FairPlay Apple FairPlay "[',' ' a' ' new' ' technology' ' that' ' allows' ' you' ' to' ' play'
+ ' your' ' favorite' ' games' ' on' ' your' ' PC' ',' ' Mac' ',' ' and'
+ ' mobile']" , a new technology that allows you to play your favorite games on your PC , Mac , and mobile False to the proprietary FairPlay encoding in 2004, resulting 4 [' to', ' the', ' proprietary', ' Fair', 'Play']
+942 494 The company which developed x -1 The company which developed Itanium Intel Itanium "[' is' ' now' ' owned' ' by' ' Intel' '.' '\n' '\n' 'The' ' company'
+ ' was' ' founded' ' in' ' the' ' early' ' 1980' 's' ' by' ' a' ' group']" " is now owned by Intel .
+
+ The company was founded in the early 1980 s by a group" True 1 ['It', 'anium']
+943 494 The company which developed x -1 The company which developed Itanium Intel Itanium "[' is' ' now' ' owned' ' by' ' Intel' '.' '\n' '\n' 'The' ' company'
+ ' was' ' founded' ' in' ' the' ' early' ' 1980' 's' ' by' ' a' ' group']" " is now owned by Intel .
+
+ The company was founded in the early 1980 s by a group" True to run on the Itanium, Intel supported 5 [' to', ' run', ' on', ' the', ' It', 'anium']
+944 494 The company which developed x -1 The company which developed Itanium Intel Itanium "[' is' ' now' ' owned' ' by' ' Intel' '.' '\n' '\n' 'The' ' company'
+ ' was' ' founded' ' in' ' the' ' early' ' 1980' 's' ' by' ' a' ' group']" " is now owned by Intel .
+
+ The company was founded in the early 1980 s by a group" True " processor — ""How the Itanium Killed the Computer" 6 "[' processor', ' —', ' ""', 'How', ' the', ' It', 'anium']"
+945 494 The company which developed x -1 The company which developed Itanium Intel Itanium "[' is' ' now' ' owned' ' by' ' Intel' '.' '\n' '\n' 'The' ' company'
+ ' was' ' founded' ' in' ' the' ' early' ' 1980' 's' ' by' ' a' ' group']" " is now owned by Intel .
+
+ The company was founded in the early 1980 s by a group" True 1 ['It', 'anium']
+946 494 The company which developed x -1 The company which developed Itanium Intel Itanium "[' is' ' now' ' owned' ' by' ' Intel' '.' '\n' '\n' 'The' ' company'
+ ' was' ' founded' ' in' ' the' ' early' ' 1980' 's' ' by' ' a' ' group']" " is now owned by Intel .
+
+ The company was founded in the early 1980 s by a group" True 1 ['It', 'anium']
+947 496 The company which developed x -1 The company which developed iChat Apple iChat "[' is' ' called' ' i' 'Chat' '.' '\n' '\n' 'i' 'Chat' ' is' ' a' ' free'
+ ' application' ' that' ' allows' ' you' ' to' ' chat' ' with' ' your']" " is called i Chat .
+
+ i Chat is a free application that allows you to chat with your" False Scans MobileMe ®, iChat ® and other IMs 7 [' Sc', 'ans', ' Mobile', 'Me', ' ®', ',', ' i', 'Chat']
+948 496 The company which developed x -1 The company which developed iChat Apple iChat "[' is' ' called' ' i' 'Chat' '.' '\n' '\n' 'i' 'Chat' ' is' ' a' ' free'
+ ' application' ' that' ' allows' ' you' ' to' ' chat' ' with' ' your']" " is called i Chat .
+
+ i Chat is a free application that allows you to chat with your" False MobileMe ®, iChat ® and other IMs 5 [' Mobile', 'Me', ' ®', ',', ' i', 'Chat']
+949 497 The company which developed x -1 The company which developed Xbox Live Microsoft Xbox Live "[' Arcade' ' game' ',' ' and' ' the' ' first' ' game' ' to' ' be'
+ ' released' ' on' ' the' ' Xbox' ' 360' '.' '\n' '\n' 'The' ' game' ' is']" " Arcade game , and the first game to be released on the Xbox 360 .
+
+ The game is" False download via the Xbox Live Marketplace. Three 4 [' download', ' via', ' the', ' Xbox', ' Live']
+950 497 The company which developed x -1 The company which developed Xbox Live Microsoft Xbox Live "[' Arcade' ' game' ',' ' and' ' the' ' first' ' game' ' to' ' be'
+ ' released' ' on' ' the' ' Xbox' ' 360' '.' '\n' '\n' 'The' ' game' ' is']" " Arcade game , and the first game to be released on the Xbox 360 .
+
+ The game is" False 360 also provided Xbox Live support for 4 [' 360', ' also', ' provided', ' Xbox', ' Live']
+951 497 The company which developed x -1 The company which developed Xbox Live Microsoft Xbox Live "[' Arcade' ' game' ',' ' and' ' the' ' first' ' game' ' to' ' be'
+ ' released' ' on' ' the' ' Xbox' ' 360' '.' '\n' '\n' 'The' ' game' ' is']" " Arcade game , and the first game to be released on the Xbox 360 .
+
+ The game is" False PlayStation Network and Xbox Live Arcade, allowing 4 [' PlayStation', ' Network', ' and', ' Xbox', ' Live']
+952 497 The company which developed x -1 The company which developed Xbox Live Microsoft Xbox Live "[' Arcade' ' game' ',' ' and' ' the' ' first' ' game' ' to' ' be'
+ ' released' ' on' ' the' ' Xbox' ' 360' '.' '\n' '\n' 'The' ' game' ' is']" " Arcade game , and the first game to be released on the Xbox 360 .
+
+ The game is" False release the game on Xbox Live in mid-2007, with 5 [' release', ' the', ' game', ' on', ' Xbox', ' Live']
+953 497 The company which developed x -1 The company which developed Xbox Live Microsoft Xbox Live "[' Arcade' ' game' ',' ' and' ' the' ' first' ' game' ' to' ' be'
+ ' released' ' on' ' the' ' Xbox' ' 360' '.' '\n' '\n' 'The' ' game' ' is']" " Arcade game , and the first game to be released on the Xbox 360 .
+
+ The game is" False which were for the Xbox Live Arcade. It sold 5 [' which', ' were', ' for', ' the', ' Xbox', ' Live']
+954 498 The company which developed x -1 The company which developed System 7 Apple System 7 "['.' '5' '.' '0' '.' '0' ' is' ' a' ' software' ' company' ' that'
+ ' develops' ' and' ' sells' ' a' ' range' ' of' ' software' ' products'
+ ' for']" . 5 . 0 . 0 is a software company that develops and sells a range of software products for False Mover utility. System 7 also fixed this. 5 [' M', 'over', ' utility', '.', ' System', ' 7']
+955 498 The company which developed x -1 The company which developed System 7 Apple System 7 "['.' '5' '.' '0' '.' '0' ' is' ' a' ' software' ' company' ' that'
+ ' develops' ' and' ' sells' ' a' ' range' ' of' ' software' ' products'
+ ' for']" . 5 . 0 . 0 is a software company that develops and sells a range of software products for False and Macintosh System 7 operating systems 3 [' and', ' Macintosh', ' System', ' 7']
+956 498 The company which developed x -1 The company which developed System 7 Apple System 7 "['.' '5' '.' '0' '.' '0' ' is' ' a' ' software' ' company' ' that'
+ ' develops' ' and' ' sells' ' a' ' range' ' of' ' software' ' products'
+ ' for']" . 5 . 0 . 0 is a software company that develops and sells a range of software products for False or Macintosh System 7 operating systems; 3 [' or', ' Macintosh', ' System', ' 7']
+957 498 The company which developed x -1 The company which developed System 7 Apple System 7 "['.' '5' '.' '0' '.' '0' ' is' ' a' ' software' ' company' ' that'
+ ' develops' ' and' ' sells' ' a' ' range' ' of' ' software' ' products'
+ ' for']" . 5 . 0 . 0 is a software company that develops and sells a range of software products for False and Macintosh System 7 operating systems 3 [' and', ' Macintosh', ' System', ' 7']
+958 498 The company which developed x -1 The company which developed System 7 Apple System 7 "['.' '5' '.' '0' '.' '0' ' is' ' a' ' software' ' company' ' that'
+ ' develops' ' and' ' sells' ' a' ' range' ' of' ' software' ' products'
+ ' for']" . 5 . 0 . 0 is a software company that develops and sells a range of software products for False compensate, the System 7 operating system 4 [' compensate', ',', ' the', ' System', ' 7']
+959 499 The company which developed x -1 The company which developed MS-DOS Microsoft MS-DOS "[' 6' '.' '0' ',' ' the' ' first' ' version' ' of' ' the' ' operating'
+ ' system' ' that' ' was' ' released' ' in' ' 1985' ',' ' is'
+ ' celebrating' ' its']" 6 . 0 , the first version of the operating system that was released in 1985 , is celebrating its False II, Macintosh and MS-DOS versions were also 6 [' II', ',', ' Macintosh', ' and', ' MS', '-', 'DOS']
+960 499 The company which developed x -1 The company which developed MS-DOS Microsoft MS-DOS "[' 6' '.' '0' ',' ' the' ' first' ' version' ' of' ' the' ' operating'
+ ' system' ' that' ' was' ' released' ' in' ' 1985' ',' ' is'
+ ' celebrating' ' its']" 6 . 0 , the first version of the operating system that was released in 1985 , is celebrating its False the PC DOS / MS-DOS came in a close 6 [' the', ' PC', ' DOS', ' /', ' MS', '-', 'DOS']
+961 499 The company which developed x -1 The company which developed MS-DOS Microsoft MS-DOS "[' 6' '.' '0' ',' ' the' ' first' ' version' ' of' ' the' ' operating'
+ ' system' ' that' ' was' ' released' ' in' ' 1985' ',' ' is'
+ ' celebrating' ' its']" 6 . 0 , the first version of the operating system that was released in 1985 , is celebrating its False Windows 3.1 and MS-DOS, then for 7 [' Windows', ' 3', '.', '1', ' and', ' MS', '-', 'DOS']
+962 499 The company which developed x -1 The company which developed MS-DOS Microsoft MS-DOS "[' 6' '.' '0' ',' ' the' ' first' ' version' ' of' ' the' ' operating'
+ ' system' ' that' ' was' ' released' ' in' ' 1985' ',' ' is'
+ ' celebrating' ' its']" 6 . 0 , the first version of the operating system that was released in 1985 , is celebrating its False Marketing included an MS-DOS computer game, part 5 [' Marketing', ' included', ' an', ' MS', '-', 'DOS']
+963 499 The company which developed x -1 The company which developed MS-DOS Microsoft MS-DOS "[' 6' '.' '0' ',' ' the' ' first' ' version' ' of' ' the' ' operating'
+ ' system' ' that' ' was' ' released' ' in' ' 1985' ',' ' is'
+ ' celebrating' ' its']" 6 . 0 , the first version of the operating system that was released in 1985 , is celebrating its False Marketing included an MS-DOS computer game, part 5 [' Marketing', ' included', ' an', ' MS', '-', 'DOS']
+964 501 The company which developed x -1 The company which developed Yoshi's Universal Gravitation Nintendo Yoshi's Universal Gravitation "[',' ' a' ' game' ' that' ' was' ' released' ' in' ' Japan' ' in' ' 1995'
+ ',' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' Yoshi']" , a game that was released in Japan in 1995 , is now working on a new game called Yoshi False other Yoshi game — Yoshi's Universal Gravitation — for the GBA. Universal 8 "[' other', ' Yoshi', ' game', ' —', ' Yoshi', ""'s"", ' Universal', ' Grav', 'itation']"
+965 501 The company which developed x -1 The company which developed Yoshi's Universal Gravitation Nintendo Yoshi's Universal Gravitation "[',' ' a' ' game' ' that' ' was' ' released' ' in' ' Japan' ' in' ' 1995'
+ ',' ' is' ' now' ' working' ' on' ' a' ' new' ' game' ' called' ' Yoshi']" , a game that was released in Japan in 1995 , is now working on a new game called Yoshi False other Yoshi game — Yoshi's Universal Gravitation — for the GBA. Universal 8 "[' other', ' Yoshi', ' game', ' —', ' Yoshi', ""'s"", ' Universal', ' Grav', 'itation']"
+966 503 The company which developed x -1 The company which developed Applesoft BASIC Microsoft Applesoft BASIC "[' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a' ' group'
+ ' of' ' students' ' at' ' the' ' University' ' of' ' Cambridge' '.'
+ ' The' ' company']" was founded in the early 1970 s by a group of students at the University of Cambridge . The company False and learned Applesoft BASIC and 6502 Assembly on 6 [' and', ' learned', ' App', 'les', 'oft', ' BAS', 'IC']
+967 503 The company which developed x -1 The company which developed Applesoft BASIC Microsoft Applesoft BASIC "[' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a' ' group'
+ ' of' ' students' ' at' ' the' ' University' ' of' ' Cambridge' '.'
+ ' The' ' company']" was founded in the early 1970 s by a group of students at the University of Cambridge . The company False school, and learned Applesoft BASIC and 6502 Assembly 8 [' school', ',', ' and', ' learned', ' App', 'les', 'oft', ' BAS', 'IC']
+968 503 The company which developed x -1 The company which developed Applesoft BASIC Microsoft Applesoft BASIC "[' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a' ' group'
+ ' of' ' students' ' at' ' the' ' University' ' of' ' Cambridge' '.'
+ ' The' ' company']" was founded in the early 1970 s by a group of students at the University of Cambridge . The company False school, and learned Applesoft BASIC and 6502 Assembly 8 [' school', ',', ' and', ' learned', ' App', 'les', 'oft', ' BAS', 'IC']
+969 503 The company which developed x -1 The company which developed Applesoft BASIC Microsoft Applesoft BASIC "[' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a' ' group'
+ ' of' ' students' ' at' ' the' ' University' ' of' ' Cambridge' '.'
+ ' The' ' company']" was founded in the early 1970 s by a group of students at the University of Cambridge . The company False school, and learned Applesoft BASIC and 6502 Assembly 8 [' school', ',', ' and', ' learned', ' App', 'les', 'oft', ' BAS', 'IC']
+970 503 The company which developed x -1 The company which developed Applesoft BASIC Microsoft Applesoft BASIC "[' was' ' founded' ' in' ' the' ' early' ' 1970' 's' ' by' ' a' ' group'
+ ' of' ' students' ' at' ' the' ' University' ' of' ' Cambridge' '.'
+ ' The' ' company']" was founded in the early 1970 s by a group of students at the University of Cambridge . The company False school, and learned Applesoft BASIC and 6502 Assembly 8 [' school', ',', ' and', ' learned', ' App', 'les', 'oft', ' BAS', 'IC']
+971 504 The company which developed x -1 The company which developed Final Fantasy X Square Final Fantasy X "['-' '2' ',' ' the' ' sequel' ' to' ' the' ' popular' ' Final' ' Fantasy'
+ ' X' ',' ' is' ' now' ' available' ' for' ' the' ' PlayStation' ' 2' '.']" - 2 , the sequel to the popular Final Fantasy X , is now available for the PlayStation 2 . False 2 ['Final', ' Fantasy', ' X']
+972 504 The company which developed x -1 The company which developed Final Fantasy X Square Final Fantasy X "['-' '2' ',' ' the' ' sequel' ' to' ' the' ' popular' ' Final' ' Fantasy'
+ ' X' ',' ' is' ' now' ' available' ' for' ' the' ' PlayStation' ' 2' '.']" - 2 , the sequel to the popular Final Fantasy X , is now available for the PlayStation 2 . False Enix announced that Final Fantasy X would be re-released 5 [' Enix', ' announced', ' that', ' Final', ' Fantasy', ' X']
+973 504 The company which developed x -1 The company which developed Final Fantasy X Square Final Fantasy X "['-' '2' ',' ' the' ' sequel' ' to' ' the' ' popular' ' Final' ' Fantasy'
+ ' X' ',' ' is' ' now' ' available' ' for' ' the' ' PlayStation' ' 2' '.']" - 2 , the sequel to the popular Final Fantasy X , is now available for the PlayStation 2 . False 2 ['Final', ' Fantasy', ' X']
+974 504 The company which developed x -1 The company which developed Final Fantasy X Square Final Fantasy X "['-' '2' ',' ' the' ' sequel' ' to' ' the' ' popular' ' Final' ' Fantasy'
+ ' X' ',' ' is' ' now' ' available' ' for' ' the' ' PlayStation' ' 2' '.']" - 2 , the sequel to the popular Final Fantasy X , is now available for the PlayStation 2 . False positive response. Final Fantasy X received praise, 5 [' positive', ' response', '.', ' Final', ' Fantasy', ' X']
+975 504 The company which developed x -1 The company which developed Final Fantasy X Square Final Fantasy X "['-' '2' ',' ' the' ' sequel' ' to' ' the' ' popular' ' Final' ' Fantasy'
+ ' X' ',' ' is' ' now' ' available' ' for' ' the' ' PlayStation' ' 2' '.']" - 2 , the sequel to the popular Final Fantasy X , is now available for the PlayStation 2 . False 2 ['Final', ' Fantasy', ' X']
+976 505 The company which developed x -1 The company which developed Zaxxon Sega Zaxxon "[',' ' a' ' new' ' kind' ' of' ' game' ' that' ' is' ' a' ' cross'
+ ' between' ' a' ' puzzle' ' game' ' and' ' a' ' strategy' ' game' ','
+ ' is']" , a new kind of game that is a cross between a puzzle game and a strategy game , is False earlier examples include Zaxxon (1982), Q * bert (1982), 5 [' earlier', ' examples', ' include', ' Z', 'ax', 'xon']
+977 505 The company which developed x -1 The company which developed Zaxxon Sega Zaxxon "[',' ' a' ' new' ' kind' ' of' ' game' ' that' ' is' ' a' ' cross'
+ ' between' ' a' ' puzzle' ' game' ' and' ' a' ' strategy' ' game' ','
+ ' is']" , a new kind of game that is a cross between a puzzle game and a strategy game , is False examples include Zaxxon (1982), Q * bert (1982), 4 [' examples', ' include', ' Z', 'ax', 'xon']
+978 505 The company which developed x -1 The company which developed Zaxxon Sega Zaxxon "[',' ' a' ' new' ' kind' ' of' ' game' ' that' ' is' ' a' ' cross'
+ ' between' ' a' ' puzzle' ' game' ' and' ' a' ' strategy' ' game' ','
+ ' is']" , a new kind of game that is a cross between a puzzle game and a strategy game , is False earlier examples include Zaxxon (1982), Q * bert (1982), 5 [' earlier', ' examples', ' include', ' Z', 'ax', 'xon']
+979 505 The company which developed x -1 The company which developed Zaxxon Sega Zaxxon "[',' ' a' ' new' ' kind' ' of' ' game' ' that' ' is' ' a' ' cross'
+ ' between' ' a' ' puzzle' ' game' ' and' ' a' ' strategy' ' game' ','
+ ' is']" , a new kind of game that is a cross between a puzzle game and a strategy game , is False examples include Zaxxon (1982), Q * bert 4 [' examples', ' include', ' Z', 'ax', 'xon']
+980 506 The company which developed x -1 The company which developed TextEdit Apple TextEdit "[' is' ' now' ' owned' ' by' ' Apple' '.' '\n' '\n' '~~' '~' '\n' '\n'
+ '~~' '~' '\n' '\n' '~~' '~' '\n' '\n']" " is now owned by Apple .
+
+ ~~ ~
+
+ ~~ ~
+
+ ~~ ~
+
+" True published a TextEdit file as an open letter 3 [' published', ' a', ' Text', 'Edit']
+981 507 The company which developed x -1 The company which developed App Store Apple App Store "[' for' ' Android' ' is' ' now' ' available' ' for' ' download' ' on'
+ ' Google' ' Play' '.' '\n' '\n' 'The' ' app' ' is' ' free' ' to'
+ ' download' ' and']" " for Android is now available for download on Google Play .
+
+ The app is free to download and" False pulled from the iOS App Store in February 2015 5 [' pulled', ' from', ' the', ' iOS', ' App', ' Store']
+982 507 The company which developed x -1 The company which developed App Store Apple App Store "[' for' ' Android' ' is' ' now' ' available' ' for' ' download' ' on'
+ ' Google' ' Play' '.' '\n' '\n' 'The' ' app' ' is' ' free' ' to'
+ ' download' ' and']" " for Android is now available for download on Google Play .
+
+ The app is free to download and" False via the Mac App Store on April 26, 2012 4 [' via', ' the', ' Mac', ' App', ' Store']
+983 507 The company which developed x -1 The company which developed App Store Apple App Store "[' for' ' Android' ' is' ' now' ' available' ' for' ' download' ' on'
+ ' Google' ' Play' '.' '\n' '\n' 'The' ' app' ' is' ' free' ' to'
+ ' download' ' and']" " for Android is now available for download on Google Play .
+
+ The app is free to download and" False the top of the App Store paid apps chart, 5 [' the', ' top', ' of', ' the', ' App', ' Store']
+984 507 The company which developed x -1 The company which developed App Store Apple App Store "[' for' ' Android' ' is' ' now' ' available' ' for' ' download' ' on'
+ ' Google' ' Play' '.' '\n' '\n' 'The' ' app' ' is' ' free' ' to'
+ ' download' ' and']" " for Android is now available for download on Google Play .
+
+ The app is free to download and" False listed in Apple's App Store lists. An Android version 5 "[' listed', ' in', ' Apple', ""'s"", ' App', ' Store']"
+985 507 The company which developed x -1 The company which developed App Store Apple App Store "[' for' ' Android' ' is' ' now' ' available' ' for' ' download' ' on'
+ ' Google' ' Play' '.' '\n' '\n' 'The' ' app' ' is' ' free' ' to'
+ ' download' ' and']" " for Android is now available for download on Google Play .
+
+ The app is free to download and" False to the top of the App Store paid apps chart, 6 [' to', ' the', ' top', ' of', ' the', ' App', ' Store']
+986 508 The company which developed x -1 The company which developed Blogger Google Blogger "[' is' ' a' ' great' ' tool' ' for' ' bloggers' ' to' ' share' ' their'
+ ' thoughts' ' and' ' ideas' ' with' ' the' ' world' '.' '\n' '\n' 'I'
+ ' am']" " is a great tool for bloggers to share their thoughts and ideas with the world .
+
+ I am" False personality. Blogger and Blogcritics 3 [' personality', '.', ' Blog', 'ger']
+987 508 The company which developed x -1 The company which developed Blogger Google Blogger "[' is' ' a' ' great' ' tool' ' for' ' bloggers' ' to' ' share' ' their'
+ ' thoughts' ' and' ' ideas' ' with' ' the' ' world' '.' '\n' '\n' 'I'
+ ' am']" " is a great tool for bloggers to share their thoughts and ideas with the world .
+
+ I am" False " [...] all the more"". Blogger Molly Lambert," 6 "[' [...]', ' all', ' the', ' more', '"".', ' Blog', 'ger']"
+988 508 The company which developed x -1 The company which developed Blogger Google Blogger "[' is' ' a' ' great' ' tool' ' for' ' bloggers' ' to' ' share' ' their'
+ ' thoughts' ' and' ' ideas' ' with' ' the' ' world' '.' '\n' '\n' 'I'
+ ' am']" " is a great tool for bloggers to share their thoughts and ideas with the world .
+
+ I am" False Fiercest Climate Blogger Takes on the 5 [' F', 'ier', 'cest', ' Climate', ' Blog', 'ger']
+989 508 The company which developed x -1 The company which developed Blogger Google Blogger "[' is' ' a' ' great' ' tool' ' for' ' bloggers' ' to' ' share' ' their'
+ ' thoughts' ' and' ' ideas' ' with' ' the' ' world' '.' '\n' '\n' 'I'
+ ' am']" " is a great tool for bloggers to share their thoughts and ideas with the world .
+
+ I am" False bright personality. Blogger and Blogcritics 4 [' bright', ' personality', '.', ' Blog', 'ger']
+990 508 The company which developed x -1 The company which developed Blogger Google Blogger "[' is' ' a' ' great' ' tool' ' for' ' bloggers' ' to' ' share' ' their'
+ ' thoughts' ' and' ' ideas' ' with' ' the' ' world' '.' '\n' '\n' 'I'
+ ' am']" " is a great tool for bloggers to share their thoughts and ideas with the world .
+
+ I am" False Tomorrow.sg, a Star Blogger for the STOMP 7 [' Tomorrow', '.', 'sg', ',', ' a', ' Star', ' Blog', 'ger']
+991 514 The company which developed x -1 The company which developed Internet Explorer Mobile Microsoft Internet Explorer Mobile "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Internet' ' Explorer' ' Mobile' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " is Microsoft .
+
+ The company which developed Internet Explorer Mobile is Microsoft .
+
+ The company which" True " soft keys. "" Internet Explorer Mobile 6 has also received" 6 "[' soft', ' keys', '.', ' ""', ' Internet', ' Explorer', ' Mobile']"
+992 514 The company which developed x -1 The company which developed Internet Explorer Mobile Microsoft Internet Explorer Mobile "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Internet' ' Explorer' ' Mobile' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " is Microsoft .
+
+ The company which developed Internet Explorer Mobile is Microsoft .
+
+ The company which" True were added to Internet Explorer Mobile along with 5 [' were', ' added', ' to', ' Internet', ' Explorer', ' Mobile']
+993 514 The company which developed x -1 The company which developed Internet Explorer Mobile Microsoft Internet Explorer Mobile "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Internet' ' Explorer' ' Mobile' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " is Microsoft .
+
+ The company which developed Internet Explorer Mobile is Microsoft .
+
+ The company which" True " replaced soft keys. "" Internet Explorer Mobile 6 has also received" 7 "[' replaced', ' soft', ' keys', '.', ' ""', ' Internet', ' Explorer', ' Mobile']"
+994 514 The company which developed x -1 The company which developed Internet Explorer Mobile Microsoft Internet Explorer Mobile "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Internet' ' Explorer' ' Mobile' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " is Microsoft .
+
+ The company which developed Internet Explorer Mobile is Microsoft .
+
+ The company which" True " soft keys. "" Internet Explorer Mobile 6 has also received" 6 "[' soft', ' keys', '.', ' ""', ' Internet', ' Explorer', ' Mobile']"
+995 514 The company which developed x -1 The company which developed Internet Explorer Mobile Microsoft Internet Explorer Mobile "[' is' ' Microsoft' '.' '\n' '\n' 'The' ' company' ' which' ' developed'
+ ' Internet' ' Explorer' ' Mobile' ' is' ' Microsoft' '.' '\n' '\n' 'The'
+ ' company' ' which']" " is Microsoft .
+
+ The company which developed Internet Explorer Mobile is Microsoft .
+
+ The company which" True support were added to Internet Explorer Mobile along with improved 6 [' support', ' were', ' added', ' to', ' Internet', ' Explorer', ' Mobile']
+996 515 The company which developed x -1 The company which developed V8 Google V8 "[' is' ' now' ' working' ' on' ' a' ' new' ' version' ' of' ' the'
+ ' engine' ',' ' called' ' V' '8' '.' '3' '.' '0' '.' ' The']" is now working on a new version of the engine , called V 8 . 3 . 0 . The False dealerships, looking for a V8 engine that might 7 [' dealers', 'hips', ',', ' looking', ' for', ' a', ' V', '8']
+997 515 The company which developed x -1 The company which developed V8 Google V8 "[' is' ' now' ' working' ' on' ' a' ' new' ' version' ' of' ' the'
+ ' engine' ',' ' called' ' V' '8' '.' '3' '.' '0' '.' ' The']" is now working on a new version of the engine , called V 8 . 3 . 0 . The False Hermética and V8 songs in his 5 [' Herm', 'ét', 'ica', ' and', ' V', '8']
+998 515 The company which developed x -1 The company which developed V8 Google V8 "[' is' ' now' ' working' ' on' ' a' ' new' ' version' ' of' ' the'
+ ' engine' ',' ' called' ' V' '8' '.' '3' '.' '0' '.' ' The']" is now working on a new version of the engine , called V 8 . 3 . 0 . The False centered around the V8 engine and the 4 [' centered', ' around', ' the', ' V', '8']
+999 515 The company which developed x -1 The company which developed V8 Google V8 "[' is' ' now' ' working' ' on' ' a' ' new' ' version' ' of' ' the'
+ ' engine' ',' ' called' ' V' '8' '.' '3' '.' '0' '.' ' The']" is now working on a new version of the engine , called V 8 . 3 . 0 . The False powerful 6.0-litre V8 — have been 8 [' powerful', ' 6', '.', '0', '-', 'lit', 're', ' V', '8']
+1000 515 The company which developed x -1 The company which developed V8 Google V8 "[' is' ' now' ' working' ' on' ' a' ' new' ' version' ' of' ' the'
+ ' engine' ',' ' called' ' V' '8' '.' '3' '.' '0' '.' ' The']" is now working on a new version of the engine , called V 8 . 3 . 0 . The False including the Aston Martin V8 Vantage, during 5 [' including', ' the', ' Aston', ' Martin', ' V', '8']
+1001 517 The company which developed x -1 The company which developed Windows Embedded Microsoft Windows Embedded "[' CE' ' 6' '.' '0' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' embedded' ' operating'
+ ' system' ',']" CE 6 . 0 has announced that it will be releasing a new version of its embedded operating system , False Microsoft announced Windows Embedded Handheld 6.5. The 4 [' Microsoft', ' announced', ' Windows', ' Emb', 'edded']
+1002 517 The company which developed x -1 The company which developed Windows Embedded Microsoft Windows Embedded "[' CE' ' 6' '.' '0' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' embedded' ' operating'
+ ' system' ',']" CE 6 . 0 has announced that it will be releasing a new version of its embedded operating system , False announced Windows Embedded Handheld 6.5. The 3 [' announced', ' Windows', ' Emb', 'edded']
+1003 517 The company which developed x -1 The company which developed Windows Embedded Microsoft Windows Embedded "[' CE' ' 6' '.' '0' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' embedded' ' operating'
+ ' system' ',']" CE 6 . 0 has announced that it will be releasing a new version of its embedded operating system , False Handheld 6.5 6 [' Hand', 'held', ' 6', '.', 'Windows', ' Emb', 'edded']
+1004 517 The company which developed x -1 The company which developed Windows Embedded Microsoft Windows Embedded "[' CE' ' 6' '.' '0' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' embedded' ' operating'
+ ' system' ',']" CE 6 . 0 has announced that it will be releasing a new version of its embedded operating system , False Microsoft announced Windows Embedded Handheld 6.5. The operating 4 [' Microsoft', ' announced', ' Windows', ' Emb', 'edded']
+1005 517 The company which developed x -1 The company which developed Windows Embedded Microsoft Windows Embedded "[' CE' ' 6' '.' '0' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' embedded' ' operating'
+ ' system' ',']" CE 6 . 0 has announced that it will be releasing a new version of its embedded operating system , False Microsoft announced Windows Embedded Handheld 6.5. The operating 4 [' Microsoft', ' announced', ' Windows', ' Emb', 'edded']
+1006 518 The company which developed x -1 The company which developed Portable Document Format Adobe Portable Document Format "[' (' 'PDF' ')' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' PDF' ' reader' ','
+ ' PDF' '-']" ( PDF ) has announced that it will be releasing a new version of its PDF reader , PDF - False Series Two scripts in Portable Document Format (PDF) and a PDF 6 [' Series', ' Two', ' scripts', ' in', ' Portable', ' Document', ' Format']
+1007 518 The company which developed x -1 The company which developed Portable Document Format Adobe Portable Document Format "[' (' 'PDF' ')' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' PDF' ' reader' ','
+ ' PDF' '-']" ( PDF ) has announced that it will be releasing a new version of its PDF reader , PDF - False Format (GIF) and Portable Document Format (PDF) for older 8 [' Format', ' (', 'G', 'IF', ')', ' and', ' Portable', ' Document', ' Format']
+1008 518 The company which developed x -1 The company which developed Portable Document Format Adobe Portable Document Format "[' (' 'PDF' ')' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' PDF' ' reader' ','
+ ' PDF' '-']" ( PDF ) has announced that it will be releasing a new version of its PDF reader , PDF - False (GIF) and Portable Document Format (PDF) for older 7 [' (', 'G', 'IF', ')', ' and', ' Portable', ' Document', ' Format']
+1009 518 The company which developed x -1 The company which developed Portable Document Format Adobe Portable Document Format "[' (' 'PDF' ')' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' PDF' ' reader' ','
+ ' PDF' '-']" ( PDF ) has announced that it will be releasing a new version of its PDF reader , PDF - False Format (GIF) and Portable Document Format (PDF) for older articles. 8 [' Format', ' (', 'G', 'IF', ')', ' and', ' Portable', ' Document', ' Format']
+1010 518 The company which developed x -1 The company which developed Portable Document Format Adobe Portable Document Format "[' (' 'PDF' ')' ' has' ' announced' ' that' ' it' ' will' ' be'
+ ' releasing' ' a' ' new' ' version' ' of' ' its' ' PDF' ' reader' ','
+ ' PDF' '-']" ( PDF ) has announced that it will be releasing a new version of its PDF reader , PDF - False Series Two scripts in Portable Document Format (PDF) and 6 [' Series', ' Two', ' scripts', ' in', ' Portable', ' Document', ' Format']
+1011 521 The company which developed x -1 The company which developed Siri Apple Siri "[',' ' the' ' voice' '-' 'activated' ' personal' ' assistant' ',' ' has'
+ ' been' ' acquired' ' by' ' Apple' ' for' ' $' '3' '.' '2' ' billion' '.']" , the voice - activated personal assistant , has been acquired by Apple for $ 3 . 2 billion . True 1 ['S', 'iri']
+1012 521 The company which developed x -1 The company which developed Siri Apple Siri "[',' ' the' ' voice' '-' 'activated' ' personal' ' assistant' ',' ' has'
+ ' been' ' acquired' ' by' ' Apple' ' for' ' $' '3' '.' '2' ' billion' '.']" , the voice - activated personal assistant , has been acquired by Apple for $ 3 . 2 billion . True generations, eulogizing new Siri features, Control 7 [' generations', ',', ' e', 'ul', 'og', 'izing', ' new', ' Siri']
+1013 521 The company which developed x -1 The company which developed Siri Apple Siri "[',' ' the' ' voice' '-' 'activated' ' personal' ' assistant' ',' ' has'
+ ' been' ' acquired' ' by' ' Apple' ' for' ' $' '3' '.' '2' ' billion' '.']" , the voice - activated personal assistant , has been acquired by Apple for $ 3 . 2 billion . True spin-off company, Siri Inc., which 5 [' spin', '-', 'off', ' company', ',', ' Siri']
+1014 521 The company which developed x -1 The company which developed Siri Apple Siri "[',' ' the' ' voice' '-' 'activated' ' personal' ' assistant' ',' ' has'
+ ' been' ' acquired' ' by' ' Apple' ' for' ' $' '3' '.' '2' ' billion' '.']" , the voice - activated personal assistant , has been acquired by Apple for $ 3 . 2 billion . True Rai Pithora and Siri Fort. Neither 5 [' Rai', ' P', 'ith', 'ora', ' and', ' Siri']
+1015 521 The company which developed x -1 The company which developed Siri Apple Siri "[',' ' the' ' voice' '-' 'activated' ' personal' ' assistant' ',' ' has'
+ ' been' ' acquired' ' by' ' Apple' ' for' ' $' '3' '.' '2' ' billion' '.']" , the voice - activated personal assistant , has been acquired by Apple for $ 3 . 2 billion . True Krishna District, and Siri Media obtained 4 [' Krishna', ' District', ',', ' and', ' Siri']
diff --git a/patchscopes/code/preprocessed_data/factual/star_constellation.tsv b/patchscopes/code/preprocessed_data/factual/star_constellation.tsv
new file mode 100644
index 00000000..8d540bb8
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/star_constellation.tsv
@@ -0,0 +1,542 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The name of the constellation including x -1 The name of the constellation including Sirius Canis Major Sirius [',', ' the', ' brightest', ' star', ' in', ' the', ' sky', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n', '\n', 'The', ' name'] ", the brightest star in the sky , is the brightest star in the sky .
+
+ The name" False Football Network and on Sirius XM Satellite 4 [' Football', ' Network', ' and', ' on', ' Sirius']
+1 0 The name of the constellation including x -1 The name of the constellation including Sirius Canis Major Sirius [',', ' the', ' brightest', ' star', ' in', ' the', ' sky', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n', '\n', 'The', ' name'] ", the brightest star in the sky , is the brightest star in the sky .
+
+ The name" False Harry has a vision of Sirius being tortured 5 [' Harry', ' has', ' a', ' vision', ' of', ' Sirius']
+2 0 The name of the constellation including x -1 The name of the constellation including Sirius Canis Major Sirius [',', ' the', ' brightest', ' star', ' in', ' the', ' sky', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n', '\n', 'The', ' name'] ", the brightest star in the sky , is the brightest star in the sky .
+
+ The name" False Gary Oldman as Sirius Black, Harry's late 4 [' Gary', ' Old', 'man', ' as', ' Sirius']
+3 0 The name of the constellation including x -1 The name of the constellation including Sirius Canis Major Sirius [',', ' the', ' brightest', ' star', ' in', ' the', ' sky', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n', '\n', 'The', ' name'] ", the brightest star in the sky , is the brightest star in the sky .
+
+ The name" False targeted by Sirius Black, his father's 2 [' targeted', ' by', ' Sirius']
+4 0 The name of the constellation including x -1 The name of the constellation including Sirius Canis Major Sirius [',', ' the', ' brightest', ' star', ' in', ' the', ' sky', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n', '\n', 'The', ' name'] ", the brightest star in the sky , is the brightest star in the sky .
+
+ The name" False LGBT-geared station on Sirius XM. Bass is also 6 [' LGBT', '-', 'ge', 'ared', ' station', ' on', ' Sirius']
+5 1 The name of the constellation including x -1 The name of the constellation including Vega Lyra Vega [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Ly', 'ra', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Ly'] ", the brightest star in the constellation of Ly ra .
+
+ The name of the constellation of Ly" True 1 ['V', 'ega']
+6 1 The name of the constellation including x -1 The name of the constellation including Vega Lyra Vega [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Ly', 'ra', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Ly'] ", the brightest star in the constellation of Ly ra .
+
+ The name of the constellation of Ly" True " Compania"" led by Savio Vega (Juan Rivera)." 7 "[' Comp', 'ania', '""', ' led', ' by', ' Sav', 'io', ' Vega']"
+7 1 The name of the constellation including x -1 The name of the constellation including Vega Lyra Vega [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Ly', 'ra', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Ly'] ", the brightest star in the constellation of Ly ra .
+
+ The name of the constellation of Ly" True radiation pressure from Vega pushes debris 3 [' radiation', ' pressure', ' from', ' Vega']
+8 1 The name of the constellation including x -1 The name of the constellation including Vega Lyra Vega [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Ly', 'ra', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Ly'] ", the brightest star in the constellation of Ly ra .
+
+ The name of the constellation of Ly" True artist Suzanne Vega at her Bucharest 2 [' artist', ' Suzanne', ' Vega']
+9 1 The name of the constellation including x -1 The name of the constellation including Vega Lyra Vega [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Ly', 'ra', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Ly'] ", the brightest star in the constellation of Ly ra .
+
+ The name of the constellation of Ly" True atmosphere. The Soviet Vega program in 4 [' atmosphere', '.', ' The', ' Soviet', ' Vega']
+10 2 The name of the constellation including x -1 The name of the constellation including Arcturus Boötes Arcturus [',', ' the', ' Bear', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' northern', ' sky', '.', '\n', '\n', 'The', ' constellation', ' of', ' the', ' Bear'] ", the Bear , is the brightest star in the northern sky .
+
+ The constellation of the Bear" False locket is Regulus Arcturus Black, the 7 [' l', 'ocket', ' is', ' Reg', 'ulus', ' Ar', 'ct', 'urus']
+11 2 The name of the constellation including x -1 The name of the constellation including Arcturus Boötes Arcturus [',', ' the', ' Bear', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' northern', ' sky', '.', '\n', '\n', 'The', ' constellation', ' of', ' the', ' Bear'] ", the Bear , is the brightest star in the northern sky .
+
+ The constellation of the Bear" False voices of Jim Raynor and Arcturus Mengsk. Notable 8 [' voices', ' of', ' Jim', ' Ray', 'nor', ' and', ' Ar', 'ct', 'urus']
+12 2 The name of the constellation including x -1 The name of the constellation including Arcturus Boötes Arcturus [',', ' the', ' Bear', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' northern', ' sky', '.', '\n', '\n', 'The', ' constellation', ' of', ' the', ' Bear'] ", the Bear , is the brightest star in the northern sky .
+
+ The constellation of the Bear" False the player defeats Arcturus Mengsk's armies, 5 [' the', ' player', ' defeats', ' Ar', 'ct', 'urus']
+13 2 The name of the constellation including x -1 The name of the constellation including Arcturus Boötes Arcturus [',', ' the', ' Bear', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' northern', ' sky', '.', '\n', '\n', 'The', ' constellation', ' of', ' the', ' Bear'] ", the Bear , is the brightest star in the northern sky .
+
+ The constellation of the Bear" False raises the suspicions of Arcturus Mengsk, who 6 [' raises', ' the', ' suspicions', ' of', ' Ar', 'ct', 'urus']
+14 2 The name of the constellation including x -1 The name of the constellation including Arcturus Boötes Arcturus [',', ' the', ' Bear', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' northern', ' sky', '.', '\n', '\n', 'The', ' constellation', ' of', ' the', ' Bear'] ", the Bear , is the brightest star in the northern sky .
+
+ The constellation of the Bear" False forces pursuing Arcturus Mengsk, upon 4 [' forces', ' pursuing', ' Ar', 'ct', 'urus']
+15 3 The name of the constellation including x -1 The name of the constellation including Capella Auriga Capella [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Vir', 'go', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Vir'] ", the brightest star in the constellation of Vir go .
+
+ The name of the constellation of Vir" False " the name of the star Capella (Latin for ""nanny" 6 [' the', ' name', ' of', ' the', ' star', ' Cape', 'lla']
+16 3 The name of the constellation including x -1 The name of the constellation including Capella Auriga Capella [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Vir', 'go', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Vir'] ", the brightest star in the constellation of Vir go .
+
+ The name of the constellation of Vir" False complete with a few A Capella chants and Italian 6 [' complete', ' with', ' a', ' few', ' A', ' Cape', 'lla']
+17 3 The name of the constellation including x -1 The name of the constellation including Capella Auriga Capella [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Vir', 'go', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Vir'] ", the brightest star in the constellation of Vir go .
+
+ The name of the constellation of Vir" False 2 ['C', 'ape', 'lla']
+18 3 The name of the constellation including x -1 The name of the constellation including Capella Auriga Capella [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Vir', 'go', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Vir'] ", the brightest star in the constellation of Vir go .
+
+ The name of the constellation of Vir" False complete with a few A Capella chants and Italian 6 [' complete', ' with', ' a', ' few', ' A', ' Cape', 'lla']
+19 3 The name of the constellation including x -1 The name of the constellation including Capella Auriga Capella [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Vir', 'go', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Vir'] ", the brightest star in the constellation of Vir go .
+
+ The name of the constellation of Vir" False " of the star Capella (Latin for ""nanny" 4 [' of', ' the', ' star', ' Cape', 'lla']
+20 4 The name of the constellation including x -1 The name of the constellation including Procyon Canis Minor Procyon [',', ' the', ' Dog', ' Star', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Can', 'is', ' Major', '.', '\n', '\n', 'The'] ", the Dog Star , is the brightest star in the constellation of Can is Major .
+
+ The" False constellation, but they saw Procyon as significant 7 [' constellation', ',', ' but', ' they', ' saw', ' Pro', 'cy', 'on']
+21 4 The name of the constellation including x -1 The name of the constellation including Procyon Canis Minor Procyon [',', ' the', ' Dog', ' Star', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Can', 'is', ' Major', '.', '\n', '\n', 'The'] ", the Dog Star , is the brightest star in the constellation of Can is Major .
+
+ The" False Arabic names for both Procyon and Gomeisa alluded 6 [' Arabic', ' names', ' for', ' both', ' Pro', 'cy', 'on']
+22 4 The name of the constellation including x -1 The name of the constellation including Procyon Canis Minor Procyon [',', ' the', ' Dog', ' Star', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Can', 'is', ' Major', '.', '\n', '\n', 'The'] ", the Dog Star , is the brightest star in the constellation of Can is Major .
+
+ The" False southern wingtip and Procyon the northern wingtip, 6 [' southern', ' wing', 'tip', ' and', ' Pro', 'cy', 'on']
+23 4 The name of the constellation including x -1 The name of the constellation including Procyon Canis Minor Procyon [',', ' the', ' Dog', ' Star', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Can', 'is', ' Major', '.', '\n', '\n', 'The'] ", the Dog Star , is the brightest star in the constellation of Can is Major .
+
+ The" False the family Procyonidae lived in 4 [' the', ' family', ' Pro', 'cy', 'on']
+24 4 The name of the constellation including x -1 The name of the constellation including Procyon Canis Minor Procyon [',', ' the', ' Dog', ' Star', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Can', 'is', ' Major', '.', '\n', '\n', 'The'] ", the Dog Star , is the brightest star in the constellation of Can is Major .
+
+ The" False Planet: Battle at Procyon for the PC in October 6 [' Planet', ':', ' Battle', ' at', ' Pro', 'cy', 'on']
+25 5 The name of the constellation including x -1 The name of the constellation including Altair Aquila Altair [',', ' Vega', ',', ' and', ' Vega', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' including', ' Alt', 'air', ',', ' Vega', ',', ' and'] ", Vega , and Vega .
+
+ The name of the constellation including Alt air , Vega , and" False Bond uses — The Altair — was the 5 [' Bond', ' uses', ' —', ' The', ' Alt', 'air']
+26 5 The name of the constellation including x -1 The name of the constellation including Altair Aquila Altair [',', ' Vega', ',', ' and', ' Vega', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' including', ' Alt', 'air', ',', ' Vega', ',', ' and'] ", Vega , and Vega .
+
+ The name of the constellation including Alt air , Vega , and" False 1 ['Alt', 'air']
+27 5 The name of the constellation including x -1 The name of the constellation including Altair Aquila Altair [',', ' Vega', ',', ' and', ' Vega', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' including', ' Alt', 'air', ',', ' Vega', ',', ' and'] ", Vega , and Vega .
+
+ The name of the constellation including Alt air , Vega , and" False November 1975, the Altair 680. The machines 5 [' November', ' 1975', ',', ' the', ' Alt', 'air']
+28 5 The name of the constellation including x -1 The name of the constellation including Altair Aquila Altair [',', ' Vega', ',', ' and', ' Vega', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' including', ' Alt', 'air', ',', ' Vega', ',', ' and'] ", Vega , and Vega .
+
+ The name of the constellation including Alt air , Vega , and" False uses — The Altair — was the name 4 [' uses', ' —', ' The', ' Alt', 'air']
+29 5 The name of the constellation including x -1 The name of the constellation including Altair Aquila Altair [',', ' Vega', ',', ' and', ' Vega', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' including', ' Alt', 'air', ',', ' Vega', ',', ' and'] ", Vega , and Vega .
+
+ The name of the constellation including Alt air , Vega , and" False (and only) World Altair Computer Convention 6 [' (', 'and', ' only', ')', ' World', ' Alt', 'air']
+30 6 The name of the constellation including x -1 The name of the constellation including Aldebaran Taurus Aldebaran [',', ' the', ' brightest', ' star', ' in', ' the', ' T', 'aurus', ' constellation', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from'] ", the brightest star in the T aurus constellation .
+
+ The name of the constellation is derived from" True constellation is Aldebaran, an orange-hued, 4 [' constellation', ' is', ' Ald', 'eb', 'aran']
+31 6 The name of the constellation including x -1 The name of the constellation including Aldebaran Taurus Aldebaran [',', ' the', ' brightest', ' star', ' in', ' the', ' T', 'aurus', ' constellation', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from'] ", the brightest star in the T aurus constellation .
+
+ The name of the constellation is derived from" True polar bear. Aldebaran represents the bear, 5 [' polar', ' bear', '.', ' Ald', 'eb', 'aran']
+32 6 The name of the constellation including x -1 The name of the constellation including Aldebaran Taurus Aldebaran [',', ' the', ' brightest', ' star', ' in', ' the', ' T', 'aurus', ' constellation', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from'] ", the brightest star in the T aurus constellation .
+
+ The name of the constellation is derived from" True much debated; Aldebaran, Betelgeuse and the 5 [' much', ' debated', ';', ' Ald', 'eb', 'aran']
+33 6 The name of the constellation including x -1 The name of the constellation including Aldebaran Taurus Aldebaran [',', ' the', ' brightest', ' star', ' in', ' the', ' T', 'aurus', ' constellation', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from'] ", the brightest star in the T aurus constellation .
+
+ The name of the constellation is derived from" True the polar bear. Aldebaran represents the bear, 6 [' the', ' polar', ' bear', '.', ' Ald', 'eb', 'aran']
+34 6 The name of the constellation including x -1 The name of the constellation including Aldebaran Taurus Aldebaran [',', ' the', ' brightest', ' star', ' in', ' the', ' T', 'aurus', ' constellation', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from'] ", the brightest star in the T aurus constellation .
+
+ The name of the constellation is derived from" True constellation is Aldebaran, an orange-hued, spectral 4 [' constellation', ' is', ' Ald', 'eb', 'aran']
+35 7 The name of the constellation including x -1 The name of the constellation including Antares Scorpius Antares [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Scorp', 'ius', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Scorp'] ", the brightest star in the constellation of Scorp ius .
+
+ The name of the constellation of Scorp" True Madonna preferred the Antares Auto-Tune plug 4 [' Madonna', ' preferred', ' the', ' Ant', 'ares']
+36 7 The name of the constellation including x -1 The name of the constellation including Antares Scorpius Antares [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Scorp', 'ius', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Scorp'] ", the brightest star in the constellation of Scorp ius .
+
+ The name of the constellation of Scorp" True " 26°east and Antares 44°southeast.
+" 5 [' 26', '°', 'east', ' and', ' Ant', 'ares']
+37 7 The name of the constellation including x -1 The name of the constellation including Antares Scorpius Antares [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Scorp', 'ius', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Scorp'] ", the brightest star in the constellation of Scorp ius .
+
+ The name of the constellation of Scorp" True Madonna preferred the Antares Auto-Tune plug in, 4 [' Madonna', ' preferred', ' the', ' Ant', 'ares']
+38 7 The name of the constellation including x -1 The name of the constellation including Antares Scorpius Antares [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Scorp', 'ius', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Scorp'] ", the brightest star in the constellation of Scorp ius .
+
+ The name of the constellation of Scorp" True mangling. He used the Antares Auto-Tune plug-in 7 [' man', 'gling', '.', ' He', ' used', ' the', ' Ant', 'ares']
+39 7 The name of the constellation including x -1 The name of the constellation including Antares Scorpius Antares [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Scorp', 'ius', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Scorp'] ", the brightest star in the constellation of Scorp ius .
+
+ The name of the constellation of Scorp" True " 26°east and Antares 44°southeast.
+" 5 [' 26', '°', 'east', ' and', ' Ant', 'ares']
+40 8 The name of the constellation including x -1 The name of the constellation including Fomalhaut Piscis Austrinus Fomalhaut [',', ' which', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' the', ' fal', 'con', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", which is the brightest star in the constellation of the fal con .
+
+ The name of the" False optical images of Fomalhaut b could be due to 6 [' optical', ' images', ' of', ' F', 'omal', 'h', 'aut']
+41 8 The name of the constellation including x -1 The name of the constellation including Fomalhaut Piscis Austrinus Fomalhaut [',', ' which', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' the', ' fal', 'con', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", which is the brightest star in the constellation of the fal con .
+
+ The name of the" False " debris disks around Fomalhaut and AU Microscopii.
+" 6 [' debris', ' disks', ' around', ' F', 'omal', 'h', 'aut']
+42 8 The name of the constellation including x -1 The name of the constellation including Fomalhaut Piscis Austrinus Fomalhaut [',', ' which', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' the', ' fal', 'con', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", which is the brightest star in the constellation of the fal con .
+
+ The name of the" False Along with Vega, Fomalhaut and Epsilon Eridani, 7 [' Along', ' with', ' Vega', ',', ' F', 'omal', 'h', 'aut']
+43 8 The name of the constellation including x -1 The name of the constellation including Fomalhaut Piscis Austrinus Fomalhaut [',', ' which', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' the', ' fal', 'con', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", which is the brightest star in the constellation of the fal con .
+
+ The name of the" False " disks around Fomalhaut and AU Microscopii.
+" 5 [' disks', ' around', ' F', 'omal', 'h', 'aut']
+44 8 The name of the constellation including x -1 The name of the constellation including Fomalhaut Piscis Austrinus Fomalhaut [',', ' which', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' the', ' fal', 'con', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", which is the brightest star in the constellation of the fal con .
+
+ The name of the" False of Pegasus through Fomalhaut will lead to 6 [' of', ' Pegasus', ' through', ' F', 'omal', 'h', 'aut']
+45 9 The name of the constellation including x -1 The name of the constellation including Deneb Cygnus Deneb [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Cy', 'gn', 'us', ',', ' the', ' Swan', '.', '\n', '\n', 'The', ' name', ' of'] ", the brightest star in the constellation of Cy gn us , the Swan .
+
+ The name of" True way to planet Deneb 5 when the 5 [' way', ' to', ' planet', ' D', 'ene', 'b']
+46 9 The name of the constellation including x -1 The name of the constellation including Deneb Cygnus Deneb [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Cy', 'gn', 'us', ',', ' the', ' Swan', '.', '\n', '\n', 'The', ' name', ' of'] ", the brightest star in the constellation of Cy gn us , the Swan .
+
+ The name of" True will fall behind Deneb and Mimosa as 20th 5 [' will', ' fall', ' behind', ' D', 'ene', 'b']
+47 9 The name of the constellation including x -1 The name of the constellation including Deneb Cygnus Deneb [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Cy', 'gn', 'us', ',', ' the', ' Swan', '.', '\n', '\n', 'The', ' name', ' of'] ", the brightest star in the constellation of Cy gn us , the Swan .
+
+ The name of" True " Vasudan systems of Deneb and Alpha Centauri.
+" 7 [' Vas', 'ud', 'an', ' systems', ' of', ' D', 'ene', 'b']
+48 9 The name of the constellation including x -1 The name of the constellation including Deneb Cygnus Deneb [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Cy', 'gn', 'us', ',', ' the', ' Swan', '.', '\n', '\n', 'The', ' name', ' of'] ", the brightest star in the constellation of Cy gn us , the Swan .
+
+ The name of" True first magnitude star Deneb in Cygnus. This 5 [' first', ' magnitude', ' star', ' D', 'ene', 'b']
+49 9 The name of the constellation including x -1 The name of the constellation including Deneb Cygnus Deneb [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Cy', 'gn', 'us', ',', ' the', ' Swan', '.', '\n', '\n', 'The', ' name', ' of'] ", the brightest star in the constellation of Cy gn us , the Swan .
+
+ The name of" True Betelgeuse will fall behind Deneb and Mimosa as 20th 9 [' Bet', 'el', 'ge', 'use', ' will', ' fall', ' behind', ' D', 'ene', 'b']
+50 10 The name of the constellation including x -1 The name of the constellation including Castor Gemini Castor [' and', ' Poll', 'ux', ',', ' the', ' Gemini', ' twins', ',', ' is', ' a', ' reference', ' to', ' the', ' constellation', ' of', ' Gemini', ',', ' the', ' twins', '.'] and Poll ux , the Gemini twins , is a reference to the constellation of Gemini , the twins . True " Loose"", The Jimmy Castor Bunch's ""It's" 6 "[' Lo', 'ose', '"",', ' The', ' Jimmy', ' Cast', 'or']"
+51 10 The name of the constellation including x -1 The name of the constellation including Castor Gemini Castor [' and', ' Poll', 'ux', ',', ' the', ' Gemini', ' twins', ',', ' is', ' a', ' reference', ' to', ' the', ' constellation', ' of', ' Gemini', ',', ' the', ' twins', '.'] and Poll ux , the Gemini twins , is a reference to the constellation of Gemini , the twins . True until 8: 05 when Castor spotted smoke bearing 6 [' until', ' 8', ':', ' 05', ' when', ' Cast', 'or']
+52 10 The name of the constellation including x -1 The name of the constellation including Castor Gemini Castor [' and', ' Poll', 'ux', ',', ' the', ' Gemini', ' twins', ',', ' is', ' a', ' reference', ' to', ' the', ' constellation', ' of', ' Gemini', ',', ' the', ' twins', '.'] and Poll ux , the Gemini twins , is a reference to the constellation of Gemini , the twins . True East Coast. Renamed Castor on 15 June 1869, 6 [' East', ' Coast', '.', ' Ren', 'amed', ' Cast', 'or']
+53 10 The name of the constellation including x -1 The name of the constellation including Castor Gemini Castor [' and', ' Poll', 'ux', ',', ' the', ' Gemini', ' twins', ',', ' is', ' a', ' reference', ' to', ' the', ' constellation', ' of', ' Gemini', ',', ' the', ' twins', '.'] and Poll ux , the Gemini twins , is a reference to the constellation of Gemini , the twins . True which host Nash Castor and commentator Adriatica 4 [' which', ' host', ' Nash', ' Cast', 'or']
+54 10 The name of the constellation including x -1 The name of the constellation including Castor Gemini Castor [' and', ' Poll', 'ux', ',', ' the', ' Gemini', ' twins', ',', ' is', ' a', ' reference', ' to', ' the', ' constellation', ' of', ' Gemini', ',', ' the', ' twins', '.'] and Poll ux , the Gemini twins , is a reference to the constellation of Gemini , the twins . True been shards of Castor ware. Fragments 4 [' been', ' shards', ' of', ' Cast', 'or']
+55 11 The name of the constellation including x -1 The name of the constellation including Mizar Ursa Major Mizar [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the star of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False Septentrione — Mizar (ミザール, Mizaaru). 6 [' Sept', 'ent', 'r', 'ione', ' —', ' Miz', 'ar']
+56 11 The name of the constellation including x -1 The name of the constellation including Mizar Ursa Major Mizar [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the star of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False themselves reunited at Mizar's Palace and 4 [' themselves', ' reunited', ' at', ' Miz', 'ar']
+57 11 The name of the constellation including x -1 The name of the constellation including Mizar Ursa Major Mizar [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the star of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False of the star Mizar in a 104-day period. 4 [' of', ' the', ' star', ' Miz', 'ar']
+58 11 The name of the constellation including x -1 The name of the constellation including Mizar Ursa Major Mizar [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the star of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False " observes double star Mizar in Ursa Major
+" 4 [' observes', ' double', ' star', ' Miz', 'ar']
+59 11 The name of the constellation including x -1 The name of the constellation including Mizar Ursa Major Mizar [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the star of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False next Septentrione — Mizar (ミザール, Mizaaru). However, 7 [' next', ' Sept', 'ent', 'r', 'ione', ' —', ' Miz', 'ar']
+60 12 The name of the constellation including x -1 The name of the constellation including Dubhe Ursa Major Dubhe [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades'] , the Ple i ades , the Ple i ades , the Ple i ades , the Ple i ades False a creature known as Dubhe (ドゥベ, Duube) appears 5 [' a', ' creature', ' known', ' as', ' Dub', 'he']
+61 12 The name of the constellation including x -1 The name of the constellation including Dubhe Ursa Major Dubhe [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades'] , the Ple i ades , the Ple i ades , the Ple i ades , the Ple i ades False creature known as Dubhe (ドゥベ, Duube) appears 4 [' creature', ' known', ' as', ' Dub', 'he']
+62 13 The name of the constellation including x -1 The name of the constellation including Alcor Ursa Major Alcor [',', ' the', ' name', ' of', ' the', ' constellation', ',', ' and', ' the', ' name', ' of', ' the', ' star', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", the name of the constellation , and the name of the star .
+
+ The name of the" False of victory, Alcor reveals the true 4 [' of', ' victory', ',', ' Al', 'cor']
+63 13 The name of the constellation including x -1 The name of the constellation including Alcor Ursa Major Alcor [',', ' the', ' name', ' of', ' the', ' constellation', ',', ' and', ' the', ' name', ' of', ' the', ' star', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", the name of the constellation , and the name of the star .
+
+ The name of the" False cost of victory, Alcor reveals the true 5 [' cost', ' of', ' victory', ',', ' Al', 'cor']
+64 13 The name of the constellation including x -1 The name of the constellation including Alcor Ursa Major Alcor [',', ' the', ' name', ' of', ' the', ' constellation', ',', ' and', ' the', ' name', ' of', ' the', ' star', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", the name of the constellation , and the name of the star .
+
+ The name of the" False cost of victory, Alcor reveals the true 5 [' cost', ' of', ' victory', ',', ' Al', 'cor']
+65 13 The name of the constellation including x -1 The name of the constellation including Alcor Ursa Major Alcor [',', ' the', ' name', ' of', ' the', ' constellation', ',', ' and', ' the', ' name', ' of', ' the', ' star', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", the name of the constellation , and the name of the star .
+
+ The name of the" False was sold from Alcor Inc. to O & G Construction 4 [' was', ' sold', ' from', ' Al', 'cor']
+66 13 The name of the constellation including x -1 The name of the constellation including Alcor Ursa Major Alcor [',', ' the', ' name', ' of', ' the', ' constellation', ',', ' and', ' the', ' name', ' of', ' the', ' star', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", the name of the constellation , and the name of the star .
+
+ The name of the" False of victory, Alcor reveals the true 4 [' of', ' victory', ',', ' Al', 'cor']
+67 14 The name of the constellation including x -1 The name of the constellation including Alphecca Corona Borealis Alphecca [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False constellation Alphecca (a name later given 3 [' constellation', ' Al', 'phe', 'cca']
+68 14 The name of the constellation including x -1 The name of the constellation including Alphecca Corona Borealis Alphecca [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False the constellation Alphecca (a name later 4 [' the', ' constellation', ' Al', 'phe', 'cca']
+69 14 The name of the constellation including x -1 The name of the constellation including Alphecca Corona Borealis Alphecca [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False the constellation Alphecca (a name later 4 [' the', ' constellation', ' Al', 'phe', 'cca']
+70 14 The name of the constellation including x -1 The name of the constellation including Alphecca Corona Borealis Alphecca [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False to dance on earth. Alphecca signifies the youngest 7 [' to', ' dance', ' on', ' earth', '.', ' Al', 'phe', 'cca']
+71 14 The name of the constellation including x -1 The name of the constellation including Alphecca Corona Borealis Alphecca [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False to dance on earth. Alphecca signifies the youngest 7 [' to', ' dance', ' on', ' earth', '.', ' Al', 'phe', 'cca']
+72 16 The name of the constellation including x -1 The name of the constellation including Gacrux Crux Gacrux [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False bisecting Gacrux and Acrux 6 [' bis', 'ect', 'ing', ' G', 'ac', 'ru', 'x']
+73 16 The name of the constellation including x -1 The name of the constellation including Gacrux Crux Gacrux [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False imaginary line bisecting Gacrux and Acrux points southward 8 [' imaginary', ' line', ' bis', 'ect', 'ing', ' G', 'ac', 'ru', 'x']
+74 16 The name of the constellation including x -1 The name of the constellation including Gacrux Crux Gacrux [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False forming a male, and Gacrux and Delta Crucis forming 8 [' forming', ' a', ' male', ',', ' and', ' G', 'ac', 'ru', 'x']
+75 17 The name of the constellation including x -1 The name of the constellation including Alnair Grus Alnair [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' sun', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' moon', '.'] , the name of the constellation of the sun , and the name of the constellation of the moon . False Located near Alnair is NGC 7213, a face-on 4 [' Located', ' near', ' Al', 'n', 'air']
+76 17 The name of the constellation including x -1 The name of the constellation including Alnair Grus Alnair [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' sun', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' moon', '.'] , the name of the constellation of the sun , and the name of the constellation of the moon . False a triangle with Alnair and Beta, Delta Gruis 5 [' a', ' triangle', ' with', ' Al', 'n', 'air']
+77 17 The name of the constellation including x -1 The name of the constellation including Alnair Grus Alnair [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' sun', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' moon', '.'] , the name of the constellation of the sun , and the name of the constellation of the moon . False Gruis, is also known as Alnair and appears as 9 [' Gru', 'is', ',', ' is', ' also', ' known', ' as', ' Al', 'n', 'air']
+78 17 The name of the constellation including x -1 The name of the constellation including Alnair Grus Alnair [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' sun', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' moon', '.'] , the name of the constellation of the sun , and the name of the constellation of the moon . False also known as Alnair and appears 5 [' also', ' known', ' as', ' Al', 'n', 'air']
+79 17 The name of the constellation including x -1 The name of the constellation including Alnair Grus Alnair [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' sun', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' moon', '.'] , the name of the constellation of the sun , and the name of the constellation of the moon . False Forming a triangle with Alnair and Beta, Delta Gruis 7 [' Form', 'ing', ' a', ' triangle', ' with', ' Al', 'n', 'air']
+80 22 The name of the constellation including x -1 The name of the constellation including Mirfak Perseus Mirfak [',', ' the', ' name', ' of', ' the', ' star', ' that', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Orion', '.', '\n', '\n', 'The'] ", the name of the star that is the brightest star in the constellation of Orion .
+
+ The" False historical name Mirfak (Arabic for elbow) 4 [' historical', ' name', ' Mir', 'f', 'ak']
+81 23 The name of the constellation including x -1 The name of the constellation including Theta2 Orionis Orion Theta2 Orionis [',', ' the', ' second', '-', 'bright', 'est', ' star', ' in', ' the', ' constellation', ' of', ' Orion', '.', '\n', '\n', 'The', 'ta', '2', ' Orion', 'is'] ", the second - bright est star in the constellation of Orion .
+
+ The ta 2 Orion is" True next brightest star, Theta2 Orionis A.) The H II 8 [' next', ' brightest', ' star', ',', ' The', 'ta', '2', ' Orion', 'is']
+82 23 The name of the constellation including x -1 The name of the constellation including Theta2 Orionis Orion Theta2 Orionis [',', ' the', ' second', '-', 'bright', 'est', ' star', ' in', ' the', ' constellation', ' of', ' Orion', '.', '\n', '\n', 'The', 'ta', '2', ' Orion', 'is'] ", the second - bright est star in the constellation of Orion .
+
+ The ta 2 Orion is" True brightest star, Theta2 Orionis A.) The H 7 [' brightest', ' star', ',', ' The', 'ta', '2', ' Orion', 'is']
+83 23 The name of the constellation including x -1 The name of the constellation including Theta2 Orionis Orion Theta2 Orionis [',', ' the', ' second', '-', 'bright', 'est', ' star', ' in', ' the', ' constellation', ' of', ' Orion', '.', '\n', '\n', 'The', 'ta', '2', ' Orion', 'is'] ", the second - bright est star in the constellation of Orion .
+
+ The ta 2 Orion is" True next brightest star, Theta2 Orionis A.) The H II 8 [' next', ' brightest', ' star', ',', ' The', 'ta', '2', ' Orion', 'is']
+84 23 The name of the constellation including x -1 The name of the constellation including Theta2 Orionis Orion Theta2 Orionis [',', ' the', ' second', '-', 'bright', 'est', ' star', ' in', ' the', ' constellation', ' of', ' Orion', '.', '\n', '\n', 'The', 'ta', '2', ' Orion', 'is'] ", the second - bright est star in the constellation of Orion .
+
+ The ta 2 Orion is" True brightest star, Theta2 Orionis A.) The H II 7 [' brightest', ' star', ',', ' The', 'ta', '2', ' Orion', 'is']
+85 23 The name of the constellation including x -1 The name of the constellation including Theta2 Orionis Orion Theta2 Orionis [',', ' the', ' second', '-', 'bright', 'est', ' star', ' in', ' the', ' constellation', ' of', ' Orion', '.', '\n', '\n', 'The', 'ta', '2', ' Orion', 'is'] ", the second - bright est star in the constellation of Orion .
+
+ The ta 2 Orion is" True brightest star, Theta2 Orionis A.) The H 7 [' brightest', ' star', ',', ' The', 'ta', '2', ' Orion', 'is']
+86 26 The name of the constellation including x -1 The name of the constellation including Iota Persei Perseus Iota Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True magnitude 4.05, nearby Iota Persei has been considered 9 [' magnitude', ' 4', '.', '05', ',', ' nearby', ' I', 'ota', ' Perse', 'i']
+87 32 The name of the constellation including x -1 The name of the constellation including Gliese 581 Libra Gliese 581 "['g', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a']" g is a bit of a mouth ful , but it 's a pretty cool name . It 's a False was sent toward the Gliese 581 system, which 8 [' was', ' sent', ' toward', ' the', ' Gl', 'ies', 'e', ' 5', '81']
+88 32 The name of the constellation including x -1 The name of the constellation including Gliese 581 Libra Gliese 581 "['g', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a']" g is a bit of a mouth ful , but it 's a pretty cool name . It 's a False The name Gliese 581 refers to the 6 [' The', ' name', ' Gl', 'ies', 'e', ' 5', '81']
+89 32 The name of the constellation including x -1 The name of the constellation including Gliese 581 Libra Gliese 581 "['g', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a']" g is a bit of a mouth ful , but it 's a pretty cool name . It 's a False 4 ['Gl', 'ies', 'e', ' 5', '81']
+90 32 The name of the constellation including x -1 The name of the constellation including Gliese 581 Libra Gliese 581 "['g', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a']" g is a bit of a mouth ful , but it 's a pretty cool name . It 's a False average, the light that Gliese 581 d receives from 9 [' average', ',', ' the', ' light', ' that', ' Gl', 'ies', 'e', ' 5', '81']
+91 32 The name of the constellation including x -1 The name of the constellation including Gliese 581 Libra Gliese 581 "['g', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a']" g is a bit of a mouth ful , but it 's a pretty cool name . It 's a False detection of Gliese 581 g after such 6 [' detection', ' of', ' Gl', 'ies', 'e', ' 5', '81']
+92 37 The name of the constellation including x -1 The name of the constellation including Wolf 359 Leo Wolf 359 [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the name of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False 1 ['Wolf', ' 359']
+93 37 The name of the constellation including x -1 The name of the constellation including Wolf 359 Leo Wolf 359 [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the name of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False 1 ['Wolf', ' 359']
+94 37 The name of the constellation including x -1 The name of the constellation including Wolf 359 Leo Wolf 359 [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the name of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False rotational velocity of Wolf 359's equator is less 5 [' rot', 'ational', ' velocity', ' of', ' Wolf', ' 359']
+95 37 The name of the constellation including x -1 The name of the constellation including Wolf 359 Leo Wolf 359 [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the name of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False the Battle of Wolf 359 in the The Next 4 [' the', ' Battle', ' of', ' Wolf', ' 359']
+96 37 The name of the constellation including x -1 The name of the constellation including Wolf 359 Leo Wolf 359 [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Big', ' Di', 'pper', '.', '\n', '\n', 'The', ' Big', ' Di', 'pper', ' is', ' a'] ", the name of the constellation of the Big Di pper .
+
+ The Big Di pper is a" False Star (at 5.9 ly), Wolf 359 (7.8 ly), 9 [' Star', ' (', 'at', ' 5', '.', '9', ' ly', '),', ' Wolf', ' 359']
+97 38 The name of the constellation including x -1 The name of the constellation including Rigel Orion Rigel [',', ' Orion', ',', ' and', ' the', ' Ple', 'i', 'ades', '.', '\n', '\n', 'The', ' Ple', 'i', 'ades', ' is', ' a', ' cluster', ' of', ' stars'] ", Orion , and the Ple i ades .
+
+ The Ple i ades is a cluster of stars" True base of Orion, Rigel and Saiph, and another, 5 [' base', ' of', ' Orion', ',', ' Rig', 'el']
+98 38 The name of the constellation including x -1 The name of the constellation including Rigel Orion Rigel [',', ' Orion', ',', ' and', ' the', ' Ple', 'i', 'ades', '.', '\n', '\n', 'The', ' Ple', 'i', 'ades', ' is', ' a', ' cluster', ' of', ' stars'] ", Orion , and the Ple i ades .
+
+ The Ple i ades is a cluster of stars" True get the stars Rigel and Sirius aligned, 4 [' get', ' the', ' stars', ' Rig', 'el']
+99 38 The name of the constellation including x -1 The name of the constellation including Rigel Orion Rigel [',', ' Orion', ',', ' and', ' the', ' Ple', 'i', 'ades', '.', '\n', '\n', 'The', ' Ple', 'i', 'ades', ' is', ' a', ' cluster', ' of', ' stars'] ", Orion , and the Ple i ades .
+
+ The Ple i ades is a cluster of stars" True " Speaking ""Rigellian"", which coincidentally" 3 "[' Speaking', ' ""', 'R', 'igel']"
+100 38 The name of the constellation including x -1 The name of the constellation including Rigel Orion Rigel [',', ' Orion', ',', ' and', ' the', ' Ple', 'i', 'ades', '.', '\n', '\n', 'The', ' Ple', 'i', 'ades', ' is', ' a', ' cluster', ' of', ' stars'] ", Orion , and the Ple i ades .
+
+ The Ple i ades is a cluster of stars" True distant planet Rigel 7, there is 3 [' distant', ' planet', ' Rig', 'el']
+101 38 The name of the constellation including x -1 The name of the constellation including Rigel Orion Rigel [',', ' Orion', ',', ' and', ' the', ' Ple', 'i', 'ades', '.', '\n', '\n', 'The', ' Ple', 'i', 'ades', ' is', ' a', ' cluster', ' of', ' stars'] ", Orion , and the Ple i ades .
+
+ The Ple i ades is a cluster of stars" True offspring: Bellatrix, Rigel, Sirius, and, in 7 [' offspring', ':', ' Bell', 'at', 'rix', ',', ' Rig', 'el']
+102 40 The name of the constellation including x -1 The name of the constellation including Betelgeuse Orion Betelgeuse [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' Orion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' Orion', ' is', ' derived', ' from'] ", the brightest star in the constellation Orion .
+
+ The name of the constellation Orion is derived from" True vessels, the USS Betelgeuse (AKA-11) launched 7 [' vessels', ',', ' the', ' USS', ' Bet', 'el', 'ge', 'use']
+103 40 The name of the constellation including x -1 The name of the constellation including Betelgeuse Orion Betelgeuse [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' Orion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' Orion', ' is', ' derived', ' from'] ", the brightest star in the constellation Orion .
+
+ The name of the constellation Orion is derived from" True " measured parallax of Betelgeuse was π
+" 8 [' measured', ' par', 'all', 'ax', ' of', ' Bet', 'el', 'ge', 'use']
+104 40 The name of the constellation including x -1 The name of the constellation including Betelgeuse Orion Betelgeuse [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' Orion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' Orion', ' is', ' derived', ' from'] ", the brightest star in the constellation Orion .
+
+ The name of the constellation Orion is derived from" True 3 ['Bet', 'el', 'ge', 'use']
+105 40 The name of the constellation including x -1 The name of the constellation including Betelgeuse Orion Betelgeuse [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' Orion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' Orion', ' is', ' derived', ' from'] ", the brightest star in the constellation Orion .
+
+ The name of the constellation Orion is derived from" True solution, showing Betelgeuse with a uniform 6 [' solution', ',', ' showing', ' Bet', 'el', 'ge', 'use']
+106 40 The name of the constellation including x -1 The name of the constellation including Betelgeuse Orion Betelgeuse [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' Orion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' Orion', ' is', ' derived', ' from'] ", the brightest star in the constellation Orion .
+
+ The name of the constellation Orion is derived from" True 3 ['Bet', 'el', 'ge', 'use']
+107 42 The name of the constellation including x -1 The name of the constellation including Algol Perseus Algol [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n'] ", the brightest star in the constellation of Perse us , is the brightest star in the sky .
+" True The association of Algol with a demon-like 5 [' The', ' association', ' of', ' Al', 'g', 'ol']
+108 42 The name of the constellation including x -1 The name of the constellation including Algol Perseus Algol [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n'] ", the brightest star in the constellation of Perse us , is the brightest star in the sky .
+" True service 3 [' servic', 'Al', 'g', 'ol']
+109 42 The name of the constellation including x -1 The name of the constellation including Algol Perseus Algol [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n'] ", the brightest star in the constellation of Perse us , is the brightest star in the sky .
+" True " Algol =
+" 2 [' Al', 'g', 'ol']
+110 42 The name of the constellation including x -1 The name of the constellation including Algol Perseus Algol [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n'] ", the brightest star in the constellation of Perse us , is the brightest star in the sky .
+" True astronomer Ptolemy, Algol is referred to as 8 [' astronomer', ' P', 'to', 'le', 'my', ',', ' Al', 'g', 'ol']
+111 42 The name of the constellation including x -1 The name of the constellation including Algol Perseus Algol [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' sky', '.', '\n'] ", the brightest star in the constellation of Perse us , is the brightest star in the sky .
+" True In Hebrew folklore, Algol was called 6 [' In', ' Hebrew', ' folklore', ',', ' Al', 'g', 'ol']
+112 45 The name of the constellation including x -1 The name of the constellation including Denebola Leo Denebola [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Cy', 'gn', 'us', ',', ' the', ' Swan', '.', '\n', '\n', 'The', ' name', ' of'] ", the brightest star in the constellation of Cy gn us , the Swan .
+
+ The name of" False arctic whales is Denebola brachycephala from 7 [' ar', 'ctic', ' whales', ' is', ' D', 'ene', 'b', 'ola']
+113 45 The name of the constellation including x -1 The name of the constellation including Denebola Leo Denebola [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Cy', 'gn', 'us', ',', ' the', ' Swan', '.', '\n', '\n', 'The', ' name', ' of'] ", the brightest star in the constellation of Cy gn us , the Swan .
+
+ The name of" False arctic whales is Denebola brachycephala from 7 [' ar', 'ctic', ' whales', ' is', ' D', 'ene', 'b', 'ola']
+114 45 The name of the constellation including x -1 The name of the constellation including Denebola Leo Denebola [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Cy', 'gn', 'us', ',', ' the', ' Swan', '.', '\n', '\n', 'The', ' name', ' of'] ", the brightest star in the constellation of Cy gn us , the Swan .
+
+ The name of" False of arctic whales is Denebola brachycephala from 8 [' of', ' ar', 'ctic', ' whales', ' is', ' D', 'ene', 'b', 'ola']
+115 47 The name of the constellation including x -1 The name of the constellation including Alpha Persei Perseus Alpha Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True yellow-white supergiant Alpha Persei (also called 8 [' yellow', '-', 'white', ' super', 'g', 'iant', ' Alpha', ' Perse', 'i']
+116 47 The name of the constellation including x -1 The name of the constellation including Alpha Persei Perseus Alpha Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True supergiant Alpha Persei (also called 5 [' super', 'g', 'iant', ' Alpha', ' Perse', 'i']
+117 47 The name of the constellation including x -1 The name of the constellation including Alpha Persei Perseus Alpha Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True elbow) or Algenib, Alpha Persei is the brightest 9 [' elbow', ')', ' or', ' Al', 'gen', 'ib', ',', ' Alpha', ' Perse', 'i']
+118 47 The name of the constellation including x -1 The name of the constellation including Alpha Persei Perseus Alpha Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True elbow) or Algenib, Alpha Persei is the brightest star 9 [' elbow', ')', ' or', ' Al', 'gen', 'ib', ',', ' Alpha', ' Perse', 'i']
+119 47 The name of the constellation including x -1 The name of the constellation including Alpha Persei Perseus Alpha Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True yellow-white supergiant Alpha Persei (also called 8 [' yellow', '-', 'white', ' super', 'g', 'iant', ' Alpha', ' Perse', 'i']
+120 50 The name of the constellation including x -1 The name of the constellation including Gliese 436 Leo Gliese 436 [',', ' which', ' is', ' a', ' red', ' dwarf', ' star', ',', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', '.', ' It', ' is', ' located', ' about'] , which is a red dwarf star , is a bit of a mouth ful . It is located about False radiation than Gliese 436 b. The side 5 [' radiation', ' than', ' Gl', 'ies', 'e', ' 436']
+121 50 The name of the constellation including x -1 The name of the constellation including Gliese 436 Leo Gliese 436 [',', ' which', ' is', ' a', ' red', ' dwarf', ' star', ',', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', '.', ' It', ' is', ' located', ' about'] , which is a red dwarf star , is a bit of a mouth ful . It is located about False red dwarf star Gliese 436 named Gliese 436 6 [' red', ' dwarf', ' star', ' Gl', 'ies', 'e', ' 436']
+122 50 The name of the constellation including x -1 The name of the constellation including Gliese 436 Leo Gliese 436 [',', ' which', ' is', ' a', ' red', ' dwarf', ' star', ',', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', '.', ' It', ' is', ' located', ' about'] , which is a red dwarf star , is a bit of a mouth ful . It is located about False radiation than Gliese 436 b. The side of the 5 [' radiation', ' than', ' Gl', 'ies', 'e', ' 436']
+123 50 The name of the constellation including x -1 The name of the constellation including Gliese 436 Leo Gliese 436 [',', ' which', ' is', ' a', ' red', ' dwarf', ' star', ',', ' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', '.', ' It', ' is', ' located', ' about'] , which is a red dwarf star , is a bit of a mouth ful . It is located about False radiation than Gliese 436 b. The side of the 5 [' radiation', ' than', ' Gl', 'ies', 'e', ' 436']
+124 54 The name of the constellation including x -1 The name of the constellation including Omicron Leonis Leo Omicron Leonis [',', ' the', ' lion', ',', ' is', ' the', ' first', ' letter', ' of', ' the', ' Greek', ' alphabet', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation'] ", the lion , is the first letter of the Greek alphabet .
+
+ The name of the constellation" False misinterpreted as Omicron Leonis Minoris. More 7 [' misinterpret', 'ed', ' as', ' O', 'mic', 'ron', ' Leon', 'is']
+125 61 The name of the constellation including x -1 The name of the constellation including X Persei Perseus X Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', ',', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse'] , the brightest star in the constellation of Perse us , is the brightest star in the constellation of Perse True to the other two. X Persei is a double system 7 [' to', ' the', ' other', ' two', '.', ' X', ' Perse', 'i']
+126 70 The name of the constellation including x -1 The name of the constellation including Bellatrix Orion Bellatrix [',', ' the', ' goddess', ' of', ' war', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' goddess', ' of', ' love', ',', ' Venus', '.'] , the goddess of war , and the name of the constellation of the goddess of love , Venus . False to the role of Bellatrix Lestrange, although 6 [' to', ' the', ' role', ' of', ' Bell', 'at', 'rix']
+127 70 The name of the constellation including x -1 The name of the constellation including Bellatrix Orion Bellatrix [',', ' the', ' goddess', ' of', ' war', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' goddess', ' of', ' love', ',', ' Venus', '.'] , the goddess of war , and the name of the constellation of the goddess of love , Venus . False into insanity by Bellatrix Lestrange, was 5 [' into', ' insanity', ' by', ' Bell', 'at', 'rix']
+128 70 The name of the constellation including x -1 The name of the constellation including Bellatrix Orion Bellatrix [',', ' the', ' goddess', ' of', ' war', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' goddess', ' of', ' love', ',', ' Venus', '.'] , the goddess of war , and the name of the constellation of the goddess of love , Venus . False Helena Bonham Carter as Bellatrix Lestrange, one 7 [' Helena', ' Bon', 'ham', ' Carter', ' as', ' Bell', 'at', 'rix']
+129 70 The name of the constellation including x -1 The name of the constellation including Bellatrix Orion Bellatrix [',', ' the', ' goddess', ' of', ' war', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' goddess', ' of', ' love', ',', ' Venus', '.'] , the goddess of war , and the name of the constellation of the goddess of love , Venus . False owns and operates Bellatrix Female Warriors, 5 [' owns', ' and', ' operates', ' Bell', 'at', 'rix']
+130 70 The name of the constellation including x -1 The name of the constellation including Bellatrix Orion Bellatrix [',', ' the', ' goddess', ' of', ' war', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' goddess', ' of', ' love', ',', ' Venus', '.'] , the goddess of war , and the name of the constellation of the goddess of love , Venus . False Meanwhile, Bellatrix Lestrange, 4 [' Meanwhile', ',', ' Bell', 'at', 'rix']
+131 82 The name of the constellation including x -1 The name of the constellation including Delta Leonis Leo Delta Leonis [',', ' the', ' lion', ',', ' is', ' the', ' constellation', ' of', ' Leo', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Leo', ' is'] ", the lion , is the constellation of Leo .
+
+ The name of the constellation of Leo is" True stars including Delta Leonis and Alpha Ophiuchi; 4 [' stars', ' including', ' Delta', ' Leon', 'is']
+132 82 The name of the constellation including x -1 The name of the constellation including Delta Leonis Leo Delta Leonis [',', ' the', ' lion', ',', ' is', ' the', ' constellation', ' of', ' Leo', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Leo', ' is'] ", the lion , is the constellation of Leo .
+
+ The name of the constellation of Leo is" True 70 stars including Delta Leonis and Alpha Ophiuchi; 5 [' 70', ' stars', ' including', ' Delta', ' Leon', 'is']
+133 85 The name of the constellation including x -1 The name of the constellation including Regulus Leo Regulus [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Lion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Lion'] ", the star of the constellation of the Lion .
+
+ The name of the constellation of the Lion" False a related Regulus species. A number 3 [' a', ' related', ' Reg', 'ulus']
+134 85 The name of the constellation including x -1 The name of the constellation including Regulus Leo Regulus [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Lion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Lion'] ", the star of the constellation of the Lion .
+
+ The name of the constellation of the Lion" False of extant Regulus species, mostly 3 [' of', ' extant', ' Reg', 'ulus']
+135 85 The name of the constellation including x -1 The name of the constellation including Regulus Leo Regulus [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Lion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Lion'] ", the star of the constellation of the Lion .
+
+ The name of the constellation of the Lion" False common firecrest (Regulus ignicapilla) also 6 [' common', ' fire', 'c', 'rest', ' (', 'Reg', 'ulus']
+136 85 The name of the constellation including x -1 The name of the constellation including Regulus Leo Regulus [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Lion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Lion'] ", the star of the constellation of the Lion .
+
+ The name of the constellation of the Lion" False found on the star Regulus A (α Leonis A). 5 [' found', ' on', ' the', ' star', ' Reg', 'ulus']
+137 85 The name of the constellation including x -1 The name of the constellation including Regulus Leo Regulus [',', ' the', ' star', ' of', ' the', ' constellation', ' of', ' the', ' Lion', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Lion'] ", the star of the constellation of the Lion .
+
+ The name of the constellation of the Lion" False applied to the Regulus species, the 4 [' applied', ' to', ' the', ' Reg', 'ulus']
+138 99 The name of the constellation including x -1 The name of the constellation including Mu Persei Perseus Mu Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True and possibly Mu Persei lay within it. 4 [' and', ' possibly', ' Mu', ' Perse', 'i']
+139 99 The name of the constellation including x -1 The name of the constellation including Mu Persei Perseus Mu Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True Lambda and possibly Mu Persei lay within it. Dàlíng 6 [' Lamb', 'da', ' and', ' possibly', ' Mu', ' Perse', 'i']
+140 99 The name of the constellation including x -1 The name of the constellation including Mu Persei Perseus Mu Persei [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Perse', 'us', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Perse'] ", the brightest star in the constellation of Perse us .
+
+ The name of the constellation of Perse" True Lambda and possibly Mu Persei lay within it. 6 [' Lamb', 'da', ' and', ' possibly', ' Mu', ' Perse', 'i']
+141 108 The name of the constellation including x -1 The name of the constellation including 48 Persei Perseus 48 Persei [',', ' the', ' constellation', ' of', ' the', ' spring', ' equ', 'in', 'ox', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' autumn'] ", the constellation of the spring equ in ox .
+
+ The name of the constellation of the autumn" False Psi (4.3), and 48 Persei (4.0); the Beta 10 [' P', 'si', ' (', '4', '.', '3', '),', ' and', ' 48', ' Perse', 'i']
+142 108 The name of the constellation including x -1 The name of the constellation including 48 Persei Perseus 48 Persei [',', ' the', ' constellation', ' of', ' the', ' spring', ' equ', 'in', 'ox', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' autumn'] ", the constellation of the spring equ in ox .
+
+ The name of the constellation of the autumn" False Psi (4.3), and 48 Persei (4.0); the Beta Cephei 10 [' P', 'si', ' (', '4', '.', '3', '),', ' and', ' 48', ' Perse', 'i']
+143 108 The name of the constellation including x -1 The name of the constellation including 48 Persei Perseus 48 Persei [',', ' the', ' constellation', ' of', ' the', ' spring', ' equ', 'in', 'ox', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' autumn'] ", the constellation of the spring equ in ox .
+
+ The name of the constellation of the autumn" False (4.3), and 48 Persei (4.0); the Beta 8 [' (', '4', '.', '3', '),', ' and', ' 48', ' Perse', 'i']
+144 108 The name of the constellation including x -1 The name of the constellation including 48 Persei Perseus 48 Persei [',', ' the', ' constellation', ' of', ' the', ' spring', ' equ', 'in', 'ox', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' the', ' autumn'] ", the constellation of the spring equ in ox .
+
+ The name of the constellation of the autumn" False (4.3), and 48 Persei (4.0); the 8 [' (', '4', '.', '3', '),', ' and', ' 48', ' Perse', 'i']
+145 110 The name of the constellation including x -1 The name of the constellation including 51 Pegasi Pegasus 51 Pegasi [',', ' the', ' star', ' that', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Pegasus', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", the star that is the brightest star in the constellation of Pegasus .
+
+ The name of the" True detected hot Jupiter 51 Pegasi b. The orbital 5 [' detected', ' hot', ' Jupiter', ' 51', ' Peg', 'asi']
+146 110 The name of the constellation including x -1 The name of the constellation including 51 Pegasi Pegasus 51 Pegasi [',', ' the', ' star', ' that', ' is', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Pegasus', '.', '\n', '\n', 'The', ' name', ' of', ' the'] ", the star that is the brightest star in the constellation of Pegasus .
+
+ The name of the" True detected hot Jupiter 51 Pegasi b. The orbital 5 [' detected', ' hot', ' Jupiter', ' 51', ' Peg', 'asi']
+147 114 The name of the constellation including x -1 The name of the constellation including 70 Virginis Virgo 70 Virginis [',', ' the', ' Virgin', ' Mary', ',', ' is', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Virgin', ' Mary', '.', '\n', '\n', 'The', ' name'] ", the Virgin Mary , is the name of the constellation of the Virgin Mary .
+
+ The name" False Thong and 70 Virginis b. The discovery 5 [' Th', 'ong', ' and', ' 70', ' Virgin', 'is']
+148 114 The name of the constellation including x -1 The name of the constellation including 70 Virginis Virgo 70 Virginis [',', ' the', ' Virgin', ' Mary', ',', ' is', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Virgin', ' Mary', '.', '\n', '\n', 'The', ' name'] ", the Virgin Mary , is the name of the constellation of the Virgin Mary .
+
+ The name" False Taphao Thong and 70 Virginis b. The discovery 7 [' Tap', 'hao', ' Th', 'ong', ' and', ' 70', ' Virgin', 'is']
+149 114 The name of the constellation including x -1 The name of the constellation including 70 Virginis Virgo 70 Virginis [',', ' the', ' Virgin', ' Mary', ',', ' is', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Virgin', ' Mary', '.', '\n', '\n', 'The', ' name'] ", the Virgin Mary , is the name of the constellation of the Virgin Mary .
+
+ The name" False Taphao Thong and 70 Virginis b. The discovery of 7 [' Tap', 'hao', ' Th', 'ong', ' and', ' 70', ' Virgin', 'is']
+150 128 The name of the constellation including x -1 The name of the constellation including Epsilon Pegasi Pegasus Epsilon Pegasi [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Pegasus', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Pegasus', ' is'] ", the brightest star in the constellation of Pegasus .
+
+ The name of the constellation of Pegasus is" True the Capella – Epsilon Pegasi node by detonating 7 [' the', ' Cape', 'lla', ' –', ' Eps', 'ilon', ' Peg', 'asi']
+151 129 The name of the constellation including x -1 The name of the constellation including Gumala Sagittarius Gumala [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False Coghill visited Gumala Mirnuwarni, a 5 [' Co', 'gh', 'ill', ' visited', ' Gum', 'ala']
+152 131 The name of the constellation including x -1 The name of the constellation including Messier 87 Virgo Messier 87 [',', ' the', ' Great', ' Bear', ',', ' is', ' a', ' constellation', ' in', ' the', ' northern', ' sky', '.', ' It', ' is', ' a', ' large', ' constellation', ',', ' and'] , the Great Bear , is a constellation in the northern sky . It is a large constellation , and False " Messier 87 =
+" 2 [' Mess', 'ier', ' 87']
+153 131 The name of the constellation including x -1 The name of the constellation including Messier 87 Virgo Messier 87 [',', ' the', ' Great', ' Bear', ',', ' is', ' a', ' constellation', ' in', ' the', ' northern', ' sky', '.', ' It', ' is', ' a', ' large', ' constellation', ',', ' and'] , the Great Bear , is a constellation in the northern sky . It is a large constellation , and False " Messier 87 =
+" 2 [' Mess', 'ier', ' 87']
+154 131 The name of the constellation including x -1 The name of the constellation including Messier 87 Virgo Messier 87 [',', ' the', ' Great', ' Bear', ',', ' is', ' a', ' constellation', ' in', ' the', ' northern', ' sky', '.', ' It', ' is', ' a', ' large', ' constellation', ',', ' and'] , the Great Bear , is a constellation in the northern sky . It is a large constellation , and False 2 ['Mess', 'ier', ' 87']
+155 131 The name of the constellation including x -1 The name of the constellation including Messier 87 Virgo Messier 87 [',', ' the', ' Great', ' Bear', ',', ' is', ' a', ' constellation', ' in', ' the', ' northern', ' sky', '.', ' It', ' is', ' a', ' large', ' constellation', ',', ' and'] , the Great Bear , is a constellation in the northern sky . It is a large constellation , and False " Messier 87 =
+" 2 [' Mess', 'ier', ' 87']
+156 131 The name of the constellation including x -1 The name of the constellation including Messier 87 Virgo Messier 87 [',', ' the', ' Great', ' Bear', ',', ' is', ' a', ' constellation', ' in', ' the', ' northern', ' sky', '.', ' It', ' is', ' a', ' large', ' constellation', ',', ' and'] , the Great Bear , is a constellation in the northern sky . It is a large constellation , and False 2 ['Mess', 'ier', ' 87']
+157 145 The name of the constellation including x -1 The name of the constellation including Gamma Pegasi Pegasus Gamma Pegasi [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Pegasus', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Pegasus', ' is'] ", the brightest star in the constellation of Pegasus .
+
+ The name of the constellation of Pegasus is" True " Andromedae and Gamma Pegasi together made ""Wall""" 7 [' And', 'rom', 'ed', 'ae', ' and', ' Gamma', ' Peg', 'asi']
+158 145 The name of the constellation including x -1 The name of the constellation including Gamma Pegasi Pegasus Gamma Pegasi [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Pegasus', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Pegasus', ' is'] ", the brightest star in the constellation of Pegasus .
+
+ The name of the constellation of Pegasus is" True " Andromedae and Gamma Pegasi together made ""Wall""" 7 [' And', 'rom', 'ed', 'ae', ' and', ' Gamma', ' Peg', 'asi']
+159 151 The name of the constellation including x -1 The name of the constellation including Spica Virgo Spica "[',', ' the', ' star', ' of', ' the', ' constellation', ' Vir', 'go', ',', ' is', ' the', ' Latin', ' word', ' for', ' ""', 'v', 'irgin', '"".', '\n', '\n']" ", the star of the constellation Vir go , is the Latin word for "" v irgin "".
+
+" True Wakata named Twin Spica, because of 5 [' Wak', 'ata', ' named', ' Twin', ' Sp', 'ica']
+160 151 The name of the constellation including x -1 The name of the constellation including Spica Virgo Spica "[',', ' the', ' star', ' of', ' the', ' constellation', ' Vir', 'go', ',', ' is', ' the', ' Latin', ' word', ' for', ' ""', 'v', 'irgin', '"".', '\n', '\n']" ", the star of the constellation Vir go , is the Latin word for "" v irgin "".
+
+" True adaptation of Twin Spica five out of five 4 [' adaptation', ' of', ' Twin', ' Sp', 'ica']
+161 151 The name of the constellation including x -1 The name of the constellation including Spica Virgo Spica "[',', ' the', ' star', ' of', ' the', ' constellation', ' Vir', 'go', ',', ' is', ' the', ' Latin', ' word', ' for', ' ""', 'v', 'irgin', '"".', '\n', '\n']" ", the star of the constellation Vir go , is the Latin word for "" v irgin "".
+
+" True from either Spica or Lulim at first. 3 [' from', ' either', ' Sp', 'ica']
+162 151 The name of the constellation including x -1 The name of the constellation including Spica Virgo Spica "[',', ' the', ' star', ' of', ' the', ' constellation', ' Vir', 'go', ',', ' is', ' the', ' Latin', ' word', ' for', ' ""', 'v', 'irgin', '"".', '\n', '\n']" ", the star of the constellation Vir go , is the Latin word for "" v irgin "".
+
+" True back-story of Twin Spica in 2000 with his debut 6 [' back', '-', 'story', ' of', ' Twin', ' Sp', 'ica']
+163 151 The name of the constellation including x -1 The name of the constellation including Spica Virgo Spica "[',', ' the', ' star', ' of', ' the', ' constellation', ' Vir', 'go', ',', ' is', ' the', ' Latin', ' word', ' for', ' ""', 'v', 'irgin', '"".', '\n', '\n']" ", the star of the constellation Vir go , is the Latin word for "" v irgin "".
+
+" True back-story of Twin Spica in 2000 with his 6 [' back', '-', 'story', ' of', ' Twin', ' Sp', 'ica']
+164 162 The name of the constellation including x -1 The name of the constellation including 3C 273 Virgo 3C 273 [',', ' the', ' most', ' distant', ' qu', 'asar', ' known', '.', '\n', '\n', 'The', ' qu', 'asar', ' is', ' located', ' at', ' a', ' red', 'shift', ' of'] ", the most distant qu asar known .
+
+ The qu asar is located at a red shift of" False after that of 3C 273 in 1963. It has an 5 [' after', ' that', ' of', ' 3', 'C', ' 273']
+165 162 The name of the constellation including x -1 The name of the constellation including 3C 273 Virgo 3C 273 [',', ' the', ' most', ' distant', ' qu', 'asar', ' known', '.', '\n', '\n', 'The', ' qu', 'asar', ' is', ' located', ' at', ' a', ' red', 'shift', ' of'] ", the most distant qu asar known .
+
+ The qu asar is located at a red shift of" False until after that of 3C 273 in 1963. It has 6 [' until', ' after', ' that', ' of', ' 3', 'C', ' 273']
+166 166 The name of the constellation including x -1 The name of the constellation including Ross 154 Sagittarius Ross 154 [',', ' which', ' is', ' a', ' binary', ' star', ' system', ',', ' is', ' a', ' double', ' star', ' system', '.', ' The', ' two', ' stars', ' are', ' separated', ' by'] , which is a binary star system , is a double star system . The two stars are separated by False solitary red dwarf Ross 154 (9.7 ly). The closest 4 [' solitary', ' red', ' dwarf', ' Ross', ' 154']
+167 166 The name of the constellation including x -1 The name of the constellation including Ross 154 Sagittarius Ross 154 [',', ' which', ' is', ' a', ' binary', ' star', ' system', ',', ' is', ' a', ' double', ' star', ' system', '.', ' The', ' two', ' stars', ' are', ' separated', ' by'] , which is a binary star system , is a double star system . The two stars are separated by False solitary red dwarf Ross 154 (9.7 ly). The 4 [' solitary', ' red', ' dwarf', ' Ross', ' 154']
+168 166 The name of the constellation including x -1 The name of the constellation including Ross 154 Sagittarius Ross 154 [',', ' which', ' is', ' a', ' binary', ' star', ' system', ',', ' is', ' a', ' double', ' star', ' system', '.', ' The', ' two', ' stars', ' are', ' separated', ' by'] , which is a binary star system , is a double star system . The two stars are separated by False solitary red dwarf Ross 154 (9.7 ly). The closest 4 [' solitary', ' red', ' dwarf', ' Ross', ' 154']
+169 166 The name of the constellation including x -1 The name of the constellation including Ross 154 Sagittarius Ross 154 [',', ' which', ' is', ' a', ' binary', ' star', ' system', ',', ' is', ' a', ' double', ' star', ' system', '.', ' The', ' two', ' stars', ' are', ' separated', ' by'] , which is a binary star system , is a double star system . The two stars are separated by False solitary red dwarf Ross 154 (9.7 ly). The closest 4 [' solitary', ' red', ' dwarf', ' Ross', ' 154']
+170 166 The name of the constellation including x -1 The name of the constellation including Ross 154 Sagittarius Ross 154 [',', ' which', ' is', ' a', ' binary', ' star', ' system', ',', ' is', ' a', ' double', ' star', ' system', '.', ' The', ' two', ' stars', ' are', ' separated', ' by'] , which is a binary star system , is a double star system . The two stars are separated by False solitary red dwarf Ross 154 (9.7 ly). The closest 4 [' solitary', ' red', ' dwarf', ' Ross', ' 154']
+171 201 The name of the constellation including x -1 The name of the constellation including Pollux Gemini Pollux [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Gemini', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Gemini', ' is'] ", the brightest star in the constellation of Gemini .
+
+ The name of the constellation of Gemini is" True " first-magnitude star, as Pollux is from the Earth.
+" 9 [' first', '-', 'm', 'agn', 'itude', ' star', ',', ' as', ' Poll', 'ux']
+172 201 The name of the constellation including x -1 The name of the constellation including Pollux Gemini Pollux [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Gemini', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Gemini', ' is'] ", the brightest star in the constellation of Gemini .
+
+ The name of the constellation of Gemini is" True conjoined twin Pollux (Brad Grusnick), 4 [' con', 'joined', ' twin', ' Poll', 'ux']
+173 201 The name of the constellation including x -1 The name of the constellation including Pollux Gemini Pollux [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Gemini', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Gemini', ' is'] ", the brightest star in the constellation of Gemini .
+
+ The name of the constellation of Gemini is" True conjoined twin Pollux (Brad Grusnick), 4 [' con', 'joined', ' twin', ' Poll', 'ux']
+174 201 The name of the constellation including x -1 The name of the constellation including Pollux Gemini Pollux [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Gemini', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Gemini', ' is'] ", the brightest star in the constellation of Gemini .
+
+ The name of the constellation of Gemini is" True first-magnitude star, as Pollux is from the 9 [' first', '-', 'm', 'agn', 'itude', ' star', ',', ' as', ' Poll', 'ux']
+175 201 The name of the constellation including x -1 The name of the constellation including Pollux Gemini Pollux [',', ' the', ' brightest', ' star', ' in', ' the', ' constellation', ' of', ' Gemini', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' of', ' Gemini', ' is'] ", the brightest star in the constellation of Gemini .
+
+ The name of the constellation of Gemini is" True (Beta Aurigae), Pollux (Beta Geminorum), 7 [' (', 'Beta', ' Aur', 'ig', 'ae', '),', ' Poll', 'ux']
+176 207 The name of the constellation including x -1 The name of the constellation including 14 Herculis Hercules 14 Herculis [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Seven', ' Sisters', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ','] , the Ple i ades , the Seven Sisters , the Ple i ades , the Ple i ades , False planet-bearing star 14 Herculis has nearly triple 7 [' planet', '-', 'bearing', ' star', ' 14', ' Her', 'cul', 'is']
+177 207 The name of the constellation including x -1 The name of the constellation including 14 Herculis Hercules 14 Herculis [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Seven', ' Sisters', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ','] , the Ple i ades , the Seven Sisters , the Ple i ades , the Ple i ades , False planet-bearing star 14 Herculis has nearly triple 7 [' planet', '-', 'bearing', ' star', ' 14', ' Her', 'cul', 'is']
+178 217 The name of the constellation including x -1 The name of the constellation including Nu Indi Indus Nu Indi [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Indian', ' sub', 'cont', 'inent', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation'] ", the name of the constellation of the Indian sub cont inent .
+
+ The name of the constellation" False Lacaille had labelled Nu Indi turned out to be in 7 [' Lac', 'ail', 'le', ' had', ' labelled', ' Nu', ' Ind', 'i']
+179 277 The name of the constellation including x -1 The name of the constellation including Upsilon Andromedae Andromeda Upsilon Andromedae [',', ' the', ' Great', ' Bear', ',', ' is', ' the', ' only', ' one', ' that', ' is', ' not', ' a', ' single', ' star', '.', '\n', '\n', 'The', ' constellation'] ", the Great Bear , is the only one that is not a single star .
+
+ The constellation" False used to detect Upsilon Andromedae c is that 8 [' used', ' to', ' detect', ' Ups', 'ilon', ' And', 'rom', 'ed', 'ae']
+180 277 The name of the constellation including x -1 The name of the constellation including Upsilon Andromedae Andromeda Upsilon Andromedae [',', ' the', ' Great', ' Bear', ',', ' is', ' the', ' only', ' one', ' that', ' is', ' not', ' a', ' single', ' star', '.', '\n', '\n', 'The', ' constellation'] ", the Great Bear , is the only one that is not a single star .
+
+ The constellation" False " Andromedae d =
+" 12 [' And', 'rom', 'ed', 'ae', ' d', ' =', 'U', 'ps', 'ilon', ' And', 'rom', 'ed', 'ae']
+181 277 The name of the constellation including x -1 The name of the constellation including Upsilon Andromedae Andromeda Upsilon Andromedae [',', ' the', ' Great', ' Bear', ',', ' is', ' the', ' only', ' one', ' that', ' is', ' not', ' a', ' single', ' star', '.', '\n', '\n', 'The', ' constellation'] ", the Great Bear , is the only one that is not a single star .
+
+ The constellation" False 6 ['U', 'ps', 'ilon', ' And', 'rom', 'ed', 'ae']
+182 277 The name of the constellation including x -1 The name of the constellation including Upsilon Andromedae Andromeda Upsilon Andromedae [',', ' the', ' Great', ' Bear', ',', ' is', ' the', ' only', ' one', ' that', ' is', ' not', ' a', ' single', ' star', '.', '\n', '\n', 'The', ' constellation'] ", the Great Bear , is the only one that is not a single star .
+
+ The constellation" False 2008. The orbit of Upsilon Andromedae c gradually 10 [' 2008', '.', ' The', ' orbit', ' of', ' Ups', 'ilon', ' And', 'rom', 'ed', 'ae']
+183 277 The name of the constellation including x -1 The name of the constellation including Upsilon Andromedae Andromeda Upsilon Andromedae [',', ' the', ' Great', ' Bear', ',', ' is', ' the', ' only', ' one', ' that', ' is', ' not', ' a', ' single', ' star', '.', '\n', '\n', 'The', ' constellation'] ", the Great Bear , is the only one that is not a single star .
+
+ The constellation" False innermost planet of the Upsilon Andromedae system was discovered 10 [' inner', 'most', ' planet', ' of', ' the', ' Ups', 'ilon', ' And', 'rom', 'ed', 'ae']
+184 283 The name of the constellation including x -1 The name of the constellation including Milky Way Cassiopeia Milky Way [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way', ',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way'] , the name of the constellation of the Milky Way , the name of the constellation of the Milky Way False be seen in the Milky Way galaxy was SN 1604, 5 [' be', ' seen', ' in', ' the', ' Milky', ' Way']
+185 283 The name of the constellation including x -1 The name of the constellation including Milky Way Cassiopeia Milky Way [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way', ',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way'] , the name of the constellation of the Milky Way , the name of the constellation of the Milky Way False center of the Milky Way, going in the direction 4 [' center', ' of', ' the', ' Milky', ' Way']
+186 283 The name of the constellation including x -1 The name of the constellation including Milky Way Cassiopeia Milky Way [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way', ',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way'] , the name of the constellation of the Milky Way , the name of the constellation of the Milky Way False 2 ['Mil', 'ky', ' Way']
+187 283 The name of the constellation including x -1 The name of the constellation including Milky Way Cassiopeia Milky Way [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way', ',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way'] , the name of the constellation of the Milky Way , the name of the constellation of the Milky Way False the shape of the Milky Way and the position 5 [' the', ' shape', ' of', ' the', ' Milky', ' Way']
+188 283 The name of the constellation including x -1 The name of the constellation including Milky Way Cassiopeia Milky Way [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way', ',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Milky', ' Way'] , the name of the constellation of the Milky Way , the name of the constellation of the Milky Way False stars in the Milky Way is estimated 4 [' stars', ' in', ' the', ' Milky', ' Way']
+189 285 The name of the constellation including x -1 The name of the constellation including Andromeda Andromeda Andromeda [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades'] , the Ple i ades , the Ple i ades , the Ple i ades , the Ple i ades False 1 ['And', 'romeda']
+190 285 The name of the constellation including x -1 The name of the constellation including Andromeda Andromeda Andromeda [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades'] , the Ple i ades , the Ple i ades , the Ple i ades , the Ple i ades False Milky Way, the Andromeda Galaxy has satellite 4 [' Milky', ' Way', ',', ' the', ' Andromeda']
+191 285 The name of the constellation including x -1 The name of the constellation including Andromeda Andromeda Andromeda [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades'] , the Ple i ades , the Ple i ades , the Ple i ades , the Ple i ades False miniseries The Andromeda Strain (2008) and returned 4 [' min', 'iser', 'ies', ' The', ' Andromeda']
+192 285 The name of the constellation including x -1 The name of the constellation including Andromeda Andromeda Andromeda [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades'] , the Ple i ades , the Ple i ades , the Ple i ades , the Ple i ades False Milky Way and the Andromeda Galaxy are a binary 4 [' Milky', ' Way', ' and', ' the', ' Andromeda']
+193 285 The name of the constellation including x -1 The name of the constellation including Andromeda Andromeda Andromeda [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Ple', 'i', 'ades'] , the Ple i ades , the Ple i ades , the Ple i ades , the Ple i ades False and the nearby Andromeda Galaxy are moving toward 3 [' and', ' the', ' nearby', ' Andromeda']
+194 286 The name of the constellation including x -1 The name of the constellation including Gliese 876 Aquarius Gliese 876 "[' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a', ' binary']" is a bit of a mouth ful , but it 's a pretty cool name . It 's a binary False " Gliese 876 d
+" 4 [' Gl', 'ies', 'e', ' 8', '76']
+195 286 The name of the constellation including x -1 The name of the constellation including Gliese 876 Aquarius Gliese 876 "[' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a', ' binary']" is a bit of a mouth ful , but it 's a pretty cool name . It 's a binary False the situation in the Gliese 876 system) can be 8 [' the', ' situation', ' in', ' the', ' Gl', 'ies', 'e', ' 8', '76']
+196 286 The name of the constellation including x -1 The name of the constellation including Gliese 876 Aquarius Gliese 876 "[' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a', ' binary']" is a bit of a mouth ful , but it 's a pretty cool name . It 's a binary False to detect Gliese 876 b is that only a lower 6 [' to', ' detect', ' Gl', 'ies', 'e', ' 8', '76']
+197 286 The name of the constellation including x -1 The name of the constellation including Gliese 876 Aquarius Gliese 876 "[' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a', ' binary']" is a bit of a mouth ful , but it 's a pretty cool name . It 's a binary False 4 ['Gl', 'ies', 'e', ' 8', '76']
+198 286 The name of the constellation including x -1 The name of the constellation including Gliese 876 Aquarius Gliese 876 "[' is', ' a', ' bit', ' of', ' a', ' mouth', 'ful', ',', ' but', ' it', ""'s"", ' a', ' pretty', ' cool', ' name', '.', ' It', ""'s"", ' a', ' binary']" is a bit of a mouth ful , but it 's a pretty cool name . It 's a binary False orbit around Gliese 876 by two independent 6 [' orbit', ' around', ' Gl', 'ies', 'e', ' 8', '76']
+199 291 The name of the constellation including x -1 The name of the constellation including Alpha Andromedae Andromeda Alpha Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False 4 ['Alpha', ' And', 'rom', 'ed', 'ae']
+200 291 The name of the constellation including x -1 The name of the constellation including Alpha Andromedae Andromeda Alpha Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False " Andromedae =
+" 9 [' And', 'rom', 'ed', 'ae', ' =', 'Alpha', ' And', 'rom', 'ed', 'ae']
+201 291 The name of the constellation including x -1 The name of the constellation including Alpha Andromedae Andromeda Alpha Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False " Andromedae =
+" 9 [' And', 'rom', 'ed', 'ae', ' =', 'Alpha', ' And', 'rom', 'ed', 'ae']
+202 291 The name of the constellation including x -1 The name of the constellation including Alpha Andromedae Andromeda Alpha Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False " Andromedae =
+" 9 [' And', 'rom', 'ed', 'ae', ' =', 'Alpha', ' And', 'rom', 'ed', 'ae']
+203 291 The name of the constellation including x -1 The name of the constellation including Alpha Andromedae Andromeda Alpha Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False Tuamotu islands, Alpha Andromedae was called Takurua-e-te-tuki-hanga-ruki, 10 [' Tu', 'am', 'ot', 'u', ' islands', ',', ' Alpha', ' And', 'rom', 'ed', 'ae']
+204 292 The name of the constellation including x -1 The name of the constellation including 14 Andromedae Andromeda 14 Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' the'] , the Ple i ades , the Hy ades , the Ple i ades , the Hy ades , the False one planet, 14 Andromedae b, discovered 7 [' one', ' planet', ',', ' 14', ' And', 'rom', 'ed', 'ae']
+205 293 The name of the constellation including x -1 The name of the constellation including Ross 248 Andromeda Ross 248 [',', ' which', ' is', ' a', ' binary', ' star', ' system', ',', ' is', ' a', ' double', ' star', ' system', '.', ' The', ' two', ' stars', ' are', ' separated', ' by'] , which is a binary star system , is a double star system . The two stars are separated by False 1 ['Ross', ' 248']
+206 293 The name of the constellation including x -1 The name of the constellation including Ross 248 Andromeda Ross 248 [',', ' which', ' is', ' a', ' binary', ' star', ' system', ',', ' is', ' a', ' double', ' star', ' system', '.', ' The', ' two', ' stars', ' are', ' separated', ' by'] , which is a binary star system , is a double star system . The two stars are separated by False 1 ['Ross', ' 248']
+207 293 The name of the constellation including x -1 The name of the constellation including Ross 248 Andromeda Ross 248 [',', ' which', ' is', ' a', ' binary', ' star', ' system', ',', ' is', ' a', ' double', ' star', ' system', '.', ' The', ' two', ' stars', ' are', ' separated', ' by'] , which is a binary star system , is a double star system . The two stars are separated by False 1 ['Ross', ' 248']
+208 298 The name of the constellation including x -1 The name of the constellation including Beta Andromedae Andromeda Beta Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False dimmer than Alpha, Beta Andromedae is a red giant, its 9 [' dim', 'mer', ' than', ' Alpha', ',', ' Beta', ' And', 'rom', 'ed', 'ae']
+209 298 The name of the constellation including x -1 The name of the constellation including Beta Andromedae Andromeda Beta Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False " wearisome toil"", and Beta Andromedae was called Piringa-o-Tautu.
+" 11 "[' wear', 'is', 'ome', ' to', 'il', '"",', ' and', ' Beta', ' And', 'rom', 'ed', 'ae']"
+210 306 The name of the constellation including x -1 The name of the constellation including Beta Cassiopeiae Cassiopeia Beta Cassiopeiae [',', ' the', ' Great', ' Bear', ',', ' is', ' a', ' constellation', ' in', ' the', ' northern', ' sky', '.', '\n', '\n', 'The', ' constellation', ' of', ' the', ' Great'] ", the Great Bear , is a constellation in the northern sky .
+
+ The constellation of the Great" False other two being Beta Cassiopeiae and Gamma 7 [' other', ' two', ' being', ' Beta', ' Cass', 'i', 'ope', 'iae']
+211 312 The name of the constellation including x -1 The name of the constellation including Navi Cassiopeia Navi [' Mumbai', ',', ' Mumbai', ',', ' Maharashtra', ',', ' India', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from', ' the', ' Sanskrit'] " Mumbai , Mumbai , Maharashtra , India .
+
+ The name of the constellation is derived from the Sanskrit" False eventually led to NetNavis. The government 5 [' eventually', ' led', ' to', ' Net', 'N', 'avi']
+212 312 The name of the constellation including x -1 The name of the constellation including Navi Cassiopeia Navi [' Mumbai', ',', ' Mumbai', ',', ' Maharashtra', ',', ' India', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from', ' the', ' Sanskrit'] " Mumbai , Mumbai , Maharashtra , India .
+
+ The name of the constellation is derived from the Sanskrit" False Cassiopeiae became Navi – Ivan (Gus Grissom's 6 [' Cass', 'i', 'ope', 'iae', ' became', ' Nav', 'i']
+213 312 The name of the constellation including x -1 The name of the constellation including Navi Cassiopeia Navi [' Mumbai', ',', ' Mumbai', ',', ' Maharashtra', ',', ' India', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from', ' the', ' Sanskrit'] " Mumbai , Mumbai , Maharashtra , India .
+
+ The name of the constellation is derived from the Sanskrit" False 1 ['N', 'avi']
+214 312 The name of the constellation including x -1 The name of the constellation including Navi Cassiopeia Navi [' Mumbai', ',', ' Mumbai', ',', ' Maharashtra', ',', ' India', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from', ' the', ' Sanskrit'] " Mumbai , Mumbai , Maharashtra , India .
+
+ The name of the constellation is derived from the Sanskrit" False 1 ['N', 'avi']
+215 312 The name of the constellation including x -1 The name of the constellation including Navi Cassiopeia Navi [' Mumbai', ',', ' Mumbai', ',', ' Maharashtra', ',', ' India', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' constellation', ' is', ' derived', ' from', ' the', ' Sanskrit'] " Mumbai , Mumbai , Maharashtra , India .
+
+ The name of the constellation is derived from the Sanskrit" False well as parts of Navi Mumbai, Mira-Bhayandar 5 [' well', ' as', ' parts', ' of', ' Nav', 'i']
+216 325 The name of the constellation including x -1 The name of the constellation including Buna Andromeda Buna [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False the Japanese back to Buna in the north. 5 [' the', ' Japanese', ' back', ' to', ' B', 'una']
+217 325 The name of the constellation including x -1 The name of the constellation including Buna Andromeda Buna [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False for a few hours at Buna on 3 September, 6 [' for', ' a', ' few', ' hours', ' at', ' B', 'una']
+218 325 The name of the constellation including x -1 The name of the constellation including Buna Andromeda Buna [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False withdraw back towards Buna and Gona. Fierce 4 [' withdraw', ' back', ' towards', ' B', 'una']
+219 325 The name of the constellation including x -1 The name of the constellation including Buna Andromeda Buna [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False the Battle of Buna – Gona before the 4 [' the', ' Battle', ' of', ' B', 'una']
+220 325 The name of the constellation including x -1 The name of the constellation including Buna Andromeda Buna [',', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Great', ' Bear', ',', ' and', ' the', ' name', ' of', ' the', ' constellation', ' of', ' the', ' Little'] , the name of the constellation of the Great Bear , and the name of the constellation of the Little False secure the area Buna – Gona upon 4 [' secure', ' the', ' area', ' B', 'una']
+221 335 The name of the constellation including x -1 The name of the constellation including Kappa Andromedae Andromeda Kappa Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False MJ, while Kappa Andromedae b has a mass 7 [' MJ', ',', ' while', ' Kappa', ' And', 'rom', 'ed', 'ae']
+222 335 The name of the constellation including x -1 The name of the constellation including Kappa Andromedae Andromeda Kappa Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False 5 ['K', 'appa', ' And', 'rom', 'ed', 'ae']
+223 335 The name of the constellation including x -1 The name of the constellation including Kappa Andromedae Andromeda Kappa Andromedae [',', ' the', ' Ple', 'i', 'ades', ',', ' the', ' Hy', 'ades', ',', ' and', ' the', ' Hy', 'ades', ' cluster', '.', '\n', '\n', 'The', ' Ple'] ", the Ple i ades , the Hy ades , and the Hy ades cluster .
+
+ The Ple" False 0.69 MJ, while Kappa Andromedae b has a mass 10 [' 0', '.', '69', ' MJ', ',', ' while', ' Kappa', ' And', 'rom', 'ed', 'ae']
diff --git a/patchscopes/code/preprocessed_data/factual/superhero_archnemesis.tsv b/patchscopes/code/preprocessed_data/factual/superhero_archnemesis.tsv
new file mode 100644
index 00000000..4b3ac066
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/superhero_archnemesis.tsv
@@ -0,0 +1,810 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Superman Lex Luthor Superman [',', ' the', ' Man', ' of', ' Steel', ',', ' is', ' Lex', ' Luthor', '.', ' He', ' is', ' the', ' most', ' powerful', ' man', ' in', ' the', ' world', ','] , the Man of Steel , is Lex Luthor . He is the most powerful man in the world , True in Zack Snyder's Superman film Man of Steel, 4 "[' in', ' Zack', ' Snyder', ""'s"", ' Superman']"
+1 0 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Superman Lex Luthor Superman [',', ' the', ' Man', ' of', ' Steel', ',', ' is', ' Lex', ' Luthor', '.', ' He', ' is', ' the', ' most', ' powerful', ' man', ' in', ' the', ' world', ','] , the Man of Steel , is Lex Luthor . He is the most powerful man in the world , True DVD box set titled Superman Ultimate Collector's 4 [' DVD', ' box', ' set', ' titled', ' Superman']
+2 0 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Superman Lex Luthor Superman [',', ' the', ' Man', ' of', ' Steel', ',', ' is', ' Lex', ' Luthor', '.', ' He', ' is', ' the', ' most', ' powerful', ' man', ' in', ' the', ' world', ','] , the Man of Steel , is Lex Luthor . He is the most powerful man in the world , True 1 ['Super', 'man']
+3 0 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Superman Lex Luthor Superman [',', ' the', ' Man', ' of', ' Steel', ',', ' is', ' Lex', ' Luthor', '.', ' He', ' is', ' the', ' most', ' powerful', ' man', ' in', ' the', ' world', ','] , the Man of Steel , is Lex Luthor . He is the most powerful man in the world , True design work for Superman Returns. Spielberg 3 [' design', ' work', ' for', ' Superman']
+4 0 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Superman Lex Luthor Superman [',', ' the', ' Man', ' of', ' Steel', ',', ' is', ' Lex', ' Luthor', '.', ' He', ' is', ' the', ' most', ' powerful', ' man', ' in', ' the', ' world', ','] , the Man of Steel , is Lex Luthor . He is the most powerful man in the world , True from the previous Superman films. John Williams'musical 3 [' from', ' the', ' previous', ' Superman']
+5 1 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batman The Joker Batman "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' comic', ' book', ' series', ',']" , the Joker , is a reference to the Joker 's famous line from the Batman comic book series , False 1986, long before Batman or Dick Tracy or 4 [' 1986', ',', ' long', ' before', ' Batman']
+6 1 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batman The Joker Batman "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' comic', ' book', ' series', ',']" , the Joker , is a reference to the Joker 's famous line from the Batman comic book series , False series, sales of Batman continued to 4 [' series', ',', ' sales', ' of', ' Batman']
+7 1 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batman The Joker Batman "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' comic', ' book', ' series', ',']" , the Joker , is a reference to the Joker 's famous line from the Batman comic book series , False " Scissorhands or Batman Returns.""
+" 5 [' Sc', 'iss', 'or', 'hands', ' or', ' Batman']
+8 1 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batman The Joker Batman "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' comic', ' book', ' series', ',']" , the Joker , is a reference to the Joker 's famous line from the Batman comic book series , False outfits for Batman were revealed in 2 [' outfits', ' for', ' Batman']
+9 1 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batman The Joker Batman "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' comic', ' book', ' series', ',']" , the Joker , is a reference to the Joker 's famous line from the Batman comic book series , False 0 ['Batman']
+10 2 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wonder Woman Cheetah Wonder Woman [' is', ' the', ' Che', 'et', 'ah', '.', ' She', ' is', ' a', ' cat', '-', 'like', ' woman', ' who', ' is', ' a', ' member', ' of', ' the', ' Justice'] is the Che et ah . She is a cat - like woman who is a member of the Justice True Comics production Wonder Woman: Amazonia 3 [' Comics', ' production', ' Wonder', ' Woman']
+11 2 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wonder Woman Cheetah Wonder Woman [' is', ' the', ' Che', 'et', 'ah', '.', ' She', ' is', ' a', ' cat', '-', 'like', ' woman', ' who', ' is', ' a', ' member', ' of', ' the', ' Justice'] is the Che et ah . She is a cat - like woman who is a member of the Justice True unproduced live-action Wonder Woman film, was selected 6 [' un', 'produced', ' live', '-', 'action', ' Wonder', ' Woman']
+12 2 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wonder Woman Cheetah Wonder Woman [' is', ' the', ' Che', 'et', 'ah', '.', ' She', ' is', ' a', ' cat', '-', 'like', ' woman', ' who', ' is', ' a', ' member', ' of', ' the', ' Justice'] is the Che et ah . She is a cat - like woman who is a member of the Justice True Flash, Aquaman, and Wonder Woman alongside 7 [' Flash', ',', ' Aqu', 'aman', ',', ' and', ' Wonder', ' Woman']
+13 2 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wonder Woman Cheetah Wonder Woman [' is', ' the', ' Che', 'et', 'ah', '.', ' She', ' is', ' a', ' cat', '-', 'like', ' woman', ' who', ' is', ' a', ' member', ' of', ' the', ' Justice'] is the Che et ah . She is a cat - like woman who is a member of the Justice True Gal Gadot was cast as Wonder Woman in December 7 [' Gal', ' Gad', 'ot', ' was', ' cast', ' as', ' Wonder', ' Woman']
+14 2 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wonder Woman Cheetah Wonder Woman [' is', ' the', ' Che', 'et', 'ah', '.', ' She', ' is', ' a', ' cat', '-', 'like', ' woman', ' who', ' is', ' a', ' member', ' of', ' the', ' Justice'] is the Che et ah . She is a cat - like woman who is a member of the Justice True with Zatanna and Wonder Woman struggling to 6 [' with', ' Z', 'at', 'anna', ' and', ' Wonder', ' Woman']
+15 3 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spider-Man Green Goblin Spider-Man [',', ' the', ' Green', ' Goblin', ',', ' is', ' Norman', ' Os', 'born', '.', ' He', ' is', ' the', ' son', ' of', ' Harry', ' Os', 'born', ',', ' who'] , the Green Goblin , is Norman Os born . He is the son of Harry Os born , who True Stacy respectively. Spider-Man co-creator, 5 [' Stacy', ' respectively', '.', ' Spider', '-', 'Man']
+16 3 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spider-Man Green Goblin Spider-Man [',', ' the', ' Green', ' Goblin', ',', ' is', ' Norman', ' Os', 'born', '.', ' He', ' is', ' the', ' son', ' of', ' Harry', ' Os', 'born', ',', ' who'] , the Green Goblin , is Norman Os born . He is the son of Harry Os born , who True who is saved by Spider-Man from falling 6 [' who', ' is', ' saved', ' by', ' Spider', '-', 'Man']
+17 3 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spider-Man Green Goblin Spider-Man [',', ' the', ' Green', ' Goblin', ',', ' is', ' Norman', ' Os', 'born', '.', ' He', ' is', ' the', ' son', ' of', ' Harry', ' Os', 'born', ',', ' who'] , the Green Goblin , is Norman Os born . He is the son of Harry Os born , who True " be Black Cat if Spider-Man 4 had been made.
+" 6 [' be', ' Black', ' Cat', ' if', ' Spider', '-', 'Man']
+18 3 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spider-Man Green Goblin Spider-Man [',', ' the', ' Green', ' Goblin', ',', ' is', ' Norman', ' Os', 'born', '.', ' He', ' is', ' the', ' son', ' of', ' Harry', ' Os', 'born', ',', ' who'] , the Green Goblin , is Norman Os born . He is the son of Harry Os born , who True as a journalist in Spider-Man 3 also represent 6 [' as', ' a', ' journalist', ' in', ' Spider', '-', 'Man']
+19 3 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spider-Man Green Goblin Spider-Man [',', ' the', ' Green', ' Goblin', ',', ' is', ' Norman', ' Os', 'born', '.', ' He', ' is', ' the', ' son', ' of', ' Harry', ' Os', 'born', ',', ' who'] , the Green Goblin , is Norman Os born . He is the son of Harry Os born , who True has featured Spider-Man in several comic 4 [' has', ' featured', ' Spider', '-', 'Man']
+20 4 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Flash Reverse-Flash The Flash [',', ' the', ' Reverse', ' Flash', ',', ' is', ' Barry', ' Allen', '.', ' He', ' is', ' the', ' son', ' of', ' Barry', ' Allen', ',', ' the', ' Flash', ','] , the Reverse Flash , is Barry Allen . He is the son of Barry Allen , the Flash , False scored on the season. The Flash finished first 6 [' scored', ' on', ' the', ' season', '.', ' The', ' Flash']
+21 4 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Flash Reverse-Flash The Flash [',', ' the', ' Reverse', ' Flash', ',', ' is', ' Barry', ' Allen', '.', ' He', ' is', ' the', ' son', ' of', ' Barry', ' Allen', ',', ' the', ' Flash', ','] , the Reverse Flash , is Barry Allen . He is the son of Barry Allen , the Flash , False 1 ['The', ' Flash']
+22 4 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Flash Reverse-Flash The Flash [',', ' the', ' Reverse', ' Flash', ',', ' is', ' Barry', ' Allen', '.', ' He', ' is', ' the', ' son', ' of', ' Barry', ' Allen', ',', ' the', ' Flash', ','] , the Reverse Flash , is Barry Allen . He is the son of Barry Allen , the Flash , False 1 ['The', ' Flash']
+23 4 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Flash Reverse-Flash The Flash [',', ' the', ' Reverse', ' Flash', ',', ' is', ' Barry', ' Allen', '.', ' He', ' is', ' the', ' son', ' of', ' Barry', ' Allen', ',', ' the', ' Flash', ','] , the Reverse Flash , is Barry Allen . He is the son of Barry Allen , the Flash , False (since unlike The Flash or Green Lantern, 4 [' (', 'since', ' unlike', ' The', ' Flash']
+24 4 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Flash Reverse-Flash The Flash [',', ' the', ' Reverse', ' Flash', ',', ' is', ' Barry', ' Allen', '.', ' He', ' is', ' the', ' son', ' of', ' Barry', ' Allen', ',', ' the', ' Flash', ','] , the Reverse Flash , is Barry Allen . He is the son of Barry Allen , the Flash , False scored on the season. The Flash finished first during 6 [' scored', ' on', ' the', ' season', '.', ' The', ' Flash']
+25 5 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Lantern Sinestro Green Lantern [',', ' the', ' Green', ' Lantern', ' Corps', ',', ' is', ' the', ' Guardians', ' of', ' the', ' Universe', '.', ' The', ' Guardians', ' are', ' the', ' protect', 'ors', ' of'] , the Green Lantern Corps , is the Guardians of the Universe . The Guardians are the protect ors of False force called the Green Lantern Corps. They 4 [' force', ' called', ' the', ' Green', ' Lantern']
+26 5 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Lantern Sinestro Green Lantern [',', ' the', ' Green', ' Lantern', ' Corps', ',', ' is', ' the', ' Guardians', ' of', ' the', ' Universe', '.', ' The', ' Guardians', ' are', ' the', ' protect', 'ors', ' of'] , the Green Lantern Corps , is the Guardians of the Universe . The Guardians are the protect ors of False centers on the Green Lanterns of Earth — 4 [' centers', ' on', ' the', ' Green', ' Lantern']
+27 5 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Lantern Sinestro Green Lantern [',', ' the', ' Green', ' Lantern', ' Corps', ',', ' is', ' the', ' Guardians', ' of', ' the', ' Universe', '.', ' The', ' Guardians', ' are', ' the', ' protect', 'ors', ' of'] , the Green Lantern Corps , is the Guardians of the Universe . The Guardians are the protect ors of False " Lantern Coaster =
+" 5 [' Lantern', ' Co', 'aster', ' =', 'Green', ' Lantern']
+28 5 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Lantern Sinestro Green Lantern [',', ' the', ' Green', ' Lantern', ' Corps', ',', ' is', ' the', ' Guardians', ' of', ' the', ' Universe', '.', ' The', ' Guardians', ' are', ' the', ' protect', 'ors', ' of'] , the Green Lantern Corps , is the Guardians of the Universe . The Guardians are the protect ors of False Arrow, and Green Lantern meet the alternate 4 [' Arrow', ',', ' and', ' Green', ' Lantern']
+29 5 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Lantern Sinestro Green Lantern [',', ' the', ' Green', ' Lantern', ' Corps', ',', ' is', ' the', ' Guardians', ' of', ' the', ' Universe', '.', ' The', ' Guardians', ' are', ' the', ' protect', 'ors', ' of'] , the Green Lantern Corps , is the Guardians of the Universe . The Guardians are the protect ors of False 1 ['Green', ' Lantern']
+30 6 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain America Red Skull Captain America [',', ' the', ' Red', ' Skull', ',', ' is', ' a', ' reference', ' to', ' the', ' Red', ' Skull', ',', ' a', ' character', ' from', ' the', ' Marvel', ' Comics', ' character'] , the Red Skull , is a reference to the Red Skull , a character from the Marvel Comics character True 1 ['Captain', ' America']
+31 6 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain America Red Skull Captain America [',', ' the', ' Red', ' Skull', ',', ' is', ' a', ' reference', ' to', ' the', ' Red', ' Skull', ',', ' a', ' character', ' from', ' the', ' Marvel', ' Comics', ' character'] , the Red Skull , is a reference to the Red Skull , a character from the Marvel Comics character True 1 ['Captain', ' America']
+32 6 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain America Red Skull Captain America [',', ' the', ' Red', ' Skull', ',', ' is', ' a', ' reference', ' to', ' the', ' Red', ' Skull', ',', ' a', ' character', ' from', ' the', ' Marvel', ' Comics', ' character'] , the Red Skull , is a reference to the Red Skull , a character from the Marvel Comics character True 1 ['Captain', ' America']
+33 6 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain America Red Skull Captain America [',', ' the', ' Red', ' Skull', ',', ' is', ' a', ' reference', ' to', ' the', ' Red', ' Skull', ',', ' a', ' character', ' from', ' the', ' Marvel', ' Comics', ' character'] , the Red Skull , is a reference to the Red Skull , a character from the Marvel Comics character True do everything Captain America can do, but backwards 3 [' do', ' everything', ' Captain', ' America']
+34 6 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain America Red Skull Captain America [',', ' the', ' Red', ' Skull', ',', ' is', ' a', ' reference', ' to', ' the', ' Red', ' Skull', ',', ' a', ' character', ' from', ' the', ' Marvel', ' Comics', ' character'] , the Red Skull , is a reference to the Red Skull , a character from the Marvel Comics character True as Steve Rogers / Captain America via archive footage 5 [' as', ' Steve', ' Rogers', ' /', ' Captain', ' America']
+35 7 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Iron Man The Mandarin Iron Man [',', ' the', ' Hulk', ',', ' is', ' a', ' reference', ' to', ' the', ' Hulk', ' Hogan', ',', ' a', ' professional', ' wrestler', ' who', ' was', ' a', ' member', ' of'] , the Hulk , is a reference to the Hulk Hogan , a professional wrestler who was a member of False the previous two Iron Man films, said participating 4 [' the', ' previous', ' two', ' Iron', ' Man']
+36 7 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Iron Man The Mandarin Iron Man [',', ' the', ' Hulk', ',', ' is', ' a', ' reference', ' to', ' the', ' Hulk', ' Hogan', ',', ' a', ' professional', ' wrestler', ' who', ' was', ' a', ' member', ' of'] , the Hulk , is a reference to the Hulk Hogan , a professional wrestler who was a member of False Thirty-Minute Iron Man match to retain 5 [' Thirty', '-', 'Min', 'ute', ' Iron', ' Man']
+37 7 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Iron Man The Mandarin Iron Man [',', ' the', ' Hulk', ',', ' is', ' a', ' reference', ' to', ' the', ' Hulk', ' Hogan', ',', ' a', ' professional', ' wrestler', ' who', ' was', ' a', ' member', ' of'] , the Hulk , is a reference to the Hulk Hogan , a professional wrestler who was a member of False their sixty-minute Iron Man match, which had ended 5 [' their', ' sixty', '-', 'minute', ' Iron', ' Man']
+38 7 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Iron Man The Mandarin Iron Man [',', ' the', ' Hulk', ',', ' is', ' a', ' reference', ' to', ' the', ' Hulk', ' Hogan', ',', ' a', ' professional', ' wrestler', ' who', ' was', ' a', ' member', ' of'] , the Hulk , is a reference to the Hulk Hogan , a professional wrestler who was a member of False After the release of Iron Man 2 in May 2010, Favreau, 5 [' After', ' the', ' release', ' of', ' Iron', ' Man']
+39 7 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Iron Man The Mandarin Iron Man [',', ' the', ' Hulk', ',', ' is', ' a', ' reference', ' to', ' the', ' Hulk', ' Hogan', ',', ' a', ' professional', ' wrestler', ' who', ' was', ' a', ' member', ' of'] , the Hulk , is a reference to the Hulk Hogan , a professional wrestler who was a member of False Angle lost an Iron Man match to Brock Lesnar 4 [' Angle', ' lost', ' an', ' Iron', ' Man']
+40 8 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Thor Loki Thor [',', ' the', ' God', ' of', ' Thunder', ',', ' is', ' M', 'j', 'oln', 'ir', ',', ' and', ' it', ' is', ' the', ' most', ' powerful', ' weapon', ' in'] , the God of Thunder , is M j oln ir , and it is the most powerful weapon in False Pop music critic Thor Christensen of the 3 [' Pop', ' music', ' critic', ' Thor']
+41 8 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Thor Loki Thor [',', ' the', ' God', ' of', ' Thunder', ',', ' is', ' M', 'j', 'oln', 'ir', ',', ' and', ' it', ' is', ' the', ' most', ' powerful', ' weapon', ' in'] , the God of Thunder , is M j oln ir , and it is the most powerful weapon in False (3 times) – with Thor the Viking 6 [' (', '3', ' times', ')', ' –', ' with', ' Thor']
+42 8 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Thor Loki Thor [',', ' the', ' God', ' of', ' Thunder', ',', ' is', ' M', 'j', 'oln', 'ir', ',', ' and', ' it', ' is', ' the', ' most', ' powerful', ' weapon', ' in'] , the God of Thunder , is M j oln ir , and it is the most powerful weapon in False against Fenrir, while Thor moves at his side, 4 [' against', ' Fenrir', ',', ' while', ' Thor']
+43 8 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Thor Loki Thor [',', ' the', ' God', ' of', ' Thunder', ',', ' is', ' M', 'j', 'oln', 'ir', ',', ' and', ' it', ' is', ' the', ' most', ' powerful', ' weapon', ' in'] , the God of Thunder , is M j oln ir , and it is the most powerful weapon in False " = Thor =
+" 1 [' =', ' Thor']
+44 8 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Thor Loki Thor [',', ' the', ' God', ' of', ' Thunder', ',', ' is', ' M', 'j', 'oln', 'ir', ',', ' and', ' it', ' is', ' the', ' most', ' powerful', ' weapon', ' in'] , the God of Thunder , is M j oln ir , and it is the most powerful weapon in False entity Ægir's hall. Thor does not attend 8 "[' entity', ' �', '�', 'g', 'ir', ""'s"", ' hall', '.', ' Thor']"
+45 9 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Panther Killmonger Black Panther [' is', ' the', ' Black', ' Panther', '.', ' He', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] is the Black Panther . He is a fictional character , a superhero appearing in American comic books published by False including the Black Panther Party, Nation 3 [' including', ' the', ' Black', ' Panther']
+46 9 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Panther Killmonger Black Panther [' is', ' the', ' Black', ' Panther', '.', ' He', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] is the Black Panther . He is a fictional character , a superhero appearing in American comic books published by False " ""kicked ... into the Black Panther Party"" after suffering" 7 "[' ""', 'k', 'icked', '...', ' into', ' the', ' Black', ' Panther']"
+47 9 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Panther Killmonger Black Panther [' is', ' the', ' Black', ' Panther', '.', ' He', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] is the Black Panther . He is a fictional character , a superhero appearing in American comic books published by False " South Korean K2 Black Panther main battle tank.
+" 5 [' South', ' Korean', ' K', '2', ' Black', ' Panther']
+48 9 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Panther Killmonger Black Panther [' is', ' the', ' Black', ' Panther', '.', ' He', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] is the Black Panther . He is a fictional character , a superhero appearing in American comic books published by False " the South Korean K2 Black Panther main battle tank.
+" 6 [' the', ' South', ' Korean', ' K', '2', ' Black', ' Panther']
+49 9 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Panther Killmonger Black Panther [' is', ' the', ' Black', ' Panther', '.', ' He', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] is the Black Panther . He is a fictional character , a superhero appearing in American comic books published by False member of the former Black Panther Party (BPP) and Black 5 [' member', ' of', ' the', ' former', ' Black', ' Panther']
+50 10 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Doctor Strange Baron Mordo Doctor Strange [' is', ' the', ' Ancient', ' One', '.', ' She', ' is', ' a', ' powerful', ' sorce', 'ress', ' who', ' has', ' been', ' around', ' for', ' a', ' long', ' time', '.'] is the Ancient One . She is a powerful sorce ress who has been around for a long time . False Jennifer and Doctor Strange as those who 3 [' Jennifer', ' and', ' Doctor', ' Strange']
+51 10 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Doctor Strange Baron Mordo Doctor Strange [' is', ' the', ' Ancient', ' One', '.', ' She', ' is', ' a', ' powerful', ' sorce', 'ress', ' who', ' has', ' been', ' around', ' for', ' a', ' long', ' time', '.'] is the Ancient One . She is a powerful sorce ress who has been around for a long time . False with films like Doctor Strange and Thor: Ragnarok 4 [' with', ' films', ' like', ' Doctor', ' Strange']
+52 10 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Doctor Strange Baron Mordo Doctor Strange [' is', ' the', ' Ancient', ' One', '.', ' She', ' is', ' a', ' powerful', ' sorce', 'ress', ' who', ' has', ' been', ' around', ' for', ' a', ' long', ' time', '.'] is the Ancient One . She is a powerful sorce ress who has been around for a long time . False " War"", with films like Doctor Strange and Thor: Ragnarok" 6 "[' War', '"",', ' with', ' films', ' like', ' Doctor', ' Strange']"
+53 10 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Doctor Strange Baron Mordo Doctor Strange [' is', ' the', ' Ancient', ' One', '.', ' She', ' is', ' a', ' powerful', ' sorce', 'ress', ' who', ' has', ' been', ' around', ' for', ' a', ' long', ' time', '.'] is the Ancient One . She is a powerful sorce ress who has been around for a long time . False we ’ ve got Doctor Strange in November, two 6 [' we', ' �', '�', ' ve', ' got', ' Doctor', ' Strange']
+54 10 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Doctor Strange Baron Mordo Doctor Strange [' is', ' the', ' Ancient', ' One', '.', ' She', ' is', ' a', ' powerful', ' sorce', 'ress', ' who', ' has', ' been', ' around', ' for', ' a', ' long', ' time', '.'] is the Ancient One . She is a powerful sorce ress who has been around for a long time . False the characters Doctor Strange and Marcus Daniels, 3 [' the', ' characters', ' Doctor', ' Strange']
+55 11 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Aquaman Black Manta Aquaman [' is', ' the', ' Ocean', ' Master', '.', ' He', ' is', ' a', ' villain', ' who', ' is', ' a', ' master', ' of', ' the', ' sea', ' and', ' has', ' the', ' ability'] is the Ocean Master . He is a villain who is a master of the sea and has the ability False work on an Aquaman pilot for The 4 [' work', ' on', ' an', ' Aqu', 'aman']
+56 11 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Aquaman Black Manta Aquaman [' is', ' the', ' Ocean', ' Master', '.', ' He', ' is', ' a', ' villain', ' who', ' is', ' a', ' master', ' of', ' the', ' sea', ' and', ' has', ' the', ' ability'] is the Ocean Master . He is a villain who is a master of the sea and has the ability False family is watching Aquaman on television. 4 [' family', ' is', ' watching', ' Aqu', 'aman']
+57 11 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Aquaman Black Manta Aquaman [' is', ' the', ' Ocean', ' Master', '.', ' He', ' is', ' a', ' villain', ' who', ' is', ' a', ' master', ' of', ' the', ' sea', ' and', ' has', ' the', ' ability'] is the Ocean Master . He is a villain who is a master of the sea and has the ability False direction with the Aquaman role and wish 4 [' direction', ' with', ' the', ' Aqu', 'aman']
+58 11 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Aquaman Black Manta Aquaman [' is', ' the', ' Ocean', ' Master', '.', ' He', ' is', ' a', ' villain', ' who', ' is', ' a', ' master', ' of', ' the', ' sea', ' and', ' has', ' the', ' ability'] is the Ocean Master . He is a villain who is a master of the sea and has the ability False " he discovered that Aquaman had a ""serious ..." 4 [' he', ' discovered', ' that', ' Aqu', 'aman']
+59 11 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Aquaman Black Manta Aquaman [' is', ' the', ' Ocean', ' Master', '.', ' He', ' is', ' a', ' villain', ' who', ' is', ' a', ' master', ' of', ' the', ' sea', ' and', ' has', ' the', ' ability'] is the Ocean Master . He is a villain who is a master of the sea and has the ability False realized that an Aquaman snow globe on his 4 [' realized', ' that', ' an', ' Aqu', 'aman']
+60 12 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkeye Crossfire Hawkeye [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False Henry's departure, Hawkeye Pierce (Alan Alda), 5 "[' Henry', ""'s"", ' departure', ',', ' Haw', 'keye']"
+61 12 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkeye Crossfire Hawkeye [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False and unshaven Hawkeye and Trapper and 5 [' and', ' un', 'sh', 'aven', ' Haw', 'keye']
+62 12 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkeye Crossfire Hawkeye [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False his role as Hawkeye Pierce from the 4 [' his', ' role', ' as', ' Haw', 'keye']
+63 12 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkeye Crossfire Hawkeye [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False of the 4077th, Hawkeye whispers to Henry 7 [' of', ' the', ' 40', '77', 'th', ',', ' Haw', 'keye']
+64 12 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkeye Crossfire Hawkeye [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False for his role as Hawkeye Pierce from the 5 [' for', ' his', ' role', ' as', ' Haw', 'keye']
+65 13 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Daredevil Kingpin Daredevil [',', ' the', ' Pun', 'isher', ',', ' is', ' Frank', ' Castle', '.', ' He', ' is', ' a', ' vigilante', ' who', ' has', ' been', ' killing', ' criminals', ' for', ' years'] , the Pun isher , is Frank Castle . He is a vigilante who has been killing criminals for years False ultra-satisfying take on Daredevil's material and lore. 7 [' ultra', '-', 's', 'atisf', 'ying', ' take', ' on', ' Daredevil']
+66 13 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Daredevil Kingpin Daredevil [',', ' the', ' Pun', 'isher', ',', ' is', ' Frank', ' Castle', '.', ' He', ' is', ' a', ' vigilante', ' who', ' has', ' been', ' killing', ' criminals', ' for', ' years'] , the Pun isher , is Frank Castle . He is a vigilante who has been killing criminals for years False Malaysia banned Daredevil in that country. 2 [' Malaysia', ' banned', ' Daredevil']
+67 13 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Daredevil Kingpin Daredevil [',', ' the', ' Pun', 'isher', ',', ' is', ' Frank', ' Castle', '.', ' He', ' is', ' a', ' vigilante', ' who', ' has', ' been', ' killing', ' criminals', ' for', ' years'] , the Pun isher , is Frank Castle . He is a vigilante who has been killing criminals for years False an artist on Daredevil comics, gave several 3 [' an', ' artist', ' on', ' Daredevil']
+68 13 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Daredevil Kingpin Daredevil [',', ' the', ' Pun', 'isher', ',', ' is', ' Frank', ' Castle', '.', ' He', ' is', ' a', ' vigilante', ' who', ' has', ' been', ' killing', ' criminals', ' for', ' years'] , the Pun isher , is Frank Castle . He is a vigilante who has been killing criminals for years False " Frank Lovece said Daredevil ""makes clear that" 4 [' Frank', ' Love', 'ce', ' said', ' Daredevil']
+69 13 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Daredevil Kingpin Daredevil [',', ' the', ' Pun', 'isher', ',', ' is', ' Frank', ' Castle', '.', ' He', ' is', ' a', ' vigilante', ' who', ' has', ' been', ' killing', ' criminals', ' for', ' years'] , the Pun isher , is Frank Castle . He is a vigilante who has been killing criminals for years False 1 ['D', 'aredevil']
+70 14 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Hulk The Abomination The Hulk [',', ' the', ' Green', ' Goblin', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', ',', ' a', ' DC', ' Comics', ' superhero', '.', '\n', '\n'] ", the Green Goblin , is a reference to the Green Lantern , a DC Comics superhero .
+
+" False Ross to Captain Ahab. The Hulk is Hurt's favorite 7 [' Ross', ' to', ' Captain', ' A', 'hab', '.', ' The', ' Hulk']
+71 14 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Hulk The Abomination The Hulk [',', ' the', ' Green', ' Goblin', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', ',', ' a', ' DC', ' Comics', ' superhero', '.', '\n', '\n'] ", the Green Goblin , is a reference to the Green Lantern , a DC Comics superhero .
+
+" False Ross to Captain Ahab. The Hulk is Hurt's favorite 7 [' Ross', ' to', ' Captain', ' A', 'hab', '.', ' The', ' Hulk']
+72 14 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Hulk The Abomination The Hulk [',', ' the', ' Green', ' Goblin', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', ',', ' a', ' DC', ' Comics', ' superhero', '.', '\n', '\n'] ", the Green Goblin , is a reference to the Green Lantern , a DC Comics superhero .
+
+" False Bruce Banner / The Hulk in the Marvel 4 [' Bruce', ' Banner', ' /', ' The', ' Hulk']
+73 14 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Hulk The Abomination The Hulk [',', ' the', ' Green', ' Goblin', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', ',', ' a', ' DC', ' Comics', ' superhero', '.', '\n', '\n'] ", the Green Goblin , is a reference to the Green Lantern , a DC Comics superhero .
+
+" False transformation into the Hulk. The Hulk seriously injures 6 [' transformation', ' into', ' the', ' Hulk', '.', ' The', ' Hulk']
+74 14 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Hulk The Abomination The Hulk [',', ' the', ' Green', ' Goblin', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', ',', ' a', ' DC', ' Comics', ' superhero', '.', '\n', '\n'] ", the Green Goblin , is a reference to the Green Lantern , a DC Comics superhero .
+
+" False Captain Ahab. The Hulk is Hurt's favorite 5 [' Captain', ' A', 'hab', '.', ' The', ' Hulk']
+75 15 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Cyclops Mister Sinister Cyclops [',', ' the', ' X', '-', 'Men', ',', ' is', ' Jean', ' Grey', '.', ' She', ' is', ' the', ' daughter', ' of', ' Charles', ' Xavier', ',', ' the', ' founder'] , the X - Men , is Jean Grey . She is the daughter of Charles Xavier , the founder False monocantha, Cyclops ladakanus and Mesocyclops 5 [' mon', 'oc', 'antha', ',', ' Cycl', 'ops']
+76 15 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Cyclops Mister Sinister Cyclops [',', ' the', ' X', '-', 'Men', ',', ' is', ' Jean', ' Grey', '.', ' She', ' is', ' the', ' daughter', ' of', ' Charles', ' Xavier', ',', ' the', ' founder'] , the X - Men , is Jean Grey . She is the daughter of Charles Xavier , the founder False conclusion when Cyclops officially closes 3 [' conclusion', ' when', ' Cycl', 'ops']
+77 15 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Cyclops Mister Sinister Cyclops [',', ' the', ' X', '-', 'Men', ',', ' is', ' Jean', ' Grey', '.', ' She', ' is', ' the', ' daughter', ' of', ' Charles', ' Xavier', ',', ' the', ' founder'] , the X - Men , is Jean Grey . She is the daughter of Charles Xavier , the founder False " Cyclops (1871) =
+" 1 [' Cycl', 'ops']
+78 15 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Cyclops Mister Sinister Cyclops [',', ' the', ' X', '-', 'Men', ',', ' is', ' Jean', ' Grey', '.', ' She', ' is', ' the', ' daughter', ' of', ' Charles', ' Xavier', ',', ' the', ' founder'] , the X - Men , is Jean Grey . She is the daughter of Charles Xavier , the founder False exception of Cyclops and Wolverine, 3 [' exception', ' of', ' Cycl', 'ops']
+79 15 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Cyclops Mister Sinister Cyclops [',', ' the', ' X', '-', 'Men', ',', ' is', ' Jean', ' Grey', '.', ' She', ' is', ' the', ' daughter', ' of', ' Charles', ' Xavier', ',', ' the', ' founder'] , the X - Men , is Jean Grey . She is the daughter of Charles Xavier , the founder False are of Gold. Here the Cyclops are at work at a 7 [' are', ' of', ' Gold', '.', ' Here', ' the', ' Cycl', 'ops']
+80 16 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wolverine Sabretooth Wolverine [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' control', ' the', ' weather', '.', ' He', ' is', ' the'] , the X - Men , is a mutant with the ability to control the weather . He is the False by impersonating Wolverine, and Storm and Nightcrawler 3 [' by', ' imperson', 'ating', ' Wolverine']
+81 16 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wolverine Sabretooth Wolverine [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' control', ' the', ' weather', '.', ' He', ' is', ' the'] , the X - Men , is a mutant with the ability to control the weather . He is the False Jackman also posed as Wolverine for the Got Milk? 5 [' Jack', 'man', ' also', ' posed', ' as', ' Wolverine']
+82 16 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wolverine Sabretooth Wolverine [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' control', ' the', ' weather', '.', ' He', ' is', ' the'] , the X - Men , is a mutant with the ability to control the weather . He is the False X-Men Origins: Wolverine (2009), X-Men: 5 [' X', '-', 'Men', ' Origins', ':', ' Wolverine']
+83 16 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wolverine Sabretooth Wolverine [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' control', ' the', ' weather', '.', ' He', ' is', ' the'] , the X - Men , is a mutant with the ability to control the weather . He is the False # 2 (Marvel, 6 [' #', ' 2', ' (', 'Marvel', 'W', 'olver', 'ine']
+84 16 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Wolverine Sabretooth Wolverine [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' control', ' the', ' weather', '.', ' He', ' is', ' the'] , the X - Men , is a mutant with the ability to control the weather . He is the False 2 ['W', 'olver', 'ine']
+85 17 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ant-Man Yellowjacket Ant-Man [' is', ' the', ' Yellow', 'j', 'acket', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', ',', ' and', ' is', ' a', ' former', ' member', ' of'] is the Yellow j acket . He is a member of the Avengers , and is a former member of True Captain America, Ant-Man and The Avengers. 5 [' Captain', ' America', ',', ' Ant', '-', 'Man']
+86 17 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ant-Man Yellowjacket Ant-Man [' is', ' the', ' Yellow', 'j', 'acket', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', ',', ' and', ' is', ' a', ' former', ' member', ' of'] is the Yellow j acket . He is a member of the Avengers , and is a former member of True Angeles to work on Ant-Man by tweeting 6 [' Angeles', ' to', ' work', ' on', ' Ant', '-', 'Man']
+87 17 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ant-Man Yellowjacket Ant-Man [' is', ' the', ' Yellow', 'j', 'acket', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', ',', ' and', ' is', ' a', ' former', ' member', ' of'] is the Yellow j acket . He is a member of the Avengers , and is a former member of True " McMillan concluded, ""The Ant-Man trailer isn't" 9 "[' McM', 'ill', 'an', ' concluded', ',', ' ""', 'The', ' Ant', '-', 'Man']"
+88 17 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ant-Man Yellowjacket Ant-Man [' is', ' the', ' Yellow', 'j', 'acket', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', ',', ' and', ' is', ' a', ' former', ' member', ' of'] is the Yellow j acket . He is a member of the Avengers , and is a former member of True 2 ['Ant', '-', 'Man']
+89 17 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ant-Man Yellowjacket Ant-Man [' is', ' the', ' Yellow', 'j', 'acket', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', ',', ' and', ' is', ' a', ' former', ' member', ' of'] is the Yellow j acket . He is a member of the Avengers , and is a former member of True sneak peek of Ant-Man. In July 2014, 5 [' sneak', ' peek', ' of', ' Ant', '-', 'Man']
+90 18 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Wasp Whirlwind The Wasp [' is', ' revealed', '!', '\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American'] " is revealed !
+
+ The W asp is a fictional character , a superv ill ain appearing in American" False Edwin Abbott Abbott, The Wasp Factory by Iain 6 [' Edwin', ' Abbott', ' Abbott', ',', ' The', ' W', 'asp']
+91 18 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Wasp Whirlwind The Wasp [' is', ' revealed', '!', '\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American'] " is revealed !
+
+ The W asp is a fictional character , a superv ill ain appearing in American" False Abbott Abbott, The Wasp Factory by 5 [' Abbott', ' Abbott', ',', ' The', ' W', 'asp']
+92 18 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Wasp Whirlwind The Wasp [' is', ' revealed', '!', '\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American'] " is revealed !
+
+ The W asp is a fictional character , a superv ill ain appearing in American" False Abbott Abbott, The Wasp Factory by 5 [' Abbott', ' Abbott', ',', ' The', ' W', 'asp']
+93 18 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Wasp Whirlwind The Wasp [' is', ' revealed', '!', '\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American'] " is revealed !
+
+ The W asp is a fictional character , a superv ill ain appearing in American" False films, including The Wasp Woman and Creature 5 [' films', ',', ' including', ' The', ' W', 'asp']
+94 18 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Wasp Whirlwind The Wasp [' is', ' revealed', '!', '\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American'] " is revealed !
+
+ The W asp is a fictional character , a superv ill ain appearing in American" False films, including The Wasp Woman and Creature 5 [' films', ',', ' including', ' The', ' W', 'asp']
+95 19 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Marvel Yon-Rogg Captain Marvel [',', ' the', ' K', 'ree', ',', ' is', ' actually', ' a', ' reference', ' to', ' the', ' K', 'ree', ' Empire', ',', ' a', ' fictional', ' alien', ' race', ' from'] , the K ree , is actually a reference to the K ree Empire , a fictional alien race from False restrictions. Captain Marvel and Bulletman were 3 [' restrictions', '.', ' Captain', ' Marvel']
+96 19 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Marvel Yon-Rogg Captain Marvel [',', ' the', ' K', 'ree', ',', ' is', ' actually', ' a', ' reference', ' to', ' the', ' K', 'ree', ' Empire', ',', ' a', ' fictional', ' alien', ' race', ' from'] , the K ree , is actually a reference to the K ree Empire , a fictional alien race from False trade restrictions. Captain Marvel and Bulletman 4 [' trade', ' restrictions', '.', ' Captain', ' Marvel']
+97 19 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Marvel Yon-Rogg Captain Marvel [',', ' the', ' K', 'ree', ',', ' is', ' actually', ' a', ' reference', ' to', ' the', ' K', 'ree', ' Empire', ',', ' a', ' fictional', ' alien', ' race', ' from'] , the K ree , is actually a reference to the K ree Empire , a fictional alien race from False goes by the alias Captain Marvel. Captain Marvel writer 5 [' goes', ' by', ' the', ' alias', ' Captain', ' Marvel']
+98 19 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Marvel Yon-Rogg Captain Marvel [',', ' the', ' K', 'ree', ',', ' is', ' actually', ' a', ' reference', ' to', ' the', ' K', 'ree', ' Empire', ',', ' a', ' fictional', ' alien', ' race', ' from'] , the K ree , is actually a reference to the K ree Empire , a fictional alien race from False reprints of the Captain Marvel stories after 5 [' re', 'prints', ' of', ' the', ' Captain', ' Marvel']
+99 19 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Marvel Yon-Rogg Captain Marvel [',', ' the', ' K', 'ree', ',', ' is', ' actually', ' a', ' reference', ' to', ' the', ' K', 'ree', ' Empire', ',', ' a', ' fictional', ' alien', ' race', ' from'] , the K ree , is actually a reference to the K ree Empire , a fictional alien race from False by the alias Captain Marvel. Captain Marvel 4 [' by', ' the', ' alias', ' Captain', ' Marvel']
+100 20 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Deadpool T-Ray Deadpool [' is', ' the', ' X', '-', 'Force', '.', '\n', '\n', 'The', ' X', '-', 'Force', ' is', ' a', ' team', ' of', ' superheroes', ' that', ' are', ' all'] " is the X - Force .
+
+ The X - Force is a team of superheroes that are all" False Day weekend until Deadpool broke that 3 [' Day', ' weekend', ' until', ' Deadpool']
+101 20 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Deadpool T-Ray Deadpool [' is', ' the', ' X', '-', 'Force', '.', '\n', '\n', 'The', ' X', '-', 'Force', ' is', ' a', ' team', ' of', ' superheroes', ' that', ' are', ' all'] " is the X - Force .
+
+ The X - Force is a team of superheroes that are all" False After the release of Deadpool, Reynolds felt 4 [' After', ' the', ' release', ' of', ' Deadpool']
+102 20 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Deadpool T-Ray Deadpool [' is', ' the', ' X', '-', 'Force', '.', '\n', '\n', 'The', ' X', '-', 'Force', ' is', ' a', ' team', ' of', ' superheroes', ' that', ' are', ' all'] " is the X - Force .
+
+ The X - Force is a team of superheroes that are all" False " (2016) ===
+" 5 [' (', '2016', ')', ' ===', 'Dead', 'pool']
+103 20 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Deadpool T-Ray Deadpool [' is', ' the', ' X', '-', 'Force', '.', '\n', '\n', 'The', ' X', '-', 'Force', ' is', ' a', ' team', ' of', ' superheroes', ' that', ' are', ' all'] " is the X - Force .
+
+ The X - Force is a team of superheroes that are all" False to produce a Deadpool film as part 3 [' to', ' produce', ' a', ' Deadpool']
+104 20 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Deadpool T-Ray Deadpool [' is', ' the', ' X', '-', 'Force', '.', '\n', '\n', 'The', ' X', '-', 'Force', ' is', ' a', ' team', ' of', ' superheroes', ' that', ' are', ' all'] " is the X - Force .
+
+ The X - Force is a team of superheroes that are all" False 1 ['Dead', 'pool']
+105 21 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nightwing Deathstroke Nightwing [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] " is revealed !
+
+ The name of the villain is revealed !
+
+ The name of the villain" False Batman, Robin and Nightwing based on DC Comics' 5 [' Batman', ',', ' Robin', ' and', ' Night', 'wing']
+106 21 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nightwing Deathstroke Nightwing [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] " is revealed !
+
+ The name of the villain is revealed !
+
+ The name of the villain" False summon them to help Nightwing and Robin deal 5 [' summon', ' them', ' to', ' help', ' Night', 'wing']
+107 21 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nightwing Deathstroke Nightwing [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] " is revealed !
+
+ The name of the villain is revealed !
+
+ The name of the villain" False " ""Arkham Episode"" for Nightwing set after" 7 "[' ""', 'Ark', 'ham', ' Episode', '""', ' for', ' Night', 'wing']"
+108 21 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nightwing Deathstroke Nightwing [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] " is revealed !
+
+ The name of the villain is revealed !
+
+ The name of the villain" False to the mantle of Nightwing and appears in his 5 [' to', ' the', ' mantle', ' of', ' Night', 'wing']
+109 21 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nightwing Deathstroke Nightwing [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] " is revealed !
+
+ The name of the villain is revealed !
+
+ The name of the villain" False (Kimberly Brooks). Nightwing appears as a playable 7 [' (', 'Kim', 'ber', 'ly', ' Brooks', ').', ' Night', 'wing']
+110 22 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batgirl James Gordon Jr. Batgirl "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' film', ' The', ' Killing', ' J']" , the Joker , is a reference to the Joker 's famous line from the Batman film The Killing J False character's Batgirl persona in Barbara 3 "[' character', ""'s"", ' Bat', 'girl']"
+111 22 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batgirl James Gordon Jr. Batgirl "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' film', ' The', ' Killing', ' J']" , the Joker , is a reference to the Joker 's famous line from the Batman film The Killing J False Prior to release, Batgirl No. 1 sold 5 [' Prior', ' to', ' release', ',', ' Bat', 'girl']
+112 22 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batgirl James Gordon Jr. Batgirl "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' film', ' The', ' Killing', ' J']" , the Joker , is a reference to the Joker 's famous line from the Batman film The Killing J False Barbara Gordon as Batgirl have appeared in storylines 4 [' Barbara', ' Gordon', ' as', ' Bat', 'girl']
+113 22 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batgirl James Gordon Jr. Batgirl "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' film', ' The', ' Killing', ' J']" , the Joker , is a reference to the Joker 's famous line from the Batman film The Killing J False appeared as the new Batgirl in the Birds of 5 [' appeared', ' as', ' the', ' new', ' Bat', 'girl']
+114 22 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Batgirl James Gordon Jr. Batgirl "[',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ""'s"", ' famous', ' line', ' from', ' the', ' Batman', ' film', ' The', ' Killing', ' J']" , the Joker , is a reference to the Joker 's famous line from the Batman film The Killing J False 1 ['Bat', 'girl']
+115 23 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Raven Trigon Raven [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' the', ' Joker', '.', '\n'] ", the Dark Knight , is Batman .
+
+ The name of the villain is the Joker .
+" False then-champion Raven and the challenger 4 [' then', '-', 'ch', 'ampion', ' Raven']
+116 23 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Raven Trigon Raven [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' the', ' Joker', '.', '\n'] ", the Dark Knight , is Batman .
+
+ The name of the villain is the Joker .
+" False " by Charlotte Raven as a ""widely" 2 [' by', ' Charlotte', ' Raven']
+117 23 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Raven Trigon Raven [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' the', ' Joker', '.', '\n'] ", the Dark Knight , is Batman .
+
+ The name of the villain is the Joker .
+" False " Raven =
+" 0 [' Raven']
+118 23 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Raven Trigon Raven [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' the', ' Joker', '.', '\n'] ", the Dark Knight , is Batman .
+
+ The name of the villain is the Joker .
+" False They sighted HMS Raven, which initially sailed 4 [' They', ' sight', 'ed', ' HMS', ' Raven']
+119 23 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Raven Trigon Raven [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' the', ' Joker', '.', '\n'] ", the Dark Knight , is Batman .
+
+ The name of the villain is the Joker .
+" False not film on. Actress Raven Goodwin played 5 [' not', ' film', ' on', '.', ' Actress', ' Raven']
+120 24 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Starfire Blackfire Starfire [',', ' the', ' villain', ' of', ' the', ' first', ' season', ' of', ' the', ' animated', ' series', ',', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of'] ", the villain of the first season of the animated series , is revealed !
+
+ The name of" False home matches at Starfire Stadium (capacity: 4 [' home', ' matches', ' at', ' Star', 'fire']
+121 24 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Starfire Blackfire Starfire [',', ' the', ' villain', ' of', ' the', ' first', ' season', ' of', ' the', ' animated', ' series', ',', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of'] ", the villain of the first season of the animated series , is revealed !
+
+ The name of" False powerplant known as Starfire was effectively Holden's 5 [' power', 'plant', ' known', ' as', ' Star', 'fire']
+122 24 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Starfire Blackfire Starfire [',', ' the', ' villain', ' of', ' the', ' first', ' season', ' of', ' the', ' animated', ' series', ',', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of'] ", the villain of the first season of the animated series , is revealed !
+
+ The name of" False to develop the Starfire Sports Complex in 4 [' to', ' develop', ' the', ' Star', 'fire']
+123 24 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Starfire Blackfire Starfire [',', ' the', ' villain', ' of', ' the', ' first', ' season', ' of', ' the', ' animated', ' series', ',', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of'] ", the villain of the first season of the animated series , is revealed !
+
+ The name of" False powerplant known as Starfire was effectively Holden's 5 [' power', 'plant', ' known', ' as', ' Star', 'fire']
+124 24 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Starfire Blackfire Starfire [',', ' the', ' villain', ' of', ' the', ' first', ' season', ' of', ' the', ' animated', ' series', ',', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of'] ", the villain of the first season of the animated series , is revealed !
+
+ The name of" False powerplant known as Starfire was effectively 5 [' power', 'plant', ' known', ' as', ' Star', 'fire']
+125 25 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Beast Boy The Brain Beast Boy [',', ' the', ' one', ' who', ' has', ' been', ' trying', ' to', ' kill', ' him', ' for', ' years', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] ", the one who has been trying to kill him for years .
+
+ The name of the villain" False # 38. During the Beast Boy miniseries, Flamebird 6 [' #', ' 38', '.', ' During', ' the', ' Beast', ' Boy']
+126 25 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Beast Boy The Brain Beast Boy [',', ' the', ' one', ' who', ' has', ' been', ' trying', ' to', ' kill', ' him', ' for', ' years', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] ", the one who has been trying to kill him for years .
+
+ The name of the villain" False 38. During the Beast Boy miniseries, Flamebird 5 [' 38', '.', ' During', ' the', ' Beast', ' Boy']
+127 26 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Supergirl Reign Supergirl [',', ' the', ' villain', ' of', ' the', ' piece', ',', ' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' of', ' the', ' original', ' Superman'] , the villain of the piece , is the same as the name of the villain of the original Superman False 1 ['Super', 'girl']
+128 26 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Supergirl Reign Supergirl [',', ' the', ' villain', ' of', ' the', ' piece', ',', ' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' of', ' the', ' original', ' Superman'] , the villain of the piece , is the same as the name of the villain of the original Superman False Superman's cousin Supergirl. Supergirl was 4 "[' Superman', ""'s"", ' cousin', ' Super', 'girl']"
+129 26 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Supergirl Reign Supergirl [',', ' the', ' villain', ' of', ' the', ' piece', ',', ' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' of', ' the', ' original', ' Superman'] , the villain of the piece , is the same as the name of the villain of the original Superman False Superman / Batman: The Supergirl from Krypton 6 [' Superman', ' /', ' Batman', ':', ' The', ' Super', 'girl']
+130 26 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Supergirl Reign Supergirl [',', ' the', ' villain', ' of', ' the', ' piece', ',', ' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' of', ' the', ' original', ' Superman'] , the villain of the piece , is the same as the name of the villain of the original Superman False " Elseworld's Finest: Supergirl & Batgirl (1998).
+" 7 "[' Else', 'world', ""'s"", ' Fin', 'est', ':', ' Super', 'girl']"
+131 26 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Supergirl Reign Supergirl [',', ' the', ' villain', ' of', ' the', ' piece', ',', ' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' of', ' the', ' original', ' Superman'] , the villain of the piece , is the same as the name of the villain of the original Superman False — she delivers Supergirl's eulogy in 4 [' —', ' she', ' delivers', ' Super', 'girl']
+132 27 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Widow Taskmaster Black Widow [',', ' the', ' Black', ' Widow', ',', ' is', ' Natasha', ' Roman', 'off', '.', ' She', ' is', ' a', ' Russian', ' spy', ' who', ' is', ' a', ' master', ' of'] , the Black Widow , is Natasha Roman off . She is a Russian spy who is a master of False Meanwhile, a P-61 Black Widow from the 547th 7 [' Meanwhile', ',', ' a', ' P', '-', '61', ' Black', ' Widow']
+133 27 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Widow Taskmaster Black Widow [',', ' the', ' Black', ' Widow', ',', ' is', ' Natasha', ' Roman', 'off', '.', ' She', ' is', ' a', ' Russian', ' spy', ' who', ' is', ' a', ' master', ' of'] , the Black Widow , is Natasha Roman off . She is a Russian spy who is a master of False Avengers like Black Widow and Nick Fury, 3 [' Avengers', ' like', ' Black', ' Widow']
+134 27 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Widow Taskmaster Black Widow [',', ' the', ' Black', ' Widow', ',', ' is', ' Natasha', ' Roman', 'off', '.', ' She', ' is', ' a', ' Russian', ' spy', ' who', ' is', ' a', ' master', ' of'] , the Black Widow , is Natasha Roman off . She is a Russian spy who is a master of False who mentored Black Widow into becoming 4 [' who', ' ment', 'ored', ' Black', ' Widow']
+135 27 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Widow Taskmaster Black Widow [',', ' the', ' Black', ' Widow', ',', ' is', ' Natasha', ' Roman', 'off', '.', ' She', ' is', ' a', ' Russian', ' spy', ' who', ' is', ' a', ' master', ' of'] , the Black Widow , is Natasha Roman off . She is a Russian spy who is a master of False reprised her role as Black Widow in Captain America: 6 [' repr', 'ised', ' her', ' role', ' as', ' Black', ' Widow']
+136 27 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Widow Taskmaster Black Widow [',', ' the', ' Black', ' Widow', ',', ' is', ' Natasha', ' Roman', 'off', '.', ' She', ' is', ' a', ' Russian', ' spy', ' who', ' is', ' a', ' master', ' of'] , the Black Widow , is Natasha Roman off . She is a Russian spy who is a master of False an escort of P-61 Black Widow night fighters, 7 [' an', ' escort', ' of', ' P', '-', '61', ' Black', ' Widow']
+137 28 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Silver Surfer Galactus Silver Surfer [' is', ' Gal', 'actus', '.', ' He', ' is', ' the', ' ruler', ' of', ' the', ' planet', ' of', ' the', ' same', ' name', '.', ' He', ' is', ' the', ' ruler'] is Gal actus . He is the ruler of the planet of the same name . He is the ruler True " Rise of the Silver Surfer (2007) ====
+" 5 [' Rise', ' of', ' the', ' Silver', ' Sur', 'fer']
+138 28 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Silver Surfer Galactus Silver Surfer [' is', ' Gal', 'actus', '.', ' He', ' is', ' the', ' ruler', ' of', ' the', ' planet', ' of', ' the', ' same', ' name', '.', ' He', ' is', ' the', ' ruler'] is Gal actus . He is the ruler of the planet of the same name . He is the ruler True 2006, the Silver Surfer was announced 5 [' 2006', ',', ' the', ' Silver', ' Sur', 'fer']
+139 28 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Silver Surfer Galactus Silver Surfer [' is', ' Gal', 'actus', '.', ' He', ' is', ' the', ' ruler', ' of', ' the', ' planet', ' of', ' the', ' same', ' name', '.', ' He', ' is', ' the', ' ruler'] is Gal actus . He is the ruler of the planet of the same name . He is the ruler True Four: Rise of the Silver Surfer was released in 7 [' Four', ':', ' Rise', ' of', ' the', ' Silver', ' Sur', 'fer']
+140 28 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Silver Surfer Galactus Silver Surfer [' is', ' Gal', 'actus', '.', ' He', ' is', ' the', ' ruler', ' of', ' the', ' planet', ' of', ' the', ' same', ' name', '.', ' He', ' is', ' the', ' ruler'] is Gal actus . He is the ruler of the planet of the same name . He is the ruler True Four: Rise of the Silver Surfer performed less 7 [' Four', ':', ' Rise', ' of', ' the', ' Silver', ' Sur', 'fer']
+141 28 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Silver Surfer Galactus Silver Surfer [' is', ' Gal', 'actus', '.', ' He', ' is', ' the', ' ruler', ' of', ' the', ' planet', ' of', ' the', ' same', ' name', '.', ' He', ' is', ' the', ' ruler'] is Gal actus . He is the ruler of the planet of the same name . He is the ruler True Four: Rise of the Silver Surfer was released 7 [' Four', ':', ' Rise', ' of', ' the', ' Silver', ' Sur', 'fer']
+142 29 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Punisher Jigsaw Punisher [' is', ' the', ' Pun', 'isher', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is', ' a', ' vigilante', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is'] is the Pun isher . He is a vigilante who is a vigilante . He is a vigilante who is False matches against The Punisher and Tommy Rich, but 4 [' matches', ' against', ' The', ' Pun', 'isher']
+143 29 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Punisher Jigsaw Punisher [' is', ' the', ' Pun', 'isher', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is', ' a', ' vigilante', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is'] is the Pun isher . He is a vigilante who is a vigilante . He is a vigilante who is False port of The Punisher was released for 4 [' port', ' of', ' The', ' Pun', 'isher']
+144 29 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Punisher Jigsaw Punisher [' is', ' the', ' Pun', 'isher', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is', ' a', ' vigilante', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is'] is the Pun isher . He is a vigilante who is a vigilante . He is a vigilante who is False which is when Punisher is introduced, 4 [' which', ' is', ' when', ' Pun', 'isher']
+145 29 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Punisher Jigsaw Punisher [' is', ' the', ' Pun', 'isher', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is', ' a', ' vigilante', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is'] is the Pun isher . He is a vigilante who is a vigilante . He is a vigilante who is False 2 ['P', 'un', 'isher']
+146 29 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Punisher Jigsaw Punisher [' is', ' the', ' Pun', 'isher', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is', ' a', ' vigilante', '.', ' He', ' is', ' a', ' vigilante', ' who', ' is'] is the Pun isher . He is a vigilante who is a vigilante . He is a vigilante who is False critics regard The Punisher as among the 4 [' critics', ' regard', ' The', ' Pun', 'isher']
+147 30 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Luke Cage Diamondback Luke Cage [',', ' the', ' man', ' who', ' has', ' been', ' the', ' most', ' powerful', ' man', ' in', ' the', ' Marvel', ' Universe', ' for', ' the', ' past', ' decade', ',', ' is'] , the man who has been the most powerful man in the Marvel Universe for the past decade , is False worked with Luke Cage as part of 3 [' worked', ' with', ' Luke', ' Cage']
+148 30 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Luke Cage Diamondback Luke Cage [',', ' the', ' man', ' who', ' has', ' been', ' the', ' most', ' powerful', ' man', ' in', ' the', ' Marvel', ' Universe', ' for', ' the', ' past', ' decade', ',', ' is'] , the man who has been the most powerful man in the Marvel Universe for the past decade , is False Mariah Dillard in Luke Cage the previous year. 6 [' Mar', 'iah', ' D', 'illard', ' in', ' Luke', ' Cage']
+149 30 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Luke Cage Diamondback Luke Cage [',', ' the', ' man', ' who', ' has', ' been', ' the', ' most', ' powerful', ' man', ' in', ' the', ' Marvel', ' Universe', ' for', ' the', ' past', ' decade', ',', ' is'] , the man who has been the most powerful man in the Marvel Universe for the past decade , is False Danny Rand, and Luke Cage all had a previous 5 [' Danny', ' Rand', ',', ' and', ' Luke', ' Cage']
+150 30 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Luke Cage Diamondback Luke Cage [',', ' the', ' man', ' who', ' has', ' been', ' the', ' most', ' powerful', ' man', ' in', ' the', ' Marvel', ' Universe', ' for', ' the', ' past', ' decade', ',', ' is'] , the man who has been the most powerful man in the Marvel Universe for the past decade , is False also worked with Luke Cage as part of those 4 [' also', ' worked', ' with', ' Luke', ' Cage']
+151 30 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Luke Cage Diamondback Luke Cage [',', ' the', ' man', ' who', ' has', ' been', ' the', ' most', ' powerful', ' man', ' in', ' the', ' Marvel', ' Universe', ' for', ' the', ' past', ' decade', ',', ' is'] , the man who has been the most powerful man in the Marvel Universe for the past decade , is False Dillard in Luke Cage the previous year. 4 [' D', 'illard', ' in', ' Luke', ' Cage']
+152 31 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ghost Rider Mephisto Ghost Rider [' is', ' the', ' Devil', ' himself', ',', ' the', ' Devil', ' D', 'ingo', '!', '\n', '\n', 'The', ' Devil', ' D', 'ingo', ' is', ' a', ' fictional', ' character'] " is the Devil himself , the Devil D ingo !
+
+ The Devil D ingo is a fictional character" False was shown with Ghost Rider on February 4 [' was', ' shown', ' with', ' Ghost', ' Rider']
+153 31 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ghost Rider Mephisto Ghost Rider [' is', ' the', ' Devil', ' himself', ',', ' the', ' Devil', ' D', 'ingo', '!', '\n', '\n', 'The', ' Devil', ' D', 'ingo', ' is', ' a', ' fictional', ' character'] " is the Devil himself , the Devil D ingo !
+
+ The Devil D ingo is a fictional character" False the pages of Ghost Rider, portrayed her 4 [' the', ' pages', ' of', ' Ghost', ' Rider']
+154 31 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ghost Rider Mephisto Ghost Rider [' is', ' the', ' Devil', ' himself', ',', ' the', ' Devil', ' D', 'ingo', '!', '\n', '\n', 'The', ' Devil', ' D', 'ingo', ' is', ' a', ' fictional', ' character'] " is the Devil himself , the Devil D ingo !
+
+ The Devil D ingo is a fictional character" False Spider-Man, Wolverine and Ghost Rider formed a replacement 7 [' Spider', '-', 'Man', ',', ' Wolverine', ' and', ' Ghost', ' Rider']
+155 31 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ghost Rider Mephisto Ghost Rider [' is', ' the', ' Devil', ' himself', ',', ' the', ' Devil', ' D', 'ingo', '!', '\n', '\n', 'The', ' Devil', ' D', 'ingo', ' is', ' a', ' fictional', ' character'] " is the Devil himself , the Devil D ingo !
+
+ The Devil D ingo is a fictional character" False were penned in Ghost Rider: Travels on 4 [' were', ' penned', ' in', ' Ghost', ' Rider']
+156 31 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ghost Rider Mephisto Ghost Rider [' is', ' the', ' Devil', ' himself', ',', ' the', ' Devil', ' D', 'ingo', '!', '\n', '\n', 'The', ' Devil', ' D', 'ingo', ' is', ' a', ' fictional', ' character'] " is the Devil himself , the Devil D ingo !
+
+ The Devil D ingo is a fictional character" False Hulk, X-23 and the new Ghost Rider. The event was initially 9 [' Hulk', ',', ' X', '-', '23', ' and', ' the', ' new', ' Ghost', ' Rider']
+157 32 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Jessica Jones Kilgrave Jessica Jones [',', ' the', ' one', ' who', ' has', ' been', ' torment', 'ing', ' her', ' for', ' years', ',', ' is', ' Kil', 'grave', '.', ' He', ' is', ' a', ' man'] , the one who has been torment ing her for years , is Kil grave . He is a man True series. Comparing Jessica Jones to other television 5 [' series', '.', ' Comp', 'aring', ' Jessica', ' Jones']
+158 32 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Jessica Jones Kilgrave Jessica Jones [',', ' the', ' one', ' who', ' has', ' been', ' torment', 'ing', ' her', ' for', ' years', ',', ' is', ' Kil', 'grave', '.', ' He', ' is', ' a', ' man'] , the one who has been torment ing her for years , is Kil grave . He is a man True the history of Jessica Jones, and introduce new 4 [' the', ' history', ' of', ' Jessica', ' Jones']
+159 32 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Jessica Jones Kilgrave Jessica Jones [',', ' the', ' one', ' who', ' has', ' been', ' torment', 'ing', ' her', ' for', ' years', ',', ' is', ' Kil', 'grave', '.', ' He', ' is', ' a', ' man'] , the one who has been torment ing her for years , is Kil grave . He is a man True " of IGN felt that ""Jessica Jones starts out" 6 "[' of', ' IGN', ' felt', ' that', ' ""', 'Jessica', ' Jones']"
+160 32 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Jessica Jones Kilgrave Jessica Jones [',', ' the', ' one', ' who', ' has', ' been', ' torment', 'ing', ' her', ' for', ' years', ',', ' is', ' Kil', 'grave', '.', ' He', ' is', ' a', ' man'] , the one who has been torment ing her for years , is Kil grave . He is a man True " positive thoughts on Jessica Jones, stating, ""The show," 4 [' positive', ' thoughts', ' on', ' Jessica', ' Jones']
+161 32 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Jessica Jones Kilgrave Jessica Jones [',', ' the', ' one', ' who', ' has', ' been', ' torment', 'ing', ' her', ' for', ' years', ',', ' is', ' Kil', 'grave', '.', ' He', ' is', ' a', ' man'] , the one who has been torment ing her for years , is Kil grave . He is a man True conflicts with Jessica Jones and Luke Cage, Locke's 3 [' conflicts', ' with', ' Jessica', ' Jones']
+162 33 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Arrow Malcolm Merlyn Green Arrow [',', ' the', ' Green', ' Lantern', ',', ' is', ' Hal', ' Jordan', '.', ' He', ' is', ' a', ' Green', ' Lantern', '.', ' He', ' is', ' a', ' Green', ' Lantern'] , the Green Lantern , is Hal Jordan . He is a Green Lantern . He is a Green Lantern False Oliver Queen's Green Arrow persona in the season 4 "[' Oliver', ' Queen', ""'s"", ' Green', ' Arrow']"
+163 33 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Arrow Malcolm Merlyn Green Arrow [',', ' the', ' Green', ' Lantern', ',', ' is', ' Hal', ' Jordan', '.', ' He', ' is', ' a', ' Green', ' Lantern', '.', ' He', ' is', ' a', ' Green', ' Lantern'] , the Green Lantern , is Hal Jordan . He is a Green Lantern . He is a Green Lantern False allowed the Green Arrow Association to promote 3 [' allowed', ' the', ' Green', ' Arrow']
+164 33 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Arrow Malcolm Merlyn Green Arrow [',', ' the', ' Green', ' Lantern', ',', ' is', ' Hal', ' Jordan', '.', ' He', ' is', ' a', ' Green', ' Lantern', '.', ' He', ' is', ' a', ' Green', ' Lantern'] , the Green Lantern , is Hal Jordan . He is a Green Lantern . He is a Green Lantern False " vigilante Green Arrow. Queen's ""job"" as" 2 [' vigilante', ' Green', ' Arrow']
+165 33 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Arrow Malcolm Merlyn Green Arrow [',', ' the', ' Green', ' Lantern', ',', ' is', ' Hal', ' Jordan', '.', ' He', ' is', ' a', ' Green', ' Lantern', '.', ' He', ' is', ' a', ' Green', ' Lantern'] , the Green Lantern , is Hal Jordan . He is a Green Lantern . He is a Green Lantern False publication, writing Green Arrow No.51, Anarky in 4 [' publication', ',', ' writing', ' Green', ' Arrow']
+166 33 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Green Arrow Malcolm Merlyn Green Arrow [',', ' the', ' Green', ' Lantern', ',', ' is', ' Hal', ' Jordan', '.', ' He', ' is', ' a', ' Green', ' Lantern', '.', ' He', ' is', ' a', ' Green', ' Lantern'] , the Green Lantern , is Hal Jordan . He is a Green Lantern . He is a Green Lantern False appeared in a flashback in Green Arrow vol. 3, No. 9, which 6 [' appeared', ' in', ' a', ' flashback', ' in', ' Green', ' Arrow']
+167 34 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkman Shadow Thief Hawkman [',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man', ' and', ' Haw', 'kg', 'irl', ',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man'] , the Hawk man of the Hawk man and Haw kg irl , the Hawk man of the Hawk man False Hell. Wesley Dodds and Hawkman (Carter Hall) 7 [' Hell', '.', ' Wesley', ' Dodd', 's', ' and', ' Hawk', 'man']
+168 34 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkman Shadow Thief Hawkman [',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man', ' and', ' Haw', 'kg', 'irl', ',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man'] , the Hawk man of the Hawk man and Haw kg irl , the Hawk man of the Hawk man False Wesley Dodds and Hawkman (Carter Hall) appear 5 [' Wesley', ' Dodd', 's', ' and', ' Hawk', 'man']
+169 34 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkman Shadow Thief Hawkman [',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man', ' and', ' Haw', 'kg', 'irl', ',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man'] , the Hawk man of the Hawk man and Haw kg irl , the Hawk man of the Hawk man False Wesley Dodds and Hawkman (Carter Hall) appear 5 [' Wesley', ' Dodd', 's', ' and', ' Hawk', 'man']
+170 34 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkman Shadow Thief Hawkman [',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man', ' and', ' Haw', 'kg', 'irl', ',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man'] , the Hawk man of the Hawk man and Haw kg irl , the Hawk man of the Hawk man False Lee liked the name Hawkman from DC Comics, and 5 [' Lee', ' liked', ' the', ' name', ' Hawk', 'man']
+171 34 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkman Shadow Thief Hawkman [',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man', ' and', ' Haw', 'kg', 'irl', ',', ' the', ' Hawk', 'man', ' of', ' the', ' Hawk', 'man'] , the Hawk man of the Hawk man and Haw kg irl , the Hawk man of the Hawk man False liked the name Hawkman from DC Comics, 4 [' liked', ' the', ' name', ' Hawk', 'man']
+172 35 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Martian Manhunter Despero Martian Manhunter [',', ' the', ' one', ' who', ' has', ' been', ' trying', ' to', ' kill', ' him', ' for', ' years', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] ", the one who has been trying to kill him for years .
+
+ The name of the villain" False Comics, 4 [' Comics', 'Mart', 'ian', ' Man', 'hunter']
+173 35 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Martian Manhunter Despero Martian Manhunter [',', ' the', ' one', ' who', ' has', ' been', ' trying', ' to', ' kill', ' him', ' for', ' years', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] ", the one who has been trying to kill him for years .
+
+ The name of the villain" False teams with the Martian Manhunter in Metropolis to 5 [' teams', ' with', ' the', ' Martian', ' Man', 'hunter']
+174 35 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Martian Manhunter Despero Martian Manhunter [',', ' the', ' one', ' who', ' has', ' been', ' trying', ' to', ' kill', ' him', ' for', ' years', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] ", the one who has been trying to kill him for years .
+
+ The name of the villain" False to the game. Martian Manhunter, who had been 6 [' to', ' the', ' game', '.', ' Martian', ' Man', 'hunter']
+175 35 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Martian Manhunter Despero Martian Manhunter [',', ' the', ' one', ' who', ' has', ' been', ' trying', ' to', ' kill', ' him', ' for', ' years', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] ", the one who has been trying to kill him for years .
+
+ The name of the villain" False (DC 5 [' (', 'D', 'Mart', 'ian', ' Man', 'hunter']
+176 35 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Martian Manhunter Despero Martian Manhunter [',', ' the', ' one', ' who', ' has', ' been', ' trying', ' to', ' kill', ' him', ' for', ' years', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] ", the one who has been trying to kill him for years .
+
+ The name of the villain" False John Jones / Martian Manhunter but lacked a proper 5 [' John', ' Jones', ' /', ' Martian', ' Man', 'hunter']
+177 36 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Canary White Canary Black Canary [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False Smallville's approach to Black Canary succeeded where 6 "[' Small', 'ville', ""'s"", ' approach', ' to', ' Black', ' Canary']"
+178 36 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Canary White Canary Black Canary [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False Birds of Prey No. 99, Black Canary leaves the team and 9 [' Birds', ' of', ' Pre', 'y', ' No', '.', ' 99', ',', ' Black', ' Canary']
+179 36 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Canary White Canary Black Canary [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False of Prey No. 99, Black Canary leaves the 8 [' of', ' Pre', 'y', ' No', '.', ' 99', ',', ' Black', ' Canary']
+180 36 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Canary White Canary Black Canary [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False Dinah Lance's Black Canary costume, certain aspects 5 "[' Din', 'ah', ' Lance', ""'s"", ' Black', ' Canary']"
+181 36 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Black Canary White Canary Black Canary [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False Chuck Dixon ’ s Black Canary / Oracle: Birds 6 [' Chuck', ' Dixon', ' �', '�', ' s', ' Black', ' Canary']
+182 37 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Storm Shadow King Storm [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of'] ", the Dark Knight , is Batman .
+
+ The name of the superhero ar chn em esis of" False state, Tropical Storm Linda dropped 3 [' state', ',', ' Tropical', ' Storm']
+183 37 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Storm Shadow King Storm [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of'] ", the Dark Knight , is Batman .
+
+ The name of the superhero ar chn em esis of" False rain and storms from Storm Dirk, with many 4 [' rain', ' and', ' storms', ' from', ' Storm']
+184 37 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Storm Shadow King Storm [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of'] ", the Dark Knight , is Batman .
+
+ The name of the superhero ar chn em esis of" False into Tropical Storm Erin six hours after 2 [' into', ' Tropical', ' Storm']
+185 37 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Storm Shadow King Storm [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of'] ", the Dark Knight , is Batman .
+
+ The name of the superhero ar chn em esis of" False early, as Tropical Storm Arlene formed on 4 [' early', ',', ' as', ' Tropical', ' Storm']
+186 37 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Storm Shadow King Storm [',', ' the', ' Dark', ' Knight', ',', ' is', ' Batman', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of'] ", the Dark Knight , is Batman .
+
+ The name of the superhero ar chn em esis of" False became Tropical Storm Isaac early 2 [' became', ' Tropical', ' Storm']
+187 40 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Invisible Woman Doctor Doom Invisible Woman [',', ' the', ' Invisible', ' Woman', '.', '\n', '\n', 'The', ' Invisible', ' Woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in'] ", the Invisible Woman .
+
+ The Invisible Woman is a fictional character , a superhero ine appearing in" False unterwegs (published as Invisible Woman in English), discusses 8 [' un', 'ter', 'we', 'gs', ' (', 'published', ' as', ' Invisible', ' Woman']
+188 40 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Invisible Woman Doctor Doom Invisible Woman [',', ' the', ' Invisible', ' Woman', '.', '\n', '\n', 'The', ' Invisible', ' Woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in'] ", the Invisible Woman .
+
+ The Invisible Woman is a fictional character , a superhero ine appearing in" False (published as Invisible Woman in English), 4 [' (', 'published', ' as', ' Invisible', ' Woman']
+189 40 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Invisible Woman Doctor Doom Invisible Woman [',', ' the', ' Invisible', ' Woman', '.', '\n', '\n', 'The', ' Invisible', ' Woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in'] ", the Invisible Woman .
+
+ The Invisible Woman is a fictional character , a superhero ine appearing in" False unterwegs (published as Invisible Woman in English), 8 [' un', 'ter', 'we', 'gs', ' (', 'published', ' as', ' Invisible', ' Woman']
+190 41 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Thing Doctor Doom The Thing [' is', ' revealed', '!', '\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic'] " is revealed !
+
+ The Thing is a fictional character , a superv ill ain appearing in American comic" False depicted, or not: The Thing from Another World, 6 [' depicted', ',', ' or', ' not', ':', ' The', ' Thing']
+191 41 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Thing Doctor Doom The Thing [' is', ' revealed', '!', '\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic'] " is revealed !
+
+ The Thing is a fictional character , a superv ill ain appearing in American comic" False Howard Hawks film The Thing from Another World 4 [' Howard', ' Hawks', ' film', ' The', ' Thing']
+192 41 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Thing Doctor Doom The Thing [' is', ' revealed', '!', '\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic'] " is revealed !
+
+ The Thing is a fictional character , a superv ill ain appearing in American comic" False Carpenter's version of The Thing was due in part to 5 "[' Carpenter', ""'s"", ' version', ' of', ' The', ' Thing']"
+193 41 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Thing Doctor Doom The Thing [' is', ' revealed', '!', '\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic'] " is revealed !
+
+ The Thing is a fictional character , a superv ill ain appearing in American comic" False Extra-Terrestrial and The Thing were eschewing the 6 [' Extra', '-', 'Ter', 'restrial', ' and', ' The', ' Thing']
+194 41 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Thing Doctor Doom The Thing [' is', ' revealed', '!', '\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic'] " is revealed !
+
+ The Thing is a fictional character , a superv ill ain appearing in American comic" False incarnations: The Thing from Another World 4 [' incarn', 'ations', ':', ' The', ' Thing']
+195 42 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ms. Marvel Mystique Ms. Marvel [',', ' the', ' most', ' powerful', ' woman', ' in', ' the', ' Marvel', ' Universe', ',', ' is', ' Carol', ' Dan', 'vers', '.', ' She', ' is', ' a', ' former', ' Air'] , the most powerful woman in the Marvel Universe , is Carol Dan vers . She is a former Air False " ""A huge aspect of Ms. Marvel is being a' second" 7 "[' ""', 'A', ' huge', ' aspect', ' of', ' Ms', '.', ' Marvel']"
+196 42 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ms. Marvel Mystique Ms. Marvel [',', ' the', ' most', ' powerful', ' woman', ' in', ' the', ' Marvel', ' Universe', ',', ' is', ' Carol', ' Dan', 'vers', '.', ' She', ' is', ' a', ' former', ' Air'] , the most powerful woman in the Marvel Universe , is Carol Dan vers . She is a former Air False the codename Ms. Marvel from her idol Carol 5 [' the', ' cod', 'ename', ' Ms', '.', ' Marvel']
+197 42 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ms. Marvel Mystique Ms. Marvel [',', ' the', ' most', ' powerful', ' woman', ' in', ' the', ' Marvel', ' Universe', ',', ' is', ' Carol', ' Dan', 'vers', '.', ' She', ' is', ' a', ' former', ' Air'] , the most powerful woman in the Marvel Universe , is Carol Dan vers . She is a former Air False nothing not to love about Ms. Marvel # 1: every 7 [' nothing', ' not', ' to', ' love', ' about', ' Ms', '.', ' Marvel']
+198 42 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ms. Marvel Mystique Ms. Marvel [',', ' the', ' most', ' powerful', ' woman', ' in', ' the', ' Marvel', ' Universe', ',', ' is', ' Carol', ' Dan', 'vers', '.', ' She', ' is', ' a', ' former', ' Air'] , the most powerful woman in the Marvel Universe , is Carol Dan vers . She is a former Air False the codename Ms. Marvel from her idol Carol 5 [' the', ' cod', 'ename', ' Ms', '.', ' Marvel']
+199 42 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Ms. Marvel Mystique Ms. Marvel [',', ' the', ' most', ' powerful', ' woman', ' in', ' the', ' Marvel', ' Universe', ',', ' is', ' Carol', ' Dan', 'vers', '.', ' She', ' is', ' a', ' former', ' Air'] , the most powerful woman in the Marvel Universe , is Carol Dan vers . She is a former Air False before taking over the Ms. Marvel comic book series 6 [' before', ' taking', ' over', ' the', ' Ms', '.', ' Marvel']
+200 43 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Shazam Black Adam Shazam ['!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sh', 'az', 'am', '!', '\n', '\n', 'The'] "!
+
+ The name of the superhero ar chn em esis of Sh az am !
+
+ The" False in a film about Shazam as a part 6 [' in', ' a', ' film', ' about', ' Sh', 'az', 'am']
+201 43 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Shazam Black Adam Shazam ['!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sh', 'az', 'am', '!', '\n', '\n', 'The'] "!
+
+ The name of the superhero ar chn em esis of Sh az am !
+
+ The" False in a film about Shazam as a part of the 6 [' in', ' a', ' film', ' about', ' Sh', 'az', 'am']
+202 43 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Shazam Black Adam Shazam ['!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sh', 'az', 'am', '!', '\n', '\n', 'The'] "!
+
+ The name of the superhero ar chn em esis of Sh az am !
+
+ The" False (Grandad). The Shazam Productions and 7 [' (', 'Grand', 'ad', ').', ' The', ' Sh', 'az', 'am']
+203 43 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Shazam Black Adam Shazam ['!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sh', 'az', 'am', '!', '\n', '\n', 'The'] "!
+
+ The name of the superhero ar chn em esis of Sh az am !
+
+ The" False a film about Shazam as a part of the DC 5 [' a', ' film', ' about', ' Sh', 'az', 'am']
+204 43 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Shazam Black Adam Shazam ['!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sh', 'az', 'am', '!', '\n', '\n', 'The'] "!
+
+ The name of the superhero ar chn em esis of Sh az am !
+
+ The" False production company, Shazam Productions. Originally 5 [' production', ' company', ',', ' Sh', 'az', 'am']
+205 44 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blue Beetle Maxwell Lord Blue Beetle [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False revival of the Blue Beetle the previous year; 4 [' revival', ' of', ' the', ' Blue', ' Beetle']
+206 44 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blue Beetle Maxwell Lord Blue Beetle [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False specials and a Blue Beetle tie-in issue 4 [' specials', ' and', ' a', ' Blue', ' Beetle']
+207 44 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blue Beetle Maxwell Lord Blue Beetle [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False specials and a Blue Beetle tie-in issue were 4 [' specials', ' and', ' a', ' Blue', ' Beetle']
+208 44 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blue Beetle Maxwell Lord Blue Beetle [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False printings as well. Blue Beetle # 20 saw much higher 6 [' print', 'ings', ' as', ' well', '.', ' Blue', ' Beetle']
+209 44 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blue Beetle Maxwell Lord Blue Beetle [',', ' the', ' Green', ' Arrow', ',', ' is', ' a', ' reference', ' to', ' the', ' Green', ' Lantern', '.', '\n', '\n', 'The', ' Green', ' Lantern', ' is', ' a'] ", the Green Arrow , is a reference to the Green Lantern .
+
+ The Green Lantern is a" False one tie-in with Blue Beetle # 20. Part One, 6 [' one', ' tie', '-', 'in', ' with', ' Blue', ' Beetle']
+210 45 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Zatanna Felix Faust Zatanna [' is', ' the', ' Spectre', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Justice', ' League', ' of', ' America', '.', ' He', ' is', ' a', ' powerful', ' wizard'] is the Spectre . He is a member of the Justice League of America . He is a powerful wizard False that JLA member Zatanna had altered 6 [' that', ' J', 'LA', ' member', ' Z', 'at', 'anna']
+211 45 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Zatanna Felix Faust Zatanna [' is', ' the', ' Spectre', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Justice', ' League', ' of', ' America', '.', ' He', ' is', ' a', ' powerful', ' wizard'] is the Spectre . He is a member of the Justice League of America . He is a powerful wizard False and deals with Zatanna and Wonder Woman 5 [' and', ' deals', ' with', ' Z', 'at', 'anna']
+212 45 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Zatanna Felix Faust Zatanna [' is', ' the', ' Spectre', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Justice', ' League', ' of', ' America', '.', ' He', ' is', ' a', ' powerful', ' wizard'] is the Spectre . He is a member of the Justice League of America . He is a powerful wizard False reveals that JLA member Zatanna had altered Batman's 7 [' reveals', ' that', ' J', 'LA', ' member', ' Z', 'at', 'anna']
+213 45 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Zatanna Felix Faust Zatanna [' is', ' the', ' Spectre', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Justice', ' League', ' of', ' America', '.', ' He', ' is', ' a', ' powerful', ' wizard'] is the Spectre . He is a member of the Justice League of America . He is a powerful wizard False and deals with Zatanna and Wonder Woman 5 [' and', ' deals', ' with', ' Z', 'at', 'anna']
+214 45 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Zatanna Felix Faust Zatanna [' is', ' the', ' Spectre', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Justice', ' League', ' of', ' America', '.', ' He', ' is', ' a', ' powerful', ' wizard'] is the Spectre . He is a member of the Justice League of America . He is a powerful wizard False that JLA member Zatanna had altered 6 [' that', ' J', 'LA', ' member', ' Z', 'at', 'anna']
+215 46 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Constantine Nergal Constantine [' is', ' the', ' Joker', '.', ' He', ' is', ' the', ' one', ' who', ' has', ' been', ' the', ' most', ' successful', ' in', ' the', ' history', ' of', ' the', ' DC'] is the Joker . He is the one who has been the most successful in the history of the DC False to marry her. Constantine was reluctant to let 4 [' to', ' marry', ' her', '.', ' Constantine']
+216 46 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Constantine Nergal Constantine [' is', ' the', ' Joker', '.', ' He', ' is', ' the', ' one', ' who', ' has', ' been', ' the', ' most', ' successful', ' in', ' the', ' history', ' of', ' the', ' DC'] is the Joker . He is the one who has been the most successful in the history of the DC False " Constantine Bypass ===
+" 0 [' Constantine']
+217 46 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Constantine Nergal Constantine [' is', ' the', ' Joker', '.', ' He', ' is', ' the', ' one', ' who', ' has', ' been', ' the', ' most', ' successful', ' in', ' the', ' history', ' of', ' the', ' DC'] is the Joker . He is the one who has been the most successful in the history of the DC False of the Schools, and Constantine Gongyles as 5 [' of', ' the', ' Schools', ',', ' and', ' Constantine']
+218 46 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Constantine Nergal Constantine [' is', ' the', ' Joker', '.', ' He', ' is', ' the', ' one', ' who', ' has', ' been', ' the', ' most', ' successful', ' in', ' the', ' history', ' of', ' the', ' DC'] is the Joker . He is the one who has been the most successful in the history of the DC False Palatinae established by Constantine the Great (reigned 6 [' Pal', 'at', 'ina', 'e', ' established', ' by', ' Constantine']
+219 46 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Constantine Nergal Constantine [' is', ' the', ' Joker', '.', ' He', ' is', ' the', ' one', ' who', ' has', ' been', ' the', ' most', ' successful', ' in', ' the', ' history', ' of', ' the', ' DC'] is the Joker . He is the one who has been the most successful in the history of the DC False the reign of Constantine the Great and Scupi 3 [' the', ' reign', ' of', ' Constantine']
+220 47 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Swamp Thing Arcane Swamp Thing [' is', ' the', ' Swamp', ' Thing', '.', ' He', ' is', ' a', ' creature', ' of', ' the', ' swamp', ',', ' and', ' he', ' is', ' a', ' creature', ' of', ' the'] is the Swamp Thing . He is a creature of the swamp , and he is a creature of the False in this process, Swamp Thing has his human 5 [' in', ' this', ' process', ',', ' Swamp', ' Thing']
+221 47 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Swamp Thing Arcane Swamp Thing [' is', ' the', ' Swamp', ' Thing', '.', ' He', ' is', ' a', ' creature', ' of', ' the', ' swamp', ',', ' and', ' he', ' is', ' a', ' creature', ' of', ' the'] is the Swamp Thing . He is a creature of the swamp , and he is a creature of the False titles. In Swamp Thing vol. 2 No. 84 4 [' titles', '.', ' In', ' Swamp', ' Thing']
+222 47 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Swamp Thing Arcane Swamp Thing [' is', ' the', ' Swamp', ' Thing', '.', ' He', ' is', ' a', ' creature', ' of', ' the', ' swamp', ',', ' and', ' he', ' is', ' a', ' creature', ' of', ' the'] is the Swamp Thing . He is a creature of the swamp , and he is a creature of the False character in both Swamp Thing and The Books 4 [' character', ' in', ' both', ' Swamp', ' Thing']
+223 47 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Swamp Thing Arcane Swamp Thing [' is', ' the', ' Swamp', ' Thing', '.', ' He', ' is', ' a', ' creature', ' of', ' the', ' swamp', ',', ' and', ' he', ' is', ' a', ' creature', ' of', ' the'] is the Swamp Thing . He is a creature of the swamp , and he is a creature of the False character in both Swamp Thing and The Books 4 [' character', ' in', ' both', ' Swamp', ' Thing']
+224 47 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Swamp Thing Arcane Swamp Thing [' is', ' the', ' Swamp', ' Thing', '.', ' He', ' is', ' a', ' creature', ' of', ' the', ' swamp', ',', ' and', ' he', ' is', ' a', ' creature', ' of', ' the'] is the Swamp Thing . He is a creature of the swamp , and he is a creature of the False character in both Swamp Thing and The Books 4 [' character', ' in', ' both', ' Swamp', ' Thing']
+225 48 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Robin Two-Face Robin [' Williams', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n', '\n', 'The'] " Williams , the Joker , is a reference to the Joker from the Batman comic books .
+
+ The" False Julie Benz as Robin Gallagher, a 3 [' Julie', ' Benz', ' as', ' Robin']
+226 48 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Robin Two-Face Robin [' Williams', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n', '\n', 'The'] " Williams , the Joker , is a reference to the Joker from the Batman comic books .
+
+ The" False 0 ['Robin']
+227 48 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Robin Two-Face Robin [' Williams', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n', '\n', 'The'] " Williams , the Joker , is a reference to the Joker from the Batman comic books .
+
+ The" False performance and duet with Robin Thicke at the 2013 5 [' performance', ' and', ' du', 'et', ' with', ' Robin']
+228 48 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Robin Two-Face Robin [' Williams', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n', '\n', 'The'] " Williams , the Joker , is a reference to the Joker from the Batman comic books .
+
+ The" False Archaeologist Robin Holgate stressed 2 [' Archae', 'ologist', ' Robin']
+229 48 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Robin Two-Face Robin [' Williams', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n', '\n', 'The'] " Williams , the Joker , is a reference to the Joker from the Batman comic books .
+
+ The" False was voiced by Robin Atkin Downes 3 [' was', ' voiced', ' by', ' Robin']
+230 49 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Catwoman Black Mask Catwoman [',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ',', ' a', ' character', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n'] ", the Joker , is a reference to the Joker , a character from the Batman comic books .
+" False sees Batman and Catwoman allied against 4 [' sees', ' Batman', ' and', ' Cat', 'woman']
+231 49 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Catwoman Black Mask Catwoman [',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ',', ' a', ' character', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n'] ", the Joker , is a reference to the Joker , a character from the Batman comic books .
+" False moving forward on a Catwoman spin-off. However, 5 [' moving', ' forward', ' on', ' a', ' Cat', 'woman']
+232 49 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Catwoman Black Mask Catwoman [',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ',', ' a', ' character', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n'] ", the Joker , is a reference to the Joker , a character from the Batman comic books .
+" False " in the night sky as Catwoman watches from afar.
+" 6 [' in', ' the', ' night', ' sky', ' as', ' Cat', 'woman']
+233 49 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Catwoman Black Mask Catwoman [',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ',', ' a', ' character', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n'] ", the Joker , is a reference to the Joker , a character from the Batman comic books .
+" False title role in the film Catwoman, a $ 100 million 6 [' title', ' role', ' in', ' the', ' film', ' Cat', 'woman']
+234 49 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Catwoman Black Mask Catwoman [',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ',', ' a', ' character', ' from', ' the', ' Batman', ' comic', ' books', '.', '\n'] ", the Joker , is a reference to the Joker , a character from the Batman comic books .
+" False 1 ['Cat', 'woman']
+235 50 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rorschach Ozymandias Rorschach [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' R', 'ors', 'ch', 'ach', ' is'] " is revealed !
+
+ The name of the superhero ar chn em esis of R ors ch ach is" False Coller portrayed Rorschach in a costume 6 [' Coll', 'er', ' portrayed', ' R', 'ors', 'ch', 'ach']
+236 50 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rorschach Ozymandias Rorschach [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' R', 'ors', 'ch', 'ach', ' is'] " is revealed !
+
+ The name of the superhero ar chn em esis of R ors ch ach is" False test Haley gave as Rorschach for Watchmen; 7 [' test', ' Haley', ' gave', ' as', ' R', 'ors', 'ch', 'ach']
+237 50 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rorschach Ozymandias Rorschach [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' R', 'ors', 'ch', 'ach', ' is'] " is revealed !
+
+ The name of the superhero ar chn em esis of R ors ch ach is" False Earle Haley as Rorschach in Watchmen. Stone 7 [' Ear', 'le', ' Haley', ' as', ' R', 'ors', 'ch', 'ach']
+238 50 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rorschach Ozymandias Rorschach [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' R', 'ors', 'ch', 'ach', ' is'] " is revealed !
+
+ The name of the superhero ar chn em esis of R ors ch ach is" False " mask"". Moore based Rorschach on Ditko's creation" 7 "[' mask', '"".', ' Moore', ' based', ' R', 'ors', 'ch', 'ach']"
+239 50 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rorschach Ozymandias Rorschach [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' R', 'ors', 'ch', 'ach', ' is'] " is revealed !
+
+ The name of the superhero ar chn em esis of R ors ch ach is" False " They resemble Rorschach ""blots"" against the" 5 [' They', ' resemble', ' R', 'ors', 'ch', 'ach']
+240 52 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkgirl Shadow Thief Hawkgirl [' is', ' Hawk', 'man', '.', ' He', ' is', ' the', ' son', ' of', ' the', ' Hawk', ' God', ',', ' Horus', ',', ' and', ' the', ' Hawk', ' Goddess', ','] is Hawk man . He is the son of the Hawk God , Horus , and the Hawk Goddess , False Green Lantern and Hawkgirl mistake Batman for 5 [' Green', ' Lantern', ' and', ' Haw', 'kg', 'irl']
+241 52 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hawkgirl Shadow Thief Hawkgirl [' is', ' Hawk', 'man', '.', ' He', ' is', ' the', ' son', ' of', ' the', ' Hawk', ' God', ',', ' Horus', ',', ' and', ' the', ' Hawk', ' Goddess', ','] is Hawk man . He is the son of the Hawk God , Horus , and the Hawk Goddess , False Green Lantern and Hawkgirl mistake Batman for 5 [' Green', ' Lantern', ' and', ' Haw', 'kg', 'irl']
+242 53 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of She-Hulk Titania She-Hulk [' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' in', ' the', ' original', ' comic', ' book', '.', '\n', '\n', 'The', ' name', ' of'] " is the same as the name of the villain in the original comic book .
+
+ The name of" False starred for Vivid as She-Hulk in their parody 8 [' starred', ' for', ' V', 'ivid', ' as', ' She', '-', 'H', 'ulk']
+243 53 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of She-Hulk Titania She-Hulk [' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' in', ' the', ' original', ' comic', ' book', '.', '\n', '\n', 'The', ' name', ' of'] " is the same as the name of the villain in the original comic book .
+
+ The name of" False for Vivid as She-Hulk in their parody 7 [' for', ' V', 'ivid', ' as', ' She', '-', 'H', 'ulk']
+244 54 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Red Hood The Joker Red Hood [' and', ' the', ' Out', 'laws', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comics', '.', '\n'] " and the Out laws , the Joker , is a reference to the Joker from the Batman comics .
+" False and dons the Red Hood. The heist 5 [' and', ' d', 'ons', ' the', ' Red', ' Hood']
+245 54 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Red Hood The Joker Red Hood [' and', ' the', ' Out', 'laws', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comics', '.', '\n'] " and the Out laws , the Joker , is a reference to the Joker from the Batman comics .
+" False Quinn and Red Hood story packs towards 3 [' Quinn', ' and', ' Red', ' Hood']
+246 54 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Red Hood The Joker Red Hood [' and', ' the', ' Out', 'laws', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comics', '.', '\n'] " and the Out laws , the Joker , is a reference to the Joker from the Batman comics .
+" False Batman: Under the Red Hood on DVD (Blu-ray 5 [' Batman', ':', ' Under', ' the', ' Red', ' Hood']
+247 54 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Red Hood The Joker Red Hood [' and', ' the', ' Out', 'laws', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comics', '.', '\n'] " and the Out laws , the Joker , is a reference to the Joker from the Batman comics .
+" False the identity of Red Hood to support 4 [' the', ' identity', ' of', ' Red', ' Hood']
+248 54 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Red Hood The Joker Red Hood [' and', ' the', ' Out', 'laws', ',', ' the', ' Joker', ',', ' is', ' a', ' reference', ' to', ' the', ' Joker', ' from', ' the', ' Batman', ' comics', '.', '\n'] " and the Out laws , the Joker , is a reference to the Joker from the Batman comics .
+" False Harley Quinn and Red Hood story packs, both 4 [' Harley', ' Quinn', ' and', ' Red', ' Hood']
+249 55 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Dr. Fate Wotan Dr. Fate [',', ' the', ' Spectre', ',', ' is', ' a', ' reference', ' to', ' the', ' Spectre', ',', ' a', ' DC', ' Comics', ' character', ' who', ' is', ' a', ' member', ' of'] , the Spectre , is a reference to the Spectre , a DC Comics character who is a member of False powers restored by Dr. Fate (Brent Stait) 5 [' powers', ' restored', ' by', ' Dr', '.', ' Fate']
+250 55 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Dr. Fate Wotan Dr. Fate [',', ' the', ' Spectre', ',', ' is', ' a', ' reference', ' to', ' the', ' Spectre', ',', ' a', ' DC', ' Comics', ' character', ' who', ' is', ' a', ' member', ' of'] , the Spectre , is a reference to the Spectre , a DC Comics character who is a member of False powers restored by Dr. Fate (Brent Stait) 5 [' powers', ' restored', ' by', ' Dr', '.', ' Fate']
+251 55 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Dr. Fate Wotan Dr. Fate [',', ' the', ' Spectre', ',', ' is', ' a', ' reference', ' to', ' the', ' Spectre', ',', ' a', ' DC', ' Comics', ' character', ' who', ' is', ' a', ' member', ' of'] , the Spectre , is a reference to the Spectre , a DC Comics character who is a member of False restored by Dr. Fate (Brent Stait) 4 [' restored', ' by', ' Dr', '.', ' Fate']
+252 57 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scarlet Witch Morgan Le Fay Scarlet Witch [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' a', ' reference', ' to', ' the', ' Scarlet', ' Witch', ',', ' a', ' Marvel', ' Comics', ' character', ' who', ' is', ' a'] , the Scarlet Witch , is a reference to the Scarlet Witch , a Marvel Comics character who is a False twin brother of the Scarlet Witch, who can move at 5 [' twin', ' brother', ' of', ' the', ' Scarlet', ' Witch']
+253 57 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scarlet Witch Morgan Le Fay Scarlet Witch [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' a', ' reference', ' to', ' the', ' Scarlet', ' Witch', ',', ' a', ' Marvel', ' Comics', ' character', ' who', ' is', ' a'] , the Scarlet Witch , is a reference to the Scarlet Witch , a Marvel Comics character who is a False 2 ['Scar', 'let', ' Witch']
+254 57 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scarlet Witch Morgan Le Fay Scarlet Witch [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' a', ' reference', ' to', ' the', ' Scarlet', ' Witch', ',', ' a', ' Marvel', ' Comics', ' character', ' who', ' is', ' a'] , the Scarlet Witch , is a reference to the Scarlet Witch , a Marvel Comics character who is a False " someone's mind, with Scarlet Witch able to ""feel and" 6 "[' someone', ""'s"", ' mind', ',', ' with', ' Scarlet', ' Witch']"
+255 57 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scarlet Witch Morgan Le Fay Scarlet Witch [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' a', ' reference', ' to', ' the', ' Scarlet', ' Witch', ',', ' a', ' Marvel', ' Comics', ' character', ' who', ' is', ' a'] , the Scarlet Witch , is a reference to the Scarlet Witch , a Marvel Comics character who is a False brother of the Scarlet Witch, who can move at superhuman 4 [' brother', ' of', ' the', ' Scarlet', ' Witch']
+256 57 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scarlet Witch Morgan Le Fay Scarlet Witch [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' a', ' reference', ' to', ' the', ' Scarlet', ' Witch', ',', ' a', ' Marvel', ' Comics', ' character', ' who', ' is', ' a'] , the Scarlet Witch , is a reference to the Scarlet Witch , a Marvel Comics character who is a False " reused, however, for Scarlet Witch instead.
+" 6 [' reused', ',', ' however', ',', ' for', ' Scarlet', ' Witch']
+257 58 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Vision Ultron Vision [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False 2010 Green Car Vision Award by the Green 3 [' 2010', ' Green', ' Car', ' Vision']
+258 58 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Vision Ultron Vision [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False Panoramic Night Vision Goggles) L-3 night-vision 4 [' Pan', 'or', 'amic', ' Night', ' Vision']
+259 58 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Vision Ultron Vision [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False capture prey. Vision is of little importance 3 [' capture', ' prey', '.', ' Vision']
+260 58 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Vision Ultron Vision [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False Abraham (1992). Vision or Villainy: Origins 4 [' Abraham', ' (', '1992', ').', ' Vision']
+261 58 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Vision Ultron Vision [',', ' the', ' Scarlet', ' Witch', ',', ' is', ' W', 'anda', ' Maxim', 'off', '.', ' She', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' manipulate'] , the Scarlet Witch , is W anda Maxim off . She is a mutant with the ability to manipulate False effects that Red Vision had created. In 3 [' effects', ' that', ' Red', ' Vision']
+262 59 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Power Girl Ultra-Humanite Power Girl [',', ' the', ' villain', ' of', ' the', ' week', ',', ' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' of', ' the', ' week', ' in'] , the villain of the week , is the same as the name of the villain of the week in False before Superman and Power Girl can apprehend her. 4 [' before', ' Superman', ' and', ' Power', ' Girl']
+263 59 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Power Girl Ultra-Humanite Power Girl [',', ' the', ' villain', ' of', ' the', ' week', ',', ' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' of', ' the', ' week', ' in'] , the villain of the week , is the same as the name of the villain of the week in False before Superman and Power Girl can apprehend 4 [' before', ' Superman', ' and', ' Power', ' Girl']
+264 60 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Mystique Bishop Mystique [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' power', ' to', ' manipulate', ' the', ' weather', '.', '\n', '\n', 'The'] ", the X - Men , is a mutant with the power to manipulate the weather .
+
+ The" False is a shapeshifter. Mystique is blue, naked and 6 [' is', ' a', ' sh', 'apeshifter', '.', ' Myst', 'ique']
+265 60 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Mystique Bishop Mystique [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' power', ' to', ' manipulate', ' the', ' weather', '.', '\n', '\n', 'The'] ", the X - Men , is a mutant with the power to manipulate the weather .
+
+ The" False the scene where Mystique drugs Magneto's prison 4 [' the', ' scene', ' where', ' Myst', 'ique']
+266 60 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Mystique Bishop Mystique [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' power', ' to', ' manipulate', ' the', ' weather', '.', '\n', '\n', 'The'] ", the X - Men , is a mutant with the power to manipulate the weather .
+
+ The" False for scenes where Mystique uses his persona 4 [' for', ' scenes', ' where', ' Myst', 'ique']
+267 60 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Mystique Bishop Mystique [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' power', ' to', ' manipulate', ' the', ' weather', '.', '\n', '\n', 'The'] ", the X - Men , is a mutant with the power to manipulate the weather .
+
+ The" False and the villains Mystique and Destiny. 4 [' and', ' the', ' villains', ' Myst', 'ique']
+268 60 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Mystique Bishop Mystique [',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' mutant', ' with', ' the', ' power', ' to', ' manipulate', ' the', ' weather', '.', '\n', '\n', 'The'] ", the X - Men , is a mutant with the power to manipulate the weather .
+
+ The" False soldiers take her away, Mystique arrives with Blob and 6 [' soldiers', ' take', ' her', ' away', ',', ' Myst', 'ique']
+269 61 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Shadowcat White Queen Shadowcat [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain'] " is revealed !
+
+ The name of the villain is revealed !
+
+ The name of the villain" False and was produced by Shadowcat Films, with 5 [' and', ' was', ' produced', ' by', ' Shadow', 'cat']
+270 62 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Colossus Juggernaut Colossus "[',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' reference', ' to', ' the', ' X', '-', 'Men', ""'s"", ' arch', '-', 'n', 'em', 'esis']" , the X - Men , is a reference to the X - Men 's arch - n em esis False Shadow of the Colossus (ワンダと巨像, Wanda to 3 [' Shadow', ' of', ' the', ' Colossus']
+271 62 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Colossus Juggernaut Colossus "[',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' reference', ' to', ' the', ' X', '-', 'Men', ""'s"", ' arch', '-', 'n', 'em', 'esis']" , the X - Men , is a reference to the X - Men 's arch - n em esis False in Shadow of the Colossus is the relationship 4 [' in', ' Shadow', ' of', ' the', ' Colossus']
+272 62 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Colossus Juggernaut Colossus "[',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' reference', ' to', ' the', ' X', '-', 'Men', ""'s"", ' arch', '-', 'n', 'em', 'esis']" , the X - Men , is a reference to the X - Men 's arch - n em esis False used on the park's Colossus wooden roller 5 "[' used', ' on', ' the', ' park', ""'s"", ' Colossus']"
+273 62 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Colossus Juggernaut Colossus "[',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' reference', ' to', ' the', ' X', '-', 'Men', ""'s"", ' arch', '-', 'n', 'em', 'esis']" , the X - Men , is a reference to the X - Men 's arch - n em esis False story of Shadow of the Colossus begins as Wander 5 [' story', ' of', ' Shadow', ' of', ' the', ' Colossus']
+274 62 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Colossus Juggernaut Colossus "[',', ' the', ' X', '-', 'Men', ',', ' is', ' a', ' reference', ' to', ' the', ' X', '-', 'Men', ""'s"", ' arch', '-', 'n', 'em', 'esis']" , the X - Men , is a reference to the X - Men 's arch - n em esis False " Special where Colossus throws Wolverine.
+" 2 [' Special', ' where', ' Colossus']
+275 63 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Magneto Professor X Magneto [',', ' the', ' X', '-', 'Men', ',', ' is', ' Magnet', 'o', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em'] ", the X - Men , is Magnet o .
+
+ The name of the superhero ar chn em" False and close-friend, Magneto now believes 6 [' and', ' close', '-', 'friend', ',', ' Magnet', 'o']
+276 63 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Magneto Professor X Magneto [',', ' the', ' X', '-', 'Men', ',', ' is', ' Magnet', 'o', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em'] ", the X - Men , is Magnet o .
+
+ The name of the superhero ar chn em" False the past Xavier and Magneto into preventing 5 [' the', ' past', ' Xavier', ' and', ' Magnet', 'o']
+277 63 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Magneto Professor X Magneto [',', ' the', ' X', '-', 'Men', ',', ' is', ' Magnet', 'o', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em'] ", the X - Men , is Magnet o .
+
+ The name of the superhero ar chn em" False He later joins Magneto and Mystique. 4 [' He', ' later', ' joins', ' Magnet', 'o']
+278 63 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Magneto Professor X Magneto [',', ' the', ' X', '-', 'Men', ',', ' is', ' Magnet', 'o', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em'] ", the X - Men , is Magnet o .
+
+ The name of the superhero ar chn em" False themes for Magneto and Shaw have similarities 3 [' themes', ' for', ' Magnet', 'o']
+279 63 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Magneto Professor X Magneto [',', ' the', ' X', '-', 'Men', ',', ' is', ' Magnet', 'o', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em'] ", the X - Men , is Magnet o .
+
+ The name of the superhero ar chn em" False overtook a planned Magneto prequel that 5 [' overt', 'ook', ' a', ' planned', ' Magnet', 'o']
+280 64 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Gamora Thanos Gamora [',', ' the', ' villain', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Kr', 'ag', 'lin', '.', ' He', ' is', ' a', ' large', ','] , the villain of the Guardians of the Galaxy , is Kr ag lin . He is a large , False with Drax and Gamora being used for 5 [' with', ' Dra', 'x', ' and', ' Gam', 'ora']
+281 64 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Gamora Thanos Gamora [',', ' the', ' villain', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Kr', 'ag', 'lin', '.', ' He', ' is', ' a', ' large', ','] , the villain of the Guardians of the Galaxy , is Kr ag lin . He is a large , False negotiations to star as Gamora in the film, and 5 [' negotiations', ' to', ' star', ' as', ' Gam', 'ora']
+282 64 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Gamora Thanos Gamora [',', ' the', ' villain', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Kr', 'ag', 'lin', '.', ' He', ' is', ' a', ' large', ','] , the villain of the Guardians of the Galaxy , is Kr ag lin . He is a large , False character, with Drax and Gamora being used for older 7 [' character', ',', ' with', ' Dra', 'x', ' and', ' Gam', 'ora']
+283 64 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Gamora Thanos Gamora [',', ' the', ' villain', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Kr', 'ag', 'lin', '.', ' He', ' is', ' a', ' large', ','] , the villain of the Guardians of the Galaxy , is Kr ag lin . He is a large , False character, with Drax and Gamora being used for 7 [' character', ',', ' with', ' Dra', 'x', ' and', ' Gam', 'ora']
+284 64 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Gamora Thanos Gamora [',', ' the', ' villain', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Kr', 'ag', 'lin', '.', ' He', ' is', ' a', ' large', ','] , the villain of the Guardians of the Galaxy , is Kr ag lin . He is a large , False negotiations to star as Gamora in the film, 5 [' negotiations', ' to', ' star', ' as', ' Gam', 'ora']
+285 65 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Groot Collector Groot [',', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Gro', 'ot', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' of', ' the'] ", the Guardians of the Galaxy , is Gro ot .
+
+ The name of the villain of the" False Raccoon and Groot would be created through 5 [' R', 'acco', 'on', ' and', ' Gro', 'ot']
+286 65 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Groot Collector Groot [',', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Gro', 'ot', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' of', ' the'] ", the Guardians of the Galaxy , is Gro ot .
+
+ The name of the villain of the" False relatives. A Groot Besogne (Grand Commission) 4 [' relatives', '.', ' A', ' Gro', 'ot']
+287 65 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Groot Collector Groot [',', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Gro', 'ot', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' of', ' the'] ", the Guardians of the Galaxy , is Gro ot .
+
+ The name of the villain of the" False portray Rocket and Groot on set, as it provided 4 [' portray', ' Rocket', ' and', ' Gro', 'ot']
+288 65 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Groot Collector Groot [',', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Gro', 'ot', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' of', ' the'] ", the Guardians of the Galaxy , is Gro ot .
+
+ The name of the villain of the" False Forfar. Henry de Groot recorded that 6 [' For', 'far', '.', ' Henry', ' de', ' Gro', 'ot']
+289 65 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Groot Collector Groot [',', ' the', ' Guardians', ' of', ' the', ' Galaxy', ',', ' is', ' Gro', 'ot', '.', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' of', ' the'] ", the Guardians of the Galaxy , is Gro ot .
+
+ The name of the villain of the" False vice versa for when Groot was needed 5 [' vice', ' versa', ' for', ' when', ' Gro', 'ot']
+290 66 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rocket Raccoon Collector Rocket Raccoon [' and', ' Gro', 'ot', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', 'ous', ' ar', 'chn', 'em', 'esis', ' of', ' Rocket'] " and Gro ot is revealed !
+
+ The name of the villain ous ar chn em esis of Rocket" False also said that both Rocket Raccoon and Groot would be 7 [' also', ' said', ' that', ' both', ' Rocket', ' R', 'acco', 'on']
+291 66 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rocket Raccoon Collector Rocket Raccoon [' and', ' Gro', 'ot', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', 'ous', ' ar', 'chn', 'em', 'esis', ' of', ' Rocket'] " and Gro ot is revealed !
+
+ The name of the villain ous ar chn em esis of Rocket" False characters, such as Rocket Raccoon and Nova, were 7 [' characters', ',', ' such', ' as', ' Rocket', ' R', 'acco', 'on']
+292 66 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rocket Raccoon Collector Rocket Raccoon [' and', ' Gro', 'ot', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', 'ous', ' ar', 'chn', 'em', 'esis', ' of', ' Rocket'] " and Gro ot is revealed !
+
+ The name of the villain ous ar chn em esis of Rocket" False said that both Rocket Raccoon and Groot would 6 [' said', ' that', ' both', ' Rocket', ' R', 'acco', 'on']
+293 66 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Rocket Raccoon Collector Rocket Raccoon [' and', ' Gro', 'ot', ' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', 'ous', ' ar', 'chn', 'em', 'esis', ' of', ' Rocket'] " and Gro ot is revealed !
+
+ The name of the villain ous ar chn em esis of Rocket" False also said that both Rocket Raccoon and Groot would 7 [' also', ' said', ' that', ' both', ' Rocket', ' R', 'acco', 'on']
+294 67 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Star-Lord Ronan the Accuser Star-Lord [',', ' the', ' Green', ' Lantern', ',', ' is', ' Kil', 'ow', 'og', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Green', ' Lantern', ' Corps', ','] , the Green Lantern , is Kil ow og . He is a member of the Green Lantern Corps , False older demographics; Star-Lord ’ s obsession 5 [' older', ' demographics', ';', ' Star', '-', 'Lord']
+295 67 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Star-Lord Ronan the Accuser Star-Lord [',', ' the', ' Green', ' Lantern', ',', ' is', ' Kil', 'ow', 'og', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Green', ' Lantern', ' Corps', ','] , the Green Lantern , is Kil ow og . He is a member of the Green Lantern Corps , False Pratt dressed as Star-Lord to entertain 5 [' Pratt', ' dressed', ' as', ' Star', '-', 'Lord']
+296 67 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Star-Lord Ronan the Accuser Star-Lord [',', ' the', ' Green', ' Lantern', ',', ' is', ' Kil', 'ow', 'og', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Green', ' Lantern', ' Corps', ','] , the Green Lantern , is Kil ow og . He is a member of the Green Lantern Corps , False demographics; Star-Lord ’ s obsession 4 [' demographics', ';', ' Star', '-', 'Lord']
+297 67 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Star-Lord Ronan the Accuser Star-Lord [',', ' the', ' Green', ' Lantern', ',', ' is', ' Kil', 'ow', 'og', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Green', ' Lantern', ' Corps', ','] , the Green Lantern , is Kil ow og . He is a member of the Green Lantern Corps , False Peter Quill / Star-Lord in Guardians of the 6 [' Peter', ' Qu', 'ill', ' /', ' Star', '-', 'Lord']
+298 67 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Star-Lord Ronan the Accuser Star-Lord [',', ' the', ' Green', ' Lantern', ',', ' is', ' Kil', 'ow', 'og', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Green', ' Lantern', ' Corps', ','] , the Green Lantern , is Kil ow og . He is a member of the Green Lantern Corps , False demographics; Star-Lord ’ s obsession with 4 [' demographics', ';', ' Star', '-', 'Lord']
+299 68 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nick Fury Baron Strucker Nick Fury [',', ' the', ' man', ' who', ' has', ' been', ' fighting', ' the', ' good', ' fight', ' against', ' the', ' forces', ' of', ' evil', ' for', ' decades', '.', '\n', '\n'] ", the man who has been fighting the good fight against the forces of evil for decades .
+
+" False idea of a Nick Fury cameo to set 4 [' idea', ' of', ' a', ' Nick', ' Fury']
+300 68 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nick Fury Baron Strucker Nick Fury [',', ' the', ' man', ' who', ' has', ' been', ' fighting', ' the', ' good', ' fight', ' against', ' the', ' forces', ' of', ' evil', ' for', ' decades', '.', '\n', '\n'] ", the man who has been fighting the good fight against the forces of evil for decades .
+
+" False that comes out of Nick Fury's mouth is a lie 5 [' that', ' comes', ' out', ' of', ' Nick', ' Fury']
+301 68 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nick Fury Baron Strucker Nick Fury [',', ' the', ' man', ' who', ' has', ' been', ' fighting', ' the', ' good', ' fight', ' against', ' the', ' forces', ' of', ' evil', ' for', ' decades', '.', '\n', '\n'] ", the man who has been fighting the good fight against the forces of evil for decades .
+
+" False headquarters, where Nick Fury reveals that 4 [' headquarters', ',', ' where', ' Nick', ' Fury']
+302 68 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nick Fury Baron Strucker Nick Fury [',', ' the', ' man', ' who', ' has', ' been', ' fighting', ' the', ' good', ' fight', ' against', ' the', ' forces', ' of', ' evil', ' for', ' decades', '.', '\n', '\n'] ", the man who has been fighting the good fight against the forces of evil for decades .
+
+" False S.H.I.E.L.D. agent Nick Fury as the second player's 14 [' S', '.', 'H', '.', 'I', '.', 'E', '.', 'L', '.', 'D', '.', ' agent', ' Nick', ' Fury']
+303 68 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Nick Fury Baron Strucker Nick Fury [',', ' the', ' man', ' who', ' has', ' been', ' fighting', ' the', ' good', ' fight', ' against', ' the', ' forces', ' of', ' evil', ' for', ' decades', '.', '\n', '\n'] ", the man who has been fighting the good fight against the forces of evil for decades .
+
+" False Entertainment to play Nick Fury in Iron Man 2 and 4 [' Entertainment', ' to', ' play', ' Nick', ' Fury']
+304 70 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellboy Grigori Rasputin Hellboy [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Hell', 'boy', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Hell boy is revealed !" False down to direct Hellboy II: The Golden Army. 4 [' down', ' to', ' direct', ' Hell', 'boy']
+305 70 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellboy Grigori Rasputin Hellboy [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Hell', 'boy', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Hell boy is revealed !" False illustrated the Hellboy Christmas Special 3 [' illustrated', ' the', ' Hell', 'boy']
+306 70 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellboy Grigori Rasputin Hellboy [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Hell', 'boy', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Hell boy is revealed !" False Toro's previous film, Hellboy II: The Golden 6 "[' Toro', ""'s"", ' previous', ' film', ',', ' Hell', 'boy']"
+307 70 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellboy Grigori Rasputin Hellboy [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Hell', 'boy', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Hell boy is revealed !" False and Simon Lee, and Hellboy II and The Hobbit 6 [' and', ' Simon', ' Lee', ',', ' and', ' Hell', 'boy']
+308 70 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellboy Grigori Rasputin Hellboy [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Hell', 'boy', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Hell boy is revealed !" False del Toro, surpassing Hellboy II: The Golden 6 [' del', ' Toro', ',', ' surpass', 'ing', ' Hell', 'boy']
+309 71 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Judge Dredd Judge Death Judge Dredd [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Judge', ' D', 'redd', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Judge D redd is revealed" False the 1995 movie Judge Dredd starring Sylvester 5 [' the', ' 1995', ' movie', ' Judge', ' D', 'redd']
+310 71 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Judge Dredd Judge Death Judge Dredd [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Judge', ' D', 'redd', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Judge D redd is revealed" False three years, Judge Dredd continued to be 5 [' three', ' years', ',', ' Judge', ' D', 'redd']
+311 71 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Judge Dredd Judge Death Judge Dredd [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Judge', ' D', 'redd', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Judge D redd is revealed" False early in the Judge Dredd strip (in 1978), the 5 [' early', ' in', ' the', ' Judge', ' D', 'redd']
+312 71 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Judge Dredd Judge Death Judge Dredd [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Judge', ' D', 'redd', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Judge D redd is revealed" False Burns, in Judge Dredd Megazine vol. 5 [' Burns', ',', ' in', ' Judge', ' D', 'redd']
+313 71 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Judge Dredd Judge Death Judge Dredd [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Judge', ' D', 'redd', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Judge D redd is revealed" False executioner. Judge Dredd is tasked by 5 [' execution', 'er', '.', ' Judge', ' D', 'redd']
+314 72 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Teenage Mutant Ninja Turtles Shredder Teenage Mutant Ninja Turtles [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles'] " is revealed !
+
+ The name of the superhero ar chn em esis of Teen age Mutant Ninja Turtles" False and movies, with Teenage Mutant Ninja Turtles a surprise 8 [' and', ' movies', ',', ' with', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles']
+315 72 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Teenage Mutant Ninja Turtles Shredder Teenage Mutant Ninja Turtles [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles'] " is revealed !
+
+ The name of the superhero ar chn em esis of Teen age Mutant Ninja Turtles" False 4 ['Teen', 'age', ' Mutant', ' Ninja', ' Turtles']
+316 72 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Teenage Mutant Ninja Turtles Shredder Teenage Mutant Ninja Turtles [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles'] " is revealed !
+
+ The name of the superhero ar chn em esis of Teen age Mutant Ninja Turtles" False In the present, the Teenage Mutant Ninja Turtles have grown 9 [' In', ' the', ' present', ',', ' the', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles']
+317 72 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Teenage Mutant Ninja Turtles Shredder Teenage Mutant Ninja Turtles [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles'] " is revealed !
+
+ The name of the superhero ar chn em esis of Teen age Mutant Ninja Turtles" False Warriors – just as Teenage Mutant Ninja Turtles was renamed Teenage 8 [' Warriors', ' –', ' just', ' as', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles']
+318 72 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Teenage Mutant Ninja Turtles Shredder Teenage Mutant Ninja Turtles [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles'] " is revealed !
+
+ The name of the superhero ar chn em esis of Teen age Mutant Ninja Turtles" False former editor of Teenage Mutant Ninja Turtles and other Mirage 7 [' former', ' editor', ' of', ' Teen', 'age', ' Mutant', ' Ninja', ' Turtles']
+319 73 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Sailor Moon Queen Beryl Sailor Moon [',', ' Sailor', ' Uran', 'us', ',', ' is', ' a', ' reference', ' to', ' the', ' planet', ' Uran', 'us', ',', ' which', ' is', ' the', ' seventh', ' planet', ' from'] , Sailor Uran us , is a reference to the planet Uran us , which is the seventh planet from False threesome with Raj and a Sailor Moon cosplayer at ComicCon 8 [' th', 're', 'esome', ' with', ' Raj', ' and', ' a', ' Sailor', ' Moon']
+320 73 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Sailor Moon Queen Beryl Sailor Moon [',', ' Sailor', ' Uran', 'us', ',', ' is', ' a', ' reference', ' to', ' the', ' planet', ' Uran', 'us', ',', ' which', ' is', ' the', ' seventh', ' planet', ' from'] , Sailor Uran us , is a reference to the planet Uran us , which is the seventh planet from False compared to Sailor Moon due to both having 3 [' compared', ' to', ' Sailor', ' Moon']
+321 73 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Sailor Moon Queen Beryl Sailor Moon [',', ' Sailor', ' Uran', 'us', ',', ' is', ' a', ' reference', ' to', ' the', ' planet', ' Uran', 'us', ',', ' which', ' is', ' the', ' seventh', ' planet', ' from'] , Sailor Uran us , is a reference to the planet Uran us , which is the seventh planet from False Raj and a Sailor Moon cosplayer at 4 [' Raj', ' and', ' a', ' Sailor', ' Moon']
+322 73 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Sailor Moon Queen Beryl Sailor Moon [',', ' Sailor', ' Uran', 'us', ',', ' is', ' a', ' reference', ' to', ' the', ' planet', ' Uran', 'us', ',', ' which', ' is', ' the', ' seventh', ' planet', ' from'] , Sailor Uran us , is a reference to the planet Uran us , which is the seventh planet from False compared to Sailor Moon due to both having 3 [' compared', ' to', ' Sailor', ' Moon']
+323 73 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Sailor Moon Queen Beryl Sailor Moon [',', ' Sailor', ' Uran', 'us', ',', ' is', ' a', ' reference', ' to', ' the', ' planet', ' Uran', 'us', ',', ' which', ' is', ' the', ' seventh', ' planet', ' from'] , Sailor Uran us , is a reference to the planet Uran us , which is the seventh planet from False has been compared to Sailor Moon due to both having 5 [' has', ' been', ' compared', ' to', ' Sailor', ' Moon']
+324 74 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Astro Boy Dr. Tenma Astro Boy [',', ' the', ' robot', ' boy', ' who', ' was', ' created', ' by', ' the', ' Japanese', ' manga', ' artist', ' Os', 'am', 'u', ' Te', 'z', 'uka', '.', ' He'] , the robot boy who was created by the Japanese manga artist Os am u Te z uka . He False Tezuka's manga Astro Boy in his Mega 6 "[' Te', 'z', 'uka', ""'s"", ' manga', ' Astro', ' Boy']"
+325 74 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Astro Boy Dr. Tenma Astro Boy [',', ' the', ' robot', ' boy', ' who', ' was', ' created', ' by', ' the', ' Japanese', ' manga', ' artist', ' Os', 'am', 'u', ' Te', 'z', 'uka', '.', ' He'] , the robot boy who was created by the Japanese manga artist Os am u Te z uka . He False Tezuka's manga Astro Boy in his Mega 6 "[' Te', 'z', 'uka', ""'s"", ' manga', ' Astro', ' Boy']"
+326 75 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Goku Frieza Goku [',', ' the', ' Dragon', ' Ball', ' Z', ' character', ' who', ' is', ' the', ' main', ' protagonist', ' of', ' the', ' Dragon', ' Ball', ' series', '.', ' He', ' is', ' the'] , the Dragon Ball Z character who is the main protagonist of the Dragon Ball series . He is the False explained the he had Goku grow up as a 4 [' explained', ' the', ' he', ' had', ' Goku']
+327 75 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Goku Frieza Goku [',', ' the', ' Dragon', ' Ball', ' Z', ' character', ' who', ' is', ' the', ' main', ' protagonist', ' of', ' the', ' Dragon', ' Ball', ' series', '.', ' He', ' is', ' the'] , the Dragon Ball Z character who is the main protagonist of the Dragon Ball series . He is the False for. With Goku being Sun Wukong, 3 [' for', '.', ' With', ' Goku']
+328 75 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Goku Frieza Goku [',', ' the', ' Dragon', ' Ball', ' Z', ' character', ' who', ' is', ' the', ' main', ' protagonist', ' of', ' the', ' Dragon', ' Ball', ' series', '.', ' He', ' is', ' the'] , the Dragon Ball Z character who is the main protagonist of the Dragon Ball series . He is the False revenge against Goku. During this time, 2 [' revenge', ' against', ' Goku']
+329 75 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Goku Frieza Goku [',', ' the', ' Dragon', ' Ball', ' Z', ' character', ' who', ' is', ' the', ' main', ' protagonist', ' of', ' the', ' Dragon', ' Ball', ' series', '.', ' He', ' is', ' the'] , the Dragon Ball Z character who is the main protagonist of the Dragon Ball series . He is the False Freeza -- Son Goku's Father ~, renamed 4 [' Free', 'za', ' --', ' Son', ' Goku']
+330 75 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Goku Frieza Goku [',', ' the', ' Dragon', ' Ball', ' Z', ' character', ' who', ' is', ' the', ' main', ' protagonist', ' of', ' the', ' Dragon', ' Ball', ' series', '.', ' He', ' is', ' the'] , the Dragon Ball Z character who is the main protagonist of the Dragon Ball series . He is the False Five years later, Goku is a young adult 4 [' Five', ' years', ' later', ',', ' Goku']
+331 76 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spawn Malebolgia Spawn [',', ' the', ' Dark', ' Horse', ' Comics', ' character', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' demon', '.', '\n'] ", the Dark Horse Comics character , is a reference to the fact that he is a demon .
+" False enormously successful Spawn. Spawn holds the 2 [' enormously', ' successful', ' Spawn']
+332 76 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spawn Malebolgia Spawn [',', ' the', ' Dark', ' Horse', ' Comics', ' character', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' demon', '.', '\n'] ", the Dark Horse Comics character , is a reference to the fact that he is a demon .
+" False Digital Chumps and Spawn Kill confirmed 4 [' Digital', ' Ch', 'umps', ' and', ' Spawn']
+333 76 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spawn Malebolgia Spawn [',', ' the', ' Dark', ' Horse', ' Comics', ' character', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' demon', '.', '\n'] ", the Dark Horse Comics character , is a reference to the fact that he is a demon .
+" False battle against Spawn Again was a rematch 2 [' battle', ' against', ' Spawn']
+334 76 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spawn Malebolgia Spawn [',', ' the', ' Dark', ' Horse', ' Comics', ' character', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' demon', '.', '\n'] ", the Dark Horse Comics character , is a reference to the fact that he is a demon .
+" False " Teenage Riot, from the Spawn soundtrack,"" and that" 6 [' Teen', 'age', ' Riot', ',', ' from', ' the', ' Spawn']
+335 76 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Spawn Malebolgia Spawn [',', ' the', ' Dark', ' Horse', ' Comics', ' character', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' demon', '.', '\n'] ", the Dark Horse Comics character , is a reference to the fact that he is a demon .
+" False " Riot, from the Spawn soundtrack,""" 4 [' Riot', ',', ' from', ' the', ' Spawn']
+336 77 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Venom Carnage Venom [' is', ' the', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', '.', ' He', ' is', ' the', ' ar', 'chn', 'em', 'esis', ' of', ' Spider', '-'] is the Spider - Man of the Marvel Universe . He is the ar chn em esis of Spider - False version of the Venom drug that gives Bane 3 [' version', ' of', ' the', ' Venom']
+337 77 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Venom Carnage Venom [' is', ' the', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', '.', ' He', ' is', ' the', ' ar', 'chn', 'em', 'esis', ' of', ' Spider', '-'] is the Spider - Man of the Marvel Universe . He is the ar chn em esis of Spider - False liquid-mercury virus, Venom is injured trying 7 [' liquid', '-', 'mer', 'c', 'ury', ' virus', ',', ' Venom']
+338 77 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Venom Carnage Venom [' is', ' the', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', '.', ' He', ' is', ' the', ' ar', 'chn', 'em', 'esis', ' of', ' Spider', '-'] is the Spider - Man of the Marvel Universe . He is the ar chn em esis of Spider - False Spider-Man focused on Venom (with Kurtzman 5 [' Spider', '-', 'Man', ' focused', ' on', ' Venom']
+339 77 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Venom Carnage Venom [' is', ' the', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', '.', ' He', ' is', ' the', ' ar', 'chn', 'em', 'esis', ' of', ' Spider', '-'] is the Spider - Man of the Marvel Universe . He is the ar chn em esis of Spider - False NWOBHM band Venom was also an important 4 [' NW', 'OB', 'HM', ' band', ' Venom']
+340 77 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Venom Carnage Venom [' is', ' the', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', '.', ' He', ' is', ' the', ' ar', 'chn', 'em', 'esis', ' of', ' Spider', '-'] is the Spider - Man of the Marvel Universe . He is the ar chn em esis of Spider - False The first Venom title, Venom: Lethal 2 [' The', ' first', ' Venom']
+341 78 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Elektra Bullseye Elektra [' is', ' the', ' Black', ' Cat', '.', '\n', '\n', 'The', ' Black', ' Cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing'] " is the Black Cat .
+
+ The Black Cat is a fictional character , a superv ill ain appearing" False released in 1992 by the Elektra imprint Nonesuch Records. 7 [' released', ' in', ' 1992', ' by', ' the', ' Ele', 'k', 'tra']
+342 78 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Elektra Bullseye Elektra [' is', ' the', ' Black', ' Cat', '.', '\n', '\n', 'The', ' Black', ' Cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing'] " is the Black Cat .
+
+ The Black Cat is a fictional character , a superv ill ain appearing" False videos for singles. Elektra Records' Peter 6 [' videos', ' for', ' singles', '.', ' Ele', 'k', 'tra']
+343 78 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Elektra Bullseye Elektra [' is', ' the', ' Black', ' Cat', '.', '\n', '\n', 'The', ' Black', ' Cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing'] " is the Black Cat .
+
+ The Black Cat is a fictional character , a superv ill ain appearing" False after signing to Elektra Records, recorded 5 [' after', ' signing', ' to', ' Ele', 'k', 'tra']
+344 78 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Elektra Bullseye Elektra [' is', ' the', ' Black', ' Cat', '.', '\n', '\n', 'The', ' Black', ' Cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing'] " is the Black Cat .
+
+ The Black Cat is a fictional character , a superv ill ain appearing" False its release, Elektra Records signed 5 [' its', ' release', ',', ' Ele', 'k', 'tra']
+345 78 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Elektra Bullseye Elektra [' is', ' the', ' Black', ' Cat', '.', '\n', '\n', 'The', ' Black', ' Cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superv', 'ill', 'ain', ' appearing'] " is the Black Cat .
+
+ The Black Cat is a fictional character , a superv ill ain appearing" False daughter, Elektra; Renard previously 4 [' daughter', ',', ' Ele', 'k', 'tra']
+346 79 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Buffy the Vampire Slayer The Master Buffy the Vampire Slayer [',', ' Angel', ',', ' is', ' a', ' reference', ' to', ' the', ' Angel', 'us', ',', ' a', ' demon', ' that', ' is', ' the', ' embodiment', ' of', ' evil', '.'] , Angel , is a reference to the Angel us , a demon that is the embodiment of evil . False " ""Ultimate Power"" to Buffy the Vampire Slayer character Dawn Summers" 8 "[' ""', 'Ultimate', ' Power', '""', ' to', ' Buffy', ' the', ' Vampire', ' Slayer']"
+347 79 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Buffy the Vampire Slayer The Master Buffy the Vampire Slayer [',', ' Angel', ',', ' is', ' a', ' reference', ' to', ' the', ' Angel', 'us', ',', ' a', ' demon', ' that', ' is', ' the', ' embodiment', ' of', ' evil', '.'] , Angel , is a reference to the Angel us , a demon that is the embodiment of evil . False as Spike in Buffy the Vampire Slayer and Angel. Whilst 6 [' as', ' Spike', ' in', ' Buffy', ' the', ' Vampire', ' Slayer']
+348 79 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Buffy the Vampire Slayer The Master Buffy the Vampire Slayer [',', ' Angel', ',', ' is', ' a', ' reference', ' to', ' the', ' Angel', 'us', ',', ' a', ' demon', ' that', ' is', ' the', ' embodiment', ' of', ' evil', '.'] , Angel , is a reference to the Angel us , a demon that is the embodiment of evil . False 4 ['B', 'uffy', ' the', ' Vampire', ' Slayer']
+349 79 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Buffy the Vampire Slayer The Master Buffy the Vampire Slayer [',', ' Angel', ',', ' is', ' a', ' reference', ' to', ' the', ' Angel', 'us', ',', ' a', ' demon', ' that', ' is', ' the', ' embodiment', ' of', ' evil', '.'] , Angel , is a reference to the Angel us , a demon that is the embodiment of evil . False " ""Bomis: The Buffy the Vampire Slayer Ring"", devoted" 9 "[' ""', 'B', 'om', 'is', ':', ' The', ' Buffy', ' the', ' Vampire', ' Slayer']"
+350 79 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Buffy the Vampire Slayer The Master Buffy the Vampire Slayer [',', ' Angel', ',', ' is', ' a', ' reference', ' to', ' the', ' Angel', 'us', ',', ' a', ' demon', ' that', ' is', ' the', ' embodiment', ' of', ' evil', '.'] , Angel , is a reference to the Angel us , a demon that is the embodiment of evil . False Whedonesque covered Buffy the Vampire Slayer and Angel, but 8 [' W', 'hed', 'ones', 'que', ' covered', ' Buffy', ' the', ' Vampire', ' Slayer']
+351 80 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of RoboCop Clarence Boddicker RoboCop [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Robo', 'Cop', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Robo Cop is revealed !" False Gun (1986), RoboCop (1987), Star Wars 5 [' Gun', ' (', '1986', '),', ' Robo', 'Cop']
+352 80 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of RoboCop Clarence Boddicker RoboCop [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Robo', 'Cop', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Robo Cop is revealed !" False Gun (1986), RoboCop (1987), Star 5 [' Gun', ' (', '1986', '),', ' Robo', 'Cop']
+353 80 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of RoboCop Clarence Boddicker RoboCop [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Robo', 'Cop', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Robo Cop is revealed !" False in the film RoboCop (1987). He also 4 [' in', ' the', ' film', ' Robo', 'Cop']
+354 80 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of RoboCop Clarence Boddicker RoboCop [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Robo', 'Cop', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Robo Cop is revealed !" False Blood (1985), RoboCop (1987) and Starship 5 [' Blood', ' (', '1985', '),', ' Robo', 'Cop']
+355 80 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of RoboCop Clarence Boddicker RoboCop [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Robo', 'Cop', ' is', ' revealed', '!'] " is revealed !
+
+ The name of the superhero ar chn em esis of Robo Cop is revealed !" False fiction-action film RoboCop (1987). During 5 [' fiction', '-', 'action', ' film', ' Robo', 'Cop']
+356 81 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of John Constantine First of the Fallen John Constantine [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' who', ' will', ' be', ' the', ' main', ' antagonist', ' of', ' the', ' upcoming', ' season'] " is revealed !
+
+ The name of the villain who will be the main antagonist of the upcoming season" False that follows John Constantine on a trip across 3 [' that', ' follows', ' John', ' Constantine']
+357 81 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of John Constantine First of the Fallen John Constantine [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' who', ' will', ' be', ' the', ' main', ' antagonist', ' of', ' the', ' upcoming', ' season'] " is revealed !
+
+ The name of the villain who will be the main antagonist of the upcoming season" False attempted to return John Constantine to his roots, with 4 [' attempted', ' to', ' return', ' John', ' Constantine']
+358 81 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of John Constantine First of the Fallen John Constantine [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' who', ' will', ' be', ' the', ' main', ' antagonist', ' of', ' the', ' upcoming', ' season'] " is revealed !
+
+ The name of the villain who will be the main antagonist of the upcoming season" False that follows John Constantine on a trip 3 [' that', ' follows', ' John', ' Constantine']
+359 81 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of John Constantine First of the Fallen John Constantine [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' who', ' will', ' be', ' the', ' main', ' antagonist', ' of', ' the', ' upcoming', ' season'] " is revealed !
+
+ The name of the villain who will be the main antagonist of the upcoming season" False consists of the John Constantine character wandering 4 [' consists', ' of', ' the', ' John', ' Constantine']
+360 81 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of John Constantine First of the Fallen John Constantine [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' villain', ' who', ' will', ' be', ' the', ' main', ' antagonist', ' of', ' the', ' upcoming', ' season'] " is revealed !
+
+ The name of the villain who will be the main antagonist of the upcoming season" False unconsciousness of mankind. John Constantine cuts his own wrists, 6 [' unconscious', 'ness', ' of', ' mankind', '.', ' John', ' Constantine']
+361 83 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blade Deacon Frost Blade [' Runner', ' 20', '49', ' is', ' the', ' Repl', 'ic', 'ant', '.', '\n', '\n', 'The', ' Repl', 'ic', 'ant', ' is', ' a', ' human', '-', 'like'] " Runner 20 49 is the Repl ic ant .
+
+ The Repl ic ant is a human - like" False " ""A Single Blade of Grass"" is the first" 3 "[' ""', 'A', ' Single', ' Blade']"
+362 83 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blade Deacon Frost Blade [' Runner', ' 20', '49', ' is', ' the', ' Repl', 'ic', 'ant', '.', '\n', '\n', 'The', ' Repl', 'ic', 'ant', ' is', ' a', ' human', '-', 'like'] " Runner 20 49 is the Repl ic ant .
+
+ The Repl ic ant is a human - like" False (known as Soul Blade outside Japan) 4 [' (', 'known', ' as', ' Soul', ' Blade']
+363 83 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blade Deacon Frost Blade [' Runner', ' 20', '49', ' is', ' the', ' Repl', 'ic', 'ant', '.', '\n', '\n', 'The', ' Repl', 'ic', 'ant', ' is', ' a', ' human', '-', 'like'] " Runner 20 49 is the Repl ic ant .
+
+ The Repl ic ant is a human - like" False destroyed the Picori Blade and petrified 4 [' destroyed', ' the', ' Pic', 'ori', ' Blade']
+364 83 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blade Deacon Frost Blade [' Runner', ' 20', '49', ' is', ' the', ' Repl', 'ic', 'ant', '.', '\n', '\n', 'The', ' Repl', 'ic', 'ant', ' is', ' a', ' human', '-', 'like'] " Runner 20 49 is the Repl ic ant .
+
+ The Repl ic ant is a human - like" False " Runner =
+" 2 [' Runner', ' =', 'Blade']
+365 83 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Blade Deacon Frost Blade [' Runner', ' 20', '49', ' is', ' the', ' Repl', 'ic', 'ant', '.', '\n', '\n', 'The', ' Repl', 'ic', 'ant', ' is', ' a', ' human', '-', 'like'] " Runner 20 49 is the Repl ic ant .
+
+ The Repl ic ant is a human - like" False 0 ['Blade']
+366 84 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kick-Ass Red Mist Kick-Ass [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Kick', '-', 'Ass', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Kick - Ass is revealed" False order to film Kick-Ass 2. The episode is 5 [' order', ' to', ' film', ' Kick', '-', 'Ass']
+367 84 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kick-Ass Red Mist Kick-Ass [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Kick', '-', 'Ass', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Kick - Ass is revealed" False " episodes to film Kick-Ass 2.
+" 5 [' episodes', ' to', ' film', ' Kick', '-', 'Ass']
+368 84 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kick-Ass Red Mist Kick-Ass [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Kick', '-', 'Ass', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Kick - Ass is revealed" False in order to film Kick-Ass 2. The episode is 6 [' in', ' order', ' to', ' film', ' Kick', '-', 'Ass']
+369 84 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kick-Ass Red Mist Kick-Ass [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Kick', '-', 'Ass', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Kick - Ass is revealed" False " Part III and Kick-Ass 2, respectively.
+" 5 [' Part', ' III', ' and', ' Kick', '-', 'Ass']
+370 84 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kick-Ass Red Mist Kick-Ass [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Kick', '-', 'Ass', ' is', ' revealed'] " is revealed !
+
+ The name of the superhero ar chn em esis of Kick - Ass is revealed" False " episodes to film Kick-Ass 2.
+" 5 [' episodes', ' to', ' film', ' Kick', '-', 'Ass']
+371 85 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Xena: Warrior Princess Callisto Xena: Warrior Princess [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' X', 'ena', ':', ' Warrior', ' Princess'] " is revealed !
+
+ The name of the superhero ar chn em esis of X ena : Warrior Princess" False " Arquette in season six of Xena: Warrior Princess (2001);
+" 11 [' Ar', 'qu', 'ette', ' in', ' season', ' six', ' of', ' X', 'ena', ':', ' Warrior', ' Princess']
+372 85 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Xena: Warrior Princess Callisto Xena: Warrior Princess [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' X', 'ena', ':', ' Warrior', ' Princess'] " is revealed !
+
+ The name of the superhero ar chn em esis of X ena : Warrior Princess" False shows such as Xena: Warrior Princess and Stargate 7 [' shows', ' such', ' as', ' X', 'ena', ':', ' Warrior', ' Princess']
+373 85 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Xena: Warrior Princess Callisto Xena: Warrior Princess [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' X', 'ena', ':', ' Warrior', ' Princess'] " is revealed !
+
+ The name of the superhero ar chn em esis of X ena : Warrior Princess" False " Arquette in season six of Xena: Warrior Princess (2001);
+" 11 [' Ar', 'qu', 'ette', ' in', ' season', ' six', ' of', ' X', 'ena', ':', ' Warrior', ' Princess']
+374 85 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Xena: Warrior Princess Callisto Xena: Warrior Princess [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' X', 'ena', ':', ' Warrior', ' Princess'] " is revealed !
+
+ The name of the superhero ar chn em esis of X ena : Warrior Princess" False pencils, as on Xena: Warrior Princess # 4 (January 2000). 9 [' pencil', 's', ',', ' as', ' on', ' X', 'ena', ':', ' Warrior', ' Princess']
+375 85 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Xena: Warrior Princess Callisto Xena: Warrior Princess [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' X', 'ena', ':', ' Warrior', ' Princess'] " is revealed !
+
+ The name of the superhero ar chn em esis of X ena : Warrior Princess" False shows such as Xena: Warrior Princess and Stargate 7 [' shows', ' such', ' as', ' X', 'ena', ':', ' Warrior', ' Princess']
+376 86 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hercules Hades Hercules [',', ' the', ' God', ' of', ' Strength', ',', ' is', ' Hercules', '.', ' He', ' is', ' the', ' son', ' of', ' Zeus', ' and', ' Al', 'cm', 'ene', ','] , the God of Strength , is Hercules . He is the son of Zeus and Al cm ene , False Dutch privateers, the Hercules and Mars. Nine 5 [' Dutch', ' private', 'ers', ',', ' the', ' Hercules']
+377 86 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hercules Hades Hercules [',', ' the', ' God', ' of', ' Strength', ',', ' is', ' Hercules', '.', ' He', ' is', ' the', ' son', ' of', ' Zeus', ' and', ' Al', 'cm', 'ene', ','] , the God of Strength , is Hercules . He is the son of Zeus and Al cm ene , False " was described as Hercules fighting a boy.
+" 3 [' was', ' described', ' as', ' Hercules']
+378 86 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hercules Hades Hercules [',', ' the', ' God', ' of', ' Strength', ',', ' is', ' Hercules', '.', ' He', ' is', ' the', ' son', ' of', ' Zeus', ' and', ' Al', 'cm', 'ene', ','] , the God of Strength , is Hercules . He is the son of Zeus and Al cm ene , False 2 ['H', 'erc', 'ules']
+379 86 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hercules Hades Hercules [',', ' the', ' God', ' of', ' Strength', ',', ' is', ' Hercules', '.', ' He', ' is', ' the', ' son', ' of', ' Zeus', ' and', ' Al', 'cm', 'ene', ','] , the God of Strength , is Hercules . He is the son of Zeus and Al cm ene , False following summer, Hercules, grossed only $ 252 3 [' following', ' summer', ',', ' Hercules']
+380 86 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hercules Hades Hercules [',', ' the', ' God', ' of', ' Strength', ',', ' is', ' Hercules', '.', ' He', ' is', ' the', ' son', ' of', ' Zeus', ' and', ' Al', 'cm', 'ene', ','] , the God of Strength , is Hercules . He is the son of Zeus and Al cm ene , False Lockheed C-130 Hercules in Australian 4 [' Lockheed', ' C', '-', '130', ' Hercules']
+381 87 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Powerpuff Girls Mojo Jojo The Powerpuff Girls "[' is', ' revealed', '!', '\n', '\n', 'The', ' Power', 'puff', ' Girls', ' are', ' back', '!', ' And', ' they', ""'re"", ' back', ' to', ' save', ' the', ' world']" " is revealed !
+
+ The Power puff Girls are back ! And they 're back to save the world" False the Looney Tunes and The Powerpuff Girls editions of Monopoly, 9 [' the', ' Lo', 'oney', ' T', 'unes', ' and', ' The', ' Power', 'puff', ' Girls']
+382 87 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Powerpuff Girls Mojo Jojo The Powerpuff Girls "[' is', ' revealed', '!', '\n', '\n', 'The', ' Power', 'puff', ' Girls', ' are', ' back', '!', ' And', ' they', ""'re"", ' back', ' to', ' save', ' the', ' world']" " is revealed !
+
+ The Power puff Girls are back ! And they 're back to save the world" False Strong as Bubbles from The Powerpuff Girls and Timmy Turner from 8 [' Strong', ' as', ' Bub', 'bles', ' from', ' The', ' Power', 'puff', ' Girls']
+383 87 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Powerpuff Girls Mojo Jojo The Powerpuff Girls "[' is', ' revealed', '!', '\n', '\n', 'The', ' Power', 'puff', ' Girls', ' are', ' back', '!', ' And', ' they', ""'re"", ' back', ' to', ' save', ' the', ' world']" " is revealed !
+
+ The Power puff Girls are back ! And they 're back to save the world" False the Looney Tunes and The Powerpuff Girls editions of Monopoly, 9 [' the', ' Lo', 'oney', ' T', 'unes', ' and', ' The', ' Power', 'puff', ' Girls']
+384 87 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Powerpuff Girls Mojo Jojo The Powerpuff Girls "[' is', ' revealed', '!', '\n', '\n', 'The', ' Power', 'puff', ' Girls', ' are', ' back', '!', ' And', ' they', ""'re"", ' back', ' to', ' save', ' the', ' world']" " is revealed !
+
+ The Power puff Girls are back ! And they 're back to save the world" False theme to promote The Powerpuff Girls Movie. The 6 [' theme', ' to', ' promote', ' The', ' Power', 'puff', ' Girls']
+385 87 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Powerpuff Girls Mojo Jojo The Powerpuff Girls "[' is', ' revealed', '!', '\n', '\n', 'The', ' Power', 'puff', ' Girls', ' are', ' back', '!', ' And', ' they', ""'re"", ' back', ' to', ' save', ' the', ' world']" " is revealed !
+
+ The Power puff Girls are back ! And they 're back to save the world" False the Looney Tunes and The Powerpuff Girls editions of Monopoly, 9 [' the', ' Lo', 'oney', ' T', 'unes', ' and', ' The', ' Power', 'puff', ' Girls']
+386 88 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellcat Mad Dog Hellcat [' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' in', ' the', ' original', ' comic', ' book', '.', '\n', '\n', 'The', ' name', ' of'] " is the same as the name of the villain in the original comic book .
+
+ The name of" False be an F6F Hellcat fighter, though the 6 [' be', ' an', ' F', '6', 'F', ' Hell', 'cat']
+387 88 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellcat Mad Dog Hellcat [' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' in', ' the', ' original', ' comic', ' book', '.', '\n', '\n', 'The', ' name', ' of'] " is the same as the name of the villain in the original comic book .
+
+ The name of" False began with Hellcat fighters strafing 3 [' began', ' with', ' Hell', 'cat']
+388 88 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellcat Mad Dog Hellcat [' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' in', ' the', ' original', ' comic', ' book', '.', '\n', '\n', 'The', ' name', ' of'] " is the same as the name of the villain in the original comic book .
+
+ The name of" False consisted of F6F Hellcat and F4U Corsair fighters, 6 [' consisted', ' of', ' F', '6', 'F', ' Hell', 'cat']
+389 88 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellcat Mad Dog Hellcat [' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' in', ' the', ' original', ' comic', ' book', '.', '\n', '\n', 'The', ' name', ' of'] " is the same as the name of the villain in the original comic book .
+
+ The name of" False by Grumman F6F Hellcat fighters, that had 8 [' by', ' Gr', 'um', 'man', ' F', '6', 'F', ' Hell', 'cat']
+390 88 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Hellcat Mad Dog Hellcat [' is', ' the', ' same', ' as', ' the', ' name', ' of', ' the', ' villain', ' in', ' the', ' original', ' comic', ' book', '.', '\n', '\n', 'The', ' name', ' of'] " is the same as the name of the villain in the original comic book .
+
+ The name of" False operated 20 Grumman F6F Hellcat fighters from 1840 9 [' operated', ' 20', ' Gr', 'um', 'man', ' F', '6', 'F', ' Hell', 'cat']
+391 90 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Planet Captain Pollution Captain Planet [',', ' the', ' evil', ' Dr', '.', ' N', 'ef', 'arious', ',', ' is', ' a', ' reference', ' to', ' the', ' villain', 'ous', ' Dr', '.', ' N', 'ef'] , the evil Dr . N ef arious , is a reference to the villain ous Dr . N ef False " claptrap"" of Captain Planet and the Planeteers," 6 "[' cl', 'apt', 'rap', '""', ' of', ' Captain', ' Planet']"
+392 90 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Planet Captain Pollution Captain Planet [',', ' the', ' evil', ' Dr', '.', ' N', 'ef', 'arious', ',', ' is', ' a', ' reference', ' to', ' the', ' villain', 'ous', ' Dr', '.', ' N', 'ef'] , the evil Dr . N ef arious , is a reference to the villain ous Dr . N ef False " correct claptrap"" of Captain Planet and the Planeteers," 7 "[' correct', ' cl', 'apt', 'rap', '""', ' of', ' Captain', ' Planet']"
+393 90 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Planet Captain Pollution Captain Planet [',', ' the', ' evil', ' Dr', '.', ' N', 'ef', 'arious', ',', ' is', ' a', ' reference', ' to', ' the', ' villain', 'ous', ' Dr', '.', ' N', 'ef'] , the evil Dr . N ef arious , is a reference to the villain ous Dr . N ef False " correct claptrap"" of Captain Planet and the Planeteers," 7 "[' correct', ' cl', 'apt', 'rap', '""', ' of', ' Captain', ' Planet']"
+394 90 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Captain Planet Captain Pollution Captain Planet [',', ' the', ' evil', ' Dr', '.', ' N', 'ef', 'arious', ',', ' is', ' a', ' reference', ' to', ' the', ' villain', 'ous', ' Dr', '.', ' N', 'ef'] , the evil Dr . N ef arious , is a reference to the villain ous Dr . N ef False " correct claptrap"" of Captain Planet and the Planeteers," 7 "[' correct', ' cl', 'apt', 'rap', '""', ' of', ' Captain', ' Planet']"
+395 91 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Phantom The Python The Phantom [',', ' the', ' Phantom', ' Stranger', ',', ' is', ' a', ' reference', ' to', ' the', ' DC', ' Comics', ' character', ',', ' who', ' was', ' created', ' by', ' writer', ' Otto'] , the Phantom Stranger , is a reference to the DC Comics character , who was created by writer Otto False 1 ['The', ' Phantom']
+396 91 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Phantom The Python The Phantom [',', ' the', ' Phantom', ' Stranger', ',', ' is', ' a', ' reference', ' to', ' the', ' DC', ' Comics', ' character', ',', ' who', ' was', ' created', ' by', ' writer', ' Otto'] , the Phantom Stranger , is a reference to the DC Comics character , who was created by writer Otto False in the war. The Phantom has the distinction 5 [' in', ' the', ' war', '.', ' The', ' Phantom']
+397 91 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Phantom The Python The Phantom [',', ' the', ' Phantom', ' Stranger', ',', ' is', ' a', ' reference', ' to', ' the', ' DC', ' Comics', ' character', ',', ' who', ' was', ' created', ' by', ' writer', ' Otto'] , the Phantom Stranger , is a reference to the DC Comics character , who was created by writer Otto False dressed as Red Death. The Phantom brings his own composition, 6 [' dressed', ' as', ' Red', ' Death', '.', ' The', ' Phantom']
+398 91 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Phantom The Python The Phantom [',', ' the', ' Phantom', ' Stranger', ',', ' is', ' a', ' reference', ' to', ' the', ' DC', ' Comics', ' character', ',', ' who', ' was', ' created', ' by', ' writer', ' Otto'] , the Phantom Stranger , is a reference to the DC Comics character , who was created by writer Otto False Wars: Episode I – The Phantom Menace and Titanic 6 [' Wars', ':', ' Episode', ' I', ' –', ' The', ' Phantom']
+399 91 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of The Phantom The Python The Phantom [',', ' the', ' Phantom', ' Stranger', ',', ' is', ' a', ' reference', ' to', ' the', ' DC', ' Comics', ' character', ',', ' who', ' was', ' created', ' by', ' writer', ' Otto'] , the Phantom Stranger , is a reference to the DC Comics character , who was created by writer Otto False Wars: Episode I – The Phantom Menace, which didn't 6 [' Wars', ':', ' Episode', ' I', ' –', ' The', ' Phantom']
+400 93 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kim Possible Dr. Drakken Kim Possible [',', ' the', ' villain', ' of', ' the', ' series', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' villain', ' who', ' is'] , the villain of the series , is a reference to the fact that he is a villain who is False " (character) =
+" 5 [' (', 'character', ')', ' =', 'Kim', ' Possible']
+401 93 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kim Possible Dr. Drakken Kim Possible [',', ' the', ' villain', ' of', ' the', ' series', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' villain', ' who', ' is'] , the villain of the series , is a reference to the fact that he is a villain who is False attraction entitled the Kim Possible World Showcase Adventure 4 [' attraction', ' entitled', ' the', ' Kim', ' Possible']
+402 93 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kim Possible Dr. Drakken Kim Possible [',', ' the', ' villain', ' of', ' the', ' series', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' villain', ' who', ' is'] , the villain of the series , is a reference to the fact that he is a villain who is False Possible (character) 4 [' Possible', ' (', 'character', 'Kim', ' Possible']
+403 93 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kim Possible Dr. Drakken Kim Possible [',', ' the', ' villain', ' of', ' the', ' series', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' villain', ' who', ' is'] , the villain of the series , is a reference to the fact that he is a villain who is False " entitled ""13 Reasons Kim Possible is the Only Role Model" 5 "[' entitled', ' ""', '13', ' Reasons', ' Kim', ' Possible']"
+404 93 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Kim Possible Dr. Drakken Kim Possible [',', ' the', ' villain', ' of', ' the', ' series', ',', ' is', ' a', ' reference', ' to', ' the', ' fact', ' that', ' he', ' is', ' a', ' villain', ' who', ' is'] , the villain of the series , is a reference to the fact that he is a villain who is False " entitled ""13 Reasons Kim Possible is the Only Role Model" 5 "[' entitled', ' ""', '13', ' Reasons', ' Kim', ' Possible']"
+405 95 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scooby-Doo Ghost Clown Scooby-Doo [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sco', 'oby', '-', 'D', 'oo'] " is revealed !
+
+ The name of the superhero ar chn em esis of Sco oby - D oo" False production included Scooby-Doo and outfits like s'mores, 6 [' production', ' included', ' Sco', 'oby', '-', 'D', 'oo']
+406 95 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scooby-Doo Ghost Clown Scooby-Doo [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sco', 'oby', '-', 'D', 'oo'] " is revealed !
+
+ The name of the superhero ar chn em esis of Sco oby - D oo" False " well as the Scooby-Doo sequence as ""a" 7 [' well', ' as', ' the', ' Sco', 'oby', '-', 'D', 'oo']
+407 95 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scooby-Doo Ghost Clown Scooby-Doo [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sco', 'oby', '-', 'D', 'oo'] " is revealed !
+
+ The name of the superhero ar chn em esis of Sco oby - D oo" False Jetsons family and Scooby-Doo and Shaggy 8 [' Jets', 'ons', ' family', ' and', ' Sco', 'oby', '-', 'D', 'oo']
+408 95 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scooby-Doo Ghost Clown Scooby-Doo [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sco', 'oby', '-', 'D', 'oo'] " is revealed !
+
+ The name of the superhero ar chn em esis of Sco oby - D oo" False plots in the mode of Scooby-Doo are framed around 9 [' plots', ' in', ' the', ' mode', ' of', ' Sco', 'oby', '-', 'D', 'oo']
+409 95 The name of the superhero archnemesis of x -1 The name of the superhero archnemesis of Scooby-Doo Ghost Clown Scooby-Doo [' is', ' revealed', '!', '\n', '\n', 'The', ' name', ' of', ' the', ' superhero', ' ar', 'chn', 'em', 'esis', ' of', ' Sco', 'oby', '-', 'D', 'oo'] " is revealed !
+
+ The name of the superhero ar chn em esis of Sco oby - D oo" False teenagers from Scooby-Doo arrive at Camp 6 [' teenagers', ' from', ' Sco', 'oby', '-', 'D', 'oo']
diff --git a/patchscopes/code/preprocessed_data/factual/superhero_person.tsv b/patchscopes/code/preprocessed_data/factual/superhero_person.tsv
new file mode 100644
index 00000000..0f422b99
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual/superhero_person.tsv
@@ -0,0 +1,1340 @@
+ sample_id prompt_target position_target target_baseline object subject generations_baseline generations_baseline_str is_correct_baseline prompt_source position_source source_cropped_toks
+0 0 Secret identity of x -1 Secret identity of Superman Clark Kent Superman ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Superman', ' is', ' Clark', ' Kent', ',', ' a', ' mild', '-', 'man', 'nered', ' reporter', ' for', ' the', ' Daily'] "
+
+ The secret identity of Superman is Clark Kent , a mild - man nered reporter for the Daily" True the traditional Superman costume, and 2 [' the', ' traditional', ' Superman']
+1 0 Secret identity of x -1 Secret identity of Superman Clark Kent Superman ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Superman', ' is', ' Clark', ' Kent', ',', ' a', ' mild', '-', 'man', 'nered', ' reporter', ' for', ' the', ' Daily'] "
+
+ The secret identity of Superman is Clark Kent , a mild - man nered reporter for the Daily" True Vendetta and Singer's Superman Returns. Chris 5 "[' Vend', 'etta', ' and', ' Singer', ""'s"", ' Superman']"
+2 0 Secret identity of x -1 Secret identity of Superman Clark Kent Superman ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Superman', ' is', ' Clark', ' Kent', ',', ' a', ' mild', '-', 'man', 'nered', ' reporter', ' for', ' the', ' Daily'] "
+
+ The secret identity of Superman is Clark Kent , a mild - man nered reporter for the Daily" True Bros. announced that Superman and Batman will unite 4 [' Bros', '.', ' announced', ' that', ' Superman']
+3 0 Secret identity of x -1 Secret identity of Superman Clark Kent Superman ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Superman', ' is', ' Clark', ' Kent', ',', ' a', ' mild', '-', 'man', 'nered', ' reporter', ' for', ' the', ' Daily'] "
+
+ The secret identity of Superman is Clark Kent , a mild - man nered reporter for the Daily" True " been dubbed Superman by the masses.
+" 2 [' been', ' dubbed', ' Superman']
+4 0 Secret identity of x -1 Secret identity of Superman Clark Kent Superman ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Superman', ' is', ' Clark', ' Kent', ',', ' a', ' mild', '-', 'man', 'nered', ' reporter', ' for', ' the', ' Daily'] "
+
+ The secret identity of Superman is Clark Kent , a mild - man nered reporter for the Daily" True DVD box set titled Superman Ultimate Collector's 4 [' DVD', ' box', ' set', ' titled', ' Superman']
+5 1 Secret identity of x -1 Secret identity of Batman Bruce Wayne Batman ['\n', '\n', 'Batman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' created', ' by', ' Bob', ' Kane', ' and', ' Bill', ' Finger', '.', ' He', ' first'] "
+
+ Batman is a fictional character , a superhero created by Bob Kane and Bill Finger . He first" False crime-fighting partner to Batman and ward of Bruce 5 [' crime', '-', 'fighting', ' partner', ' to', ' Batman']
+6 1 Secret identity of x -1 Secret identity of Batman Bruce Wayne Batman ['\n', '\n', 'Batman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' created', ' by', ' Bob', ' Kane', ' and', ' Bill', ' Finger', '.', ' He', ' first'] "
+
+ Batman is a fictional character , a superhero created by Bob Kane and Bill Finger . He first" False Robin (1997) 3 [' Robin', ' (', '1997', 'Batman']
+7 1 Secret identity of x -1 Secret identity of Batman Bruce Wayne Batman ['\n', '\n', 'Batman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' created', ' by', ' Bob', ' Kane', ' and', ' Bill', ' Finger', '.', ' He', ' first'] "
+
+ Batman is a fictional character , a superhero created by Bob Kane and Bill Finger . He first" False with alternate Batman outfits based 2 [' with', ' alternate', ' Batman']
+8 1 Secret identity of x -1 Secret identity of Batman Bruce Wayne Batman ['\n', '\n', 'Batman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' created', ' by', ' Bob', ' Kane', ' and', ' Bill', ' Finger', '.', ' He', ' first'] "
+
+ Batman is a fictional character , a superhero created by Bob Kane and Bill Finger . He first" False limited to the Batman Family, serving a 3 [' limited', ' to', ' the', ' Batman']
+9 1 Secret identity of x -1 Secret identity of Batman Bruce Wayne Batman ['\n', '\n', 'Batman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' created', ' by', ' Bob', ' Kane', ' and', ' Bill', ' Finger', '.', ' He', ' first'] "
+
+ Batman is a fictional character , a superhero created by Bob Kane and Bill Finger . He first" False dumping ground for Batman's psychotic 3 [' dumping', ' ground', ' for', ' Batman']
+10 2 Secret identity of x -1 Secret identity of Wonder Woman Diana Prince Wonder Woman ['\n', '\n', 'Wonder', ' Woman', ' is', ' a', ' fictional', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' Created'] "
+
+ Wonder Woman is a fictional superhero ine appearing in American comic books published by DC Comics . Created" False with Zatanna and Wonder Woman struggling 6 [' with', ' Z', 'at', 'anna', ' and', ' Wonder', ' Woman']
+11 2 Secret identity of x -1 Secret identity of Wonder Woman Diana Prince Wonder Woman ['\n', '\n', 'Wonder', ' Woman', ' is', ' a', ' fictional', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' Created'] "
+
+ Wonder Woman is a fictional superhero ine appearing in American comic books published by DC Comics . Created" False " Beyoncé is seen wearing a Wonder Woman T-shirt.
+" 7 [' Beyon', 'cé', ' is', ' seen', ' wearing', ' a', ' Wonder', ' Woman']
+12 2 Secret identity of x -1 Secret identity of Wonder Woman Diana Prince Wonder Woman ['\n', '\n', 'Wonder', ' Woman', ' is', ' a', ' fictional', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' Created'] "
+
+ Wonder Woman is a fictional superhero ine appearing in American comic books published by DC Comics . Created" False Comics production Wonder Woman: Amazonia 3 [' Comics', ' production', ' Wonder', ' Woman']
+13 2 Secret identity of x -1 Secret identity of Wonder Woman Diana Prince Wonder Woman ['\n', '\n', 'Wonder', ' Woman', ' is', ' a', ' fictional', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' Created'] "
+
+ Wonder Woman is a fictional superhero ine appearing in American comic books published by DC Comics . Created" False Superman, Batman, and Wonder Woman — were still 6 [' Superman', ',', ' Batman', ',', ' and', ' Wonder', ' Woman']
+14 2 Secret identity of x -1 Secret identity of Wonder Woman Diana Prince Wonder Woman ['\n', '\n', 'Wonder', ' Woman', ' is', ' a', ' fictional', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' Created'] "
+
+ Wonder Woman is a fictional superhero ine appearing in American comic books published by DC Comics . Created" False Gadot was cast as Wonder Woman in December 2013. 6 [' Gad', 'ot', ' was', ' cast', ' as', ' Wonder', ' Woman']
+15 3 Secret identity of x -1 Secret identity of Spider-Man Peter Parker Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False Friends, the 1967 Spider-Man cartoon and 6 [' Friends', ',', ' the', ' 1967', ' Spider', '-', 'Man']
+16 3 Secret identity of x -1 Secret identity of Spider-Man Peter Parker Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False to The Amazing Spider-Man — her sole 5 [' to', ' The', ' Amazing', ' Spider', '-', 'Man']
+17 3 Secret identity of x -1 Secret identity of Spider-Man Peter Parker Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False 2 ['Spider', '-', 'Man']
+18 3 Secret identity of x -1 Secret identity of Spider-Man Peter Parker Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False and wind effects. Spider-Man attempts to pull 6 [' and', ' wind', ' effects', '.', ' Spider', '-', 'Man']
+19 3 Secret identity of x -1 Secret identity of Spider-Man Peter Parker Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False 2 ['Spider', '-', 'Man']
+20 4 Secret identity of x -1 Secret identity of Iron Man Tony Stark Iron Man ['\n', '\n', 'The', ' Iron', ' Man', ' suit', ' is', ' a', ' fictional', ' superhero', ' costume', ' worn', ' by', ' Tony', ' Stark', ',', ' a', '.', 'k', '.'] "
+
+ The Iron Man suit is a fictional superhero costume worn by Tony Stark , a . k ." True Hayter, and pitted Iron Man against his 6 [' Hay', 'ter', ',', ' and', ' pitted', ' Iron', ' Man']
+21 4 Secret identity of x -1 Secret identity of Iron Man Tony Stark Iron Man ['\n', '\n', 'The', ' Iron', ' Man', ' suit', ' is', ' a', ' fictional', ' superhero', ' costume', ' worn', ' by', ' Tony', ' Stark', ',', ' a', '.', 'k', '.'] "
+
+ The Iron Man suit is a fictional superhero costume worn by Tony Stark , a . k ." True " Stark build the first Iron Man suit.
+" 5 [' Stark', ' build', ' the', ' first', ' Iron', ' Man']
+22 4 Secret identity of x -1 Secret identity of Iron Man Tony Stark Iron Man ['\n', '\n', 'The', ' Iron', ' Man', ' suit', ' is', ' a', ' fictional', ' superhero', ' costume', ' worn', ' by', ' Tony', ' Stark', ',', ' a', '.', 'k', '.'] "
+
+ The Iron Man suit is a fictional superhero costume worn by Tony Stark , a . k ." True Orton in a 60-minute Iron Man match. Cena successfully 8 [' Or', 'ton', ' in', ' a', ' 60', '-', 'minute', ' Iron', ' Man']
+23 4 Secret identity of x -1 Secret identity of Iron Man Tony Stark Iron Man ['\n', '\n', 'The', ' Iron', ' Man', ' suit', ' is', ' a', ' fictional', ' superhero', ' costume', ' worn', ' by', ' Tony', ' Stark', ',', ' a', '.', 'k', '.'] "
+
+ The Iron Man suit is a fictional superhero costume worn by Tony Stark , a . k ." True premiere of Iron Man 3 on April 24. 3 [' premiere', ' of', ' Iron', ' Man']
+24 4 Secret identity of x -1 Secret identity of Iron Man Tony Stark Iron Man ['\n', '\n', 'The', ' Iron', ' Man', ' suit', ' is', ' a', ' fictional', ' superhero', ' costume', ' worn', ' by', ' Tony', ' Stark', ',', ' a', '.', 'k', '.'] "
+
+ The Iron Man suit is a fictional superhero costume worn by Tony Stark , a . k ." True because Favreau felt Iron Man's archnemesis, the 5 [' because', ' Fav', 'reau', ' felt', ' Iron', ' Man']
+25 5 Secret identity of x -1 Secret identity of Captain America Steve Rogers Captain America [':', ' The', ' Winter', ' Soldier', '\n', '\n', 'The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ": The Winter Soldier
+
+ The Winter Soldier is a fictional character appearing in American comic books published by" False 2, Thor and Captain America: The First 5 [' 2', ',', ' Thor', ' and', ' Captain', ' America']
+26 5 Secret identity of x -1 Secret identity of Captain America Steve Rogers Captain America [':', ' The', ' Winter', ' Soldier', '\n', '\n', 'The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ": The Winter Soldier
+
+ The Winter Soldier is a fictional character appearing in American comic books published by" False " year to watch Captain America save the day.""
+" 4 [' year', ' to', ' watch', ' Captain', ' America']
+27 5 Secret identity of x -1 Secret identity of Captain America Steve Rogers Captain America [':', ' The', ' Winter', ' Soldier', '\n', '\n', 'The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ": The Winter Soldier
+
+ The Winter Soldier is a fictional character appearing in American comic books published by" False 1 ['Captain', ' America']
+28 5 Secret identity of x -1 Secret identity of Captain America Steve Rogers Captain America [':', ' The', ' Winter', ' Soldier', '\n', '\n', 'The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ": The Winter Soldier
+
+ The Winter Soldier is a fictional character appearing in American comic books published by" False Man, Thor, Captain America with neo-noir attitude. 5 [' Man', ',', ' Thor', ',', ' Captain', ' America']
+29 5 Secret identity of x -1 Secret identity of Captain America Steve Rogers Captain America [':', ' The', ' Winter', ' Soldier', '\n', '\n', 'The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ": The Winter Soldier
+
+ The Winter Soldier is a fictional character appearing in American comic books published by" False around the fictional Captain America Adventure Program radio 4 [' around', ' the', ' fictional', ' Captain', ' America']
+30 6 Secret identity of x -1 Secret identity of The Flash Barry Allen The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" True in complexity. The Flash version was much 4 [' in', ' complexity', '.', ' The', ' Flash']
+31 6 Secret identity of x -1 Secret identity of The Flash Barry Allen The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" True the season. The Flash finished first during 4 [' the', ' season', '.', ' The', ' Flash']
+32 6 Secret identity of x -1 Secret identity of The Flash Barry Allen The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" True Comics, 2 [' Comics', 'The', ' Flash']
+33 6 Secret identity of x -1 Secret identity of The Flash Barry Allen The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" True care is needed. The Flash is used for yacht 5 [' care', ' is', ' needed', '.', ' The', ' Flash']
+34 6 Secret identity of x -1 Secret identity of The Flash Barry Allen The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" True Portal, titled Portal: The Flash Version, just prior 6 [' Portal', ',', ' titled', ' Portal', ':', ' The', ' Flash']
+35 7 Secret identity of x -1 Secret identity of Green Lantern Hal Jordan Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False 1 ['Green', ' Lantern']
+36 7 Secret identity of x -1 Secret identity of Green Lantern Hal Jordan Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False 1 ['Green', ' Lantern']
+37 7 Secret identity of x -1 Secret identity of Green Lantern Hal Jordan Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False 1 ['Green', ' Lantern']
+38 7 Secret identity of x -1 Secret identity of Green Lantern Hal Jordan Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False by DC Comics in its Green Lantern and Green Lantern 6 [' by', ' DC', ' Comics', ' in', ' its', ' Green', ' Lantern']
+39 7 Secret identity of x -1 Secret identity of Green Lantern Hal Jordan Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False 1 ['Green', ' Lantern']
+40 8 Secret identity of x -1 Secret identity of Green Arrow Oliver Queen Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False of Oliver Queen's Green Arrow persona in the season 5 "[' of', ' Oliver', ' Queen', ""'s"", ' Green', ' Arrow']"
+41 8 Secret identity of x -1 Secret identity of Green Arrow Oliver Queen Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False begins assisting Green Arrow (Justin Hartley) 3 [' begins', ' assisting', ' Green', ' Arrow']
+42 8 Secret identity of x -1 Secret identity of Green Arrow Oliver Queen Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False creating Oliver's Green Arrow costume, Cranstoun 4 "[' creating', ' Oliver', ""'s"", ' Green', ' Arrow']"
+43 8 Secret identity of x -1 Secret identity of Green Arrow Oliver Queen Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False debut in the Green Arrow television 4 [' debut', ' in', ' the', ' Green', ' Arrow']
+44 8 Secret identity of x -1 Secret identity of Green Arrow Oliver Queen Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False and his alter ego Green Arrow. Oliver was 5 [' and', ' his', ' alter', ' ego', ' Green', ' Arrow']
+45 9 Secret identity of x -1 Secret identity of Aquaman Arthur Curry Aquaman ['\n', '\n', 'A', 'qu', 'aman', ' is', ' the', ' king', ' of', ' the', ' sea', '.', ' He', ' is', ' the', ' ruler', ' of', ' Atlantis', ',', ' a'] "
+
+ A qu aman is the king of the sea . He is the ruler of Atlantis , a" False mentions that Aquaman is pretty 3 [' mentions', ' that', ' Aqu', 'aman']
+46 9 Secret identity of x -1 Secret identity of Aquaman Arthur Curry Aquaman ['\n', '\n', 'A', 'qu', 'aman', ' is', ' the', ' king', ' of', ' the', ' sea', '.', ' He', ' is', ' the', ' ruler', ' of', ' Atlantis', ',', ' a'] "
+
+ A qu aman is the king of the sea . He is the ruler of Atlantis , a" False " to work on the Aquaman pilot.
+" 5 [' to', ' work', ' on', ' the', ' Aqu', 'aman']
+47 9 Secret identity of x -1 Secret identity of Aquaman Arthur Curry Aquaman ['\n', '\n', 'A', 'qu', 'aman', ' is', ' the', ' king', ' of', ' the', ' sea', '.', ' He', ' is', ' the', ' ruler', ' of', ' Atlantis', ',', ' a'] "
+
+ A qu aman is the king of the sea . He is the ruler of Atlantis , a" False The concept of Aquaman stemmed from a fifth 4 [' The', ' concept', ' of', ' Aqu', 'aman']
+48 9 Secret identity of x -1 Secret identity of Aquaman Arthur Curry Aquaman ['\n', '\n', 'A', 'qu', 'aman', ' is', ' the', ' king', ' of', ' the', ' sea', '.', ' He', ' is', ' the', ' ruler', ' of', ' Atlantis', ',', ' a'] "
+
+ A qu aman is the king of the sea . He is the ruler of Atlantis , a" False " he discovered that Aquaman had a ""serious ..." 4 [' he', ' discovered', ' that', ' Aqu', 'aman']
+49 9 Secret identity of x -1 Secret identity of Aquaman Arthur Curry Aquaman ['\n', '\n', 'A', 'qu', 'aman', ' is', ' the', ' king', ' of', ' the', ' sea', '.', ' He', ' is', ' the', ' ruler', ' of', ' Atlantis', ',', ' a'] "
+
+ A qu aman is the king of the sea . He is the ruler of Atlantis , a" False (DC Comics, 5 [' (', 'DC', ' Comics', 'A', 'qu', 'aman']
+50 10 Secret identity of x -1 Secret identity of Daredevil Matt Murdock Daredevil [':', ' The', ' Man', ' Without', ' Fear', '\n', '\n', 'The', ' Man', ' Without', ' Fear', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in'] ": The Man Without Fear
+
+ The Man Without Fear is a fictional character , a superhero appearing in" False (Matt Murdock / Daredevil in Daredevil) 5 [' (', 'Matt', ' Murd', 'ock', ' /', ' Daredevil']
+51 10 Secret identity of x -1 Secret identity of Daredevil Matt Murdock Daredevil [':', ' The', ' Man', ' Without', ' Fear', '\n', '\n', 'The', ' Man', ' Without', ' Fear', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in'] ": The Man Without Fear
+
+ The Man Without Fear is a fictional character , a superhero appearing in" False streaming with Marvel's Daredevil on Netflix in 2015, 4 "[' streaming', ' with', ' Marvel', ""'s"", ' Daredevil']"
+52 10 Secret identity of x -1 Secret identity of Daredevil Matt Murdock Daredevil [':', ' The', ' Man', ' Without', ' Fear', '\n', '\n', 'The', ' Man', ' Without', ' Fear', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in'] ": The Man Without Fear
+
+ The Man Without Fear is a fictional character , a superhero appearing in" False Charles Xavier and Daredevil — are introduced 3 [' Charles', ' Xavier', ' and', ' Daredevil']
+53 10 Secret identity of x -1 Secret identity of Daredevil Matt Murdock Daredevil [':', ' The', ' Man', ' Without', ' Fear', '\n', '\n', 'The', ' Man', ' Without', ' Fear', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in'] ": The Man Without Fear
+
+ The Man Without Fear is a fictional character , a superhero appearing in" False director's cut of Daredevil was released, reincorporating 4 "[' director', ""'s"", ' cut', ' of', ' Daredevil']"
+54 10 Secret identity of x -1 Secret identity of Daredevil Matt Murdock Daredevil [':', ' The', ' Man', ' Without', ' Fear', '\n', '\n', 'The', ' Man', ' Without', ' Fear', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in'] ": The Man Without Fear
+
+ The Man Without Fear is a fictional character , a superhero appearing in" False 2015 TV series Daredevil and the costume 3 [' 2015', ' TV', ' series', ' Daredevil']
+55 11 Secret identity of x -1 Secret identity of Cyclops Scott Summers Cyclops ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False conclusion when Cyclops officially closes 3 [' conclusion', ' when', ' Cycl', 'ops']
+56 11 Secret identity of x -1 Secret identity of Cyclops Scott Summers Cyclops ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False Professor X and Cyclops go to see Magneto. 4 [' Professor', ' X', ' and', ' Cycl', 'ops']
+57 11 Secret identity of x -1 Secret identity of Cyclops Scott Summers Cyclops ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False brainwashed Cyclops while trying to 3 [' brain', 'washed', ' Cycl', 'ops']
+58 11 Secret identity of x -1 Secret identity of Cyclops Scott Summers Cyclops ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False her shock. Meanwhile, Cyclops, still distraught 6 [' her', ' shock', '.', ' Meanwhile', ',', ' Cycl', 'ops']
+59 11 Secret identity of x -1 Secret identity of Cyclops Scott Summers Cyclops ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False shock. Meanwhile, Cyclops, still distraught 5 [' shock', '.', ' Meanwhile', ',', ' Cycl', 'ops']
+60 12 Secret identity of x -1 Secret identity of Storm Ororo Munroe Storm ['y', ' Daniels', '\n', '\n', 'The', ' porn', ' star', ' who', ' claims', ' she', ' had', ' an', ' affair', ' with', ' Donald', ' Trump', ' in', ' 2006', ' has', ' been'] "y Daniels
+
+ The porn star who claims she had an affair with Donald Trump in 2006 has been" False between it and Tropical Storm Humberto to its northeast. 4 [' between', ' it', ' and', ' Tropical', ' Storm']
+61 12 Secret identity of x -1 Secret identity of Storm Ororo Munroe Storm ['y', ' Daniels', '\n', '\n', 'The', ' porn', ' star', ' who', ' claims', ' she', ' had', ' an', ' affair', ' with', ' Donald', ' Trump', ' in', ' 2006', ' has', ' been'] "y Daniels
+
+ The porn star who claims she had an affair with Donald Trump in 2006 has been" False the Gathering Storm by the National 2 [' the', ' Gathering', ' Storm']
+62 12 Secret identity of x -1 Secret identity of Storm Ororo Munroe Storm ['y', ' Daniels', '\n', '\n', 'The', ' porn', ' star', ' who', ' claims', ' she', ' had', ' an', ' affair', ' with', ' Donald', ' Trump', ' in', ' 2006', ' has', ' been'] "y Daniels
+
+ The porn star who claims she had an affair with Donald Trump in 2006 has been" False " === Cyclonic Storm BOB 06 ===
+" 3 [' ===', ' Cycl', 'onic', ' Storm']
+63 12 Secret identity of x -1 Secret identity of Storm Ororo Munroe Storm ['y', ' Daniels', '\n', '\n', 'The', ' porn', ' star', ' who', ' claims', ' she', ' had', ' an', ' affair', ' with', ' Donald', ' Trump', ' in', ' 2006', ' has', ' been'] "y Daniels
+
+ The porn star who claims she had an affair with Donald Trump in 2006 has been" False " into Tropical Storm Estelle.
+" 2 [' into', ' Tropical', ' Storm']
+64 12 Secret identity of x -1 Secret identity of Storm Ororo Munroe Storm ['y', ' Daniels', '\n', '\n', 'The', ' porn', ' star', ' who', ' claims', ' she', ' had', ' an', ' affair', ' with', ' Donald', ' Trump', ' in', ' 2006', ' has', ' been'] "y Daniels
+
+ The porn star who claims she had an affair with Donald Trump in 2006 has been" False intensified into Tropical Storm Kammuri off the 3 [' intensified', ' into', ' Tropical', ' Storm']
+65 14 Secret identity of x -1 Secret identity of Invisible Woman Sue Storm Invisible Woman ['\n', '\n', 'The', ' Invisible', ' Woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The Invisible Woman is a fictional character , a superhero ine appearing in American comic books published by" False (published as Invisible Woman in English), discusses 4 [' (', 'published', ' as', ' Invisible', ' Woman']
+66 14 Secret identity of x -1 Secret identity of Invisible Woman Sue Storm Invisible Woman ['\n', '\n', 'The', ' Invisible', ' Woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The Invisible Woman is a fictional character , a superhero ine appearing in American comic books published by" False (published as Invisible Woman in English), discusses 4 [' (', 'published', ' as', ' Invisible', ' Woman']
+67 14 Secret identity of x -1 Secret identity of Invisible Woman Sue Storm Invisible Woman ['\n', '\n', 'The', ' Invisible', ' Woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The Invisible Woman is a fictional character , a superhero ine appearing in American comic books published by" False (published as Invisible Woman in English), discusses 4 [' (', 'published', ' as', ' Invisible', ' Woman']
+68 14 Secret identity of x -1 Secret identity of Invisible Woman Sue Storm Invisible Woman ['\n', '\n', 'The', ' Invisible', ' Woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The Invisible Woman is a fictional character , a superhero ine appearing in American comic books published by" False unterwegs (published as Invisible Woman in English), discusses 8 [' un', 'ter', 'we', 'gs', ' (', 'published', ' as', ' Invisible', ' Woman']
+69 14 Secret identity of x -1 Secret identity of Invisible Woman Sue Storm Invisible Woman ['\n', '\n', 'The', ' Invisible', ' Woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The Invisible Woman is a fictional character , a superhero ine appearing in American comic books published by" False unterwegs (published as Invisible Woman in English), 8 [' un', 'ter', 'we', 'gs', ' (', 'published', ' as', ' Invisible', ' Woman']
+70 15 Secret identity of x -1 Secret identity of Human Torch Johnny Storm Human Torch [' and', ' the', ' Fantastic', ' Four', '\n', '\n', 'The', ' Fantastic', ' Four', ' are', ' a', ' fictional', ' superhero', ' team', ' appearing', ' in', ' American', ' comic', ' books', ' published'] " and the Fantastic Four
+
+ The Fantastic Four are a fictional superhero team appearing in American comic books published" False with Marvel as the Human Torch in the Fantastic 5 [' with', ' Marvel', ' as', ' the', ' Human', ' Torch']
+71 16 Secret identity of x -1 Secret identity of The Thing Ben Grimm The Thing ['\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' fictional', ' alien', ',', ' and', ' a', ' fictional', ' alien', ' from', ' the', ' Marvel'] "
+
+ The Thing is a fictional character , a fictional alien , and a fictional alien from the Marvel" False from the ending of The Thing from Another World, 5 [' from', ' the', ' ending', ' of', ' The', ' Thing']
+72 16 Secret identity of x -1 Secret identity of The Thing Ben Grimm The Thing ['\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' fictional', ' alien', ',', ' and', ' a', ' fictional', ' alien', ' from', ' the', ' Marvel'] "
+
+ The Thing is a fictional character , a fictional alien , and a fictional alien from the Marvel" False its debt to The Thing, and Juliette 4 [' its', ' debt', ' to', ' The', ' Thing']
+73 16 Secret identity of x -1 Secret identity of The Thing Ben Grimm The Thing ['\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' fictional', ' alien', ',', ' and', ' a', ' fictional', ' alien', ' from', ' the', ' Marvel'] "
+
+ The Thing is a fictional character , a fictional alien , and a fictional alien from the Marvel" False occupant, was filmed as The Thing from Another World 6 [' occupant', ',', ' was', ' filmed', ' as', ' The', ' Thing']
+74 16 Secret identity of x -1 Secret identity of The Thing Ben Grimm The Thing ['\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' fictional', ' alien', ',', ' and', ' a', ' fictional', ' alien', ' from', ' the', ' Marvel'] "
+
+ The Thing is a fictional character , a fictional alien , and a fictional alien from the Marvel" False horror formula. The Thing from 2002 has been 4 [' horror', ' formula', '.', ' The', ' Thing']
+75 16 Secret identity of x -1 Secret identity of The Thing Ben Grimm The Thing ['\n', '\n', 'The', ' Thing', ' is', ' a', ' fictional', ' character', ',', ' a', ' fictional', ' alien', ',', ' and', ' a', ' fictional', ' alien', ' from', ' the', ' Marvel'] "
+
+ The Thing is a fictional character , a fictional alien , and a fictional alien from the Marvel" False Dead (1968) and The Thing (1982) in the plot 6 [' Dead', ' (', '1968', ')', ' and', ' The', ' Thing']
+76 18 Secret identity of x -1 Secret identity of Nightwing Dick Grayson Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" True is restored as Nightwing with Wayne serving 4 [' is', ' restored', ' as', ' Night', 'wing']
+77 18 Secret identity of x -1 Secret identity of Nightwing Dick Grayson Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" True summon them to help Nightwing and Robin deal 5 [' summon', ' them', ' to', ' help', ' Night', 'wing']
+78 18 Secret identity of x -1 Secret identity of Nightwing Dick Grayson Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" True " released. The ""Nightwing"" pack, released" 5 "[' released', '.', ' The', ' ""', 'Night', 'wing']"
+79 18 Secret identity of x -1 Secret identity of Nightwing Dick Grayson Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" True returns to the mantle of Nightwing and appears in his 6 [' returns', ' to', ' the', ' mantle', ' of', ' Night', 'wing']
+80 18 Secret identity of x -1 Secret identity of Nightwing Dick Grayson Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" True Batman's partner Nightwing (Barbara Gordon). 4 "[' Batman', ""'s"", ' partner', ' Night', 'wing']"
+81 19 Secret identity of x -1 Secret identity of Shazam Billy Batson Shazam ['!', '\n', '\n', 'The', ' DC', ' Universe', ' is', ' a', ' vast', ',', ' interconnected', ' world', ' of', ' heroes', ',', ' villains', ',', ' and', ' everything', ' in'] "!
+
+ The DC Universe is a vast , interconnected world of heroes , villains , and everything in" False " ""first listen"" on Shazam and the official" 7 "[' ""', 'first', ' listen', '""', ' on', ' Sh', 'az', 'am']"
+82 19 Secret identity of x -1 Secret identity of Shazam Billy Batson Shazam ['!', '\n', '\n', 'The', ' DC', ' Universe', ' is', ' a', ' vast', ',', ' interconnected', ' world', ' of', ' heroes', ',', ' villains', ',', ' and', ' everything', ' in'] "!
+
+ The DC Universe is a vast , interconnected world of heroes , villains , and everything in" False " listen"" on Shazam and the official" 5 "[' listen', '""', ' on', ' Sh', 'az', 'am']"
+83 19 Secret identity of x -1 Secret identity of Shazam Billy Batson Shazam ['!', '\n', '\n', 'The', ' DC', ' Universe', ' is', ' a', ' vast', ',', ' interconnected', ' world', ' of', ' heroes', ',', ' villains', ',', ' and', ' everything', ' in'] "!
+
+ The DC Universe is a vast , interconnected world of heroes , villains , and everything in" False villain, in a film about Shazam as a part of 8 [' villain', ',', ' in', ' a', ' film', ' about', ' Sh', 'az', 'am']
+84 19 Secret identity of x -1 Secret identity of Shazam Billy Batson Shazam ['!', '\n', '\n', 'The', ' DC', ' Universe', ' is', ' a', ' vast', ',', ' interconnected', ' world', ' of', ' heroes', ',', ' villains', ',', ' and', ' everything', ' in'] "!
+
+ The DC Universe is a vast , interconnected world of heroes , villains , and everything in" False Daniels (Grandad). The Shazam Productions 8 [' Daniels', ' (', 'Grand', 'ad', ').', ' The', ' Sh', 'az', 'am']
+85 19 Secret identity of x -1 Secret identity of Shazam Billy Batson Shazam ['!', '\n', '\n', 'The', ' DC', ' Universe', ' is', ' a', ' vast', ',', ' interconnected', ' world', ' of', ' heroes', ',', ' villains', ',', ' and', ' everything', ' in'] "!
+
+ The DC Universe is a vast , interconnected world of heroes , villains , and everything in" False production company, Shazam Productions. 5 [' production', ' company', ',', ' Sh', 'az', 'am']
+86 20 Secret identity of x -1 Secret identity of Green Lantern Kyle Rayner Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False releasing a series of Green Lantern Movie Prequel 5 [' releasing', ' a', ' series', ' of', ' Green', ' Lantern']
+87 20 Secret identity of x -1 Secret identity of Green Lantern Kyle Rayner Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False one-shot and co-write Green Lantern / Sinestro 8 [' one', '-', 'shot', ' and', ' co', '-', 'write', ' Green', ' Lantern']
+88 20 Secret identity of x -1 Secret identity of Green Lantern Kyle Rayner Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False through the Green Lantern and Green Lantern 3 [' through', ' the', ' Green', ' Lantern']
+89 20 Secret identity of x -1 Secret identity of Green Lantern Kyle Rayner Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False upper body; Green Lantern creates pressure 4 [' upper', ' body', ';', ' Green', ' Lantern']
+90 20 Secret identity of x -1 Secret identity of Green Lantern Kyle Rayner Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False 1 ['Green', ' Lantern']
+91 21 Secret identity of x -1 Secret identity of Red Robin Tim Drake Red Robin ['\n', '\n', 'The', ' Red', ' Robin', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' is', ' the', ' alter', ' ego'] "
+
+ The Red Robin is a fictional character in the DC Comics universe . He is the alter ego" False " ""When the Red, Red Robin (Comes Bob," 6 "[' ""', 'When', ' the', ' Red', ',', ' Red', ' Robin']"
+92 21 Secret identity of x -1 Secret identity of Red Robin Tim Drake Red Robin ['\n', '\n', 'The', ' Red', ' Robin', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' is', ' the', ' alter', ' ego'] "
+
+ The Red Robin is a fictional character in the DC Comics universe . He is the alter ego" False commissioned by Red Robin and released 3 [' commissioned', ' by', ' Red', ' Robin']
+93 21 Secret identity of x -1 Secret identity of Red Robin Tim Drake Red Robin ['\n', '\n', 'The', ' Red', ' Robin', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' is', ' the', ' alter', ' ego'] "
+
+ The Red Robin is a fictional character in the DC Comics universe . He is the alter ego" False costumes, his Red Robin costume, the designs 4 [' costumes', ',', ' his', ' Red', ' Robin']
+94 21 Secret identity of x -1 Secret identity of Red Robin Tim Drake Red Robin ['\n', '\n', 'The', ' Red', ' Robin', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' is', ' the', ' alter', ' ego'] "
+
+ The Red Robin is a fictional character in the DC Comics universe . He is the alter ego" False character in a new Red Robin series. However, 5 [' character', ' in', ' a', ' new', ' Red', ' Robin']
+95 21 Secret identity of x -1 Secret identity of Red Robin Tim Drake Red Robin ['\n', '\n', 'The', ' Red', ' Robin', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' is', ' the', ' alter', ' ego'] "
+
+ The Red Robin is a fictional character in the DC Comics universe . He is the alter ego" False character in a new Red Robin series. However, 5 [' character', ' in', ' a', ' new', ' Red', ' Robin']
+96 22 Secret identity of x -1 Secret identity of Red Hood Jason Todd Red Hood [' and', ' the', ' Out', 'laws', '\n', '\n', 'The', ' Red', ' Hood', ' and', ' the', ' Out', 'laws', ' is', ' a', ' fictional', ' superhero', ' team', ' appearing', ' in'] " and the Out laws
+
+ The Red Hood and the Out laws is a fictional superhero team appearing in" False Robin, Red Robin, Red Hood and Batgirl, 6 [' Robin', ',', ' Red', ' Robin', ',', ' Red', ' Hood']
+97 22 Secret identity of x -1 Secret identity of Red Hood Jason Todd Red Hood [' and', ' the', ' Out', 'laws', '\n', '\n', 'The', ' Red', ' Hood', ' and', ' the', ' Out', 'laws', ' is', ' a', ' fictional', ' superhero', ' team', ' appearing', ' in'] " and the Out laws
+
+ The Red Hood and the Out laws is a fictional superhero team appearing in" False the helmet of the Red Hood is revealed. In addition 5 [' the', ' helmet', ' of', ' the', ' Red', ' Hood']
+98 22 Secret identity of x -1 Secret identity of Red Hood Jason Todd Red Hood [' and', ' the', ' Out', 'laws', '\n', '\n', 'The', ' Red', ' Hood', ' and', ' the', ' Out', 'laws', ' is', ' a', ' fictional', ' superhero', ' team', ' appearing', ' in'] " and the Out laws
+
+ The Red Hood and the Out laws is a fictional superhero team appearing in" False Harley Quinn and Red Hood story packs, 4 [' Harley', ' Quinn', ' and', ' Red', ' Hood']
+99 22 Secret identity of x -1 Secret identity of Red Hood Jason Todd Red Hood [' and', ' the', ' Out', 'laws', '\n', '\n', 'The', ' Red', ' Hood', ' and', ' the', ' Out', 'laws', ' is', ' a', ' fictional', ' superhero', ' team', ' appearing', ' in'] " and the Out laws
+
+ The Red Hood and the Out laws is a fictional superhero team appearing in" False Jason Todd as Red Hood is also a playable 4 [' Jason', ' Todd', ' as', ' Red', ' Hood']
+100 22 Secret identity of x -1 Secret identity of Red Hood Jason Todd Red Hood [' and', ' the', ' Out', 'laws', '\n', '\n', 'The', ' Red', ' Hood', ' and', ' the', ' Out', 'laws', ' is', ' a', ' fictional', ' superhero', ' team', ' appearing', ' in'] " and the Out laws
+
+ The Red Hood and the Out laws is a fictional superhero team appearing in" False the identity of Red Hood to support 4 [' the', ' identity', ' of', ' Red', ' Hood']
+101 23 Secret identity of x -1 Secret identity of Catwoman Selina Kyle Catwoman ['\n', '\n', 'Cat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Cat woman is a fictional character , a superhero ine appearing in American comic books published by DC" False " Michelle Pfeiffer as Catwoman / Selina Kyle
+" 7 [' Michelle', ' P', 'fe', 'iff', 'er', ' as', ' Cat', 'woman']
+102 23 Secret identity of x -1 Secret identity of Catwoman Selina Kyle Catwoman ['\n', '\n', 'Cat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Cat woman is a fictional character , a superhero ine appearing in American comic books published by DC" False Aguilera as Catwoman and performing 5 [' Agu', 'iler', 'a', ' as', ' Cat', 'woman']
+103 23 Secret identity of x -1 Secret identity of Catwoman Selina Kyle Catwoman ['\n', '\n', 'Cat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Cat woman is a fictional character , a superhero ine appearing in American comic books published by DC" False 1 ['Cat', 'woman']
+104 23 Secret identity of x -1 Secret identity of Catwoman Selina Kyle Catwoman ['\n', '\n', 'Cat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Cat woman is a fictional character , a superhero ine appearing in American comic books published by DC" False the feature in the Catwoman soundtrack. 5 [' the', ' feature', ' in', ' the', ' Cat', 'woman']
+105 23 Secret identity of x -1 Secret identity of Catwoman Selina Kyle Catwoman ['\n', '\n', 'Cat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Cat woman is a fictional character , a superhero ine appearing in American comic books published by DC" False ones such as Catwoman. In Catwoman 4 [' ones', ' such', ' as', ' Cat', 'woman']
+106 24 Secret identity of x -1 Secret identity of Phoenix Jean Grey Phoenix [' Wright', ':', ' Ace', ' Attorney', '\n', '\n', 'The', ' game', ' is', ' a', ' spin', '-', 'off', ' of', ' the', ' Ace', ' Attorney', ' series', ',', ' which'] " Wright : Ace Attorney
+
+ The game is a spin - off of the Ace Attorney series , which" False Hensley died in Phoenix on June 21, 2000. 5 [' H', 'ens', 'ley', ' died', ' in', ' Phoenix']
+107 24 Secret identity of x -1 Secret identity of Phoenix Jean Grey Phoenix [' Wright', ':', ' Ace', ' Attorney', '\n', '\n', 'The', ' game', ' is', ' a', ' spin', '-', 'off', ' of', ' the', ' Ace', ' Attorney', ' series', ',', ' which'] " Wright : Ace Attorney
+
+ The game is a spin - off of the Ace Attorney series , which" False British destroyer Phoenix at 09: 18 2 [' British', ' destroyer', ' Phoenix']
+108 24 Secret identity of x -1 Secret identity of Phoenix Jean Grey Phoenix [' Wright', ':', ' Ace', ' Attorney', '\n', '\n', 'The', ' game', ' is', ' a', ' spin', '-', 'off', ' of', ' the', ' Ace', ' Attorney', ' series', ',', ' which'] " Wright : Ace Attorney
+
+ The game is a spin - off of the Ace Attorney series , which" False society. Both Joaquin Phoenix and John Cusack 5 [' society', '.', ' Both', ' Jo', 'aquin', ' Phoenix']
+109 24 Secret identity of x -1 Secret identity of Phoenix Jean Grey Phoenix [' Wright', ':', ' Ace', ' Attorney', '\n', '\n', 'The', ' game', ' is', ' a', ' spin', '-', 'off', ' of', ' the', ' Ace', ' Attorney', ' series', ',', ' which'] " Wright : Ace Attorney
+
+ The game is a spin - off of the Ace Attorney series , which" False Qwark on the Phoenix, Ratchet goes to 4 [' Q', 'wark', ' on', ' the', ' Phoenix']
+110 24 Secret identity of x -1 Secret identity of Phoenix Jean Grey Phoenix [' Wright', ':', ' Ace', ' Attorney', '\n', '\n', 'The', ' game', ' is', ' a', ' spin', '-', 'off', ' of', ' the', ' Ace', ' Attorney', ' series', ',', ' which'] " Wright : Ace Attorney
+
+ The game is a spin - off of the Ace Attorney series , which" False experienced in Phoenix and the surrounding 2 [' experienced', ' in', ' Phoenix']
+111 25 Secret identity of x -1 Secret identity of Wolverine Logan Howlett Wolverine ['\n', '\n', 'The', ' Wolverine', ' is', ' a', ' fictional', ' character', ',', ' a', ' mutant', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wolverine is a fictional character , a mutant superhero appearing in American comic books published by Marvel" False following the Death of Wolverine comic book, 4 [' following', ' the', ' Death', ' of', ' Wolverine']
+112 25 Secret identity of x -1 Secret identity of Wolverine Logan Howlett Wolverine ['\n', '\n', 'The', ' Wolverine', ' is', ' a', ' fictional', ' character', ',', ' a', ' mutant', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wolverine is a fictional character , a mutant superhero appearing in American comic books published by Marvel" False the depiction of Wolverine as an Army veteran 3 [' the', ' depiction', ' of', ' Wolverine']
+113 25 Secret identity of x -1 Secret identity of Wolverine Logan Howlett Wolverine ['\n', '\n', 'The', ' Wolverine', ' is', ' a', ' fictional', ' character', ',', ' a', ' mutant', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wolverine is a fictional character , a mutant superhero appearing in American comic books published by Marvel" False " way: ""[The] Wolverine halfback skidded" 5 "[' way', ':', ' ""[', 'The', ']', ' Wolverine']"
+114 25 Secret identity of x -1 Secret identity of Wolverine Logan Howlett Wolverine ['\n', '\n', 'The', ' Wolverine', ' is', ' a', ' fictional', ' character', ',', ' a', ' mutant', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wolverine is a fictional character , a mutant superhero appearing in American comic books published by Marvel" False " release of a third Wolverine film in 2017.
+" 4 [' release', ' of', ' a', ' third', ' Wolverine']
+115 25 Secret identity of x -1 Secret identity of Wolverine Logan Howlett Wolverine ['\n', '\n', 'The', ' Wolverine', ' is', ' a', ' fictional', ' character', ',', ' a', ' mutant', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wolverine is a fictional character , a mutant superhero appearing in American comic books published by Marvel" False gruff loner, Wolverine is also a 5 [' gru', 'ff', ' l', 'oner', ',', ' Wolverine']
+116 26 Secret identity of x -1 Secret identity of Supergirl Kara Zor-El Supergirl ['\n', '\n', 'Super', 'girl', ' is', ' Kara', ' Z', 'or', '-', 'El', ',', ' a', ' Krypt', 'onian', ' from', ' the', ' planet', ' Krypt', 'on', '.'] "
+
+ Super girl is Kara Z or - El , a Krypt onian from the planet Krypt on ." True starred alongside Supergirl and Wonder 3 [' starred', ' alongside', ' Super', 'girl']
+117 26 Secret identity of x -1 Secret identity of Supergirl Kara Zor-El Supergirl ['\n', '\n', 'Super', 'girl', ' is', ' Kara', ' Z', 'or', '-', 'El', ',', ' a', ' Krypt', 'onian', ' from', ' the', ' planet', ' Krypt', 'on', '.'] "
+
+ Super girl is Kara Z or - El , a Krypt onian from the planet Krypt on ." True 1984 spin-off Supergirl before selling 5 [' 1984', ' spin', '-', 'off', ' Super', 'girl']
+118 26 Secret identity of x -1 Secret identity of Supergirl Kara Zor-El Supergirl ['\n', '\n', 'Super', 'girl', ' is', ' Kara', ' Z', 'or', '-', 'El', ',', ' a', ' Krypt', 'onian', ' from', ' the', ' planet', ' Krypt', 'on', '.'] "
+
+ Super girl is Kara Z or - El , a Krypt onian from the planet Krypt on ." True Superman / Batman: The Supergirl from Krypton with 6 [' Superman', ' /', ' Batman', ':', ' The', ' Super', 'girl']
+119 26 Secret identity of x -1 Secret identity of Supergirl Kara Zor-El Supergirl ['\n', '\n', 'Super', 'girl', ' is', ' Kara', ' Z', 'or', '-', 'El', ',', ' a', ' Krypt', 'onian', ' from', ' the', ' planet', ' Krypt', 'on', '.'] "
+
+ Super girl is Kara Z or - El , a Krypt onian from the planet Krypt on ." True " Supergirl (1984) ===
+" 1 [' Super', 'girl']
+120 26 Secret identity of x -1 Secret identity of Supergirl Kara Zor-El Supergirl ['\n', '\n', 'Super', 'girl', ' is', ' Kara', ' Z', 'or', '-', 'El', ',', ' a', ' Krypt', 'onian', ' from', ' the', ' planet', ' Krypt', 'on', '.'] "
+
+ Super girl is Kara Z or - El , a Krypt onian from the planet Krypt on ." True " Elseworld's Finest: Supergirl & Batgirl (1998).
+" 7 "[' Else', 'world', ""'s"", ' Fin', 'est', ':', ' Super', 'girl']"
+121 27 Secret identity of x -1 Secret identity of Black Widow Natasha Romanoff Black Widow [':', '\n', '\n', 'Black', ' Widow', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The'] ":
+
+ Black Widow is a fictional character appearing in American comic books published by Marvel Comics . The" False Avengers like Black Widow and Nick Fury, because 3 [' Avengers', ' like', ' Black', ' Widow']
+122 27 Secret identity of x -1 Secret identity of Black Widow Natasha Romanoff Black Widow [':', '\n', '\n', 'Black', ' Widow', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The'] ":
+
+ Black Widow is a fictional character appearing in American comic books published by Marvel Comics . The" False featuring Gail “ Black Widow ” McIntyre and her 6 [' featuring', ' G', 'ail', ' �', '�', ' Black', ' Widow']
+123 27 Secret identity of x -1 Secret identity of Black Widow Natasha Romanoff Black Widow [':', '\n', '\n', 'Black', ' Widow', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The'] ":
+
+ Black Widow is a fictional character appearing in American comic books published by Marvel Comics . The" False 1 ['Black', ' Widow']
+124 27 Secret identity of x -1 Secret identity of Black Widow Natasha Romanoff Black Widow [':', '\n', '\n', 'Black', ' Widow', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The'] ":
+
+ Black Widow is a fictional character appearing in American comic books published by Marvel Comics . The" False 1 ['Black', ' Widow']
+125 27 Secret identity of x -1 Secret identity of Black Widow Natasha Romanoff Black Widow [':', '\n', '\n', 'Black', ' Widow', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The'] ":
+
+ Black Widow is a fictional character appearing in American comic books published by Marvel Comics . The" False 1 ['Black', ' Widow']
+126 28 Secret identity of x -1 Secret identity of Hawkeye Clint Barton Hawkeye [':', ' The', ' identity', ' of', ' Haw', 'keye', ' is', ' a', ' secret', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n'] ": The identity of Haw keye is a secret . He is a member of the Avengers .
+
+" False the signal to E2C Hawkeye aircraft orbiting 7 [' the', ' signal', ' to', ' E', '2', 'C', ' Haw', 'keye']
+127 28 Secret identity of x -1 Secret identity of Hawkeye Clint Barton Hawkeye [':', ' The', ' identity', ' of', ' Haw', 'keye', ' is', ' a', ' secret', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n'] ": The identity of Haw keye is a secret . He is a member of the Avengers .
+
+" False Whedon said that Hawkeye interacts more 4 [' Whedon', ' said', ' that', ' Haw', 'keye']
+128 28 Secret identity of x -1 Secret identity of Hawkeye Clint Barton Hawkeye [':', ' The', ' identity', ' of', ' Haw', 'keye', ' is', ' a', ' secret', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n'] ": The identity of Haw keye is a secret . He is a member of the Avengers .
+
+" False uniform and unshaven Hawkeye and Trapper 6 [' uniform', ' and', ' un', 'sh', 'aven', ' Haw', 'keye']
+129 28 Secret identity of x -1 Secret identity of Hawkeye Clint Barton Hawkeye [':', ' The', ' identity', ' of', ' Haw', 'keye', ' is', ' a', ' secret', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n'] ": The identity of Haw keye is a secret . He is a member of the Avengers .
+
+" False known for his role as Hawkeye Pierce from 6 [' known', ' for', ' his', ' role', ' as', ' Haw', 'keye']
+130 28 Secret identity of x -1 Secret identity of Hawkeye Clint Barton Hawkeye [':', ' The', ' identity', ' of', ' Haw', 'keye', ' is', ' a', ' secret', '.', ' He', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n'] ": The identity of Haw keye is a secret . He is a member of the Avengers .
+
+" False about this Hawkeye character, 3 [' about', ' this', ' Haw', 'keye']
+131 29 Secret identity of x -1 Secret identity of Captain Marvel Carol Danvers Captain Marvel ['\n', '\n', 'The', ' Marvel', ' Cinem', 'atic', ' Universe', ' has', ' been', ' a', ' huge', ' success', ' for', ' the', ' company', ',', ' and', ' it', '�', '�'] "
+
+ The Marvel Cinem atic Universe has been a huge success for the company , and it � �" False reprints of the Captain Marvel stories after 5 [' re', 'prints', ' of', ' the', ' Captain', ' Marvel']
+132 29 Secret identity of x -1 Secret identity of Captain Marvel Carol Danvers Captain Marvel ['\n', '\n', 'The', ' Marvel', ' Cinem', 'atic', ' Universe', ' has', ' been', ' a', ' huge', ' success', ' for', ' the', ' company', ',', ' and', ' it', '�', '�'] "
+
+ The Marvel Cinem atic Universe has been a huge success for the company , and it � �" False wanted to include Captain Marvel and Spider-Man 4 [' wanted', ' to', ' include', ' Captain', ' Marvel']
+133 29 Secret identity of x -1 Secret identity of Captain Marvel Carol Danvers Captain Marvel ['\n', '\n', 'The', ' Marvel', ' Cinem', 'atic', ' Universe', ' has', ' been', ' a', ' huge', ' success', ' for', ' the', ' company', ',', ' and', ' it', '�', '�'] "
+
+ The Marvel Cinem atic Universe has been a huge success for the company , and it � �" False revealed that Captain Marvel, who is scheduled 3 [' revealed', ' that', ' Captain', ' Marvel']
+134 29 Secret identity of x -1 Secret identity of Captain Marvel Carol Danvers Captain Marvel ['\n', '\n', 'The', ' Marvel', ' Cinem', 'atic', ' Universe', ' has', ' been', ' a', ' huge', ' success', ' for', ' the', ' company', ',', ' and', ' it', '�', '�'] "
+
+ The Marvel Cinem atic Universe has been a huge success for the company , and it � �" False restrictions. Captain Marvel and Bulletman 3 [' restrictions', '.', ' Captain', ' Marvel']
+135 29 Secret identity of x -1 Secret identity of Captain Marvel Carol Danvers Captain Marvel ['\n', '\n', 'The', ' Marvel', ' Cinem', 'atic', ' Universe', ' has', ' been', ' a', ' huge', ' success', ' for', ' the', ' company', ',', ' and', ' it', '�', '�'] "
+
+ The Marvel Cinem atic Universe has been a huge success for the company , and it � �" False wanted to include Captain Marvel and Spider-Man at 4 [' wanted', ' to', ' include', ' Captain', ' Marvel']
+136 31 Secret identity of x -1 Secret identity of Spider-Woman Jessica Drew Spider-Woman [' (', 'Jessica', ' Drew', ')', '\n', '\n', 'Jessica', ' Drew', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] " ( Jessica Drew )
+
+ Jessica Drew is a fictional character appearing in American comic books published by Marvel" True mother had voiced Spider-Woman in the original 5 [' mother', ' had', ' voiced', ' Spider', '-', 'Woman']
+137 31 Secret identity of x -1 Secret identity of Spider-Woman Jessica Drew Spider-Woman [' (', 'Jessica', ' Drew', ')', '\n', '\n', 'Jessica', ' Drew', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] " ( Jessica Drew )
+
+ Jessica Drew is a fictional character appearing in American comic books published by Marvel" True 2 ['Spider', '-', 'Woman']
+138 31 Secret identity of x -1 Secret identity of Spider-Woman Jessica Drew Spider-Woman [' (', 'Jessica', ' Drew', ')', '\n', '\n', 'Jessica', ' Drew', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] " ( Jessica Drew )
+
+ Jessica Drew is a fictional character appearing in American comic books published by Marvel" True Ganke and Spider-Woman convinces him 5 [' Gan', 'ke', ' and', ' Spider', '-', 'Woman']
+139 31 Secret identity of x -1 Secret identity of Spider-Woman Jessica Drew Spider-Woman [' (', 'Jessica', ' Drew', ')', '\n', '\n', 'Jessica', ' Drew', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] " ( Jessica Drew )
+
+ Jessica Drew is a fictional character appearing in American comic books published by Marvel" True 2 ['Spider', '-', 'Woman']
+140 31 Secret identity of x -1 Secret identity of Spider-Woman Jessica Drew Spider-Woman [' (', 'Jessica', ' Drew', ')', '\n', '\n', 'Jessica', ' Drew', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] " ( Jessica Drew )
+
+ Jessica Drew is a fictional character appearing in American comic books published by Marvel" True mother had voiced Spider-Woman in the original 5 [' mother', ' had', ' voiced', ' Spider', '-', 'Woman']
+141 32 Secret identity of x -1 Secret identity of Spider-Man Miles Morales Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False 2 ['Spider', '-', 'Man']
+142 32 Secret identity of x -1 Secret identity of Spider-Man Miles Morales Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False chase scene in Spider-Man 3 (2007), though 5 [' chase', ' scene', ' in', ' Spider', '-', 'Man']
+143 32 Secret identity of x -1 Secret identity of Spider-Man Miles Morales Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False herself on Spider-Man, admitting she 4 [' herself', ' on', ' Spider', '-', 'Man']
+144 32 Secret identity of x -1 Secret identity of Spider-Man Miles Morales Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False The Amazing Spider-Man # 573, October 4 [' The', ' Amazing', ' Spider', '-', 'Man']
+145 32 Secret identity of x -1 Secret identity of Spider-Man Miles Morales Spider-Man ['\n', '\n', 'The', ' Spider', '-', 'Man', ' of', ' the', ' Marvel', ' Universe', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American'] "
+
+ The Spider - Man of the Marvel Universe is a fictional character , a superhero appearing in American" False 2 ['Spider', '-', 'Man']
+146 33 Secret identity of x -1 Secret identity of Ms. Marvel Kamala Khan Ms. Marvel [':', ' She', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n', 'Ms', '.', ' Marvel', ' is', ' a', ' member', ' of', ' the', ' Avengers'] ": She is a member of the Avengers .
+
+ Ms . Marvel is a member of the Avengers" False takes the name Ms. Marvel from Carol Danvers, 5 [' takes', ' the', ' name', ' Ms', '.', ' Marvel']
+147 33 Secret identity of x -1 Secret identity of Ms. Marvel Kamala Khan Ms. Marvel [':', ' She', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n', 'Ms', '.', ' Marvel', ' is', ' a', ' member', ' of', ' the', ' Avengers'] ": She is a member of the Avengers .
+
+ Ms . Marvel is a member of the Avengers" False Khan) 3 [' Khan', 'Ms', '.', ' Marvel']
+148 33 Secret identity of x -1 Secret identity of Ms. Marvel Kamala Khan Ms. Marvel [':', ' She', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n', 'Ms', '.', ' Marvel', ' is', ' a', ' member', ' of', ' the', ' Avengers'] ": She is a member of the Avengers .
+
+ Ms . Marvel is a member of the Avengers" False " in June 2015, Ms. Marvel tied into the ""Secret" 6 [' in', ' June', ' 2015', ',', ' Ms', '.', ' Marvel']
+149 33 Secret identity of x -1 Secret identity of Ms. Marvel Kamala Khan Ms. Marvel [':', ' She', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n', 'Ms', '.', ' Marvel', ' is', ' a', ' member', ' of', ' the', ' Avengers'] ": She is a member of the Avengers .
+
+ Ms . Marvel is a member of the Avengers" False takes the name Ms. Marvel from Carol Danvers, 5 [' takes', ' the', ' name', ' Ms', '.', ' Marvel']
+150 33 Secret identity of x -1 Secret identity of Ms. Marvel Kamala Khan Ms. Marvel [':', ' She', ' is', ' a', ' member', ' of', ' the', ' Avengers', '.', '\n', '\n', 'Ms', '.', ' Marvel', ' is', ' a', ' member', ' of', ' the', ' Avengers'] ": She is a member of the Avengers .
+
+ Ms . Marvel is a member of the Avengers" False announced that Ms. Marvel would tie into the 4 [' announced', ' that', ' Ms', '.', ' Marvel']
+151 34 Secret identity of x -1 Secret identity of Falcon Sam Wilson Falcon "['er', '\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' gamer', ',', ' a', ' geek', 'y']" "er
+
+ I 'm a writer , a reader , a geek , a gamer , a geek y" False The Maltese Falcon (1941). Lawrence 3 [' The', ' Malt', 'ese', ' Falcon']
+152 34 Secret identity of x -1 Secret identity of Falcon Sam Wilson Falcon "['er', '\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' gamer', ',', ' a', ' geek', 'y']" "er
+
+ I 'm a writer , a reader , a geek , a gamer , a geek y" False F-16 Fighting Falcon and F / A-18; 4 [' F', '-', '16', ' Fighting', ' Falcon']
+153 34 Secret identity of x -1 Secret identity of Falcon Sam Wilson Falcon "['er', '\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' gamer', ',', ' a', ' geek', 'y']" "er
+
+ I 'm a writer , a reader , a geek , a gamer , a geek y" False journeys of Robert Falcon Scott and Ernest 3 [' journeys', ' of', ' Robert', ' Falcon']
+154 34 Secret identity of x -1 Secret identity of Falcon Sam Wilson Falcon "['er', '\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' gamer', ',', ' a', ' geek', 'y']" "er
+
+ I 'm a writer , a reader , a geek , a gamer , a geek y" False simulators, such as Falcon 3.0, in favor 5 [' sim', 'ulators', ',', ' such', ' as', ' Falcon']
+155 34 Secret identity of x -1 Secret identity of Falcon Sam Wilson Falcon "['er', '\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' gamer', ',', ' a', ' geek', 'y']" "er
+
+ I 'm a writer , a reader , a geek , a gamer , a geek y" False equip subsequent Falcon 9 first-stages 2 [' equip', ' subsequent', ' Falcon']
+156 35 Secret identity of x -1 Secret identity of Black Panther T'Challa Black Panther [':', ' The', ' Panther', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.'] : The Panther is a fictional character , a superhero appearing in American comic books published by Marvel Comics . False for members of the Black Panther Party. According to 5 [' for', ' members', ' of', ' the', ' Black', ' Panther']
+157 35 Secret identity of x -1 Secret identity of Black Panther T'Challa Black Panther [':', ' The', ' Panther', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.'] : The Panther is a fictional character , a superhero appearing in American comic books published by Marvel Comics . False Brixton branch of the Black Panther Party to join Labour. 6 [' Bri', 'xton', ' branch', ' of', ' the', ' Black', ' Panther']
+158 35 Secret identity of x -1 Secret identity of Black Panther T'Challa Black Panther [':', ' The', ' Panther', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.'] : The Panther is a fictional character , a superhero appearing in American comic books published by Marvel Comics . False of the former Black Panther Party (BPP) and Black 4 [' of', ' the', ' former', ' Black', ' Panther']
+159 35 Secret identity of x -1 Secret identity of Black Panther T'Challa Black Panther [':', ' The', ' Panther', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.'] : The Panther is a fictional character , a superhero appearing in American comic books published by Marvel Comics . False movements such as the Black Panther Party, the 5 [' movements', ' such', ' as', ' the', ' Black', ' Panther']
+160 35 Secret identity of x -1 Secret identity of Black Panther T'Challa Black Panther [':', ' The', ' Panther', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.'] : The Panther is a fictional character , a superhero appearing in American comic books published by Marvel Comics . False involved in the Black Panther Party (BPP), and 4 [' involved', ' in', ' the', ' Black', ' Panther']
+161 36 Secret identity of x -1 Secret identity of Doctor Strange Stephen Strange Doctor Strange [':', ' The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The'] ": The man who can do anything .
+
+ The man who can do anything .
+
+ The" False films like Doctor Strange and Thor: Ragnarok 3 [' films', ' like', ' Doctor', ' Strange']
+162 36 Secret identity of x -1 Secret identity of Doctor Strange Stephen Strange Doctor Strange [':', ' The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The'] ": The man who can do anything .
+
+ The man who can do anything .
+
+ The" False characters Doctor Strange and Marcus Daniels, 2 [' characters', ' Doctor', ' Strange']
+163 36 Secret identity of x -1 Secret identity of Doctor Strange Stephen Strange Doctor Strange [':', ' The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The'] ": The man who can do anything .
+
+ The man who can do anything .
+
+ The" False we ’ ve got Doctor Strange in November, two 6 [' we', ' �', '�', ' ve', ' got', ' Doctor', ' Strange']
+164 36 Secret identity of x -1 Secret identity of Doctor Strange Stephen Strange Doctor Strange [':', ' The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The'] ": The man who can do anything .
+
+ The man who can do anything .
+
+ The" False films like Doctor Strange and Thor: Ragnarok 3 [' films', ' like', ' Doctor', ' Strange']
+165 36 Secret identity of x -1 Secret identity of Doctor Strange Stephen Strange Doctor Strange [':', ' The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The', ' man', ' who', ' can', ' do', ' anything', '.', '\n', '\n', 'The'] ": The man who can do anything .
+
+ The man who can do anything .
+
+ The" False " was inspired by a Doctor Strange comic book ""which" 5 [' was', ' inspired', ' by', ' a', ' Doctor', ' Strange']
+166 37 Secret identity of x -1 Secret identity of Ant-Man Hank Pym Ant-Man [':', ' Scott', ' Lang', ' is', ' a', ' former', ' NASA', ' engineer', ' who', ' was', ' recruited', ' by', ' S', '.', 'H', '.', 'I', '.', 'E', '.'] : Scott Lang is a former NASA engineer who was recruited by S . H . I . E . False International, Wright said that Ant-Man would not fit in 7 [' International', ',', ' Wright', ' said', ' that', ' Ant', '-', 'Man']
+167 37 Secret identity of x -1 Secret identity of Ant-Man Hank Pym Ant-Man [':', ' Scott', ' Lang', ' is', ' a', ' former', ' NASA', ' engineer', ' who', ' was', ' recruited', ' by', ' S', '.', 'H', '.', 'I', '.', 'E', '.'] : Scott Lang is a former NASA engineer who was recruited by S . H . I . E . False difference between Ant-Man and other films 4 [' difference', ' between', ' Ant', '-', 'Man']
+168 37 Secret identity of x -1 Secret identity of Ant-Man Hank Pym Ant-Man [':', ' Scott', ' Lang', ' is', ' a', ' former', ' NASA', ' engineer', ' who', ' was', ' recruited', ' by', ' S', '.', 'H', '.', 'I', '.', 'E', '.'] : Scott Lang is a former NASA engineer who was recruited by S . H . I . E . False " Sun-Times said Ant-Man ""is a lightweight," 6 [' Sun', '-', 'Times', ' said', ' Ant', '-', 'Man']
+169 37 Secret identity of x -1 Secret identity of Ant-Man Hank Pym Ant-Man [':', ' Scott', ' Lang', ' is', ' a', ' former', ' NASA', ' engineer', ' who', ' was', ' recruited', ' by', ' S', '.', 'H', '.', 'I', '.', 'E', '.'] : Scott Lang is a former NASA engineer who was recruited by S . H . I . E . False being overtaken by Ant-Man in its third weekend 6 [' being', ' overt', 'aken', ' by', ' Ant', '-', 'Man']
+170 37 Secret identity of x -1 Secret identity of Ant-Man Hank Pym Ant-Man [':', ' Scott', ' Lang', ' is', ' a', ' former', ' NASA', ' engineer', ' who', ' was', ' recruited', ' by', ' S', '.', 'H', '.', 'I', '.', 'E', '.'] : Scott Lang is a former NASA engineer who was recruited by S . H . I . E . False became the original Ant-Man in 1963 after discovering 5 [' became', ' the', ' original', ' Ant', '-', 'Man']
+171 38 Secret identity of x -1 Secret identity of Wasp Janet Van Dyne Wasp ['\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The W asp is a fictional character , a superhero ine appearing in American comic books published by" False the schooner Wasp for a two-year 5 [' the', ' sch', 'oon', 'er', ' W', 'asp']
+172 38 Secret identity of x -1 Secret identity of Wasp Janet Van Dyne Wasp ['\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The W asp is a fictional character , a superhero ine appearing in American comic books published by" False " advanced"". Wasp turned north, from" 3 "[' advanced', '"".', ' W', 'asp']"
+173 38 Secret identity of x -1 Secret identity of Wasp Janet Van Dyne Wasp ['\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The W asp is a fictional character , a superhero ine appearing in American comic books published by" False 1 ['W', 'asp']
+174 38 Secret identity of x -1 Secret identity of Wasp Janet Van Dyne Wasp ['\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The W asp is a fictional character , a superhero ine appearing in American comic books published by" False R-1535 Twin Wasp Junior air-cooled 6 [' R', '-', '15', '35', ' Twin', ' W', 'asp']
+175 38 Secret identity of x -1 Secret identity of Wasp Janet Van Dyne Wasp ['\n', '\n', 'The', ' W', 'asp', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] "
+
+ The W asp is a fictional character , a superhero ine appearing in American comic books published by" False & Whitney R-4360 Wasp Major radial engine 7 [' &', ' Whitney', ' R', '-', '43', '60', ' W', 'asp']
+176 39 Secret identity of x -1 Secret identity of Blue Beetle Jaime Reyes Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False tie-in with Blue Beetle # 20. Part One, the 5 [' tie', '-', 'in', ' with', ' Blue', ' Beetle']
+177 39 Secret identity of x -1 Secret identity of Blue Beetle Jaime Reyes Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False one-shot specials and a Blue Beetle tie-in issue were 7 [' one', '-', 'shot', ' specials', ' and', ' a', ' Blue', ' Beetle']
+178 39 Secret identity of x -1 Secret identity of Blue Beetle Jaime Reyes Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False as one tie-in with Blue Beetle # 20. Part One, the 7 [' as', ' one', ' tie', '-', 'in', ' with', ' Blue', ' Beetle']
+179 39 Secret identity of x -1 Secret identity of Blue Beetle Jaime Reyes Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False specials and a Blue Beetle tie-in issue 4 [' specials', ' and', ' a', ' Blue', ' Beetle']
+180 39 Secret identity of x -1 Secret identity of Blue Beetle Jaime Reyes Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False one tie-in with Blue Beetle # 20. Part One, the 6 [' one', ' tie', '-', 'in', ' with', ' Blue', ' Beetle']
+181 40 Secret identity of x -1 Secret identity of Batgirl Cassandra Cain Batgirl ['\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat'] "
+
+ The Bat girl of Burn side
+
+ The Bat girl of Burn side
+
+ The Bat" False " Silverstone as Batgirl / Barbara Wilson
+" 4 [' Silver', 'stone', ' as', ' Bat', 'girl']
+182 40 Secret identity of x -1 Secret identity of Batgirl Cassandra Cain Batgirl ['\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat'] "
+
+ The Bat girl of Burn side
+
+ The Bat girl of Burn side
+
+ The Bat" False " Barbara Gordon as Batgirl (2012) Softcover
+" 4 [' Barbara', ' Gordon', ' as', ' Bat', 'girl']
+183 40 Secret identity of x -1 Secret identity of Batgirl Cassandra Cain Batgirl ['\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat'] "
+
+ The Bat girl of Burn side
+
+ The Bat girl of Burn side
+
+ The Bat" False them was a cover to Batgirl # 41 by artist 6 [' them', ' was', ' a', ' cover', ' to', ' Bat', 'girl']
+184 40 Secret identity of x -1 Secret identity of Batgirl Cassandra Cain Batgirl ['\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat'] "
+
+ The Bat girl of Burn side
+
+ The Bat girl of Burn side
+
+ The Bat" False the early 1970s, Batgirl had matured, using 6 [' the', ' early', ' 1970', 's', ',', ' Bat', 'girl']
+185 40 Secret identity of x -1 Secret identity of Batgirl Cassandra Cain Batgirl ['\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat', 'girl', ' of', ' Burn', 'side', '\n', '\n', 'The', ' Bat'] "
+
+ The Bat girl of Burn side
+
+ The Bat girl of Burn side
+
+ The Bat" False new modified Batgirl design that acts 3 [' new', ' modified', ' Bat', 'girl']
+186 41 Secret identity of x -1 Secret identity of Spoiler Stephanie Brown Spoiler [':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler'] ":
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler" False Brown, the former Spoiler and Robin. After 4 [' Brown', ',', ' the', ' former', ' Spoiler']
+187 41 Secret identity of x -1 Secret identity of Spoiler Stephanie Brown Spoiler [':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler'] ":
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler" False Brown, the former Spoiler and Robin. After 4 [' Brown', ',', ' the', ' former', ' Spoiler']
+188 41 Secret identity of x -1 Secret identity of Spoiler Stephanie Brown Spoiler [':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler'] ":
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler" False Brown, the former Spoiler and Robin. 4 [' Brown', ',', ' the', ' former', ' Spoiler']
+189 41 Secret identity of x -1 Secret identity of Spoiler Stephanie Brown Spoiler [':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler'] ":
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler" False " Spoiler embargo ===
+" 0 [' Spoiler']
+190 41 Secret identity of x -1 Secret identity of Spoiler Stephanie Brown Spoiler [':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler', ':', '\n', '\n', 'Spoiler'] ":
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler :
+
+ Spoiler" False the former Spoiler and Robin. After 2 [' the', ' former', ' Spoiler']
+191 42 Secret identity of x -1 Secret identity of Oracle Barbara Gordon Oracle "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False Microsoft and Oracle for software compatibility 2 [' Microsoft', ' and', ' Oracle']
+192 42 Secret identity of x -1 Secret identity of Oracle Barbara Gordon Oracle "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False building's roof, where Oracle is revealed to be alive, 5 "[' building', ""'s"", ' roof', ',', ' where', ' Oracle']"
+193 42 Secret identity of x -1 Secret identity of Oracle Barbara Gordon Oracle "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False one, based on Oracle of Seasons, was published 4 [' one', ',', ' based', ' on', ' Oracle']
+194 42 Secret identity of x -1 Secret identity of Oracle Barbara Gordon Oracle "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False Sam Weiss to the Oracle in The Matrix, which 4 [' Sam', ' Weiss', ' to', ' the', ' Oracle']
+195 42 Secret identity of x -1 Secret identity of Oracle Barbara Gordon Oracle "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False California, at the Oracle Arena. Cyrus later 4 [' California', ',', ' at', ' the', ' Oracle']
+196 44 Secret identity of x -1 Secret identity of Iron Fist Danny Rand Iron Fist [':', ' The', ' Immortal', ' Iron', ' F', 'ists', '\n', '\n', 'The', ' Immortal', ' Iron', ' F', 'ists', ' is', ' a', ' comic', ' book', ' series', ' published', ' by'] ": The Immortal Iron F ists
+
+ The Immortal Iron F ists is a comic book series published by" False connection to the Iron Fist antagonist Steel Serpent; 4 [' connection', ' to', ' the', ' Iron', ' Fist']
+197 44 Secret identity of x -1 Secret identity of Iron Fist Danny Rand Iron Fist [':', ' The', ' Immortal', ' Iron', ' F', 'ists', '\n', '\n', 'The', ' Immortal', ' Iron', ' F', 'ists', ' is', ' a', ' comic', ' book', ' series', ' published', ' by'] ": The Immortal Iron F ists
+
+ The Immortal Iron F ists is a comic book series published by" False with Calculator and Iron Fist also in the 4 [' with', ' Calculator', ' and', ' Iron', ' Fist']
+198 44 Secret identity of x -1 Secret identity of Iron Fist Danny Rand Iron Fist [':', ' The', ' Immortal', ' Iron', ' F', 'ists', '\n', '\n', 'The', ' Immortal', ' Iron', ' F', 'ists', ' is', ' a', ' comic', ' book', ' series', ' published', ' by'] ": The Immortal Iron F ists
+
+ The Immortal Iron F ists is a comic book series published by" False Calculator and Iron Fist also in the field. 3 [' Calculator', ' and', ' Iron', ' Fist']
+199 44 Secret identity of x -1 Secret identity of Iron Fist Danny Rand Iron Fist [':', ' The', ' Immortal', ' Iron', ' F', 'ists', '\n', '\n', 'The', ' Immortal', ' Iron', ' F', 'ists', ' is', ' a', ' comic', ' book', ' series', ' published', ' by'] ": The Immortal Iron F ists
+
+ The Immortal Iron F ists is a comic book series published by" False with Calculator and Iron Fist also in the field. 4 [' with', ' Calculator', ' and', ' Iron', ' Fist']
+200 44 Secret identity of x -1 Secret identity of Iron Fist Danny Rand Iron Fist [':', ' The', ' Immortal', ' Iron', ' F', 'ists', '\n', '\n', 'The', ' Immortal', ' Iron', ' F', 'ists', ' is', ' a', ' comic', ' book', ' series', ' published', ' by'] ": The Immortal Iron F ists
+
+ The Immortal Iron F ists is a comic book series published by" False art of Hung Ga Iron Fist technique, 5 [' art', ' of', ' Hung', ' Ga', ' Iron', ' Fist']
+201 45 Secret identity of x -1 Secret identity of Jewel Jessica Jones Jewel ['ry', '\n', '\n', 'I', ' am', ' a', ' jewelry', ' designer', ' and', ' I', ' have', ' been', ' designing', ' jewelry', ' for', ' over', ' 20', ' years', '.', ' I'] "ry
+
+ I am a jewelry designer and I have been designing jewelry for over 20 years . I" False address, called the Jewel Voice Broadcast 4 [' address', ',', ' called', ' the', ' Jewel']
+202 45 Secret identity of x -1 Secret identity of Jewel Jessica Jones Jewel ['ry', '\n', '\n', 'I', ' am', ' a', ' jewelry', ' designer', ' and', ' I', ' have', ' been', ' designing', ' jewelry', ' for', ' over', ' 20', ' years', '.', ' I'] "ry
+
+ I am a jewelry designer and I have been designing jewelry for over 20 years . I" False Monster Kingdom: Jewel Summoner is a turn-based 3 [' Monster', ' Kingdom', ':', ' Jewel']
+203 45 Secret identity of x -1 Secret identity of Jewel Jessica Jones Jewel ['ry', '\n', '\n', 'I', ' am', ' a', ' jewelry', ' designer', ' and', ' I', ' have', ' been', ' designing', ' jewelry', ' for', ' over', ' 20', ' years', '.', ' I'] "ry
+
+ I am a jewelry designer and I have been designing jewelry for over 20 years . I" False 1 ['Jew', 'el']
+204 45 Secret identity of x -1 Secret identity of Jewel Jessica Jones Jewel ['ry', '\n', '\n', 'I', ' am', ' a', ' jewelry', ' designer', ' and', ' I', ' have', ' been', ' designing', ' jewelry', ' for', ' over', ' 20', ' years', '.', ' I'] "ry
+
+ I am a jewelry designer and I have been designing jewelry for over 20 years . I" False Green Terrors, Jewel cichlids, Salvini, 4 [' Green', ' Ter', 'rors', ',', ' Jewel']
+205 45 Secret identity of x -1 Secret identity of Jewel Jessica Jones Jewel ['ry', '\n', '\n', 'I', ' am', ' a', ' jewelry', ' designer', ' and', ' I', ' have', ' been', ' designing', ' jewelry', ' for', ' over', ' 20', ' years', '.', ' I'] "ry
+
+ I am a jewelry designer and I have been designing jewelry for over 20 years . I" False " possession, ""the Jewel in the Crown"", and" 4 "[' possession', ',', ' ""', 'the', ' Jewel']"
+206 46 Secret identity of x -1 Secret identity of Starfire Koriand'r Starfire "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False of Rusty Pup. Starfire Studios were 6 [' of', ' Rusty', ' P', 'up', '.', ' Star', 'fire']
+207 46 Secret identity of x -1 Secret identity of Starfire Koriand'r Starfire "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False two companies, Starfire Publishing – which 4 [' two', ' companies', ',', ' Star', 'fire']
+208 46 Secret identity of x -1 Secret identity of Starfire Koriand'r Starfire "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False games at Tukwila's Starfire Sports stadium. The 8 "[' games', ' at', ' T', 'uk', 'w', 'ila', ""'s"", ' Star', 'fire']"
+209 46 Secret identity of x -1 Secret identity of Starfire Koriand'r Starfire "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False located at the Starfire Sports Complex in 4 [' located', ' at', ' the', ' Star', 'fire']
+210 46 Secret identity of x -1 Secret identity of Starfire Koriand'r Starfire "['\n', '\n', 'I', ""'m"", ' not', ' sure', ' if', ' this', ' is', ' the', ' right', ' place', ' to', ' ask', ' this', ',', ' but', ' I', ""'m"", ' not']" "
+
+ I 'm not sure if this is the right place to ask this , but I 'm not" False two companies, Starfire Publishing 4 [' two', ' companies', ',', ' Star', 'fire']
+211 47 Secret identity of x -1 Secret identity of Cyborg Victor Stone Cyborg [' 00', '9', '\n', '\n', 'The', ' Cyborg', ' 00', '9', ' series', ' is', ' a', ' Japanese', ' manga', ' series', ' written', ' and', ' illustrated', ' by', ' Shot', 'aro'] " 00 9
+
+ The Cyborg 00 9 series is a Japanese manga series written and illustrated by Shot aro" False fiction-themed Cyborg Commando, which 3 [' fiction', '-', 'themed', ' Cyborg']
+212 47 Secret identity of x -1 Secret identity of Cyborg Victor Stone Cyborg [' 00', '9', '\n', '\n', 'The', ' Cyborg', ' 00', '9', ' series', ' is', ' a', ' Japanese', ' manga', ' series', ' written', ' and', ' illustrated', ' by', ' Shot', 'aro'] " 00 9
+
+ The Cyborg 00 9 series is a Japanese manga series written and illustrated by Shot aro" False for Superman based on Cyborg Superman, released 4 [' for', ' Superman', ' based', ' on', ' Cyborg']
+213 47 Secret identity of x -1 Secret identity of Cyborg Victor Stone Cyborg [' 00', '9', '\n', '\n', 'The', ' Cyborg', ' 00', '9', ' series', ' is', ' a', ' Japanese', ' manga', ' series', ' written', ' and', ' illustrated', ' by', ' Shot', 'aro'] " 00 9
+
+ The Cyborg 00 9 series is a Japanese manga series written and illustrated by Shot aro" False Superman-Prime), Cyborg Superman and the 4 [' Superman', '-', 'Prime', '),', ' Cyborg']
+214 47 Secret identity of x -1 Secret identity of Cyborg Victor Stone Cyborg [' 00', '9', '\n', '\n', 'The', ' Cyborg', ' 00', '9', ' series', ' is', ' a', ' Japanese', ' manga', ' series', ' written', ' and', ' illustrated', ' by', ' Shot', 'aro'] " 00 9
+
+ The Cyborg 00 9 series is a Japanese manga series written and illustrated by Shot aro" False counterpart. Cyborg visits the 2 [' counterpart', '.', ' Cyborg']
+215 47 Secret identity of x -1 Secret identity of Cyborg Victor Stone Cyborg [' 00', '9', '\n', '\n', 'The', ' Cyborg', ' 00', '9', ' series', ' is', ' a', ' Japanese', ' manga', ' series', ' written', ' and', ' illustrated', ' by', ' Shot', 'aro'] " 00 9
+
+ The Cyborg 00 9 series is a Japanese manga series written and illustrated by Shot aro" False fiction-themed Cyborg Commando, which was 3 [' fiction', '-', 'themed', ' Cyborg']
+216 48 Secret identity of x -1 Secret identity of Raven Rachel Roth Raven [',', ' the', ' girl', ' who', ' is', ' the', ' main', ' character', ' of', ' the', ' story', '.', ' She', ' is', ' a', ' girl', ' who', ' is', ' a', ' member'] , the girl who is the main character of the story . She is a girl who is a member False the show was between Raven and Chris K. 4 [' the', ' show', ' was', ' between', ' Raven']
+217 48 Secret identity of x -1 Secret identity of Raven Rachel Roth Raven [',', ' the', ' girl', ' who', ' is', ' the', ' main', ' character', ' of', ' the', ' story', '.', ' She', ' is', ' a', ' girl', ' who', ' is', ' a', ' member'] , the girl who is the main character of the story . She is a girl who is a member False Squadron with the carriers Raven II, Anne, 4 [' Squadron', ' with', ' the', ' carriers', ' Raven']
+218 48 Secret identity of x -1 Secret identity of Raven Rachel Roth Raven [',', ' the', ' girl', ' who', ' is', ' the', ' main', ' character', ' of', ' the', ' story', '.', ' She', ' is', ' a', ' girl', ' who', ' is', ' a', ' member'] , the girl who is the main character of the story . She is a girl who is a member False " Brown by defeating Raven and Abyss.
+" 3 [' Brown', ' by', ' defeating', ' Raven']
+219 48 Secret identity of x -1 Secret identity of Raven Rachel Roth Raven [',', ' the', ' girl', ' who', ' is', ' the', ' main', ' character', ' of', ' the', ' story', '.', ' She', ' is', ' a', ' girl', ' who', ' is', ' a', ' member'] , the girl who is the main character of the story . She is a girl who is a member False " success of ""The Raven"" prompted Wiley and" 4 "[' success', ' of', ' ""', 'The', ' Raven']"
+220 48 Secret identity of x -1 Secret identity of Raven Rachel Roth Raven [',', ' the', ' girl', ' who', ' is', ' the', ' main', ' character', ' of', ' the', ' story', '.', ' She', ' is', ' a', ' girl', ' who', ' is', ' a', ' member'] , the girl who is the main character of the story . She is a girl who is a member False Championship match between Raven and Rhino was 3 [' Championship', ' match', ' between', ' Raven']
+221 49 Secret identity of x -1 Secret identity of Beast Boy Garfield Logan Beast Boy "[':', ' He', ""'s"", ' a', ' were', 'wolf', '.', '\n', '\n', 'The', ' Teen', ' Titans', ':', '\n', '\n', 'Robin', ':', ' He', ""'s"", ' a']" ": He 's a were wolf .
+
+ The Teen Titans :
+
+ Robin : He 's a" False During the Beast Boy miniseries, Flamebird 3 [' During', ' the', ' Beast', ' Boy']
+222 49 Secret identity of x -1 Secret identity of Beast Boy Garfield Logan Beast Boy "[':', ' He', ""'s"", ' a', ' were', 'wolf', '.', '\n', '\n', 'The', ' Teen', ' Titans', ':', '\n', '\n', 'Robin', ':', ' He', ""'s"", ' a']" ": He 's a were wolf .
+
+ The Teen Titans :
+
+ Robin : He 's a" False # 38. During the Beast Boy miniseries, 6 [' #', ' 38', '.', ' During', ' the', ' Beast', ' Boy']
+223 49 Secret identity of x -1 Secret identity of Beast Boy Garfield Logan Beast Boy "[':', ' He', ""'s"", ' a', ' were', 'wolf', '.', '\n', '\n', 'The', ' Teen', ' Titans', ':', '\n', '\n', 'Robin', ':', ' He', ""'s"", ' a']" ": He 's a were wolf .
+
+ The Teen Titans :
+
+ Robin : He 's a" False 38. During the Beast Boy miniseries, 5 [' 38', '.', ' During', ' the', ' Beast', ' Boy']
+224 49 Secret identity of x -1 Secret identity of Beast Boy Garfield Logan Beast Boy "[':', ' He', ""'s"", ' a', ' were', 'wolf', '.', '\n', '\n', 'The', ' Teen', ' Titans', ':', '\n', '\n', 'Robin', ':', ' He', ""'s"", ' a']" ": He 's a were wolf .
+
+ The Teen Titans :
+
+ Robin : He 's a" False During the Beast Boy miniseries, 3 [' During', ' the', ' Beast', ' Boy']
+225 50 Secret identity of x -1 Secret identity of Batwoman Kate Kane Batwoman [':', '\n', '\n', 'Bat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ":
+
+ Bat woman is a fictional character , a superhero ine appearing in American comic books published by" False father; serving as Batwoman's ally, whom 5 [' father', ';', ' serving', ' as', ' Bat', 'woman']
+226 50 Secret identity of x -1 Secret identity of Batwoman Kate Kane Batwoman [':', '\n', '\n', 'Bat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ":
+
+ Bat woman is a fictional character , a superhero ine appearing in American comic books published by" False request for Batwoman to appear 3 [' request', ' for', ' Bat', 'woman']
+227 50 Secret identity of x -1 Secret identity of Batwoman Kate Kane Batwoman [':', '\n', '\n', 'Bat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ":
+
+ Bat woman is a fictional character , a superhero ine appearing in American comic books published by" False concept of an older Batwoman in Brave and the Bold 5 [' concept', ' of', ' an', ' older', ' Bat', 'woman']
+228 50 Secret identity of x -1 Secret identity of Batwoman Kate Kane Batwoman [':', '\n', '\n', 'Bat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ":
+
+ Bat woman is a fictional character , a superhero ine appearing in American comic books published by" False Gods Among Us, Batwoman (Kate Kane) 5 [' Gods', ' Among', ' Us', ',', ' Bat', 'woman']
+229 50 Secret identity of x -1 Secret identity of Batwoman Kate Kane Batwoman [':', '\n', '\n', 'Bat', 'woman', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by'] ":
+
+ Bat woman is a fictional character , a superhero ine appearing in American comic books published by" False decision to make Batwoman a gay character 4 [' decision', ' to', ' make', ' Bat', 'woman']
+230 51 Secret identity of x -1 Secret identity of Arsenal Roy Harper Arsenal "[""'s"", ' new', ' signing', '\n', '\n', 'Arsenal', ' have', ' signed', ' French', ' midfielder', ' Y', 'aya', ' San', 'ogo', ' from', ' Aux', 'er', 're', ' for', ' an']" "'s new signing
+
+ Arsenal have signed French midfielder Y aya San ogo from Aux er re for an" False First Division side Arsenal 1 – 0 at home 3 [' First', ' Division', ' side', ' Arsenal']
+231 51 Secret identity of x -1 Secret identity of Arsenal Roy Harper Arsenal "[""'s"", ' new', ' signing', '\n', '\n', 'Arsenal', ' have', ' signed', ' French', ' midfielder', ' Y', 'aya', ' San', 'ogo', ' from', ' Aux', 'er', 're', ' for', ' an']" "'s new signing
+
+ Arsenal have signed French midfielder Y aya San ogo from Aux er re for an" False delight at joining Arsenal in his press conference: 3 [' delight', ' at', ' joining', ' Arsenal']
+232 51 Secret identity of x -1 Secret identity of Arsenal Roy Harper Arsenal "[""'s"", ' new', ' signing', '\n', '\n', 'Arsenal', ' have', ' signed', ' French', ' midfielder', ' Y', 'aya', ' San', 'ogo', ' from', ' Aux', 'er', 're', ' for', ' an']" "'s new signing
+
+ Arsenal have signed French midfielder Y aya San ogo from Aux er re for an" False the Yokosuka Naval Arsenal to a modified Yamato-class 5 [' the', ' Yok', 'os', 'uka', ' Naval', ' Arsenal']
+233 51 Secret identity of x -1 Secret identity of Arsenal Roy Harper Arsenal "[""'s"", ' new', ' signing', '\n', '\n', 'Arsenal', ' have', ' signed', ' French', ' midfielder', ' Y', 'aya', ' San', 'ogo', ' from', ' Aux', 'er', 're', ' for', ' an']" "'s new signing
+
+ Arsenal have signed French midfielder Y aya San ogo from Aux er re for an" False – 92 season, giving Arsenal the opportunity 5 [' –', ' 92', ' season', ',', ' giving', ' Arsenal']
+234 51 Secret identity of x -1 Secret identity of Arsenal Roy Harper Arsenal "[""'s"", ' new', ' signing', '\n', '\n', 'Arsenal', ' have', ' signed', ' French', ' midfielder', ' Y', 'aya', ' San', 'ogo', ' from', ' Aux', 'er', 're', ' for', ' an']" "'s new signing
+
+ Arsenal have signed French midfielder Y aya San ogo from Aux er re for an" False voted second by Arsenal fans behind Thierry 3 [' voted', ' second', ' by', ' Arsenal']
+235 52 Secret identity of x -1 Secret identity of Gambit Remy LeBeau Gambit [':', ' The', ' X', '-', 'Men', '\n', '\n', 'Gam', 'bit', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others'] ": The X - Men
+
+ Gam bit is a mutant with the ability to teleport himself and others" False to test the Rice Gambit. Lasker narrowly 5 [' to', ' test', ' the', ' Rice', ' Gam', 'bit']
+236 52 Secret identity of x -1 Secret identity of Gambit Remy LeBeau Gambit [':', ' The', ' X', '-', 'Men', '\n', '\n', 'Gam', 'bit', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others'] ": The X - Men
+
+ Gam bit is a mutant with the ability to teleport himself and others" False " finished second, behind Gambit Gaming.
+" 5 [' finished', ' second', ',', ' behind', ' Gam', 'bit']
+237 52 Secret identity of x -1 Secret identity of Gambit Remy LeBeau Gambit [':', ' The', ' X', '-', 'Men', '\n', '\n', 'Gam', 'bit', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others'] ": The X - Men
+
+ Gam bit is a mutant with the ability to teleport himself and others" False was known as Gambit at Thrill Valley 4 [' was', ' known', ' as', ' Gam', 'bit']
+238 52 Secret identity of x -1 Secret identity of Gambit Remy LeBeau Gambit [':', ' The', ' X', '-', 'Men', '\n', '\n', 'Gam', 'bit', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others'] ": The X - Men
+
+ Gam bit is a mutant with the ability to teleport himself and others" False to the Queen's Gambit, Lasker's Defense 5 "[' to', ' the', ' Queen', ""'s"", ' Gam', 'bit']"
+239 52 Secret identity of x -1 Secret identity of Gambit Remy LeBeau Gambit [':', ' The', ' X', '-', 'Men', '\n', '\n', 'Gam', 'bit', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others'] ": The X - Men
+
+ Gam bit is a mutant with the ability to teleport himself and others" False quiet Queen's Gambit Declined. After 19.Be4!, 4 "[' quiet', ' Queen', ""'s"", ' Gam', 'bit']"
+240 53 Secret identity of x -1 Secret identity of Iceman Bobby Drake Iceman "[':', ' I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.']" ": I 'm a superhero .
+
+ I 'm a superhero .
+
+ I 'm a superhero ." False featuring Wolverine, Iceman and Nightcrawler as 4 [' featuring', ' Wolverine', ',', ' Ice', 'man']
+241 53 Secret identity of x -1 Secret identity of Iceman Bobby Drake Iceman "[':', ' I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.']" ": I 'm a superhero .
+
+ I 'm a superhero .
+
+ I 'm a superhero ." False tattooing, and Ötzi the Iceman was found to have 9 [' tattoo', 'ing', ',', ' and', ' Ö', 't', 'zi', ' the', ' Ice', 'man']
+242 53 Secret identity of x -1 Secret identity of Iceman Bobby Drake Iceman "[':', ' I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.']" ": I 'm a superhero .
+
+ I 'm a superhero .
+
+ I 'm a superhero ." False Rogue should give Iceman a passionate 4 [' Rogue', ' should', ' give', ' Ice', 'man']
+243 53 Secret identity of x -1 Secret identity of Iceman Bobby Drake Iceman "[':', ' I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.']" ": I 'm a superhero .
+
+ I 'm a superhero .
+
+ I 'm a superhero ." False affection for Iceman further adds 3 [' affection', ' for', ' Ice', 'man']
+244 53 Secret identity of x -1 Secret identity of Iceman Bobby Drake Iceman "[':', ' I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.', '\n', '\n', 'I', ""'m"", ' a', ' superhero', '.']" ": I 'm a superhero .
+
+ I 'm a superhero .
+
+ I 'm a superhero ." False featuring Wolverine, Iceman and Nightcrawler 4 [' featuring', ' Wolverine', ',', ' Ice', 'man']
+245 55 Secret identity of x -1 Secret identity of Hellblazer John Constantine Hellblazer [':', ' The', ' original', ' John', ' Constantine', ',', ' the', ' one', ' who', ' was', ' the', ' first', ' to', ' wear', ' the', ' red', ' suit', ',', ' the', ' one'] : The original John Constantine , the one who was the first to wear the red suit , the one True protagonist of Hellblazer was first hinted 4 [' protagonist', ' of', ' Hell', 'bl', 'azer']
+246 55 Secret identity of x -1 Secret identity of Hellblazer John Constantine Hellblazer [':', ' The', ' original', ' John', ' Constantine', ',', ' the', ' one', ' who', ' was', ' the', ' first', ' to', ' wear', ' the', ' red', ' suit', ',', ' the', ' one'] : The original John Constantine , the one who was the first to wear the red suit , the one True John Constantine in Hellblazer No. 19 leading into 5 [' John', ' Constantine', ' in', ' Hell', 'bl', 'azer']
+247 55 Secret identity of x -1 Secret identity of Hellblazer John Constantine Hellblazer [':', ' The', ' original', ' John', ' Constantine', ',', ' the', ' one', ' who', ' was', ' the', ' first', ' to', ' wear', ' the', ' red', ' suit', ',', ' the', ' one'] : The original John Constantine , the one who was the first to wear the red suit , the one True series based on Hellblazer and that NBC 5 [' series', ' based', ' on', ' Hell', 'bl', 'azer']
+248 55 Secret identity of x -1 Secret identity of Hellblazer John Constantine Hellblazer [':', ' The', ' original', ' John', ' Constantine', ',', ' the', ' one', ' who', ' was', ' the', ' first', ' to', ' wear', ' the', ' red', ' suit', ',', ' the', ' one'] : The original John Constantine , the one who was the first to wear the red suit , the one True Peter Milligan. Hellblazer was one of the first 6 [' Peter', ' Mill', 'igan', '.', ' Hell', 'bl', 'azer']
+249 55 Secret identity of x -1 Secret identity of Hellblazer John Constantine Hellblazer [':', ' The', ' original', ' John', ' Constantine', ',', ' the', ' one', ' who', ' was', ' the', ' first', ' to', ' wear', ' the', ' red', ' suit', ',', ' the', ' one'] : The original John Constantine , the one who was the first to wear the red suit , the one True The final panel of Hellblazer reveals John, 6 [' The', ' final', ' panel', ' of', ' Hell', 'bl', 'azer']
+250 56 Secret identity of x -1 Secret identity of Green Lantern Alan Scott Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False Warner Bros. film, Green Lantern on Blu-ray 6 [' Warner', ' Bros', '.', ' film', ',', ' Green', ' Lantern']
+251 56 Secret identity of x -1 Secret identity of Green Lantern Alan Scott Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False 1 ['Green', ' Lantern']
+252 56 Secret identity of x -1 Secret identity of Green Lantern Alan Scott Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False killed. The alternate Green Lantern and Hawkgirl mistake 5 [' killed', '.', ' The', ' alternate', ' Green', ' Lantern']
+253 56 Secret identity of x -1 Secret identity of Green Lantern Alan Scott Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False possess him. In Green Lantern (vol. 4) # 5 [' possess', ' him', '.', ' In', ' Green', ' Lantern']
+254 56 Secret identity of x -1 Secret identity of Green Lantern Alan Scott Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False member of the Green Lantern Corps who teaches 4 [' member', ' of', ' the', ' Green', ' Lantern']
+255 57 Secret identity of x -1 Secret identity of The Flash Jay Garrick The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" False pack Portal: The Flash Version created by 4 [' pack', ' Portal', ':', ' The', ' Flash']
+256 57 Secret identity of x -1 Secret identity of The Flash Jay Garrick The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" False 1 ['The', ' Flash']
+257 57 Secret identity of x -1 Secret identity of The Flash Jay Garrick The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" False (since unlike The Flash or Green Lantern, Batman 4 [' (', 'since', ' unlike', ' The', ' Flash']
+258 57 Secret identity of x -1 Secret identity of The Flash Jay Garrick The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" False 1 ['The', ' Flash']
+259 57 Secret identity of x -1 Secret identity of The Flash Jay Garrick The Flash [':', ' Barry', ' Allen', '\n', '\n', 'The', ' Flash', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics'] ": Barry Allen
+
+ The Flash is a fictional superhero appearing in American comic books published by DC Comics" False complexity. The Flash version was much simpler 3 [' complexity', '.', ' The', ' Flash']
+260 58 Secret identity of x -1 Secret identity of Blue Beetle Ted Kord Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False one tie-in with Blue Beetle # 20. Part One, 6 [' one', ' tie', '-', 'in', ' with', ' Blue', ' Beetle']
+261 58 Secret identity of x -1 Secret identity of Blue Beetle Ted Kord Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False one tie-in with Blue Beetle # 20. Part One, 6 [' one', ' tie', '-', 'in', ' with', ' Blue', ' Beetle']
+262 58 Secret identity of x -1 Secret identity of Blue Beetle Ted Kord Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False one tie-in with Blue Beetle # 20. Part One, the 6 [' one', ' tie', '-', 'in', ' with', ' Blue', ' Beetle']
+263 58 Secret identity of x -1 Secret identity of Blue Beetle Ted Kord Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False printings as well. Blue Beetle # 20 saw much 6 [' print', 'ings', ' as', ' well', '.', ' Blue', ' Beetle']
+264 58 Secret identity of x -1 Secret identity of Blue Beetle Ted Kord Blue Beetle [' (', '2013', ')', '\n', '\n', 'The', ' Blue', ' Beetle', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books'] " ( 2013 )
+
+ The Blue Beetle is a fictional character , a superhero appearing in American comic books" False printings as well. Blue Beetle # 20 saw much higher 6 [' print', 'ings', ' as', ' well', '.', ' Blue', ' Beetle']
+265 59 Secret identity of x -1 Secret identity of Booster Gold Michael Jon Carter Booster Gold ['\n', '\n', 'Bo', 'oster', ' Gold', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' first'] "
+
+ Bo oster Gold is a fictional character , a superhero in the DC Comics universe . He first" False " (DC Comics, 2007)
+" 8 [' (', 'DC', ' Comics', ',', ' 2007', ')', 'Bo', 'oster', ' Gold']
+266 59 Secret identity of x -1 Secret identity of Booster Gold Michael Jon Carter Booster Gold ['\n', '\n', 'Bo', 'oster', ' Gold', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' first'] "
+
+ Bo oster Gold is a fictional character , a superhero in the DC Comics universe . He first" False in DC Comic's Booster Gold comic series, 5 "[' in', ' DC', ' Comic', ""'s"", ' Booster', ' Gold']"
+267 59 Secret identity of x -1 Secret identity of Booster Gold Michael Jon Carter Booster Gold ['\n', '\n', 'Bo', 'oster', ' Gold', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' first'] "
+
+ Bo oster Gold is a fictional character , a superhero in the DC Comics universe . He first" False (DC Comics, 5 [' (', 'DC', ' Comics', 'Bo', 'oster', ' Gold']
+268 59 Secret identity of x -1 Secret identity of Booster Gold Michael Jon Carter Booster Gold ['\n', '\n', 'Bo', 'oster', ' Gold', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' first'] "
+
+ Bo oster Gold is a fictional character , a superhero in the DC Comics universe . He first" False " (DC Comics, 2007)
+" 8 [' (', 'DC', ' Comics', ',', ' 2007', ')', 'Bo', 'oster', ' Gold']
+269 59 Secret identity of x -1 Secret identity of Booster Gold Michael Jon Carter Booster Gold ['\n', '\n', 'Bo', 'oster', ' Gold', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' in', ' the', ' DC', ' Comics', ' universe', '.', ' He', ' first'] "
+
+ Bo oster Gold is a fictional character , a superhero in the DC Comics universe . He first" False (DC 4 [' (', 'DB', 'o', 'oster', ' Gold']
+270 60 Secret identity of x -1 Secret identity of Huntress Helena Bertinelli Huntress [' (', 'Hel', 'ena', ' Wayne', ')', '\n', '\n', 'The', ' Hunt', 'ress', ' is', ' Helena', ' Wayne', ',', ' the', ' daughter', ' of', ' Bruce', ' Wayne', ' and'] " ( Hel ena Wayne )
+
+ The Hunt ress is Helena Wayne , the daughter of Bruce Wayne and" False primarily as the Huntress and as Batgirl for 4 [' primarily', ' as', ' the', ' Hunt', 'ress']
+271 60 Secret identity of x -1 Secret identity of Huntress Helena Bertinelli Huntress [' (', 'Hel', 'ena', ' Wayne', ')', '\n', '\n', 'The', ' Hunt', 'ress', ' is', ' Helena', ' Wayne', ',', ' the', ' daughter', ' of', ' Bruce', ' Wayne', ' and'] " ( Hel ena Wayne )
+
+ The Hunt ress is Helena Wayne , the daughter of Bruce Wayne and" False qualifier William Huntress in the first round 3 [' qualifier', ' William', ' Hunt', 'ress']
+272 60 Secret identity of x -1 Secret identity of Huntress Helena Bertinelli Huntress [' (', 'Hel', 'ena', ' Wayne', ')', '\n', '\n', 'The', ' Hunt', 'ress', ' is', ' Helena', ' Wayne', ',', ' the', ' daughter', ' of', ' Bruce', ' Wayne', ' and'] " ( Hel ena Wayne )
+
+ The Hunt ress is Helena Wayne , the daughter of Bruce Wayne and" False Banks' The Vampire Huntress Legend Series, 5 "[' Banks', ""'"", ' The', ' Vampire', ' Hunt', 'ress']"
+273 60 Secret identity of x -1 Secret identity of Huntress Helena Bertinelli Huntress [' (', 'Hel', 'ena', ' Wayne', ')', '\n', '\n', 'The', ' Hunt', 'ress', ' is', ' Helena', ' Wayne', ',', ' the', ' daughter', ' of', ' Bruce', ' Wayne', ' and'] " ( Hel ena Wayne )
+
+ The Hunt ress is Helena Wayne , the daughter of Bruce Wayne and" False television series The Huntress when Ettinger 4 [' television', ' series', ' The', ' Hunt', 'ress']
+274 60 Secret identity of x -1 Secret identity of Huntress Helena Bertinelli Huntress [' (', 'Hel', 'ena', ' Wayne', ')', '\n', '\n', 'The', ' Hunt', 'ress', ' is', ' Helena', ' Wayne', ',', ' the', ' daughter', ' of', ' Bruce', ' Wayne', ' and'] " ( Hel ena Wayne )
+
+ The Hunt ress is Helena Wayne , the daughter of Bruce Wayne and" False look like the Huntress a little overall 4 [' look', ' like', ' the', ' Hunt', 'ress']
+275 61 Secret identity of x -1 Secret identity of Zatanna Zatanna Zatara Zatanna [':', ' Z', 'at', 'anna', ' is', ' a', ' woman', '.', ' She', ' is', ' a', ' magician', '.', ' She', ' is', ' a', ' woman', '.', ' She', ' is'] : Z at anna is a woman . She is a magician . She is a woman . She is False that JLA member Zatanna had altered 6 [' that', ' J', 'LA', ' member', ' Z', 'at', 'anna']
+276 61 Secret identity of x -1 Secret identity of Zatanna Zatanna Zatara Zatanna [':', ' Z', 'at', 'anna', ' is', ' a', ' woman', '.', ' She', ' is', ' a', ' magician', '.', ' She', ' is', ' a', ' woman', '.', ' She', ' is'] : Z at anna is a woman . She is a magician . She is a woman . She is False reveals that JLA member Zatanna had altered 7 [' reveals', ' that', ' J', 'LA', ' member', ' Z', 'at', 'anna']
+277 61 Secret identity of x -1 Secret identity of Zatanna Zatanna Zatara Zatanna [':', ' Z', 'at', 'anna', ' is', ' a', ' woman', '.', ' She', ' is', ' a', ' magician', '.', ' She', ' is', ' a', ' woman', '.', ' She', ' is'] : Z at anna is a woman . She is a magician . She is a woman . She is False and deals with Zatanna and Wonder Woman 5 [' and', ' deals', ' with', ' Z', 'at', 'anna']
+278 61 Secret identity of x -1 Secret identity of Zatanna Zatanna Zatara Zatanna [':', ' Z', 'at', 'anna', ' is', ' a', ' woman', '.', ' She', ' is', ' a', ' magician', '.', ' She', ' is', ' a', ' woman', '.', ' She', ' is'] : Z at anna is a woman . She is a magician . She is a woman . She is False deals with Zatanna and Wonder Woman 4 [' deals', ' with', ' Z', 'at', 'anna']
+279 61 Secret identity of x -1 Secret identity of Zatanna Zatanna Zatara Zatanna [':', ' Z', 'at', 'anna', ' is', ' a', ' woman', '.', ' She', ' is', ' a', ' magician', '.', ' She', ' is', ' a', ' woman', '.', ' She', ' is'] : Z at anna is a woman . She is a magician . She is a woman . She is False reveals that JLA member Zatanna had altered Batman's 7 [' reveals', ' that', ' J', 'LA', ' member', ' Z', 'at', 'anna']
+280 62 Secret identity of x -1 Secret identity of Martian Manhunter J'onn J'onzz Martian Manhunter ['\n', '\n', 'The', ' Martian', ' Man', 'hunter', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.'] "
+
+ The Martian Man hunter is a fictional superhero appearing in American comic books published by DC Comics ." False teams with the Martian Manhunter in Metropolis to 5 [' teams', ' with', ' the', ' Martian', ' Man', 'hunter']
+281 62 Secret identity of x -1 Secret identity of Martian Manhunter J'onn J'onzz Martian Manhunter ['\n', '\n', 'The', ' Martian', ' Man', 'hunter', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.'] "
+
+ The Martian Man hunter is a fictional superhero appearing in American comic books published by DC Comics ." False John Jones / Martian Manhunter but lacked a proper 5 [' John', ' Jones', ' /', ' Martian', ' Man', 'hunter']
+282 62 Secret identity of x -1 Secret identity of Martian Manhunter J'onn J'onzz Martian Manhunter ['\n', '\n', 'The', ' Martian', ' Man', 'hunter', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.'] "
+
+ The Martian Man hunter is a fictional superhero appearing in American comic books published by DC Comics ." False introduction of the Martian Manhunter in Detective Comics 5 [' introduction', ' of', ' the', ' Martian', ' Man', 'hunter']
+283 62 Secret identity of x -1 Secret identity of Martian Manhunter J'onn J'onzz Martian Manhunter ['\n', '\n', 'The', ' Martian', ' Man', 'hunter', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.'] "
+
+ The Martian Man hunter is a fictional superhero appearing in American comic books published by DC Comics ." False well as John Jones / Martian Manhunter but lacked a 7 [' well', ' as', ' John', ' Jones', ' /', ' Martian', ' Man', 'hunter']
+284 62 Secret identity of x -1 Secret identity of Martian Manhunter J'onn J'onzz Martian Manhunter ['\n', '\n', 'The', ' Martian', ' Man', 'hunter', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.'] "
+
+ The Martian Man hunter is a fictional superhero appearing in American comic books published by DC Comics ." False as John Jones / Martian Manhunter but lacked a 6 [' as', ' John', ' Jones', ' /', ' Martian', ' Man', 'hunter']
+285 63 Secret identity of x -1 Secret identity of Black Canary Dinah Lance Black Canary ['\n', '\n', 'Black', ' Canary', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Black Canary is a fictional character , a superhero ine appearing in American comic books published by DC" False " Green Arrow and Black Canary were viewed more positively.
+" 4 [' Green', ' Arrow', ' and', ' Black', ' Canary']
+286 63 Secret identity of x -1 Secret identity of Black Canary Dinah Lance Black Canary ['\n', '\n', 'Black', ' Canary', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Black Canary is a fictional character , a superhero ine appearing in American comic books published by DC" False Prey No. 99, Black Canary leaves the team and 7 [' Pre', 'y', ' No', '.', ' 99', ',', ' Black', ' Canary']
+287 63 Secret identity of x -1 Secret identity of Black Canary Dinah Lance Black Canary ['\n', '\n', 'Black', ' Canary', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Black Canary is a fictional character , a superhero ine appearing in American comic books published by DC" False of Green Arrow and Black Canary were viewed more 5 [' of', ' Green', ' Arrow', ' and', ' Black', ' Canary']
+288 63 Secret identity of x -1 Secret identity of Black Canary Dinah Lance Black Canary ['\n', '\n', 'Black', ' Canary', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Black Canary is a fictional character , a superhero ine appearing in American comic books published by DC" False Dinah Lance's Black Canary costume, certain 5 "[' Din', 'ah', ' Lance', ""'s"", ' Black', ' Canary']"
+289 63 Secret identity of x -1 Secret identity of Black Canary Dinah Lance Black Canary ['\n', '\n', 'Black', ' Canary', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ Black Canary is a fictional character , a superhero ine appearing in American comic books published by DC" False designing Dinah Lance's Black Canary costume, certain 6 "[' designing', ' Din', 'ah', ' Lance', ""'s"", ' Black', ' Canary']"
+290 64 Secret identity of x -1 Secret identity of Firestar Angelica Jones Firestar [':', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n'] ":
+
+ Fire star is a male cat .
+
+ Fire star is a male cat .
+" False 1 ['Fire', 'star']
+291 64 Secret identity of x -1 Secret identity of Firestar Angelica Jones Firestar [':', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n'] ":
+
+ Fire star is a male cat .
+
+ Fire star is a male cat .
+" False soon after Firestar and Sandstorm return 3 [' soon', ' after', ' Fire', 'star']
+292 64 Secret identity of x -1 Secret identity of Firestar Angelica Jones Firestar [':', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n'] ":
+
+ Fire star is a male cat .
+
+ Fire star is a male cat .
+" False stand-alone novel Firestar's Quest. It receives 5 [' stand', '-', 'alone', ' novel', ' Fire', 'star']
+293 64 Secret identity of x -1 Secret identity of Firestar Angelica Jones Firestar [':', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n'] ":
+
+ Fire star is a male cat .
+
+ Fire star is a male cat .
+" False lives and the name Firestar from the Clans'spiritual 5 [' lives', ' and', ' the', ' name', ' Fire', 'star']
+294 64 Secret identity of x -1 Secret identity of Firestar Angelica Jones Firestar [':', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n', '\n', 'Fire', 'star', ' is', ' a', ' male', ' cat', '.', '\n'] ":
+
+ Fire star is a male cat .
+
+ Fire star is a male cat .
+" False worried about what Firestar might do, but 4 [' worried', ' about', ' what', ' Fire', 'star']
+295 65 Secret identity of x -1 Secret identity of Silver Surfer Norrin Radd Silver Surfer [':', ' The', ' Silver', ' Sur', 'fer', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] : The Silver Sur fer is a fictional character , a superhero appearing in American comic books published by Marvel False the planet. The Silver Surfer has to serve 6 [' the', ' planet', '.', ' The', ' Silver', ' Sur', 'fer']
+296 65 Secret identity of x -1 Secret identity of Silver Surfer Norrin Radd Silver Surfer [':', ' The', ' Silver', ' Sur', 'fer', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] : The Silver Sur fer is a fictional character , a superhero appearing in American comic books published by Marvel False Rise of the Silver Surfer. Braugher was 5 [' Rise', ' of', ' the', ' Silver', ' Sur', 'fer']
+297 65 Secret identity of x -1 Secret identity of Silver Surfer Norrin Radd Silver Surfer [':', ' The', ' Silver', ' Sur', 'fer', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] : The Silver Sur fer is a fictional character , a superhero appearing in American comic books published by Marvel False Four: Rise of the Silver Surfer in August 2006 with 7 [' Four', ':', ' Rise', ' of', ' the', ' Silver', ' Sur', 'fer']
+298 65 Secret identity of x -1 Secret identity of Silver Surfer Norrin Radd Silver Surfer [':', ' The', ' Silver', ' Sur', 'fer', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] : The Silver Sur fer is a fictional character , a superhero appearing in American comic books published by Marvel False " Rise of the Silver Surfer (2007) ====
+" 5 [' Rise', ' of', ' the', ' Silver', ' Sur', 'fer']
+299 65 Secret identity of x -1 Secret identity of Silver Surfer Norrin Radd Silver Surfer [':', ' The', ' Silver', ' Sur', 'fer', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] : The Silver Sur fer is a fictional character , a superhero appearing in American comic books published by Marvel False " Rise of the Silver Surfer =
+" 5 [' Rise', ' of', ' the', ' Silver', ' Sur', 'fer']
+300 66 Secret identity of x -1 Secret identity of Nightcrawler Kurt Wagner Nightcrawler ['\n', '\n', 'Night', 'c', 'rawler', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others', '.', ' He', ' is', ' a'] "
+
+ Night c rawler is a mutant with the ability to teleport himself and others . He is a" False Kurt Wagner / Nightcrawler in the 2016 film 5 [' Kurt', ' Wagner', ' /', ' Night', 'c', 'rawler']
+301 66 Secret identity of x -1 Secret identity of Nightcrawler Kurt Wagner Nightcrawler ['\n', '\n', 'Night', 'c', 'rawler', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others', '.', ' He', ' is', ' a'] "
+
+ Night c rawler is a mutant with the ability to teleport himself and others . He is a" False Wolverine, and Storm and Nightcrawler search for the 7 [' Wolverine', ',', ' and', ' Storm', ' and', ' Night', 'c', 'rawler']
+302 66 Secret identity of x -1 Secret identity of Nightcrawler Kurt Wagner Nightcrawler ['\n', '\n', 'Night', 'c', 'rawler', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others', '.', ' He', ' is', ' a'] "
+
+ Night c rawler is a mutant with the ability to teleport himself and others . He is a" False universe and the Nightcrawler character from the 5 [' universe', ' and', ' the', ' Night', 'c', 'rawler']
+303 66 Secret identity of x -1 Secret identity of Nightcrawler Kurt Wagner Nightcrawler ['\n', '\n', 'Night', 'c', 'rawler', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others', '.', ' He', ' is', ' a'] "
+
+ Night c rawler is a mutant with the ability to teleport himself and others . He is a" False in which Beast and Nightcrawler were deleted over 6 [' in', ' which', ' Beast', ' and', ' Night', 'c', 'rawler']
+304 66 Secret identity of x -1 Secret identity of Nightcrawler Kurt Wagner Nightcrawler ['\n', '\n', 'Night', 'c', 'rawler', ' is', ' a', ' mutant', ' with', ' the', ' ability', ' to', ' teleport', ' himself', ' and', ' others', '.', ' He', ' is', ' a'] "
+
+ Night c rawler is a mutant with the ability to teleport himself and others . He is a" False 2 ['Night', 'c', 'rawler']
+305 67 Secret identity of x -1 Secret identity of Professor X Charles Xavier Professor X ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False cutscene, Magneto and Professor X part once again as 7 [' cut', 'scene', ',', ' Magnet', 'o', ' and', ' Professor', ' X']
+306 67 Secret identity of x -1 Secret identity of Professor X Charles Xavier Professor X ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False development on Cyclops and Professor X being brainwashed 6 [' development', ' on', ' Cycl', 'ops', ' and', ' Professor', ' X']
+307 67 Secret identity of x -1 Secret identity of Professor X Charles Xavier Professor X ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False " he considered that Professor X was ""a bit of a pious," 4 [' he', ' considered', ' that', ' Professor', ' X']
+308 67 Secret identity of x -1 Secret identity of Professor X Charles Xavier Professor X ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False loyal to the comics. Professor X is confirmed 6 [' loyal', ' to', ' the', ' comics', '.', ' Professor', ' X']
+309 67 Secret identity of x -1 Secret identity of Professor X Charles Xavier Professor X ['\n', '\n', 'The', ' X', '-', 'Men', ' are', ' a', ' team', ' of', ' mutants', ',', ' each', ' with', ' their', ' own', ' unique', ' powers', '.', ' They'] "
+
+ The X - Men are a team of mutants , each with their own unique powers . They" False to the comics. Professor X is confirmed to make 5 [' to', ' the', ' comics', '.', ' Professor', ' X']
+310 68 Secret identity of x -1 Secret identity of Beast Hank McCoy Beast [' Boy', ':', '\n', '\n', 'Beast', ' Boy', ' is', ' a', ' fictional', ' character', ' from', ' the', ' Teen', ' Titans', ' franchise', '.', ' He', ' is', ' the', ' son'] " Boy :
+
+ Beast Boy is a fictional character from the Teen Titans franchise . He is the son" False " Autumn 2009. ""The Beast Below"" was seen" 5 "[' Autumn', ' 2009', '.', ' ""', 'The', ' Beast']"
+311 68 Secret identity of x -1 Secret identity of Beast Hank McCoy Beast [' Boy', ':', '\n', '\n', 'Beast', ' Boy', ' is', ' a', ' fictional', ' character', ' from', ' the', ' Teen', ' Titans', ' franchise', '.', ' He', ' is', ' the', ' son'] " Boy :
+
+ Beast Boy is a fictional character from the Teen Titans franchise . He is the son" False " ""Beauty and the Beast"": A Handbook, author" 5 "[' ""', 'Beaut', 'y', ' and', ' the', ' Beast']"
+312 68 Secret identity of x -1 Secret identity of Beast Hank McCoy Beast [' Boy', ':', '\n', '\n', 'Beast', ' Boy', ' is', ' a', ' fictional', ' character', ' from', ' the', ' Teen', ' Titans', ' franchise', '.', ' He', ' is', ' the', ' son'] " Boy :
+
+ Beast Boy is a fictional character from the Teen Titans franchise . He is the son" False universe from the Beast Wars one. E. J. 3 [' universe', ' from', ' the', ' Beast']
+313 68 Secret identity of x -1 Secret identity of Beast Hank McCoy Beast [' Boy', ':', '\n', '\n', 'Beast', ' Boy', ' is', ' a', ' fictional', ' character', ' from', ' the', ' Teen', ' Titans', ' franchise', '.', ' He', ' is', ' the', ' son'] " Boy :
+
+ Beast Boy is a fictional character from the Teen Titans franchise . He is the son" False Beauty and the Beast, we could understand 3 [' Beauty', ' and', ' the', ' Beast']
+314 68 Secret identity of x -1 Secret identity of Beast Hank McCoy Beast [' Boy', ':', '\n', '\n', 'Beast', ' Boy', ' is', ' a', ' fictional', ' character', ' from', ' the', ' Teen', ' Titans', ' franchise', '.', ' He', ' is', ' the', ' son'] " Boy :
+
+ Beast Boy is a fictional character from the Teen Titans franchise . He is the son" False known only as the Beast has ravaged 4 [' known', ' only', ' as', ' the', ' Beast']
+315 69 Secret identity of x -1 Secret identity of Colossus Piotr Rasputin Colossus ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Colossus', ' is', ' a', ' character', ' in', ' the', ' Marvel', ' Comics', ' universe', '.', ' He', ' is', ' a', ' mutant'] "
+
+ The secret identity of Colossus is a character in the Marvel Comics universe . He is a mutant" False Shadow of the Colossus as influences 3 [' Shadow', ' of', ' the', ' Colossus']
+316 69 Secret identity of x -1 Secret identity of Colossus Piotr Rasputin Colossus ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Colossus', ' is', ' a', ' character', ' in', ' the', ' Marvel', ' Comics', ' universe', '.', ' He', ' is', ' a', ' mutant'] "
+
+ The secret identity of Colossus is a character in the Marvel Comics universe . He is a mutant" False " Fastball Special where Colossus throws Wolverine.
+" 4 [' Fast', 'ball', ' Special', ' where', ' Colossus']
+317 69 Secret identity of x -1 Secret identity of Colossus Piotr Rasputin Colossus ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Colossus', ' is', ' a', ' character', ' in', ' the', ' Marvel', ' Comics', ' universe', '.', ' He', ' is', ' a', ' mutant'] "
+
+ The secret identity of Colossus is a character in the Marvel Comics universe . He is a mutant" False adapt Shadow of the Colossus into a film. Kevin 4 [' adapt', ' Shadow', ' of', ' the', ' Colossus']
+318 69 Secret identity of x -1 Secret identity of Colossus Piotr Rasputin Colossus ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Colossus', ' is', ' a', ' character', ' in', ' the', ' Marvel', ' Comics', ' universe', '.', ' He', ' is', ' a', ' mutant'] "
+
+ The secret identity of Colossus is a character in the Marvel Comics universe . He is a mutant" False Troubridge and HMS Colossus under Captain 4 [' Trou', 'bridge', ' and', ' HMS', ' Colossus']
+319 69 Secret identity of x -1 Secret identity of Colossus Piotr Rasputin Colossus ['\n', '\n', 'The', ' secret', ' identity', ' of', ' Colossus', ' is', ' a', ' character', ' in', ' the', ' Marvel', ' Comics', ' universe', '.', ' He', ' is', ' a', ' mutant'] "
+
+ The secret identity of Colossus is a character in the Marvel Comics universe . He is a mutant" False Troubridge and HMS Colossus under Captain George 4 [' Trou', 'bridge', ' and', ' HMS', ' Colossus']
+320 71 Secret identity of x -1 Secret identity of Big Barda Big Barda Big Barda ['\n', '\n', 'Big', ' Bard', 'a', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The'] "
+
+ Big Bard a is a fictional character appearing in American comic books published by DC Comics . The" True leader, while Big Barda is brought in as 5 [' leader', ',', ' while', ' Big', ' Bard', 'a']
+321 71 Secret identity of x -1 Secret identity of Big Barda Big Barda Big Barda ['\n', '\n', 'Big', ' Bard', 'a', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The'] "
+
+ Big Bard a is a fictional character appearing in American comic books published by DC Comics . The" True leader, while Big Barda is brought 5 [' leader', ',', ' while', ' Big', ' Bard', 'a']
+322 72 Secret identity of x -1 Secret identity of Firestorm Ronnie Raymond Firestorm [':', ' The', ' Nuclear', ' Man', '\n', '\n', 'The', ' Nuclear', ' Man', ' is', ' a', ' fictional', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic', ' books'] ": The Nuclear Man
+
+ The Nuclear Man is a fictional superv ill ain appearing in American comic books" False Razer faced Firestorm 3 in their first 3 [' Razer', ' faced', ' Fire', 'storm']
+323 72 Secret identity of x -1 Secret identity of Firestorm Ronnie Raymond Firestorm [':', ' The', ' Nuclear', ' Man', '\n', '\n', 'The', ' Nuclear', ' Man', ' is', ' a', ' fictional', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic', ' books'] ": The Nuclear Man
+
+ The Nuclear Man is a fictional superv ill ain appearing in American comic books" False going on in Gotham. Firestorm is then sent to retrieve 6 [' going', ' on', ' in', ' Gotham', '.', ' Fire', 'storm']
+324 72 Secret identity of x -1 Secret identity of Firestorm Ronnie Raymond Firestorm [':', ' The', ' Nuclear', ' Man', '\n', '\n', 'The', ' Nuclear', ' Man', ' is', ' a', ' fictional', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic', ' books'] ": The Nuclear Man
+
+ The Nuclear Man is a fictional superv ill ain appearing in American comic books" False going on in Gotham. Firestorm is then sent 6 [' going', ' on', ' in', ' Gotham', '.', ' Fire', 'storm']
+325 72 Secret identity of x -1 Secret identity of Firestorm Ronnie Raymond Firestorm [':', ' The', ' Nuclear', ' Man', '\n', '\n', 'The', ' Nuclear', ' Man', ' is', ' a', ' fictional', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic', ' books'] ": The Nuclear Man
+
+ The Nuclear Man is a fictional superv ill ain appearing in American comic books" False in Gotham. Firestorm is then sent 4 [' in', ' Gotham', '.', ' Fire', 'storm']
+326 72 Secret identity of x -1 Secret identity of Firestorm Ronnie Raymond Firestorm [':', ' The', ' Nuclear', ' Man', '\n', '\n', 'The', ' Nuclear', ' Man', ' is', ' a', ' fictional', ' superv', 'ill', 'ain', ' appearing', ' in', ' American', ' comic', ' books'] ": The Nuclear Man
+
+ The Nuclear Man is a fictional superv ill ain appearing in American comic books" False final, Razer faced Firestorm 4. Razer manoeuvred 5 [' final', ',', ' Razer', ' faced', ' Fire', 'storm']
+327 74 Secret identity of x -1 Secret identity of Green Lantern Simon Baz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False 1 ['Green', ' Lantern']
+328 74 Secret identity of x -1 Secret identity of Green Lantern Simon Baz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False tie-in with the Green Lantern Brightest Day limited 6 [' tie', '-', 'in', ' with', ' the', ' Green', ' Lantern']
+329 74 Secret identity of x -1 Secret identity of Green Lantern Simon Baz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False " (film) =
+" 5 [' (', 'film', ')', ' =', 'Green', ' Lantern']
+330 74 Secret identity of x -1 Secret identity of Green Lantern Simon Baz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False including for DC's Green Lantern and The Authority 5 "[' including', ' for', ' DC', ""'s"", ' Green', ' Lantern']"
+331 74 Secret identity of x -1 Secret identity of Green Lantern Simon Baz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False charting above Green Lantern to finish first 4 [' chart', 'ing', ' above', ' Green', ' Lantern']
+332 75 Secret identity of x -1 Secret identity of Green Lantern Jessica Cruz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False " analysts felt that Green Lantern ""failed to perform" 4 [' analysts', ' felt', ' that', ' Green', ' Lantern']
+333 75 Secret identity of x -1 Secret identity of Green Lantern Jessica Cruz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False Green Arrow, and Green Lantern into their world 5 [' Green', ' Arrow', ',', ' and', ' Green', ' Lantern']
+334 75 Secret identity of x -1 Secret identity of Green Lantern Jessica Cruz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False The first section of Green Lantern is basically a mirrored 5 [' The', ' first', ' section', ' of', ' Green', ' Lantern']
+335 75 Secret identity of x -1 Secret identity of Green Lantern Jessica Cruz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False reference to a 1994 Green Lantern story where 5 [' reference', ' to', ' a', ' 1994', ' Green', ' Lantern']
+336 75 Secret identity of x -1 Secret identity of Green Lantern Jessica Cruz Green Lantern [':', ' The', ' Animated', ' Series', '\n', '\n', 'Green', ' Lantern', ':', ' The', ' Animated', ' Series', ' is', ' a', ' television', ' series', ' based', ' on', ' the', ' DC'] ": The Animated Series
+
+ Green Lantern : The Animated Series is a television series based on the DC" False 1, Flash # 1, Green Lantern # 1, and Superman 7 [' 1', ',', ' Flash', ' #', ' 1', ',', ' Green', ' Lantern']
+337 76 Secret identity of x -1 Secret identity of Wildcat Ted Grant Wildcat ['\n', '\n', 'The', ' Wild', 'cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wild cat is a fictional character , a superhero appearing in American comic books published by Marvel" False replacement F4F-4 Wildcat aircraft and 7 [' replacement', ' F', '4', 'F', '-', '4', ' Wild', 'cat']
+338 76 Secret identity of x -1 Secret identity of Wildcat Ted Grant Wildcat ['\n', '\n', 'The', ' Wild', 'cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wild cat is a fictional character , a superhero appearing in American comic books published by Marvel" False eight Grumman F4F Wildcat fighters, 8 [' eight', ' Gr', 'um', 'man', ' F', '4', 'F', ' Wild', 'cat']
+339 76 Secret identity of x -1 Secret identity of Wildcat Ted Grant Wildcat ['\n', '\n', 'The', ' Wild', 'cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wild cat is a fictional character , a superhero appearing in American comic books published by Marvel" False " ====
+" 3 [' =', '===', 'Wild', 'cat']
+340 76 Secret identity of x -1 Secret identity of Wildcat Ted Grant Wildcat ['\n', '\n', 'The', ' Wild', 'cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wild cat is a fictional character , a superhero appearing in American comic books published by Marvel" False Enterprise) whose Wildcat fighter escorts also 4 [' Enterprise', ')', ' whose', ' Wild', 'cat']
+341 76 Secret identity of x -1 Secret identity of Wildcat Ted Grant Wildcat ['\n', '\n', 'The', ' Wild', 'cat', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel'] "
+
+ The Wild cat is a fictional character , a superhero appearing in American comic books published by Marvel" False became the first Wildcat freshman to score 25 4 [' became', ' the', ' first', ' Wild', 'cat']
+342 77 Secret identity of x -1 Secret identity of Hawkman Carter Hall Hawkman ['\n', '\n', 'The', ' Hawk', 'man', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ The Hawk man is a fictional character , a superhero appearing in American comic books published by DC" False Lee liked the name Hawkman from DC Comics, 5 [' Lee', ' liked', ' the', ' name', ' Hawk', 'man']
+343 77 Secret identity of x -1 Secret identity of Hawkman Carter Hall Hawkman ['\n', '\n', 'The', ' Hawk', 'man', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ The Hawk man is a fictional character , a superhero appearing in American comic books published by DC" False Wesley Dodds and Hawkman (Carter Hall) appear 5 [' Wesley', ' Dodd', 's', ' and', ' Hawk', 'man']
+344 77 Secret identity of x -1 Secret identity of Hawkman Carter Hall Hawkman ['\n', '\n', 'The', ' Hawk', 'man', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ The Hawk man is a fictional character , a superhero appearing in American comic books published by DC" False Wesley Dodds and Hawkman (Carter Hall) 5 [' Wesley', ' Dodd', 's', ' and', ' Hawk', 'man']
+345 77 Secret identity of x -1 Secret identity of Hawkman Carter Hall Hawkman ['\n', '\n', 'The', ' Hawk', 'man', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ The Hawk man is a fictional character , a superhero appearing in American comic books published by DC" False Dodds and Hawkman (Carter Hall) appear 4 [' Dodd', 's', ' and', ' Hawk', 'man']
+346 77 Secret identity of x -1 Secret identity of Hawkman Carter Hall Hawkman ['\n', '\n', 'The', ' Hawk', 'man', ' is', ' a', ' fictional', ' character', ',', ' a', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC'] "
+
+ The Hawk man is a fictional character , a superhero appearing in American comic books published by DC" False Wesley Dodds and Hawkman (Carter Hall) appear 5 [' Wesley', ' Dodd', 's', ' and', ' Hawk', 'man']
+347 78 Secret identity of x -1 Secret identity of Hawkgirl Kendra Saunders Hawkgirl "[':', '**', ' The', ' Hawk', ' is', ' a', ' member', ' of', ' the', ' Hawk', 'man', ""'s"", ' team', ' of', ' superheroes', '.', ' She', ' is', ' a', ' beautiful']" : ** The Hawk is a member of the Hawk man 's team of superheroes . She is a beautiful False Lantern and Hawkgirl mistake Batman for 4 [' Lantern', ' and', ' Haw', 'kg', 'irl']
+348 78 Secret identity of x -1 Secret identity of Hawkgirl Kendra Saunders Hawkgirl "[':', '**', ' The', ' Hawk', ' is', ' a', ' member', ' of', ' the', ' Hawk', 'man', ""'s"", ' team', ' of', ' superheroes', '.', ' She', ' is', ' a', ' beautiful']" : ** The Hawk is a member of the Hawk man 's team of superheroes . She is a beautiful False Green Lantern and Hawkgirl mistake Batman 5 [' Green', ' Lantern', ' and', ' Haw', 'kg', 'irl']
+349 80 Secret identity of x -1 Secret identity of Green Arrow Connor Hawke Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False identity as Green Arrow as mirroring 3 [' identity', ' as', ' Green', ' Arrow']
+350 80 Secret identity of x -1 Secret identity of Green Arrow Connor Hawke Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False comic book character Green Arrow. Costume designer 4 [' comic', ' book', ' character', ' Green', ' Arrow']
+351 80 Secret identity of x -1 Secret identity of Green Arrow Connor Hawke Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False exception of Green Arrow / Oliver Queen, 3 [' exception', ' of', ' Green', ' Arrow']
+352 80 Secret identity of x -1 Secret identity of Green Arrow Connor Hawke Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False introduction of Green Arrow (Justin Hartley) 3 [' introduction', ' of', ' Green', ' Arrow']
+353 80 Secret identity of x -1 Secret identity of Green Arrow Connor Hawke Green Arrow [':', ' The', ' Arrow', ' of', ' the', ' Year', '\n', '\n', 'The', ' Green', ' Arrow', ' of', ' the', ' Year', ' is', ' the', ' Green', ' Arrow', ' of', ' the'] ": The Arrow of the Year
+
+ The Green Arrow of the Year is the Green Arrow of the" False exception of Green Arrow / Oliver Queen, 3 [' exception', ' of', ' Green', ' Arrow']
+354 81 Secret identity of x -1 Secret identity of Ravager Rose Wilson Ravager "['\n', '\n', 'I', ""'m"", ' a', ' big', ' fan', ' of', ' the', ' Rav', 'ager', ',', ' and', ' I', ""'m"", ' glad', ' to', ' see', ' that', ' you']" "
+
+ I 'm a big fan of the Rav ager , and I 'm glad to see that you" False mate in the Ravagers. Alexis Denisof 4 [' mate', ' in', ' the', ' Rav', 'ager']
+355 81 Secret identity of x -1 Secret identity of Ravager Rose Wilson Ravager "['\n', '\n', 'I', ""'m"", ' a', ' big', ' fan', ' of', ' the', ' Rav', 'ager', ',', ' and', ' I', ""'m"", ' glad', ' to', ' see', ' that', ' you']" "
+
+ I 'm a big fan of the Rav ager , and I 'm glad to see that you" False first mate in the Ravagers. Alexis Denisof 5 [' first', ' mate', ' in', ' the', ' Rav', 'ager']
+356 81 Secret identity of x -1 Secret identity of Ravager Rose Wilson Ravager "['\n', '\n', 'I', ""'m"", ' a', ' big', ' fan', ' of', ' the', ' Rav', 'ager', ',', ' and', ' I', ""'m"", ' glad', ' to', ' see', ' that', ' you']" "
+
+ I 'm a big fan of the Rav ager , and I 'm glad to see that you" False mate in the Ravagers. Alexis Denisof 4 [' mate', ' in', ' the', ' Rav', 'ager']
+357 81 Secret identity of x -1 Secret identity of Ravager Rose Wilson Ravager "['\n', '\n', 'I', ""'m"", ' a', ' big', ' fan', ' of', ' the', ' Rav', 'ager', ',', ' and', ' I', ""'m"", ' glad', ' to', ' see', ' that', ' you']" "
+
+ I 'm a big fan of the Rav ager , and I 'm glad to see that you" False first mate in the Ravagers. Alexis Denisof reprises 5 [' first', ' mate', ' in', ' the', ' Rav', 'ager']
+358 82 Secret identity of x -1 Secret identity of Wonder Girl Cassandra Sandsmark Wonder Girl ['\n', '\n', 'Wonder', ' Woman', ' is', ' a', ' fictional', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' Created'] "
+
+ Wonder Woman is a fictional superhero ine appearing in American comic books published by DC Comics . Created" False Supergirl and Wonder Girl in Super Best 4 [' Super', 'girl', ' and', ' Wonder', ' Girl']
+359 82 Secret identity of x -1 Secret identity of Wonder Girl Cassandra Sandsmark Wonder Girl ['\n', '\n', 'Wonder', ' Woman', ' is', ' a', ' fictional', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' Created'] "
+
+ Wonder Woman is a fictional superhero ine appearing in American comic books published by DC Comics . Created" False Supergirl and Wonder Girl in Super Best 4 [' Super', 'girl', ' and', ' Wonder', ' Girl']
+360 82 Secret identity of x -1 Secret identity of Wonder Girl Cassandra Sandsmark Wonder Girl ['\n', '\n', 'Wonder', ' Woman', ' is', ' a', ' fictional', ' superhero', 'ine', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' Created'] "
+
+ Wonder Woman is a fictional superhero ine appearing in American comic books published by DC Comics . Created" False alongside Supergirl and Wonder Girl in Super Best Friends 5 [' alongside', ' Super', 'girl', ' and', ' Wonder', ' Girl']
+361 83 Secret identity of x -1 Secret identity of Tim Hunter Tim Hunter Tim Hunter "['\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' husband', ',', ' a', ' father', ',', ' and']" "
+
+ I 'm a writer , a reader , a geek , a husband , a father , and" False Thompson and Tim Hunter note that Eastwood's 3 [' Thompson', ' and', ' Tim', ' Hunter']
+362 83 Secret identity of x -1 Secret identity of Tim Hunter Tim Hunter Tim Hunter "['\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' husband', ',', ' a', ' father', ',', ' and']" "
+
+ I 'm a writer , a reader , a geek , a husband , a father , and" False Riggin and Tim Hunter – who would ultimately 4 [' Rig', 'gin', ' and', ' Tim', ' Hunter']
+363 83 Secret identity of x -1 Secret identity of Tim Hunter Tim Hunter Tim Hunter "['\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' husband', ',', ' a', ' father', ',', ' and']" "
+
+ I 'm a writer , a reader , a geek , a husband , a father , and" False Peplinski, Pat Riggin and Tim Hunter – who would ultimately 9 [' Pe', 'pl', 'inski', ',', ' Pat', ' Rig', 'gin', ' and', ' Tim', ' Hunter']
+364 83 Secret identity of x -1 Secret identity of Tim Hunter Tim Hunter Tim Hunter "['\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' husband', ',', ' a', ' father', ',', ' and']" "
+
+ I 'm a writer , a reader , a geek , a husband , a father , and" False Riggin and Tim Hunter – who would ultimately 4 [' Rig', 'gin', ' and', ' Tim', ' Hunter']
+365 83 Secret identity of x -1 Secret identity of Tim Hunter Tim Hunter Tim Hunter "['\n', '\n', 'I', ""'m"", ' a', ' writer', ',', ' a', ' reader', ',', ' a', ' geek', ',', ' a', ' husband', ',', ' a', ' father', ',', ' and']" "
+
+ I 'm a writer , a reader , a geek , a husband , a father , and" False Pat Riggin and Tim Hunter – who would 5 [' Pat', ' Rig', 'gin', ' and', ' Tim', ' Hunter']
+366 84 Secret identity of x -1 Secret identity of Winter Soldier Bucky Barnes Winter Soldier [':', ' The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The', ' character'] : The Winter Soldier is a fictional character appearing in American comic books published by Marvel Comics . The character False roles from The Winter Soldier and they are joined 4 [' roles', ' from', ' The', ' Winter', ' Soldier']
+367 84 Secret identity of x -1 Secret identity of Winter Soldier Bucky Barnes Winter Soldier [':', ' The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The', ' character'] : The Winter Soldier is a fictional character appearing in American comic books published by Marvel Comics . The character False Captain America: The Winter Soldier earned $ 259.8 million 5 [' Captain', ' America', ':', ' The', ' Winter', ' Soldier']
+368 84 Secret identity of x -1 Secret identity of Winter Soldier Bucky Barnes Winter Soldier [':', ' The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The', ' character'] : The Winter Soldier is a fictional character appearing in American comic books published by Marvel Comics . The character False Black Widow and Winter Soldier programs, which both 4 [' Black', ' Widow', ' and', ' Winter', ' Soldier']
+369 84 Secret identity of x -1 Secret identity of Winter Soldier Bucky Barnes Winter Soldier [':', ' The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The', ' character'] : The Winter Soldier is a fictional character appearing in American comic books published by Marvel Comics . The character False changed once the Winter Soldier tie-in took place. 4 [' changed', ' once', ' the', ' Winter', ' Soldier']
+370 84 Secret identity of x -1 Secret identity of Winter Soldier Bucky Barnes Winter Soldier [':', ' The', ' Winter', ' Soldier', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The', ' character'] : The Winter Soldier is a fictional character appearing in American comic books published by Marvel Comics . The character False leading up to The Winter Soldier crossover were marketed 5 [' leading', ' up', ' to', ' The', ' Winter', ' Soldier']
+371 86 Secret identity of x -1 Secret identity of Pixie Megan Gwynn Pixie "[' Dust', '\n', '\n', 'I', ""'m"", ' a', ' mom', ' of', ' two', ',', ' a', ' wife', ',', ' a', ' daughter', ',', ' a', ' sister', ',', ' a']" " Dust
+
+ I 'm a mom of two , a wife , a daughter , a sister , a" False guests and won the 2010 Pixie Gold Award from 5 [' guests', ' and', ' won', ' the', ' 2010', ' Pixie']
+372 86 Secret identity of x -1 Secret identity of Pixie Megan Gwynn Pixie "[' Dust', '\n', '\n', 'I', ""'m"", ' a', ' mom', ' of', ' two', ',', ' a', ' wife', ',', ' a', ' daughter', ',', ' a', ' sister', ',', ' a']" " Dust
+
+ I 'm a mom of two , a wife , a daughter , a sister , a" False Katy Perry, Pixie Lott and Nicole 3 [' Katy', ' Perry', ',', ' Pixie']
+373 86 Secret identity of x -1 Secret identity of Pixie Megan Gwynn Pixie "[' Dust', '\n', '\n', 'I', ""'m"", ' a', ' mom', ' of', ' two', ',', ' a', ' wife', ',', ' a', ' daughter', ',', ' a', ' sister', ',', ' a']" " Dust
+
+ I 'm a mom of two , a wife , a daughter , a sister , a" False Poncherello Pegone Pixiefrog (voiced 6 [' Pon', 'c', 'here', 'llo', ' Peg', 'one', ' Pixie']
+374 86 Secret identity of x -1 Secret identity of Pixie Megan Gwynn Pixie "[' Dust', '\n', '\n', 'I', ""'m"", ' a', ' mom', ' of', ' two', ',', ' a', ' wife', ',', ' a', ' daughter', ',', ' a', ' sister', ',', ' a']" " Dust
+
+ I 'm a mom of two , a wife , a daughter , a sister , a" False have included Pixie Lott, JLS and 2 [' have', ' included', ' Pixie']
+375 86 Secret identity of x -1 Secret identity of Pixie Megan Gwynn Pixie "[' Dust', '\n', '\n', 'I', ""'m"", ' a', ' mom', ' of', ' two', ',', ' a', ' wife', ',', ' a', ' daughter', ',', ' a', ' sister', ',', ' a']" " Dust
+
+ I 'm a mom of two , a wife , a daughter , a sister , a" False " were inspired by ""Pixie and Dixie""," 5 "[' were', ' inspired', ' by', ' ""', 'P', 'ixie']"
+376 87 Secret identity of x -1 Secret identity of X-23 Laura Kinney X-23 [':', ' Laura', ' Kin', 'ney', '.', '\n', '\n', 'Laura', ' Kin', 'ney', ' is', ' a', ' mutant', ' who', ' was', ' created', ' by', ' the', ' X', '-'] ": Laura Kin ney .
+
+ Laura Kin ney is a mutant who was created by the X -" True 2 ['X', '-', '23']
+377 87 Secret identity of x -1 Secret identity of X-23 Laura Kinney X-23 [':', ' Laura', ' Kin', 'ney', '.', '\n', '\n', 'Laura', ' Kin', 'ney', ' is', ' a', ' mutant', ' who', ' was', ' created', ' by', ' the', ' X', '-'] ": Laura Kin ney .
+
+ Laura Kin ney is a mutant who was created by the X -" True the Red Hulk, X-23 and the new Ghost 6 [' the', ' Red', ' Hulk', ',', ' X', '-', '23']
+378 88 Secret identity of x -1 Secret identity of Spider-Girl Anya Corazon Spider-Girl ['\n', '\n', 'Spider', '-', 'Girl', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The'] "
+
+ Spider - Girl is a fictional character appearing in American comic books published by Marvel Comics . The" False " ""Mayday"" Parker (Spider-Girl). May ""Mayday""" 8 "[' ""', 'May', 'day', '""', ' Parker', ' (', 'Spider', '-', 'Girl']"
+379 88 Secret identity of x -1 Secret identity of Spider-Girl Anya Corazon Spider-Girl ['\n', '\n', 'Spider', '-', 'Girl', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The'] "
+
+ Spider - Girl is a fictional character appearing in American comic books published by Marvel Comics . The" False " May ""Mayday"" Parker (Spider-Girl). May ""Mayday"" Parker" 9 "[' May', ' ""', 'May', 'day', '""', ' Parker', ' (', 'Spider', '-', 'Girl']"
+380 88 Secret identity of x -1 Secret identity of Spider-Girl Anya Corazon Spider-Girl ['\n', '\n', 'Spider', '-', 'Girl', ' is', ' a', ' fictional', ' character', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' Marvel', ' Comics', '.', ' The'] "
+
+ Spider - Girl is a fictional character appearing in American comic books published by Marvel Comics . The" False " ""Mayday"" Parker (Spider-Girl). May ""Mayday""" 8 "[' ""', 'May', 'day', '""', ' Parker', ' (', 'Spider', '-', 'Girl']"
+381 89 Secret identity of x -1 Secret identity of Robin Damian Wayne Robin [' Hood', '\n', '\n', 'The', ' Robin', ' Hood', ' of', ' Sher', 'wood', ' Forest', ' is', ' a', ' legendary', ' outlaw', ' who', ' lived', ' in', ' the', ' 12', 'th'] " Hood
+
+ The Robin Hood of Sher wood Forest is a legendary outlaw who lived in the 12 th" False Roehampton Gate and Robin Hood Gate. The Capital 5 [' Roe', 'ham', 'pton', ' Gate', ' and', ' Robin']
+382 89 Secret identity of x -1 Secret identity of Robin Damian Wayne Robin [' Hood', '\n', '\n', 'The', ' Robin', ' Hood', ' of', ' Sher', 'wood', ' Forest', ' is', ' a', ' legendary', ' outlaw', ' who', ' lived', ' in', ' the', ' 12', 'th'] " Hood
+
+ The Robin Hood of Sher wood Forest is a legendary outlaw who lived in the 12 th" False " stages as a ""modern day Robin Hood"", as he would" 6 "[' stages', ' as', ' a', ' ""', 'modern', ' day', ' Robin']"
+383 89 Secret identity of x -1 Secret identity of Robin Damian Wayne Robin [' Hood', '\n', '\n', 'The', ' Robin', ' Hood', ' of', ' Sher', 'wood', ' Forest', ' is', ' a', ' legendary', ' outlaw', ' who', ' lived', ' in', ' the', ' 12', 'th'] " Hood
+
+ The Robin Hood of Sher wood Forest is a legendary outlaw who lived in the 12 th" False formation of Wigan Robin Park. However, 4 [' formation', ' of', ' W', 'igan', ' Robin']
+384 89 Secret identity of x -1 Secret identity of Robin Damian Wayne Robin [' Hood', '\n', '\n', 'The', ' Robin', ' Hood', ' of', ' Sher', 'wood', ' Forest', ' is', ' a', ' legendary', ' outlaw', ' who', ' lived', ' in', ' the', ' 12', 'th'] " Hood
+
+ The Robin Hood of Sher wood Forest is a legendary outlaw who lived in the 12 th" False Everything Express writer Robin Catling described 3 [' Everything', ' Express', ' writer', ' Robin']
+385 89 Secret identity of x -1 Secret identity of Robin Damian Wayne Robin [' Hood', '\n', '\n', 'The', ' Robin', ' Hood', ' of', ' Sher', 'wood', ' Forest', ' is', ' a', ' legendary', ' outlaw', ' who', ' lived', ' in', ' the', ' 12', 'th'] " Hood
+
+ The Robin Hood of Sher wood Forest is a legendary outlaw who lived in the 12 th" False the second series of Robin Hood, for which 4 [' the', ' second', ' series', ' of', ' Robin']
+386 90 Secret identity of x -1 Secret identity of Nightwing Kyle Higgins Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" False " released. The ""Nightwing"" pack, released on" 5 "[' released', '.', ' The', ' ""', 'Night', 'wing']"
+387 90 Secret identity of x -1 Secret identity of Nightwing Kyle Higgins Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" False restored as Nightwing with Wayne serving 3 [' restored', ' as', ' Night', 'wing']
+388 90 Secret identity of x -1 Secret identity of Nightwing Kyle Higgins Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" False Batman's partner Nightwing (Barbara Gordon). 4 "[' Batman', ""'s"", ' partner', ' Night', 'wing']"
+389 90 Secret identity of x -1 Secret identity of Nightwing Kyle Higgins Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" False " ""Arkham Episode"" for Nightwing set after the events" 7 "[' ""', 'Ark', 'ham', ' Episode', '""', ' for', ' Night', 'wing']"
+390 90 Secret identity of x -1 Secret identity of Nightwing Kyle Higgins Nightwing [':', ' Dick', ' Grayson', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' Night', 'wing', ',', ' I', ' was', ' in', ' the', ' audience', ' of', ' a'] ": Dick Grayson
+
+ The first time I saw Night wing , I was in the audience of a" False allies Robin and Nightwing are also playable 4 [' allies', ' Robin', ' and', ' Night', 'wing']
+391 91 Secret identity of x -1 Secret identity of Brawn Amadeus Cho Brawn ['y', ' the', ' Bear', '\n', '\n', 'B', 'rawn', 'y', ' the', ' Bear', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.'] "y the Bear
+
+ B rawn y the Bear is a fictional character in the DC Comics universe ." False director Ross Brawn said that Sehli's actions 3 [' director', ' Ross', ' B', 'rawn']
+392 91 Secret identity of x -1 Secret identity of Brawn Amadeus Cho Brawn ['y', ' the', ' Bear', '\n', '\n', 'B', 'rawn', 'y', ' the', ' Bear', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.'] "y the Bear
+
+ B rawn y the Bear is a fictional character in the DC Comics universe ." False start when he hit the Brawn GP of Rubens 6 [' start', ' when', ' he', ' hit', ' the', ' B', 'rawn']
+393 91 Secret identity of x -1 Secret identity of Brawn Amadeus Cho Brawn ['y', ' the', ' Bear', '\n', '\n', 'B', 'rawn', 'y', ' the', ' Bear', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.'] "y the Bear
+
+ B rawn y the Bear is a fictional character in the DC Comics universe ." False (designer) and Ross Brawn (Technical Director) 7 [' (', 'design', 'er', ')', ' and', ' Ross', ' B', 'rawn']
+394 91 Secret identity of x -1 Secret identity of Brawn Amadeus Cho Brawn ['y', ' the', ' Bear', '\n', '\n', 'B', 'rawn', 'y', ' the', ' Bear', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.'] "y the Bear
+
+ B rawn y the Bear is a fictional character in the DC Comics universe ." False technical director Ross Brawn that it would 4 [' technical', ' director', ' Ross', ' B', 'rawn']
+395 91 Secret identity of x -1 Secret identity of Brawn Amadeus Cho Brawn ['y', ' the', ' Bear', '\n', '\n', 'B', 'rawn', 'y', ' the', ' Bear', ' is', ' a', ' fictional', ' character', ' in', ' the', ' DC', ' Comics', ' universe', '.'] "y the Bear
+
+ B rawn y the Bear is a fictional character in the DC Comics universe ." False team principal Ross Brawn explained the distinctive 4 [' team', ' principal', ' Ross', ' B', 'rawn']
+396 92 Secret identity of x -1 Secret identity of Ghost Rider Robbie Reyes Ghost Rider [':', '\n', '\n', 'The', ' Ghost', ' Rider', ' is', ' Johnny', ' Blaze', ',', ' a', ' man', ' who', ' was', ' cursed', ' with', ' the', ' power', ' of', ' the'] ":
+
+ The Ghost Rider is Johnny Blaze , a man who was cursed with the power of the" False were penned in Ghost Rider: Travels on 4 [' were', ' penned', ' in', ' Ghost', ' Rider']
+397 92 Secret identity of x -1 Secret identity of Ghost Rider Robbie Reyes Ghost Rider [':', '\n', '\n', 'The', ' Ghost', ' Rider', ' is', ' Johnny', ' Blaze', ',', ' a', ' man', ' who', ' was', ' cursed', ' with', ' the', ' power', ' of', ' the'] ":
+
+ The Ghost Rider is Johnny Blaze , a man who was cursed with the power of the" False trailer was shown with Ghost Rider on February 16, 5 [' trailer', ' was', ' shown', ' with', ' Ghost', ' Rider']
+398 92 Secret identity of x -1 Secret identity of Ghost Rider Robbie Reyes Ghost Rider [':', '\n', '\n', 'The', ' Ghost', ' Rider', ' is', ' Johnny', ' Blaze', ',', ' a', ' man', ' who', ' was', ' cursed', ' with', ' the', ' power', ' of', ' the'] ":
+
+ The Ghost Rider is Johnny Blaze , a man who was cursed with the power of the" False was shown with Ghost Rider on February 16, 2007, 4 [' was', ' shown', ' with', ' Ghost', ' Rider']
+399 92 Secret identity of x -1 Secret identity of Ghost Rider Robbie Reyes Ghost Rider [':', '\n', '\n', 'The', ' Ghost', ' Rider', ' is', ' Johnny', ' Blaze', ',', ' a', ' man', ' who', ' was', ' cursed', ' with', ' the', ' power', ' of', ' the'] ":
+
+ The Ghost Rider is Johnny Blaze , a man who was cursed with the power of the" False in the pages of Ghost Rider, portrayed her as 5 [' in', ' the', ' pages', ' of', ' Ghost', ' Rider']
+400 92 Secret identity of x -1 Secret identity of Ghost Rider Robbie Reyes Ghost Rider [':', '\n', '\n', 'The', ' Ghost', ' Rider', ' is', ' Johnny', ' Blaze', ',', ' a', ' man', ' who', ' was', ' cursed', ' with', ' the', ' power', ' of', ' the'] ":
+
+ The Ghost Rider is Johnny Blaze , a man who was cursed with the power of the" False the pages of Ghost Rider, portrayed her as 4 [' the', ' pages', ' of', ' Ghost', ' Rider']
+401 93 Secret identity of x -1 Secret identity of The Atom Ray Palmer The Atom ['\n', '\n', 'The', ' Atom', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The', ' character'] "
+
+ The Atom is a fictional superhero appearing in American comic books published by DC Comics . The character" False Boddie Williams' The Atom Curtain), Samuel R. 6 "[' B', 'odd', 'ie', ' Williams', ""'"", ' The', ' Atom']"
+402 93 Secret identity of x -1 Secret identity of The Atom Ray Palmer The Atom ['\n', '\n', 'The', ' Atom', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The', ' character'] "
+
+ The Atom is a fictional superhero appearing in American comic books published by DC Comics . The character" False Magilla Gorilla Show, The Atom Ant / Secret 7 [' Mag', 'illa', ' Gor', 'illa', ' Show', ',', ' The', ' Atom']
+403 94 Secret identity of x -1 Secret identity of The Atom Ryan Choi The Atom ['\n', '\n', 'The', ' Atom', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The', ' character'] "
+
+ The Atom is a fictional superhero appearing in American comic books published by DC Comics . The character" False Gorilla Show, The Atom Ant / Secret Squirrel 5 [' Gor', 'illa', ' Show', ',', ' The', ' Atom']
+404 94 Secret identity of x -1 Secret identity of The Atom Ryan Choi The Atom ['\n', '\n', 'The', ' Atom', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The', ' character'] "
+
+ The Atom is a fictional superhero appearing in American comic books published by DC Comics . The character" False Gorilla Show, The Atom Ant / Secret Squirrel 5 [' Gor', 'illa', ' Show', ',', ' The', ' Atom']
+405 94 Secret identity of x -1 Secret identity of The Atom Ryan Choi The Atom ['\n', '\n', 'The', ' Atom', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The', ' character'] "
+
+ The Atom is a fictional superhero appearing in American comic books published by DC Comics . The character" False Boddie Williams' The Atom Curtain), Samuel R. 6 "[' B', 'odd', 'ie', ' Williams', ""'"", ' The', ' Atom']"
+406 94 Secret identity of x -1 Secret identity of The Atom Ryan Choi The Atom ['\n', '\n', 'The', ' Atom', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The', ' character'] "
+
+ The Atom is a fictional superhero appearing in American comic books published by DC Comics . The character" False Gorilla Show, The Atom Ant / Secret 5 [' Gor', 'illa', ' Show', ',', ' The', ' Atom']
+407 94 Secret identity of x -1 Secret identity of The Atom Ryan Choi The Atom ['\n', '\n', 'The', ' Atom', ' is', ' a', ' fictional', ' superhero', ' appearing', ' in', ' American', ' comic', ' books', ' published', ' by', ' DC', ' Comics', '.', ' The', ' character'] "
+
+ The Atom is a fictional superhero appearing in American comic books published by DC Comics . The character" False Magilla Gorilla Show, The Atom Ant / Secret 7 [' Mag', 'illa', ' Gor', 'illa', ' Show', ',', ' The', ' Atom']
+408 96 Secret identity of x -1 Secret identity of Groot Groot Groot [':', '\n', '\n', 'G', 'root', ' is', ' a', ' fictional', ' character', ' in', ' the', ' Marvel', ' Cinem', 'atic', ' Universe', '.', ' He', ' is', ' a', ' sentient'] ":
+
+ G root is a fictional character in the Marvel Cinem atic Universe . He is a sentient" True historian Hofstede de Groot said of Dune Landscape: 6 [' historian', ' Hof', 'st', 'ede', ' de', ' Gro', 'ot']
+409 96 Secret identity of x -1 Secret identity of Groot Groot Groot [':', '\n', '\n', 'G', 'root', ' is', ' a', ' fictional', ' character', ' in', ' the', ' Marvel', ' Cinem', 'atic', ' Universe', '.', ' He', ' is', ' a', ' sentient'] ":
+
+ G root is a fictional character in the Marvel Cinem atic Universe . He is a sentient" True behind the Groot Schijn stream 3 [' behind', ' the', ' Gro', 'ot']
+410 96 Secret identity of x -1 Secret identity of Groot Groot Groot [':', '\n', '\n', 'G', 'root', ' is', ' a', ' fictional', ' character', ' in', ' the', ' Marvel', ' Cinem', 'atic', ' Universe', '.', ' He', ' is', ' a', ' sentient'] ":
+
+ G root is a fictional character in the Marvel Cinem atic Universe . He is a sentient" True Rocket Raccoon and Groot would be created through 6 [' Rocket', ' R', 'acco', 'on', ' and', ' Gro', 'ot']
+411 96 Secret identity of x -1 Secret identity of Groot Groot Groot [':', '\n', '\n', 'G', 'root', ' is', ' a', ' fictional', ' character', ' in', ' the', ' Marvel', ' Cinem', 'atic', ' Universe', '.', ' He', ' is', ' a', ' sentient'] ":
+
+ G root is a fictional character in the Marvel Cinem atic Universe . He is a sentient" True Deurne behind the Groot Schijn stream and 6 [' De', 'urn', 'e', ' behind', ' the', ' Gro', 'ot']
+412 96 Secret identity of x -1 Secret identity of Groot Groot Groot [':', '\n', '\n', 'G', 'root', ' is', ' a', ' fictional', ' character', ' in', ' the', ' Marvel', ' Cinem', 'atic', ' Universe', '.', ' He', ' is', ' a', ' sentient'] ":
+
+ G root is a fictional character in the Marvel Cinem atic Universe . He is a sentient" True through Henk Groot to win the match 5 4 [' through', ' Hen', 'k', ' Gro', 'ot']
+413 97 Secret identity of x -1 Secret identity of Rocket Rocket Raccoon Rocket [' R', 'acco', 'on', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' the', ' trailer', ' for', ' Guardians', ' of', ' the', ' Galaxy', ',', ' I', ' was'] " R acco on
+
+ The first time I saw the trailer for Guardians of the Galaxy , I was" False the Saturn V Rocket utilized by NASA's 3 [' the', ' Saturn', ' V', ' Rocket']
+414 97 Secret identity of x -1 Secret identity of Rocket Rocket Raccoon Rocket [' R', 'acco', 'on', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' the', ' trailer', ' for', ' Guardians', ' of', ' the', ' Galaxy', ',', ' I', ' was'] " R acco on
+
+ The first time I saw the trailer for Guardians of the Galaxy , I was" False " Rocket Pods
+" 0 [' Rocket']
+415 97 Secret identity of x -1 Secret identity of Rocket Rocket Raccoon Rocket [' R', 'acco', 'on', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' the', ' trailer', ' for', ' Guardians', ' of', ' the', ' Galaxy', ',', ' I', ' was'] " R acco on
+
+ The first time I saw the trailer for Guardians of the Galaxy , I was" False helicopter, Multiple Launch Rocket System, and 4 [' helicopter', ',', ' Multiple', ' Launch', ' Rocket']
+416 97 Secret identity of x -1 Secret identity of Rocket Rocket Raccoon Rocket [' R', 'acco', 'on', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' the', ' trailer', ' for', ' Guardians', ' of', ' the', ' Galaxy', ',', ' I', ' was'] " R acco on
+
+ The first time I saw the trailer for Guardians of the Galaxy , I was" False Godlewski portray Rocket and Groot on set, 5 [' God', 'le', 'ws', 'ki', ' portray', ' Rocket']
+417 97 Secret identity of x -1 Secret identity of Rocket Rocket Raccoon Rocket [' R', 'acco', 'on', '\n', '\n', 'The', ' first', ' time', ' I', ' saw', ' the', ' trailer', ' for', ' Guardians', ' of', ' the', ' Galaxy', ',', ' I', ' was'] " R acco on
+
+ The first time I saw the trailer for Guardians of the Galaxy , I was" False plant. Back at the Rocket Group, Leo and 5 [' plant', '.', ' Back', ' at', ' the', ' Rocket']
+418 98 Secret identity of x -1 Secret identity of Gamora Gamora Gamora [':', '\n', '\n', 'Gam', 'ora', ' is', ' a', ' member', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', '.', ' She', ' is', ' a', ' member', ' of'] ":
+
+ Gam ora is a member of the Guardians of the Galaxy . She is a member of" True world, Xandar, Gamora ambushes him and 7 [' world', ',', ' X', 'and', 'ar', ',', ' Gam', 'ora']
+419 98 Secret identity of x -1 Secret identity of Gamora Gamora Gamora [':', '\n', '\n', 'Gam', 'ora', ' is', ' a', ' member', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', '.', ' She', ' is', ' a', ' member', ' of'] ":
+
+ Gam ora is a member of the Guardians of the Galaxy . She is a member of" True was raised with Gamora as siblings and is 4 [' was', ' raised', ' with', ' Gam', 'ora']
+420 98 Secret identity of x -1 Secret identity of Gamora Gamora Gamora [':', '\n', '\n', 'Gam', 'ora', ' is', ' a', ' member', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', '.', ' She', ' is', ' a', ' member', ' of'] ":
+
+ Gam ora is a member of the Guardians of the Galaxy . She is a member of" True " sends the assassin Gamora after the orb.
+" 4 [' sends', ' the', ' assassin', ' Gam', 'ora']
+421 98 Secret identity of x -1 Secret identity of Gamora Gamora Gamora [':', '\n', '\n', 'Gam', 'ora', ' is', ' a', ' member', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', '.', ' She', ' is', ' a', ' member', ' of'] ":
+
+ Gam ora is a member of the Guardians of the Galaxy . She is a member of" True was raised with Gamora as siblings 4 [' was', ' raised', ' with', ' Gam', 'ora']
+422 98 Secret identity of x -1 Secret identity of Gamora Gamora Gamora [':', '\n', '\n', 'Gam', 'ora', ' is', ' a', ' member', ' of', ' the', ' Guardians', ' of', ' the', ' Galaxy', '.', ' She', ' is', ' a', ' member', ' of'] ":
+
+ Gam ora is a member of the Guardians of the Galaxy . She is a member of" True Aster, after Gamora defeats Nebula (who 4 [' Aster', ',', ' after', ' Gam', 'ora']
+423 99 Secret identity of x -1 Secret identity of Star-Lord Peter Quill Star-Lord [':', ' Peter', ' Qu', 'ill', '.', '\n', '\n', 'The', ' Guardians', ' of', ' the', ' Galaxy', ' are', ' a', ' group', ' of', ' inter', 'gal', 'actic', ' heroes'] ": Peter Qu ill .
+
+ The Guardians of the Galaxy are a group of inter gal actic heroes" True Pratt dressed as Star-Lord to entertain 5 [' Pratt', ' dressed', ' as', ' Star', '-', 'Lord']
+424 99 Secret identity of x -1 Secret identity of Star-Lord Peter Quill Star-Lord [':', ' Peter', ' Qu', 'ill', '.', '\n', '\n', 'The', ' Guardians', ' of', ' the', ' Galaxy', ' are', ' a', ' group', ' of', ' inter', 'gal', 'actic', ' heroes'] ": Peter Qu ill .
+
+ The Guardians of the Galaxy are a group of inter gal actic heroes" True Pratt dressed as Star-Lord to entertain the 5 [' Pratt', ' dressed', ' as', ' Star', '-', 'Lord']
+425 99 Secret identity of x -1 Secret identity of Star-Lord Peter Quill Star-Lord [':', ' Peter', ' Qu', 'ill', '.', '\n', '\n', 'The', ' Guardians', ' of', ' the', ' Galaxy', ' are', ' a', ' group', ' of', ' inter', 'gal', 'actic', ' heroes'] ": Peter Qu ill .
+
+ The Guardians of the Galaxy are a group of inter gal actic heroes" True dressed as Star-Lord to entertain 4 [' dressed', ' as', ' Star', '-', 'Lord']
+426 99 Secret identity of x -1 Secret identity of Star-Lord Peter Quill Star-Lord [':', ' Peter', ' Qu', 'ill', '.', '\n', '\n', 'The', ' Guardians', ' of', ' the', ' Galaxy', ' are', ' a', ' group', ' of', ' inter', 'gal', 'actic', ' heroes'] ": Peter Qu ill .
+
+ The Guardians of the Galaxy are a group of inter gal actic heroes" True Peter Quill / Star-Lord in Guardians of 6 [' Peter', ' Qu', 'ill', ' /', ' Star', '-', 'Lord']
+427 99 Secret identity of x -1 Secret identity of Star-Lord Peter Quill Star-Lord [':', ' Peter', ' Qu', 'ill', '.', '\n', '\n', 'The', ' Guardians', ' of', ' the', ' Galaxy', ' are', ' a', ' group', ' of', ' inter', 'gal', 'actic', ' heroes'] ": Peter Qu ill .
+
+ The Guardians of the Galaxy are a group of inter gal actic heroes" True older demographics; Star-Lord ’ s obsession 5 [' older', ' demographics', ';', ' Star', '-', 'Lord']
diff --git a/patchscopes/code/preprocessed_data/factual_multihop/multihop_CoT_vicuna-13b-v1.1.tsv b/patchscopes/code/preprocessed_data/factual_multihop/multihop_CoT_vicuna-13b-v1.1.tsv
new file mode 100644
index 00000000..137faa38
--- /dev/null
+++ b/patchscopes/code/preprocessed_data/factual_multihop/multihop_CoT_vicuna-13b-v1.1.tsv
@@ -0,0 +1,73601 @@
+ sample_id prompt_source prompt_target position_source position_target baseline_hop2 baseline_hop3 baseline_multihop3 hop1 hop2 hop3 layer_source layer_target
+0 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 0
+1 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 0
+2 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 0
+3 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 0
+4 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 0
+5 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 0
+6 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 0
+7 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 0
+8 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 0
+9 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 0
+10 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 0
+11 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 0
+12 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 0
+13 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 0
+14 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 0
+15 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 0
+16 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 0
+17 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 0
+18 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 0
+19 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 0
+20 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 0
+21 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 0
+22 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 0
+23 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 0
+24 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 0
+25 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 0
+26 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 0
+27 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 0
+28 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 0
+29 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 0
+30 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 0
+31 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 0
+32 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 0
+33 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 0
+34 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 0
+35 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 0
+36 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 0
+37 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 0
+38 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 0
+39 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 0
+40 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 0
+41 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 0
+42 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 0
+43 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 0
+44 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 0
+45 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 0
+46 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 1
+47 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 1
+48 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 1
+49 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 1
+50 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 1
+51 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 1
+52 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 1
+53 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 1
+54 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 1
+55 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 1
+56 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 1
+57 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 1
+58 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 1
+59 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 1
+60 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 1
+61 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 1
+62 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 1
+63 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 1
+64 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 1
+65 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 1
+66 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 1
+67 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 1
+68 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 1
+69 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 1
+70 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 1
+71 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 1
+72 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 1
+73 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 1
+74 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 1
+75 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 1
+76 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 1
+77 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 1
+78 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 1
+79 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 1
+80 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 1
+81 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 1
+82 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 1
+83 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 1
+84 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 1
+85 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 1
+86 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 1
+87 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 1
+88 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 1
+89 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 1
+90 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 1
+91 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 1
+92 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 2
+93 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 2
+94 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 2
+95 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 2
+96 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 2
+97 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 2
+98 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 2
+99 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 2
+100 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 2
+101 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 2
+102 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 2
+103 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 2
+104 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 2
+105 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 2
+106 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 2
+107 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 2
+108 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 2
+109 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 2
+110 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 2
+111 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 2
+112 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 2
+113 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 2
+114 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 2
+115 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 2
+116 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 2
+117 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 2
+118 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 2
+119 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 2
+120 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 2
+121 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 2
+122 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 2
+123 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 2
+124 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 2
+125 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 2
+126 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 2
+127 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 2
+128 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 2
+129 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 2
+130 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 2
+131 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 2
+132 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 2
+133 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 2
+134 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 2
+135 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 2
+136 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 2
+137 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 2
+138 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 3
+139 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 3
+140 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 3
+141 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 3
+142 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 3
+143 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 3
+144 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 3
+145 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 3
+146 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 3
+147 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 3
+148 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 3
+149 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 3
+150 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 3
+151 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 3
+152 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 3
+153 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 3
+154 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 3
+155 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 3
+156 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 3
+157 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 3
+158 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 3
+159 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 3
+160 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 3
+161 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 3
+162 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 3
+163 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 3
+164 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 3
+165 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 3
+166 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 3
+167 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 3
+168 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 3
+169 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 3
+170 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 3
+171 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 3
+172 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 3
+173 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 3
+174 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 3
+175 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 3
+176 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 3
+177 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 3
+178 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 3
+179 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 3
+180 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 3
+181 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 3
+182 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 3
+183 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 3
+184 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 4
+185 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 4
+186 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 4
+187 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 4
+188 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 4
+189 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 4
+190 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 4
+191 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 4
+192 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 4
+193 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 4
+194 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 4
+195 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 4
+196 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 4
+197 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 4
+198 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 4
+199 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 4
+200 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 4
+201 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 4
+202 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 4
+203 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 4
+204 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 4
+205 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 4
+206 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 4
+207 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 4
+208 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 4
+209 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 4
+210 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 4
+211 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 4
+212 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 4
+213 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 4
+214 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 4
+215 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 4
+216 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 4
+217 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 4
+218 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 4
+219 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 4
+220 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 4
+221 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 4
+222 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 4
+223 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 4
+224 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 4
+225 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 4
+226 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 4
+227 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 4
+228 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 4
+229 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 4
+230 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 5
+231 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 5
+232 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 5
+233 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 5
+234 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 5
+235 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 5
+236 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 5
+237 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 5
+238 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 5
+239 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 5
+240 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 5
+241 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 5
+242 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 5
+243 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 5
+244 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 5
+245 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 5
+246 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 5
+247 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 5
+248 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 5
+249 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 5
+250 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 5
+251 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 5
+252 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 5
+253 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 5
+254 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 5
+255 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 5
+256 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 5
+257 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 5
+258 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 5
+259 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 5
+260 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 5
+261 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 5
+262 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 5
+263 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 5
+264 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 5
+265 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 5
+266 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 5
+267 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 5
+268 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 5
+269 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 5
+270 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 5
+271 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 5
+272 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 5
+273 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 5
+274 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 5
+275 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 5
+276 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 6
+277 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 6
+278 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 6
+279 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 6
+280 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 6
+281 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 6
+282 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 6
+283 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 6
+284 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 6
+285 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 6
+286 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 6
+287 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 6
+288 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 6
+289 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 6
+290 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 6
+291 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 6
+292 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 6
+293 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 6
+294 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 6
+295 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 6
+296 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 6
+297 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 6
+298 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 6
+299 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 6
+300 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 6
+301 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 6
+302 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 6
+303 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 6
+304 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 6
+305 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 6
+306 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 6
+307 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 6
+308 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 6
+309 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 6
+310 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 6
+311 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 6
+312 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 6
+313 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 6
+314 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 6
+315 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 6
+316 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 6
+317 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 6
+318 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 6
+319 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 6
+320 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 6
+321 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 6
+322 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 7
+323 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 7
+324 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 7
+325 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 7
+326 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 7
+327 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 7
+328 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 7
+329 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 7
+330 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 7
+331 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 7
+332 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 7
+333 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 7
+334 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 7
+335 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 7
+336 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 7
+337 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 7
+338 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 7
+339 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 7
+340 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 7
+341 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 7
+342 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 7
+343 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 7
+344 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 7
+345 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 7
+346 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 7
+347 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 7
+348 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 7
+349 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 7
+350 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 7
+351 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 7
+352 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 7
+353 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 7
+354 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 7
+355 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 7
+356 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 7
+357 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 7
+358 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 7
+359 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 7
+360 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 7
+361 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 7
+362 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 7
+363 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 7
+364 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 7
+365 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 7
+366 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 7
+367 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 7
+368 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 8
+369 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 8
+370 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 8
+371 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 8
+372 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 8
+373 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 8
+374 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 8
+375 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 8
+376 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 8
+377 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 8
+378 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 8
+379 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 8
+380 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 8
+381 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 8
+382 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 8
+383 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 8
+384 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 8
+385 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 8
+386 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 8
+387 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 8
+388 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 8
+389 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 8
+390 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 8
+391 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 8
+392 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 8
+393 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 8
+394 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 8
+395 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 8
+396 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 8
+397 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 8
+398 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 8
+399 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 8
+400 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 8
+401 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 8
+402 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 8
+403 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 8
+404 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 8
+405 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 8
+406 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 8
+407 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 8
+408 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 8
+409 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 8
+410 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 8
+411 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 8
+412 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 8
+413 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 8
+414 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 9
+415 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 9
+416 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 9
+417 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 9
+418 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 9
+419 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 9
+420 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 9
+421 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 9
+422 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 9
+423 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 9
+424 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 9
+425 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 9
+426 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 9
+427 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 9
+428 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 9
+429 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 9
+430 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 9
+431 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 9
+432 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 9
+433 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 9
+434 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 9
+435 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 9
+436 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 9
+437 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 9
+438 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 9
+439 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 9
+440 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 9
+441 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 9
+442 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 9
+443 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 9
+444 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 9
+445 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 9
+446 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 9
+447 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 9
+448 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 9
+449 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 9
+450 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 9
+451 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 9
+452 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 9
+453 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 9
+454 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 9
+455 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 9
+456 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 9
+457 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 9
+458 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 9
+459 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 9
+460 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 10
+461 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 10
+462 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 10
+463 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 10
+464 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 10
+465 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 10
+466 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 10
+467 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 10
+468 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 10
+469 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 10
+470 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 10
+471 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 10
+472 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 10
+473 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 10
+474 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 10
+475 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 10
+476 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 10
+477 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 10
+478 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 10
+479 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 10
+480 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 10
+481 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 10
+482 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 10
+483 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 10
+484 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 10
+485 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 10
+486 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 10
+487 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 10
+488 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 10
+489 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 10
+490 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 10
+491 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 10
+492 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 10
+493 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 10
+494 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 10
+495 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 10
+496 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 10
+497 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 10
+498 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 10
+499 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 10
+500 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 10
+501 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 10
+502 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 10
+503 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 10
+504 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 10
+505 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 10
+506 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 11
+507 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 11
+508 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 11
+509 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 11
+510 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 11
+511 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 11
+512 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 11
+513 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 11
+514 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 11
+515 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 11
+516 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 11
+517 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 11
+518 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 11
+519 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 11
+520 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 11
+521 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 11
+522 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 11
+523 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 11
+524 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 11
+525 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 11
+526 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 11
+527 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 11
+528 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 11
+529 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 11
+530 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 11
+531 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 11
+532 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 11
+533 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 11
+534 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 11
+535 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 11
+536 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 11
+537 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 11
+538 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 11
+539 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 11
+540 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 11
+541 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 11
+542 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 11
+543 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 11
+544 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 11
+545 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 11
+546 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 11
+547 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 11
+548 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 11
+549 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 11
+550 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 11
+551 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 11
+552 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 12
+553 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 12
+554 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 12
+555 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 12
+556 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 12
+557 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 12
+558 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 12
+559 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 12
+560 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 12
+561 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 12
+562 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 12
+563 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 12
+564 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 12
+565 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 12
+566 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 12
+567 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 12
+568 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 12
+569 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 12
+570 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 12
+571 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 12
+572 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 12
+573 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 12
+574 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 12
+575 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 12
+576 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 12
+577 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 12
+578 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 12
+579 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 12
+580 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 12
+581 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 12
+582 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 12
+583 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 12
+584 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 12
+585 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 12
+586 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 12
+587 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 12
+588 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 12
+589 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 12
+590 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 12
+591 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 12
+592 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 12
+593 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 12
+594 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 12
+595 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 12
+596 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 12
+597 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 12
+598 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 13
+599 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 13
+600 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 13
+601 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 13
+602 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 13
+603 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 13
+604 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 13
+605 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 13
+606 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 13
+607 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 13
+608 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 13
+609 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 13
+610 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 13
+611 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 13
+612 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 13
+613 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 13
+614 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 13
+615 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 13
+616 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 13
+617 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 13
+618 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 13
+619 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 13
+620 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 13
+621 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 13
+622 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 13
+623 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 13
+624 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 13
+625 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 13
+626 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 13
+627 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 13
+628 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 13
+629 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 13
+630 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 13
+631 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 13
+632 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 13
+633 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 13
+634 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 13
+635 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 13
+636 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 13
+637 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 13
+638 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 13
+639 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 13
+640 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 13
+641 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 13
+642 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 13
+643 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 13
+644 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 14
+645 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 14
+646 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 14
+647 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 14
+648 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 14
+649 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 14
+650 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 14
+651 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 14
+652 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 14
+653 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 14
+654 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 14
+655 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 14
+656 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 14
+657 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 14
+658 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 14
+659 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 14
+660 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 14
+661 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 14
+662 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 14
+663 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 14
+664 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 14
+665 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 14
+666 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 14
+667 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 14
+668 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 14
+669 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 14
+670 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 14
+671 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 14
+672 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 14
+673 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 14
+674 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 14
+675 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 14
+676 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 14
+677 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 14
+678 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 14
+679 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 14
+680 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 14
+681 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 14
+682 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 14
+683 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 14
+684 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 14
+685 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 14
+686 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 14
+687 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 14
+688 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 14
+689 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 14
+690 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 15
+691 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 15
+692 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 15
+693 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 15
+694 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 15
+695 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 15
+696 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 15
+697 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 15
+698 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 15
+699 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 15
+700 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 15
+701 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 15
+702 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 15
+703 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 15
+704 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 15
+705 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 15
+706 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 15
+707 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 15
+708 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 15
+709 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 15
+710 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 15
+711 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 15
+712 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 15
+713 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 15
+714 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 15
+715 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 15
+716 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 15
+717 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 15
+718 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 15
+719 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 15
+720 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 15
+721 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 15
+722 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 15
+723 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 15
+724 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 15
+725 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 15
+726 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 15
+727 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 15
+728 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 15
+729 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 15
+730 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 15
+731 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 15
+732 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 15
+733 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 15
+734 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 15
+735 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 15
+736 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 16
+737 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 16
+738 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 16
+739 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 16
+740 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 16
+741 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 16
+742 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 16
+743 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 16
+744 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 16
+745 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 16
+746 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 16
+747 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 16
+748 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 16
+749 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 16
+750 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 16
+751 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 16
+752 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 16
+753 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 16
+754 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 16
+755 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 16
+756 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 16
+757 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 16
+758 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 16
+759 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 16
+760 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 16
+761 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 16
+762 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 16
+763 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 16
+764 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 16
+765 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 16
+766 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 16
+767 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 16
+768 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 16
+769 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 16
+770 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 16
+771 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 16
+772 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 16
+773 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 16
+774 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 16
+775 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 16
+776 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 16
+777 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 16
+778 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 16
+779 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 16
+780 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 16
+781 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 16
+782 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 17
+783 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 17
+784 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 17
+785 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 17
+786 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 17
+787 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 17
+788 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 17
+789 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 17
+790 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 17
+791 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 17
+792 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 17
+793 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 17
+794 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 17
+795 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 17
+796 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 17
+797 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 17
+798 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 17
+799 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 17
+800 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 17
+801 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 17
+802 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 17
+803 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 17
+804 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 17
+805 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 17
+806 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 17
+807 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 17
+808 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 17
+809 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 17
+810 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 17
+811 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 17
+812 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 17
+813 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 17
+814 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 17
+815 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 17
+816 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 17
+817 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 17
+818 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 17
+819 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 17
+820 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 17
+821 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 17
+822 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 17
+823 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 17
+824 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 17
+825 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 17
+826 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 17
+827 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 17
+828 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 18
+829 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 18
+830 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 18
+831 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 18
+832 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 18
+833 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 18
+834 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 18
+835 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 18
+836 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 18
+837 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 18
+838 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 18
+839 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 18
+840 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 18
+841 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 18
+842 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 18
+843 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 18
+844 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 18
+845 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 18
+846 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 18
+847 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 18
+848 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 18
+849 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 18
+850 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 18
+851 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 18
+852 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 18
+853 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 18
+854 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 18
+855 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 18
+856 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 18
+857 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 18
+858 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 18
+859 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 18
+860 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 18
+861 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 18
+862 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 18
+863 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 18
+864 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 18
+865 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 18
+866 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 18
+867 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 18
+868 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 18
+869 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 18
+870 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 18
+871 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 18
+872 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 18
+873 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 18
+874 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 19
+875 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 19
+876 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 19
+877 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 19
+878 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 19
+879 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 19
+880 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 19
+881 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 19
+882 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 19
+883 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 19
+884 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 19
+885 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 19
+886 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 19
+887 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 19
+888 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 19
+889 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 19
+890 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 19
+891 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 19
+892 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 19
+893 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 19
+894 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 19
+895 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 19
+896 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 19
+897 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 19
+898 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 19
+899 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 19
+900 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 19
+901 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 19
+902 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 19
+903 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 19
+904 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 19
+905 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 19
+906 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 19
+907 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 19
+908 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 19
+909 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 19
+910 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 19
+911 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 19
+912 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 19
+913 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 19
+914 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 19
+915 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 19
+916 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 19
+917 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 19
+918 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 19
+919 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 19
+920 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 20
+921 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 20
+922 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 20
+923 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 20
+924 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 20
+925 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 20
+926 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 20
+927 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 20
+928 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 20
+929 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 20
+930 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 20
+931 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 20
+932 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 20
+933 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 20
+934 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 20
+935 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 20
+936 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 20
+937 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 20
+938 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 20
+939 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 20
+940 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 20
+941 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 20
+942 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 20
+943 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 20
+944 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 20
+945 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 20
+946 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 20
+947 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 20
+948 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 20
+949 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 20
+950 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 20
+951 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 20
+952 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 20
+953 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 20
+954 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 20
+955 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 20
+956 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 20
+957 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 20
+958 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 20
+959 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 20
+960 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 20
+961 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 20
+962 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 20
+963 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 20
+964 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 20
+965 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 20
+966 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 21
+967 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 21
+968 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 21
+969 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 21
+970 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 21
+971 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 21
+972 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 21
+973 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 21
+974 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 21
+975 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 21
+976 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 21
+977 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 21
+978 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 21
+979 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 21
+980 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 21
+981 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 21
+982 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 21
+983 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 21
+984 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 21
+985 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 21
+986 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 21
+987 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 21
+988 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 21
+989 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 21
+990 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 21
+991 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 21
+992 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 21
+993 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 21
+994 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 21
+995 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 21
+996 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 21
+997 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 21
+998 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 21
+999 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 21
+1000 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 21
+1001 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 21
+1002 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 21
+1003 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 21
+1004 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 21
+1005 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 21
+1006 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 21
+1007 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 21
+1008 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 21
+1009 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 21
+1010 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 21
+1011 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 21
+1012 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 22
+1013 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 22
+1014 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 22
+1015 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 22
+1016 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 22
+1017 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 22
+1018 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 22
+1019 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 22
+1020 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 22
+1021 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 22
+1022 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 22
+1023 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 22
+1024 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 22
+1025 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 22
+1026 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 22
+1027 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 22
+1028 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 22
+1029 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 22
+1030 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 22
+1031 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 22
+1032 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 22
+1033 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 22
+1034 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 22
+1035 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 22
+1036 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 22
+1037 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 22
+1038 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 22
+1039 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 22
+1040 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 22
+1041 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 22
+1042 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 22
+1043 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 22
+1044 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 22
+1045 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 22
+1046 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 22
+1047 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 22
+1048 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 22
+1049 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 22
+1050 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 22
+1051 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 22
+1052 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 22
+1053 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 22
+1054 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 22
+1055 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 22
+1056 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 22
+1057 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 22
+1058 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 23
+1059 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 23
+1060 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 23
+1061 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 23
+1062 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 23
+1063 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 23
+1064 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 23
+1065 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 23
+1066 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 23
+1067 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 23
+1068 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 23
+1069 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 23
+1070 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 23
+1071 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 23
+1072 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 23
+1073 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 23
+1074 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 23
+1075 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 23
+1076 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 23
+1077 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 23
+1078 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 23
+1079 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 23
+1080 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 23
+1081 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 23
+1082 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 23
+1083 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 23
+1084 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 23
+1085 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 23
+1086 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 23
+1087 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 23
+1088 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 23
+1089 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 23
+1090 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 23
+1091 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 23
+1092 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 23
+1093 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 23
+1094 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 23
+1095 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 23
+1096 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 23
+1097 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 23
+1098 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 23
+1099 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 23
+1100 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 23
+1101 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 23
+1102 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 23
+1103 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 23
+1104 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 24
+1105 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 24
+1106 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 24
+1107 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 24
+1108 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 24
+1109 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 24
+1110 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 24
+1111 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 24
+1112 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 24
+1113 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 24
+1114 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 24
+1115 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 24
+1116 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 24
+1117 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 24
+1118 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 24
+1119 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 24
+1120 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 24
+1121 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 24
+1122 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 24
+1123 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 24
+1124 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 24
+1125 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 24
+1126 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 24
+1127 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 24
+1128 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 24
+1129 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 24
+1130 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 24
+1131 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 24
+1132 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 24
+1133 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 24
+1134 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 24
+1135 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 24
+1136 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 24
+1137 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 24
+1138 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 24
+1139 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 24
+1140 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 24
+1141 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 24
+1142 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 24
+1143 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 24
+1144 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 24
+1145 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 24
+1146 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 24
+1147 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 24
+1148 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 24
+1149 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 24
+1150 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 25
+1151 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 25
+1152 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 25
+1153 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 25
+1154 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 25
+1155 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 25
+1156 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 25
+1157 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 25
+1158 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 25
+1159 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 25
+1160 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 25
+1161 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 25
+1162 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 25
+1163 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 25
+1164 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 25
+1165 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 25
+1166 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 25
+1167 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 25
+1168 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 25
+1169 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 25
+1170 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 25
+1171 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 25
+1172 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 25
+1173 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 25
+1174 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 25
+1175 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 25
+1176 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 25
+1177 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 25
+1178 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 25
+1179 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 25
+1180 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 25
+1181 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 25
+1182 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 25
+1183 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 25
+1184 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 25
+1185 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 25
+1186 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 25
+1187 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 25
+1188 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 25
+1189 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 25
+1190 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 25
+1191 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 25
+1192 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 25
+1193 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 25
+1194 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 25
+1195 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 25
+1196 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 26
+1197 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 26
+1198 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 26
+1199 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 26
+1200 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 26
+1201 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 26
+1202 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 26
+1203 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 26
+1204 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 26
+1205 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 26
+1206 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 26
+1207 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 26
+1208 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 26
+1209 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 26
+1210 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 26
+1211 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 26
+1212 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 26
+1213 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 26
+1214 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 26
+1215 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 26
+1216 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 26
+1217 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 26
+1218 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 26
+1219 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 26
+1220 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 26
+1221 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 26
+1222 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 26
+1223 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 26
+1224 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 26
+1225 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 26
+1226 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 26
+1227 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 26
+1228 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 26
+1229 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 26
+1230 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 26
+1231 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 26
+1232 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 26
+1233 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 26
+1234 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 26
+1235 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 26
+1236 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 26
+1237 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 26
+1238 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 26
+1239 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 26
+1240 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 26
+1241 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 26
+1242 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 27
+1243 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 27
+1244 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 27
+1245 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 27
+1246 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 27
+1247 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 27
+1248 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 27
+1249 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 27
+1250 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 27
+1251 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 27
+1252 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 27
+1253 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 27
+1254 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 27
+1255 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 27
+1256 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 27
+1257 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 27
+1258 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 27
+1259 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 27
+1260 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 27
+1261 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 27
+1262 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 27
+1263 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 27
+1264 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 27
+1265 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 27
+1266 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 27
+1267 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 27
+1268 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 27
+1269 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 27
+1270 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 27
+1271 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 27
+1272 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 27
+1273 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 27
+1274 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 27
+1275 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 27
+1276 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 27
+1277 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 27
+1278 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 27
+1279 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 27
+1280 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 27
+1281 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 27
+1282 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 27
+1283 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 27
+1284 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 27
+1285 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 27
+1286 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 27
+1287 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 27
+1288 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 28
+1289 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 28
+1290 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 28
+1291 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 28
+1292 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 28
+1293 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 28
+1294 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 28
+1295 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 28
+1296 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 28
+1297 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 28
+1298 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 28
+1299 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 28
+1300 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 28
+1301 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 28
+1302 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 28
+1303 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 28
+1304 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 28
+1305 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 28
+1306 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 28
+1307 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 28
+1308 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 28
+1309 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 28
+1310 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 28
+1311 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 28
+1312 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 28
+1313 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 28
+1314 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 28
+1315 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 28
+1316 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 28
+1317 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 28
+1318 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 28
+1319 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 28
+1320 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 28
+1321 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 28
+1322 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 28
+1323 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 28
+1324 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 28
+1325 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 28
+1326 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 28
+1327 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 28
+1328 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 28
+1329 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 28
+1330 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 28
+1331 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 28
+1332 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 28
+1333 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 28
+1334 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 29
+1335 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 29
+1336 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 29
+1337 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 29
+1338 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 29
+1339 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 29
+1340 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 29
+1341 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 29
+1342 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 29
+1343 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 29
+1344 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 29
+1345 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 29
+1346 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 29
+1347 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 29
+1348 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 29
+1349 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 29
+1350 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 29
+1351 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 29
+1352 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 29
+1353 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 29
+1354 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 29
+1355 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 29
+1356 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 29
+1357 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 29
+1358 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 29
+1359 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 29
+1360 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 29
+1361 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 29
+1362 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 29
+1363 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 29
+1364 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 29
+1365 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 29
+1366 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 29
+1367 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 29
+1368 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 29
+1369 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 29
+1370 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 29
+1371 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 29
+1372 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 29
+1373 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 29
+1374 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 29
+1375 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 29
+1376 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 29
+1377 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 29
+1378 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 29
+1379 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 29
+1380 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 30
+1381 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 30
+1382 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 30
+1383 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 30
+1384 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 30
+1385 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 30
+1386 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 30
+1387 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 30
+1388 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 30
+1389 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 30
+1390 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 30
+1391 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 30
+1392 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 30
+1393 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 30
+1394 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 30
+1395 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 30
+1396 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 30
+1397 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 30
+1398 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 30
+1399 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 30
+1400 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 30
+1401 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 30
+1402 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 30
+1403 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 30
+1404 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 30
+1405 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 30
+1406 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 30
+1407 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 30
+1408 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 30
+1409 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 30
+1410 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 30
+1411 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 30
+1412 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 30
+1413 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 30
+1414 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 30
+1415 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 30
+1416 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 30
+1417 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 30
+1418 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 30
+1419 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 30
+1420 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 30
+1421 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 30
+1422 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 30
+1423 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 30
+1424 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 30
+1425 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 30
+1426 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 31
+1427 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 31
+1428 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 31
+1429 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 31
+1430 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 31
+1431 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 31
+1432 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 31
+1433 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 31
+1434 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 31
+1435 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 31
+1436 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 31
+1437 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 31
+1438 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 31
+1439 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 31
+1440 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 31
+1441 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 31
+1442 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 31
+1443 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 31
+1444 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 31
+1445 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 31
+1446 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 31
+1447 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 31
+1448 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 31
+1449 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 31
+1450 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 31
+1451 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 31
+1452 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 31
+1453 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 31
+1454 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 31
+1455 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 31
+1456 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 31
+1457 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 31
+1458 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 31
+1459 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 31
+1460 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 31
+1461 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 31
+1462 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 31
+1463 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 31
+1464 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 31
+1465 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 31
+1466 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 31
+1467 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 31
+1468 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 31
+1469 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 31
+1470 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 31
+1471 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 31
+1472 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 32
+1473 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 32
+1474 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 32
+1475 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 32
+1476 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 32
+1477 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 32
+1478 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 32
+1479 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 32
+1480 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 32
+1481 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 32
+1482 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 32
+1483 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 32
+1484 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 32
+1485 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 32
+1486 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 32
+1487 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 32
+1488 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 32
+1489 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 32
+1490 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 32
+1491 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 32
+1492 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 32
+1493 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 32
+1494 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 32
+1495 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 32
+1496 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 32
+1497 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 32
+1498 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 32
+1499 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 32
+1500 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 32
+1501 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 32
+1502 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 32
+1503 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 32
+1504 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 32
+1505 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 32
+1506 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 32
+1507 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 32
+1508 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 32
+1509 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 32
+1510 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 32
+1511 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 32
+1512 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 32
+1513 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 32
+1514 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 32
+1515 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 32
+1516 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 32
+1517 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 32
+1518 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 33
+1519 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 33
+1520 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 33
+1521 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 33
+1522 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 33
+1523 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 33
+1524 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 33
+1525 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 33
+1526 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 33
+1527 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 33
+1528 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 33
+1529 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 33
+1530 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 33
+1531 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 33
+1532 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 33
+1533 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 33
+1534 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 33
+1535 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 33
+1536 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 33
+1537 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 33
+1538 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 33
+1539 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 33
+1540 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 33
+1541 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 33
+1542 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 33
+1543 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 33
+1544 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 33
+1545 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 33
+1546 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 33
+1547 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 33
+1548 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 33
+1549 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 33
+1550 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 33
+1551 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 33
+1552 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 33
+1553 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 33
+1554 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 33
+1555 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 33
+1556 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 33
+1557 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 33
+1558 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 33
+1559 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 33
+1560 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 33
+1561 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 33
+1562 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 33
+1563 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 33
+1564 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 34
+1565 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 34
+1566 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 34
+1567 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 34
+1568 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 34
+1569 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 34
+1570 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 34
+1571 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 34
+1572 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 34
+1573 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 34
+1574 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 34
+1575 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 34
+1576 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 34
+1577 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 34
+1578 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 34
+1579 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 34
+1580 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 34
+1581 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 34
+1582 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 34
+1583 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 34
+1584 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 34
+1585 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 34
+1586 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 34
+1587 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 34
+1588 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 34
+1589 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 34
+1590 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 34
+1591 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 34
+1592 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 34
+1593 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 34
+1594 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 34
+1595 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 34
+1596 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 34
+1597 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 34
+1598 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 34
+1599 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 34
+1600 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 34
+1601 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 34
+1602 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 34
+1603 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 34
+1604 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 34
+1605 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 34
+1606 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 34
+1607 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 34
+1608 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 34
+1609 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 34
+1610 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 35
+1611 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 35
+1612 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 35
+1613 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 35
+1614 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 35
+1615 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 35
+1616 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 35
+1617 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 35
+1618 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 35
+1619 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 35
+1620 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 35
+1621 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 35
+1622 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 35
+1623 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 35
+1624 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 35
+1625 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 35
+1626 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 35
+1627 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 35
+1628 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 35
+1629 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 35
+1630 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 35
+1631 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 35
+1632 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 35
+1633 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 35
+1634 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 35
+1635 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 35
+1636 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 35
+1637 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 35
+1638 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 35
+1639 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 35
+1640 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 35
+1641 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 35
+1642 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 35
+1643 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 35
+1644 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 35
+1645 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 35
+1646 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 35
+1647 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 35
+1648 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 35
+1649 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 35
+1650 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 35
+1651 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 35
+1652 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 35
+1653 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 35
+1654 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 35
+1655 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 35
+1656 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 36
+1657 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 36
+1658 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 36
+1659 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 36
+1660 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 36
+1661 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 36
+1662 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 36
+1663 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 36
+1664 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 36
+1665 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 36
+1666 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 36
+1667 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 36
+1668 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 36
+1669 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 36
+1670 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 36
+1671 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 36
+1672 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 36
+1673 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 36
+1674 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 36
+1675 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 36
+1676 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 36
+1677 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 36
+1678 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 36
+1679 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 36
+1680 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 36
+1681 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 36
+1682 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 36
+1683 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 36
+1684 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 36
+1685 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 36
+1686 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 36
+1687 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 36
+1688 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 36
+1689 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 36
+1690 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 36
+1691 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 36
+1692 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 36
+1693 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 36
+1694 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 36
+1695 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 36
+1696 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 36
+1697 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 36
+1698 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 36
+1699 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 36
+1700 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 36
+1701 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 36
+1702 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 37
+1703 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 37
+1704 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 37
+1705 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 37
+1706 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 37
+1707 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 37
+1708 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 37
+1709 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 37
+1710 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 37
+1711 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 37
+1712 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 37
+1713 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 37
+1714 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 37
+1715 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 37
+1716 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 37
+1717 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 37
+1718 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 37
+1719 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 37
+1720 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 37
+1721 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 37
+1722 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 37
+1723 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 37
+1724 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 37
+1725 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 37
+1726 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 37
+1727 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 37
+1728 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 37
+1729 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 37
+1730 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 37
+1731 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 37
+1732 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 37
+1733 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 37
+1734 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 37
+1735 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 37
+1736 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 37
+1737 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 37
+1738 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 37
+1739 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 37
+1740 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 37
+1741 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 37
+1742 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 37
+1743 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 37
+1744 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 37
+1745 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 37
+1746 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 37
+1747 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 37
+1748 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 38
+1749 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 38
+1750 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 38
+1751 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 38
+1752 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 38
+1753 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 38
+1754 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 38
+1755 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 38
+1756 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 38
+1757 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 38
+1758 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 38
+1759 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 38
+1760 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 38
+1761 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 38
+1762 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 38
+1763 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 38
+1764 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 38
+1765 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 38
+1766 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 38
+1767 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 38
+1768 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 38
+1769 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 38
+1770 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 38
+1771 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 38
+1772 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 38
+1773 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 38
+1774 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 38
+1775 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 38
+1776 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 38
+1777 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 38
+1778 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 38
+1779 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 38
+1780 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 38
+1781 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 38
+1782 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 38
+1783 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 38
+1784 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 38
+1785 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 38
+1786 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 38
+1787 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 38
+1788 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 38
+1789 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 38
+1790 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 38
+1791 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 38
+1792 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 38
+1793 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 38
+1794 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 0 39
+1795 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 0 39
+1796 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 0 39
+1797 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 0 39
+1798 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 0 39
+1799 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 0 39
+1800 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 0 39
+1801 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 39
+1802 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 0 39
+1803 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 0 39
+1804 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 0 39
+1805 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 0 39
+1806 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 0 39
+1807 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 0 39
+1808 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 0 39
+1809 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 0 39
+1810 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 0 39
+1811 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 0 39
+1812 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 0 39
+1813 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 0 39
+1814 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 0 39
+1815 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 0 39
+1816 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 0 39
+1817 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 0 39
+1818 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 0 39
+1819 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 0 39
+1820 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 0 39
+1821 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 0 39
+1822 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 0 39
+1823 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 0 39
+1824 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 0 39
+1825 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 0 39
+1826 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 0 39
+1827 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 0 39
+1828 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 0 39
+1829 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 0 39
+1830 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 0 39
+1831 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 0 39
+1832 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 0 39
+1833 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 0 39
+1834 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 39
+1835 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 0 39
+1836 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 0 39
+1837 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 0 39
+1838 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 39
+1839 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 0 39
+1840 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 0
+1841 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 0
+1842 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 0
+1843 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 0
+1844 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 0
+1845 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 0
+1846 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 0
+1847 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 0
+1848 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 0
+1849 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 0
+1850 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 0
+1851 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 0
+1852 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 0
+1853 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 0
+1854 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 0
+1855 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 0
+1856 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 0
+1857 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 0
+1858 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 0
+1859 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 0
+1860 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 0
+1861 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 0
+1862 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 0
+1863 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 0
+1864 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 0
+1865 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 0
+1866 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 0
+1867 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 0
+1868 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 0
+1869 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 0
+1870 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 0
+1871 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 0
+1872 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 0
+1873 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 0
+1874 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 0
+1875 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 0
+1876 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 0
+1877 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 0
+1878 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 0
+1879 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 0
+1880 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 0
+1881 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 0
+1882 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 0
+1883 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 0
+1884 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 0
+1885 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 0
+1886 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 1
+1887 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 1
+1888 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 1
+1889 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 1
+1890 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 1
+1891 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 1
+1892 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 1
+1893 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 1
+1894 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 1
+1895 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 1
+1896 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 1
+1897 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 1
+1898 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 1
+1899 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 1
+1900 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 1
+1901 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 1
+1902 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 1
+1903 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 1
+1904 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 1
+1905 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 1
+1906 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 1
+1907 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 1
+1908 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 1
+1909 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 1
+1910 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 1
+1911 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 1
+1912 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 1
+1913 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 1
+1914 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 1
+1915 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 1
+1916 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 1
+1917 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 1
+1918 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 1
+1919 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 1
+1920 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 1
+1921 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 1
+1922 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 1
+1923 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 1
+1924 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 1
+1925 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 1
+1926 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 1
+1927 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 1
+1928 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 1
+1929 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 1
+1930 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 1
+1931 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 1
+1932 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 2
+1933 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 2
+1934 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 2
+1935 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 2
+1936 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 2
+1937 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 2
+1938 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 2
+1939 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 2
+1940 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 2
+1941 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 2
+1942 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 2
+1943 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 2
+1944 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 2
+1945 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 2
+1946 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 2
+1947 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 2
+1948 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 2
+1949 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 2
+1950 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 2
+1951 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 2
+1952 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 2
+1953 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 2
+1954 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 2
+1955 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 2
+1956 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 2
+1957 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 2
+1958 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 2
+1959 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 2
+1960 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 2
+1961 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 2
+1962 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 2
+1963 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 2
+1964 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 2
+1965 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 2
+1966 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 2
+1967 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 2
+1968 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 2
+1969 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 2
+1970 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 2
+1971 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 2
+1972 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 2
+1973 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 2
+1974 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 2
+1975 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 2
+1976 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 2
+1977 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 2
+1978 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 3
+1979 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 3
+1980 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 3
+1981 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 3
+1982 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 3
+1983 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 3
+1984 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 3
+1985 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 3
+1986 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 3
+1987 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 3
+1988 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 3
+1989 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 3
+1990 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 3
+1991 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 3
+1992 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 3
+1993 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 3
+1994 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 3
+1995 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 3
+1996 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 3
+1997 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 3
+1998 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 3
+1999 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 3
+2000 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 3
+2001 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 3
+2002 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 3
+2003 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 3
+2004 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 3
+2005 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 3
+2006 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 3
+2007 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 3
+2008 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 3
+2009 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 3
+2010 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 3
+2011 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 3
+2012 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 3
+2013 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 3
+2014 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 3
+2015 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 3
+2016 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 3
+2017 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 3
+2018 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 3
+2019 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 3
+2020 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 3
+2021 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 3
+2022 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 3
+2023 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 3
+2024 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 4
+2025 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 4
+2026 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 4
+2027 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 4
+2028 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 4
+2029 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 4
+2030 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 4
+2031 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 4
+2032 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 4
+2033 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 4
+2034 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 4
+2035 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 4
+2036 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 4
+2037 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 4
+2038 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 4
+2039 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 4
+2040 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 4
+2041 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 4
+2042 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 4
+2043 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 4
+2044 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 4
+2045 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 4
+2046 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 4
+2047 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 4
+2048 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 4
+2049 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 4
+2050 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 4
+2051 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 4
+2052 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 4
+2053 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 4
+2054 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 4
+2055 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 4
+2056 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 4
+2057 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 4
+2058 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 4
+2059 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 4
+2060 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 4
+2061 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 4
+2062 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 4
+2063 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 4
+2064 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 4
+2065 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 4
+2066 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 4
+2067 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 4
+2068 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 4
+2069 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 4
+2070 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 5
+2071 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 5
+2072 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 5
+2073 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 5
+2074 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 5
+2075 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 5
+2076 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 5
+2077 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 5
+2078 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 5
+2079 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 5
+2080 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 5
+2081 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 5
+2082 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 5
+2083 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 5
+2084 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 5
+2085 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 5
+2086 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 5
+2087 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 5
+2088 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 5
+2089 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 5
+2090 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 5
+2091 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 5
+2092 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 5
+2093 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 5
+2094 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 5
+2095 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 5
+2096 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 5
+2097 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 5
+2098 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 5
+2099 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 5
+2100 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 5
+2101 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 5
+2102 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 5
+2103 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 5
+2104 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 5
+2105 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 5
+2106 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 5
+2107 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 5
+2108 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 5
+2109 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 5
+2110 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 5
+2111 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 5
+2112 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 5
+2113 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 5
+2114 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 5
+2115 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 5
+2116 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 6
+2117 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 6
+2118 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 6
+2119 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 6
+2120 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 6
+2121 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 6
+2122 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 6
+2123 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 6
+2124 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 6
+2125 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 6
+2126 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 6
+2127 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 6
+2128 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 6
+2129 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 6
+2130 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 6
+2131 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 6
+2132 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 6
+2133 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 6
+2134 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 6
+2135 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 6
+2136 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 6
+2137 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 6
+2138 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 6
+2139 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 6
+2140 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 6
+2141 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 6
+2142 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 6
+2143 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 6
+2144 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 6
+2145 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 6
+2146 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 6
+2147 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 6
+2148 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 6
+2149 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 6
+2150 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 6
+2151 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 6
+2152 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 6
+2153 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 6
+2154 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 6
+2155 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 6
+2156 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 6
+2157 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 6
+2158 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 6
+2159 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 6
+2160 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 6
+2161 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 6
+2162 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 7
+2163 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 7
+2164 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 7
+2165 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 7
+2166 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 7
+2167 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 7
+2168 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 7
+2169 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 7
+2170 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 7
+2171 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 7
+2172 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 7
+2173 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 7
+2174 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 7
+2175 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 7
+2176 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 7
+2177 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 7
+2178 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 7
+2179 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 7
+2180 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 7
+2181 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 7
+2182 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 7
+2183 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 7
+2184 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 7
+2185 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 7
+2186 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 7
+2187 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 7
+2188 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 7
+2189 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 7
+2190 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 7
+2191 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 7
+2192 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 7
+2193 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 7
+2194 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 7
+2195 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 7
+2196 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 7
+2197 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 7
+2198 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 7
+2199 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 7
+2200 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 7
+2201 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 7
+2202 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 7
+2203 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 7
+2204 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 7
+2205 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 7
+2206 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 7
+2207 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 7
+2208 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 8
+2209 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 8
+2210 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 8
+2211 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 8
+2212 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 8
+2213 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 8
+2214 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 8
+2215 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 8
+2216 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 8
+2217 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 8
+2218 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 8
+2219 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 8
+2220 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 8
+2221 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 8
+2222 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 8
+2223 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 8
+2224 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 8
+2225 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 8
+2226 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 8
+2227 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 8
+2228 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 8
+2229 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 8
+2230 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 8
+2231 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 8
+2232 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 8
+2233 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 8
+2234 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 8
+2235 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 8
+2236 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 8
+2237 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 8
+2238 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 8
+2239 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 8
+2240 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 8
+2241 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 8
+2242 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 8
+2243 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 8
+2244 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 8
+2245 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 8
+2246 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 8
+2247 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 8
+2248 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 8
+2249 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 8
+2250 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 8
+2251 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 8
+2252 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 8
+2253 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 8
+2254 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 9
+2255 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 9
+2256 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 9
+2257 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 9
+2258 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 9
+2259 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 9
+2260 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 9
+2261 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 9
+2262 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 9
+2263 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 9
+2264 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 9
+2265 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 9
+2266 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 9
+2267 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 9
+2268 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 9
+2269 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 9
+2270 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 9
+2271 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 9
+2272 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 9
+2273 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 9
+2274 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 9
+2275 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 9
+2276 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 9
+2277 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 9
+2278 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 9
+2279 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 9
+2280 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 9
+2281 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 9
+2282 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 9
+2283 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 9
+2284 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 9
+2285 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 9
+2286 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 9
+2287 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 9
+2288 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 9
+2289 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 9
+2290 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 9
+2291 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 9
+2292 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 9
+2293 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 9
+2294 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 9
+2295 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 9
+2296 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 9
+2297 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 9
+2298 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 9
+2299 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 9
+2300 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 10
+2301 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 10
+2302 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 10
+2303 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 10
+2304 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 10
+2305 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 10
+2306 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 10
+2307 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 10
+2308 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 10
+2309 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 10
+2310 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 10
+2311 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 10
+2312 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 10
+2313 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 10
+2314 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 10
+2315 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 10
+2316 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 10
+2317 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 10
+2318 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 10
+2319 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 10
+2320 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 10
+2321 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 10
+2322 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 10
+2323 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 10
+2324 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 10
+2325 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 10
+2326 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 10
+2327 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 10
+2328 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 10
+2329 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 10
+2330 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 10
+2331 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 10
+2332 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 10
+2333 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 10
+2334 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 10
+2335 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 10
+2336 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 10
+2337 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 10
+2338 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 10
+2339 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 10
+2340 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 10
+2341 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 10
+2342 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 10
+2343 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 10
+2344 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 10
+2345 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 10
+2346 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 11
+2347 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 11
+2348 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 11
+2349 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 11
+2350 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 11
+2351 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 11
+2352 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 11
+2353 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 11
+2354 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 11
+2355 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 11
+2356 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 11
+2357 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 11
+2358 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 11
+2359 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 11
+2360 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 11
+2361 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 11
+2362 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 11
+2363 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 11
+2364 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 11
+2365 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 11
+2366 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 11
+2367 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 11
+2368 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 11
+2369 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 11
+2370 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 11
+2371 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 11
+2372 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 11
+2373 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 11
+2374 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 11
+2375 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 11
+2376 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 11
+2377 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 11
+2378 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 11
+2379 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 11
+2380 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 11
+2381 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 11
+2382 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 11
+2383 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 11
+2384 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 11
+2385 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 11
+2386 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 11
+2387 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 11
+2388 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 11
+2389 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 11
+2390 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 11
+2391 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 11
+2392 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 12
+2393 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 12
+2394 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 12
+2395 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 12
+2396 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 12
+2397 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 12
+2398 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 12
+2399 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 12
+2400 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 12
+2401 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 12
+2402 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 12
+2403 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 12
+2404 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 12
+2405 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 12
+2406 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 12
+2407 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 12
+2408 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 12
+2409 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 12
+2410 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 12
+2411 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 12
+2412 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 12
+2413 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 12
+2414 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 12
+2415 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 12
+2416 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 12
+2417 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 12
+2418 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 12
+2419 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 12
+2420 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 12
+2421 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 12
+2422 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 12
+2423 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 12
+2424 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 12
+2425 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 12
+2426 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 12
+2427 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 12
+2428 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 12
+2429 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 12
+2430 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 12
+2431 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 12
+2432 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 12
+2433 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 12
+2434 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 12
+2435 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 12
+2436 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 12
+2437 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 12
+2438 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 13
+2439 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 13
+2440 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 13
+2441 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 13
+2442 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 13
+2443 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 13
+2444 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 13
+2445 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 13
+2446 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 13
+2447 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 13
+2448 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 13
+2449 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 13
+2450 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 13
+2451 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 13
+2452 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 13
+2453 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 13
+2454 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 13
+2455 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 13
+2456 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 13
+2457 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 13
+2458 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 13
+2459 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 13
+2460 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 13
+2461 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 13
+2462 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 13
+2463 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 13
+2464 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 13
+2465 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 13
+2466 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 13
+2467 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 13
+2468 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 13
+2469 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 13
+2470 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 13
+2471 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 13
+2472 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 13
+2473 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 13
+2474 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 13
+2475 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 13
+2476 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 13
+2477 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 13
+2478 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 13
+2479 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 13
+2480 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 13
+2481 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 13
+2482 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 13
+2483 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 13
+2484 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 14
+2485 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 14
+2486 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 14
+2487 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 14
+2488 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 14
+2489 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 14
+2490 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 14
+2491 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 14
+2492 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 14
+2493 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 14
+2494 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 14
+2495 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 14
+2496 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 14
+2497 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 14
+2498 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 14
+2499 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 14
+2500 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 14
+2501 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 14
+2502 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 14
+2503 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 14
+2504 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 14
+2505 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 14
+2506 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 14
+2507 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 14
+2508 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 14
+2509 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 14
+2510 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 14
+2511 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 14
+2512 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 14
+2513 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 14
+2514 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 14
+2515 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 14
+2516 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 14
+2517 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 14
+2518 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 14
+2519 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 14
+2520 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 14
+2521 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 14
+2522 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 14
+2523 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 14
+2524 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 14
+2525 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 14
+2526 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 14
+2527 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 14
+2528 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 14
+2529 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 14
+2530 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 15
+2531 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 15
+2532 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 15
+2533 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 15
+2534 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 15
+2535 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 15
+2536 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 15
+2537 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 15
+2538 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 15
+2539 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 15
+2540 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 15
+2541 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 15
+2542 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 15
+2543 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 15
+2544 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 15
+2545 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 15
+2546 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 15
+2547 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 15
+2548 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 15
+2549 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 15
+2550 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 15
+2551 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 15
+2552 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 15
+2553 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 15
+2554 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 15
+2555 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 15
+2556 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 15
+2557 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 15
+2558 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 15
+2559 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 15
+2560 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 15
+2561 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 15
+2562 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 15
+2563 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 15
+2564 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 15
+2565 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 15
+2566 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 15
+2567 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 15
+2568 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 15
+2569 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 15
+2570 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 15
+2571 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 15
+2572 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 15
+2573 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 15
+2574 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 15
+2575 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 15
+2576 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 16
+2577 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 16
+2578 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 16
+2579 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 16
+2580 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 16
+2581 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 16
+2582 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 16
+2583 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 16
+2584 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 16
+2585 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 16
+2586 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 16
+2587 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 16
+2588 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 16
+2589 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 16
+2590 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 16
+2591 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 16
+2592 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 16
+2593 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 16
+2594 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 16
+2595 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 16
+2596 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 16
+2597 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 16
+2598 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 16
+2599 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 16
+2600 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 16
+2601 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 16
+2602 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 16
+2603 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 16
+2604 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 16
+2605 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 16
+2606 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 16
+2607 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 16
+2608 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 16
+2609 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 16
+2610 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 16
+2611 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 16
+2612 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 16
+2613 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 16
+2614 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 16
+2615 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 16
+2616 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 16
+2617 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 16
+2618 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 16
+2619 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 16
+2620 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 16
+2621 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 16
+2622 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 17
+2623 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 17
+2624 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 17
+2625 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 17
+2626 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 17
+2627 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 17
+2628 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 17
+2629 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 17
+2630 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 17
+2631 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 17
+2632 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 17
+2633 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 17
+2634 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 17
+2635 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 17
+2636 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 17
+2637 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 17
+2638 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 17
+2639 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 17
+2640 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 17
+2641 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 17
+2642 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 17
+2643 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 17
+2644 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 17
+2645 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 17
+2646 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 17
+2647 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 17
+2648 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 17
+2649 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 17
+2650 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 17
+2651 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 17
+2652 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 17
+2653 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 17
+2654 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 17
+2655 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 17
+2656 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 17
+2657 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 17
+2658 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 17
+2659 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 17
+2660 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 17
+2661 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 17
+2662 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 17
+2663 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 17
+2664 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 17
+2665 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 17
+2666 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 17
+2667 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 17
+2668 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 18
+2669 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 18
+2670 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 18
+2671 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 18
+2672 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 18
+2673 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 18
+2674 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 18
+2675 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 18
+2676 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 18
+2677 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 18
+2678 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 18
+2679 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 18
+2680 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 18
+2681 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 18
+2682 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 18
+2683 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 18
+2684 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 18
+2685 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 18
+2686 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 18
+2687 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 18
+2688 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 18
+2689 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 18
+2690 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 18
+2691 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 18
+2692 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 18
+2693 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 18
+2694 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 18
+2695 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 18
+2696 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 18
+2697 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 18
+2698 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 18
+2699 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 18
+2700 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 18
+2701 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 18
+2702 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 18
+2703 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 18
+2704 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 18
+2705 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 18
+2706 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 18
+2707 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 18
+2708 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 18
+2709 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 18
+2710 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 18
+2711 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 18
+2712 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 18
+2713 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 18
+2714 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 19
+2715 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 19
+2716 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 19
+2717 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 19
+2718 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 19
+2719 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 19
+2720 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 19
+2721 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 19
+2722 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 19
+2723 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 19
+2724 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 19
+2725 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 19
+2726 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 19
+2727 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 19
+2728 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 19
+2729 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 19
+2730 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 19
+2731 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 19
+2732 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 19
+2733 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 19
+2734 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 19
+2735 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 19
+2736 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 19
+2737 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 19
+2738 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 19
+2739 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 19
+2740 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 19
+2741 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 19
+2742 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 19
+2743 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 19
+2744 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 19
+2745 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 19
+2746 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 19
+2747 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 19
+2748 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 19
+2749 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 19
+2750 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 19
+2751 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 19
+2752 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 19
+2753 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 19
+2754 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 19
+2755 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 19
+2756 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 19
+2757 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 19
+2758 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 19
+2759 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 19
+2760 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 20
+2761 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 20
+2762 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 20
+2763 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 20
+2764 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 20
+2765 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 20
+2766 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 20
+2767 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 20
+2768 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 20
+2769 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 20
+2770 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 20
+2771 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 20
+2772 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 20
+2773 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 20
+2774 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 20
+2775 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 20
+2776 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 20
+2777 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 20
+2778 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 20
+2779 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 20
+2780 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 20
+2781 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 20
+2782 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 20
+2783 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 20
+2784 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 20
+2785 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 20
+2786 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 20
+2787 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 20
+2788 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 20
+2789 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 20
+2790 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 20
+2791 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 20
+2792 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 20
+2793 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 20
+2794 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 20
+2795 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 20
+2796 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 20
+2797 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 20
+2798 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 20
+2799 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 20
+2800 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 20
+2801 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 20
+2802 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 20
+2803 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 20
+2804 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 20
+2805 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 20
+2806 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 21
+2807 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 21
+2808 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 21
+2809 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 21
+2810 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 21
+2811 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 21
+2812 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 21
+2813 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 21
+2814 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 21
+2815 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 21
+2816 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 21
+2817 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 21
+2818 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 21
+2819 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 21
+2820 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 21
+2821 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 21
+2822 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 21
+2823 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 21
+2824 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 21
+2825 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 21
+2826 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 21
+2827 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 21
+2828 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 21
+2829 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 21
+2830 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 21
+2831 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 21
+2832 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 21
+2833 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 21
+2834 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 21
+2835 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 21
+2836 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 21
+2837 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 21
+2838 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 21
+2839 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 21
+2840 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 21
+2841 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 21
+2842 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 21
+2843 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 21
+2844 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 21
+2845 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 21
+2846 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 21
+2847 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 21
+2848 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 21
+2849 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 21
+2850 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 21
+2851 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 21
+2852 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 22
+2853 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 22
+2854 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 22
+2855 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 22
+2856 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 22
+2857 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 22
+2858 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 22
+2859 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 22
+2860 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 22
+2861 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 22
+2862 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 22
+2863 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 22
+2864 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 22
+2865 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 22
+2866 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 22
+2867 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 22
+2868 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 22
+2869 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 22
+2870 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 22
+2871 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 22
+2872 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 22
+2873 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 22
+2874 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 22
+2875 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 22
+2876 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 22
+2877 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 22
+2878 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 22
+2879 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 22
+2880 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 22
+2881 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 22
+2882 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 22
+2883 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 22
+2884 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 22
+2885 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 22
+2886 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 22
+2887 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 22
+2888 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 22
+2889 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 22
+2890 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 22
+2891 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 22
+2892 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 22
+2893 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 22
+2894 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 22
+2895 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 22
+2896 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 22
+2897 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 22
+2898 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 23
+2899 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 23
+2900 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 23
+2901 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 23
+2902 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 23
+2903 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 23
+2904 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 23
+2905 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 23
+2906 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 23
+2907 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 23
+2908 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 23
+2909 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 23
+2910 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 23
+2911 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 23
+2912 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 23
+2913 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 23
+2914 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 23
+2915 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 23
+2916 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 23
+2917 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 23
+2918 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 23
+2919 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 23
+2920 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 23
+2921 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 23
+2922 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 23
+2923 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 23
+2924 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 23
+2925 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 23
+2926 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 23
+2927 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 23
+2928 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 23
+2929 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 23
+2930 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 23
+2931 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 23
+2932 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 23
+2933 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 23
+2934 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 23
+2935 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 23
+2936 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 23
+2937 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 23
+2938 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 23
+2939 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 23
+2940 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 23
+2941 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 23
+2942 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 23
+2943 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 23
+2944 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 24
+2945 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 24
+2946 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 24
+2947 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 24
+2948 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 24
+2949 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 24
+2950 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 24
+2951 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 24
+2952 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 24
+2953 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 24
+2954 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 24
+2955 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 24
+2956 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 24
+2957 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 24
+2958 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 24
+2959 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 24
+2960 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 24
+2961 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 24
+2962 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 24
+2963 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 24
+2964 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 24
+2965 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 24
+2966 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 24
+2967 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 24
+2968 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 24
+2969 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 24
+2970 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 24
+2971 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 24
+2972 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 24
+2973 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 24
+2974 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 24
+2975 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 24
+2976 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 24
+2977 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 24
+2978 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 24
+2979 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 24
+2980 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 24
+2981 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 24
+2982 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 24
+2983 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 24
+2984 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 24
+2985 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 24
+2986 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 24
+2987 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 24
+2988 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 24
+2989 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 24
+2990 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 25
+2991 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 25
+2992 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 25
+2993 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 25
+2994 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 25
+2995 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 25
+2996 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 25
+2997 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 25
+2998 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 25
+2999 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 25
+3000 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 25
+3001 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 25
+3002 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 25
+3003 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 25
+3004 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 25
+3005 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 25
+3006 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 25
+3007 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 25
+3008 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 25
+3009 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 25
+3010 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 25
+3011 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 25
+3012 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 25
+3013 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 25
+3014 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 25
+3015 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 25
+3016 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 25
+3017 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 25
+3018 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 25
+3019 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 25
+3020 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 25
+3021 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 25
+3022 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 25
+3023 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 25
+3024 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 25
+3025 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 25
+3026 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 25
+3027 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 25
+3028 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 25
+3029 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 25
+3030 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 25
+3031 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 25
+3032 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 25
+3033 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 25
+3034 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 25
+3035 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 25
+3036 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 26
+3037 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 26
+3038 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 26
+3039 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 26
+3040 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 26
+3041 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 26
+3042 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 26
+3043 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 26
+3044 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 26
+3045 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 26
+3046 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 26
+3047 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 26
+3048 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 26
+3049 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 26
+3050 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 26
+3051 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 26
+3052 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 26
+3053 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 26
+3054 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 26
+3055 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 26
+3056 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 26
+3057 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 26
+3058 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 26
+3059 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 26
+3060 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 26
+3061 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 26
+3062 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 26
+3063 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 26
+3064 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 26
+3065 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 26
+3066 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 26
+3067 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 26
+3068 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 26
+3069 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 26
+3070 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 26
+3071 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 26
+3072 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 26
+3073 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 26
+3074 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 26
+3075 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 26
+3076 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 26
+3077 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 26
+3078 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 26
+3079 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 26
+3080 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 26
+3081 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 26
+3082 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 27
+3083 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 27
+3084 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 27
+3085 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 27
+3086 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 27
+3087 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 27
+3088 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 27
+3089 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 27
+3090 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 27
+3091 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 27
+3092 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 27
+3093 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 27
+3094 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 27
+3095 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 27
+3096 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 27
+3097 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 27
+3098 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 27
+3099 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 27
+3100 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 27
+3101 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 27
+3102 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 27
+3103 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 27
+3104 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 27
+3105 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 27
+3106 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 27
+3107 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 27
+3108 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 27
+3109 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 27
+3110 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 27
+3111 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 27
+3112 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 27
+3113 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 27
+3114 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 27
+3115 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 27
+3116 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 27
+3117 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 27
+3118 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 27
+3119 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 27
+3120 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 27
+3121 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 27
+3122 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 27
+3123 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 27
+3124 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 27
+3125 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 27
+3126 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 27
+3127 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 27
+3128 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 28
+3129 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 28
+3130 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 28
+3131 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 28
+3132 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 28
+3133 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 28
+3134 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 28
+3135 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 28
+3136 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 28
+3137 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 28
+3138 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 28
+3139 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 28
+3140 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 28
+3141 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 28
+3142 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 28
+3143 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 28
+3144 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 28
+3145 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 28
+3146 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 28
+3147 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 28
+3148 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 28
+3149 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 28
+3150 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 28
+3151 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 28
+3152 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 28
+3153 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 28
+3154 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 28
+3155 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 28
+3156 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 28
+3157 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 28
+3158 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 28
+3159 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 28
+3160 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 28
+3161 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 28
+3162 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 28
+3163 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 28
+3164 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 28
+3165 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 28
+3166 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 28
+3167 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 28
+3168 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 28
+3169 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 28
+3170 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 28
+3171 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 28
+3172 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 28
+3173 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 28
+3174 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 29
+3175 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 29
+3176 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 29
+3177 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 29
+3178 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 29
+3179 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 29
+3180 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 29
+3181 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 29
+3182 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 29
+3183 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 29
+3184 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 29
+3185 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 29
+3186 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 29
+3187 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 29
+3188 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 29
+3189 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 29
+3190 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 29
+3191 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 29
+3192 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 29
+3193 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 29
+3194 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 29
+3195 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 29
+3196 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 29
+3197 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 29
+3198 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 29
+3199 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 29
+3200 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 29
+3201 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 29
+3202 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 29
+3203 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 29
+3204 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 29
+3205 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 29
+3206 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 29
+3207 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 29
+3208 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 29
+3209 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 29
+3210 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 29
+3211 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 29
+3212 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 29
+3213 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 29
+3214 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 29
+3215 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 29
+3216 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 29
+3217 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 29
+3218 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 29
+3219 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 29
+3220 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 30
+3221 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 30
+3222 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 30
+3223 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 30
+3224 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 30
+3225 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 30
+3226 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 30
+3227 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 30
+3228 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 30
+3229 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 30
+3230 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 30
+3231 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 30
+3232 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 30
+3233 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 30
+3234 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 30
+3235 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 30
+3236 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 30
+3237 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 30
+3238 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 30
+3239 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 30
+3240 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 30
+3241 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 30
+3242 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 30
+3243 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 30
+3244 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 30
+3245 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 30
+3246 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 30
+3247 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 30
+3248 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 30
+3249 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 30
+3250 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 30
+3251 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 30
+3252 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 30
+3253 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 30
+3254 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 30
+3255 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 30
+3256 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 30
+3257 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 30
+3258 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 30
+3259 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 30
+3260 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 30
+3261 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 30
+3262 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 30
+3263 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 30
+3264 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 30
+3265 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 30
+3266 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 31
+3267 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 31
+3268 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 31
+3269 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 31
+3270 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 31
+3271 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 31
+3272 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 31
+3273 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 31
+3274 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 31
+3275 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 31
+3276 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 31
+3277 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 31
+3278 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 31
+3279 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 31
+3280 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 31
+3281 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 31
+3282 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 31
+3283 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 31
+3284 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 31
+3285 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 31
+3286 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 31
+3287 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 31
+3288 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 31
+3289 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 31
+3290 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 31
+3291 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 31
+3292 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 31
+3293 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 31
+3294 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 31
+3295 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 31
+3296 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 31
+3297 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 31
+3298 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 31
+3299 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 31
+3300 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 31
+3301 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 31
+3302 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 31
+3303 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 31
+3304 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 31
+3305 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 31
+3306 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 31
+3307 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 31
+3308 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 31
+3309 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 31
+3310 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 31
+3311 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 31
+3312 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 32
+3313 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 32
+3314 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 32
+3315 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 32
+3316 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 32
+3317 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 32
+3318 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 32
+3319 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 32
+3320 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 32
+3321 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 32
+3322 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 32
+3323 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 32
+3324 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 32
+3325 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 32
+3326 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 32
+3327 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 32
+3328 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 32
+3329 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 32
+3330 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 32
+3331 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 32
+3332 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 32
+3333 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 32
+3334 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 32
+3335 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 32
+3336 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 32
+3337 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 32
+3338 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 32
+3339 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 32
+3340 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 32
+3341 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 32
+3342 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 32
+3343 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 32
+3344 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 32
+3345 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 32
+3346 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 32
+3347 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 32
+3348 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 32
+3349 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 32
+3350 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 32
+3351 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 32
+3352 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 32
+3353 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 32
+3354 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 32
+3355 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 32
+3356 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 32
+3357 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 32
+3358 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 33
+3359 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 33
+3360 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 33
+3361 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 33
+3362 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 33
+3363 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 33
+3364 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 33
+3365 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 33
+3366 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 33
+3367 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 33
+3368 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 33
+3369 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 33
+3370 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 33
+3371 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 33
+3372 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 33
+3373 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 33
+3374 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 33
+3375 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 33
+3376 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 33
+3377 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 33
+3378 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 33
+3379 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 33
+3380 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 33
+3381 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 33
+3382 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 33
+3383 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 33
+3384 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 33
+3385 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 33
+3386 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 33
+3387 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 33
+3388 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 33
+3389 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 33
+3390 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 33
+3391 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 33
+3392 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 33
+3393 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 33
+3394 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 33
+3395 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 33
+3396 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 33
+3397 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 33
+3398 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 33
+3399 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 33
+3400 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 33
+3401 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 33
+3402 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 33
+3403 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 33
+3404 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 34
+3405 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 34
+3406 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 34
+3407 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 34
+3408 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 34
+3409 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 34
+3410 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 34
+3411 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 34
+3412 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 34
+3413 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 34
+3414 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 34
+3415 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 34
+3416 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 34
+3417 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 34
+3418 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 34
+3419 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 34
+3420 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 34
+3421 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 34
+3422 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 34
+3423 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 34
+3424 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 34
+3425 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 34
+3426 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 34
+3427 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 34
+3428 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 34
+3429 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 34
+3430 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 34
+3431 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 34
+3432 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 34
+3433 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 34
+3434 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 34
+3435 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 34
+3436 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 34
+3437 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 34
+3438 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 34
+3439 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 34
+3440 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 34
+3441 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 34
+3442 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 34
+3443 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 34
+3444 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 34
+3445 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 34
+3446 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 34
+3447 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 34
+3448 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 34
+3449 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 34
+3450 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 35
+3451 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 35
+3452 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 35
+3453 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 35
+3454 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 35
+3455 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 35
+3456 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 35
+3457 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 35
+3458 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 35
+3459 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 35
+3460 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 35
+3461 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 35
+3462 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 35
+3463 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 35
+3464 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 35
+3465 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 35
+3466 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 35
+3467 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 35
+3468 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 35
+3469 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 35
+3470 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 35
+3471 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 35
+3472 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 35
+3473 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 35
+3474 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 35
+3475 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 35
+3476 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 35
+3477 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 35
+3478 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 35
+3479 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 35
+3480 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 35
+3481 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 35
+3482 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 35
+3483 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 35
+3484 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 35
+3485 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 35
+3486 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 35
+3487 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 35
+3488 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 35
+3489 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 35
+3490 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 35
+3491 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 35
+3492 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 35
+3493 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 35
+3494 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 35
+3495 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 35
+3496 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 36
+3497 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 36
+3498 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 36
+3499 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 36
+3500 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 36
+3501 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 36
+3502 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 36
+3503 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 36
+3504 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 36
+3505 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 36
+3506 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 36
+3507 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 36
+3508 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 36
+3509 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 36
+3510 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 36
+3511 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 36
+3512 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 36
+3513 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 36
+3514 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 36
+3515 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 36
+3516 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 36
+3517 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 36
+3518 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 36
+3519 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 36
+3520 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 36
+3521 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 36
+3522 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 36
+3523 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 36
+3524 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 36
+3525 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 36
+3526 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 36
+3527 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 36
+3528 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 36
+3529 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 36
+3530 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 36
+3531 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 36
+3532 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 36
+3533 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 36
+3534 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 36
+3535 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 36
+3536 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 36
+3537 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 36
+3538 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 36
+3539 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 36
+3540 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 36
+3541 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 36
+3542 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 37
+3543 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 37
+3544 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 37
+3545 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 37
+3546 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 37
+3547 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 37
+3548 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 37
+3549 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 37
+3550 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 37
+3551 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 37
+3552 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 37
+3553 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 37
+3554 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 37
+3555 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 37
+3556 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 37
+3557 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 37
+3558 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 37
+3559 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 37
+3560 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 37
+3561 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 37
+3562 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 37
+3563 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 37
+3564 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 37
+3565 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 37
+3566 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 37
+3567 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 37
+3568 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 37
+3569 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 37
+3570 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 37
+3571 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 37
+3572 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 37
+3573 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 37
+3574 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 37
+3575 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 37
+3576 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 37
+3577 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 37
+3578 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 37
+3579 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 37
+3580 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 37
+3581 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 37
+3582 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 37
+3583 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 37
+3584 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 37
+3585 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 37
+3586 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 37
+3587 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 37
+3588 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 38
+3589 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 38
+3590 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 38
+3591 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 38
+3592 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 38
+3593 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 38
+3594 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 38
+3595 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 38
+3596 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 38
+3597 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 38
+3598 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 38
+3599 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 38
+3600 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 38
+3601 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 38
+3602 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 38
+3603 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 38
+3604 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 38
+3605 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 38
+3606 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 38
+3607 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 38
+3608 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 38
+3609 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 38
+3610 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 38
+3611 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 38
+3612 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 38
+3613 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 38
+3614 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 38
+3615 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 38
+3616 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 38
+3617 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 38
+3618 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 38
+3619 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 38
+3620 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 38
+3621 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 38
+3622 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 38
+3623 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 38
+3624 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 38
+3625 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 38
+3626 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 38
+3627 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 38
+3628 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 38
+3629 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 38
+3630 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 38
+3631 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 38
+3632 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 38
+3633 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 38
+3634 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 1 39
+3635 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 1 39
+3636 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 1 39
+3637 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 1 39
+3638 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 1 39
+3639 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 1 39
+3640 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 1 39
+3641 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 39
+3642 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 1 39
+3643 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 1 39
+3644 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 1 39
+3645 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 1 39
+3646 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 1 39
+3647 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 1 39
+3648 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 1 39
+3649 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 1 39
+3650 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 1 39
+3651 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 1 39
+3652 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 1 39
+3653 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 1 39
+3654 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 1 39
+3655 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 1 39
+3656 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 1 39
+3657 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 1 39
+3658 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 1 39
+3659 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 1 39
+3660 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 1 39
+3661 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 1 39
+3662 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 1 39
+3663 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 1 39
+3664 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 1 39
+3665 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 1 39
+3666 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 1 39
+3667 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 1 39
+3668 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 1 39
+3669 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 1 39
+3670 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 1 39
+3671 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 1 39
+3672 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 1 39
+3673 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 1 39
+3674 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 39
+3675 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 1 39
+3676 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 1 39
+3677 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 1 39
+3678 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 39
+3679 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 1 39
+3680 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 0
+3681 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 0
+3682 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 0
+3683 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 0
+3684 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 0
+3685 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 0
+3686 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 0
+3687 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 0
+3688 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 0
+3689 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 0
+3690 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 0
+3691 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 0
+3692 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 0
+3693 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 0
+3694 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 0
+3695 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 0
+3696 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 0
+3697 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 0
+3698 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 0
+3699 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 0
+3700 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 0
+3701 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 0
+3702 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 0
+3703 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 0
+3704 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 0
+3705 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 0
+3706 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 0
+3707 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 0
+3708 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 0
+3709 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 0
+3710 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 0
+3711 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 0
+3712 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 0
+3713 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 0
+3714 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 0
+3715 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 0
+3716 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 0
+3717 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 0
+3718 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 0
+3719 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 0
+3720 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 0
+3721 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 0
+3722 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 0
+3723 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 0
+3724 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 0
+3725 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 0
+3726 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 1
+3727 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 1
+3728 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 1
+3729 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 1
+3730 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 1
+3731 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 1
+3732 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 1
+3733 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 1
+3734 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 1
+3735 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 1
+3736 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 1
+3737 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 1
+3738 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 1
+3739 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 1
+3740 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 1
+3741 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 1
+3742 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 1
+3743 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 1
+3744 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 1
+3745 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 1
+3746 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 1
+3747 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 1
+3748 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 1
+3749 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 1
+3750 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 1
+3751 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 1
+3752 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 1
+3753 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 1
+3754 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 1
+3755 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 1
+3756 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 1
+3757 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 1
+3758 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 1
+3759 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 1
+3760 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 1
+3761 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 1
+3762 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 1
+3763 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 1
+3764 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 1
+3765 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 1
+3766 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 1
+3767 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 1
+3768 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 1
+3769 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 1
+3770 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 1
+3771 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 1
+3772 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 2
+3773 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 2
+3774 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 2
+3775 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 2
+3776 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 2
+3777 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 2
+3778 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 2
+3779 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 2
+3780 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 2
+3781 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 2
+3782 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 2
+3783 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 2
+3784 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 2
+3785 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 2
+3786 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 2
+3787 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 2
+3788 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 2
+3789 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 2
+3790 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 2
+3791 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 2
+3792 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 2
+3793 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 2
+3794 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 2
+3795 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 2
+3796 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 2
+3797 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 2
+3798 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 2
+3799 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 2
+3800 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 2
+3801 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 2
+3802 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 2
+3803 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 2
+3804 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 2
+3805 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 2
+3806 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 2
+3807 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 2
+3808 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 2
+3809 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 2
+3810 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 2
+3811 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 2
+3812 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 2
+3813 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 2
+3814 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 2
+3815 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 2
+3816 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 2
+3817 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 2
+3818 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 3
+3819 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 3
+3820 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 3
+3821 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 3
+3822 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 3
+3823 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 3
+3824 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 3
+3825 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 3
+3826 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 3
+3827 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 3
+3828 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 3
+3829 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 3
+3830 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 3
+3831 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 3
+3832 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 3
+3833 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 3
+3834 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 3
+3835 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 3
+3836 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 3
+3837 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 3
+3838 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 3
+3839 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 3
+3840 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 3
+3841 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 3
+3842 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 3
+3843 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 3
+3844 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 3
+3845 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 3
+3846 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 3
+3847 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 3
+3848 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 3
+3849 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 3
+3850 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 3
+3851 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 3
+3852 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 3
+3853 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 3
+3854 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 3
+3855 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 3
+3856 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 3
+3857 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 3
+3858 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 3
+3859 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 3
+3860 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 3
+3861 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 3
+3862 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 3
+3863 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 3
+3864 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 4
+3865 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 4
+3866 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 4
+3867 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 4
+3868 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 4
+3869 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 4
+3870 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 4
+3871 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 4
+3872 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 4
+3873 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 4
+3874 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 4
+3875 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 4
+3876 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 4
+3877 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 4
+3878 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 4
+3879 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 4
+3880 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 4
+3881 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 4
+3882 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 4
+3883 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 4
+3884 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 4
+3885 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 4
+3886 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 4
+3887 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 4
+3888 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 4
+3889 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 4
+3890 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 4
+3891 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 4
+3892 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 4
+3893 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 4
+3894 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 4
+3895 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 4
+3896 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 4
+3897 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 4
+3898 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 4
+3899 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 4
+3900 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 4
+3901 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 4
+3902 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 4
+3903 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 4
+3904 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 4
+3905 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 4
+3906 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 4
+3907 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 4
+3908 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 4
+3909 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 4
+3910 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 5
+3911 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 5
+3912 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 5
+3913 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 5
+3914 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 5
+3915 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 5
+3916 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 5
+3917 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 5
+3918 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 5
+3919 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 5
+3920 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 5
+3921 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 5
+3922 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 5
+3923 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 5
+3924 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 5
+3925 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 5
+3926 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 5
+3927 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 5
+3928 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 5
+3929 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 5
+3930 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 5
+3931 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 5
+3932 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 5
+3933 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 5
+3934 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 5
+3935 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 5
+3936 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 5
+3937 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 5
+3938 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 5
+3939 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 5
+3940 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 5
+3941 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 5
+3942 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 5
+3943 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 5
+3944 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 5
+3945 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 5
+3946 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 5
+3947 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 5
+3948 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 5
+3949 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 5
+3950 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 5
+3951 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 5
+3952 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 5
+3953 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 5
+3954 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 5
+3955 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 5
+3956 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 6
+3957 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 6
+3958 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 6
+3959 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 6
+3960 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 6
+3961 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 6
+3962 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 6
+3963 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 6
+3964 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 6
+3965 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 6
+3966 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 6
+3967 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 6
+3968 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 6
+3969 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 6
+3970 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 6
+3971 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 6
+3972 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 6
+3973 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 6
+3974 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 6
+3975 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 6
+3976 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 6
+3977 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 6
+3978 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 6
+3979 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 6
+3980 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 6
+3981 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 6
+3982 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 6
+3983 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 6
+3984 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 6
+3985 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 6
+3986 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 6
+3987 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 6
+3988 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 6
+3989 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 6
+3990 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 6
+3991 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 6
+3992 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 6
+3993 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 6
+3994 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 6
+3995 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 6
+3996 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 6
+3997 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 6
+3998 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 6
+3999 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 6
+4000 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 6
+4001 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 6
+4002 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 7
+4003 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 7
+4004 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 7
+4005 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 7
+4006 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 7
+4007 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 7
+4008 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 7
+4009 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 7
+4010 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 7
+4011 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 7
+4012 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 7
+4013 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 7
+4014 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 7
+4015 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 7
+4016 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 7
+4017 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 7
+4018 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 7
+4019 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 7
+4020 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 7
+4021 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 7
+4022 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 7
+4023 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 7
+4024 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 7
+4025 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 7
+4026 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 7
+4027 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 7
+4028 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 7
+4029 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 7
+4030 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 7
+4031 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 7
+4032 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 7
+4033 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 7
+4034 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 7
+4035 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 7
+4036 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 7
+4037 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 7
+4038 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 7
+4039 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 7
+4040 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 7
+4041 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 7
+4042 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 7
+4043 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 7
+4044 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 7
+4045 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 7
+4046 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 7
+4047 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 7
+4048 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 8
+4049 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 8
+4050 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 8
+4051 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 8
+4052 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 8
+4053 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 8
+4054 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 8
+4055 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 8
+4056 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 8
+4057 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 8
+4058 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 8
+4059 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 8
+4060 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 8
+4061 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 8
+4062 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 8
+4063 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 8
+4064 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 8
+4065 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 8
+4066 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 8
+4067 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 8
+4068 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 8
+4069 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 8
+4070 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 8
+4071 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 8
+4072 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 8
+4073 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 8
+4074 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 8
+4075 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 8
+4076 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 8
+4077 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 8
+4078 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 8
+4079 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 8
+4080 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 8
+4081 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 8
+4082 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 8
+4083 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 8
+4084 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 8
+4085 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 8
+4086 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 8
+4087 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 8
+4088 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 8
+4089 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 8
+4090 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 8
+4091 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 8
+4092 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 8
+4093 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 8
+4094 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 9
+4095 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 9
+4096 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 9
+4097 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 9
+4098 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 9
+4099 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 9
+4100 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 9
+4101 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 9
+4102 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 9
+4103 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 9
+4104 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 9
+4105 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 9
+4106 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 9
+4107 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 9
+4108 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 9
+4109 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 9
+4110 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 9
+4111 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 9
+4112 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 9
+4113 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 9
+4114 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 9
+4115 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 9
+4116 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 9
+4117 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 9
+4118 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 9
+4119 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 9
+4120 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 9
+4121 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 9
+4122 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 9
+4123 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 9
+4124 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 9
+4125 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 9
+4126 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 9
+4127 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 9
+4128 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 9
+4129 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 9
+4130 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 9
+4131 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 9
+4132 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 9
+4133 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 9
+4134 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 9
+4135 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 9
+4136 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 9
+4137 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 9
+4138 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 9
+4139 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 9
+4140 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 10
+4141 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 10
+4142 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 10
+4143 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 10
+4144 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 10
+4145 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 10
+4146 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 10
+4147 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 10
+4148 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 10
+4149 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 10
+4150 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 10
+4151 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 10
+4152 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 10
+4153 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 10
+4154 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 10
+4155 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 10
+4156 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 10
+4157 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 10
+4158 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 10
+4159 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 10
+4160 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 10
+4161 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 10
+4162 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 10
+4163 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 10
+4164 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 10
+4165 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 10
+4166 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 10
+4167 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 10
+4168 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 10
+4169 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 10
+4170 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 10
+4171 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 10
+4172 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 10
+4173 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 10
+4174 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 10
+4175 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 10
+4176 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 10
+4177 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 10
+4178 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 10
+4179 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 10
+4180 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 10
+4181 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 10
+4182 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 10
+4183 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 10
+4184 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 10
+4185 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 10
+4186 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 11
+4187 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 11
+4188 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 11
+4189 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 11
+4190 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 11
+4191 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 11
+4192 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 11
+4193 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 11
+4194 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 11
+4195 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 11
+4196 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 11
+4197 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 11
+4198 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 11
+4199 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 11
+4200 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 11
+4201 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 11
+4202 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 11
+4203 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 11
+4204 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 11
+4205 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 11
+4206 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 11
+4207 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 11
+4208 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 11
+4209 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 11
+4210 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 11
+4211 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 11
+4212 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 11
+4213 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 11
+4214 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 11
+4215 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 11
+4216 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 11
+4217 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 11
+4218 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 11
+4219 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 11
+4220 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 11
+4221 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 11
+4222 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 11
+4223 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 11
+4224 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 11
+4225 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 11
+4226 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 11
+4227 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 11
+4228 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 11
+4229 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 11
+4230 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 11
+4231 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 11
+4232 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 12
+4233 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 12
+4234 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 12
+4235 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 12
+4236 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 12
+4237 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 12
+4238 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 12
+4239 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 12
+4240 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 12
+4241 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 12
+4242 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 12
+4243 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 12
+4244 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 12
+4245 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 12
+4246 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 12
+4247 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 12
+4248 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 12
+4249 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 12
+4250 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 12
+4251 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 12
+4252 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 12
+4253 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 12
+4254 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 12
+4255 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 12
+4256 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 12
+4257 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 12
+4258 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 12
+4259 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 12
+4260 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 12
+4261 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 12
+4262 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 12
+4263 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 12
+4264 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 12
+4265 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 12
+4266 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 12
+4267 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 12
+4268 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 12
+4269 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 12
+4270 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 12
+4271 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 12
+4272 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 12
+4273 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 12
+4274 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 12
+4275 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 12
+4276 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 12
+4277 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 12
+4278 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 13
+4279 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 13
+4280 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 13
+4281 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 13
+4282 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 13
+4283 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 13
+4284 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 13
+4285 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 13
+4286 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 13
+4287 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 13
+4288 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 13
+4289 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 13
+4290 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 13
+4291 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 13
+4292 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 13
+4293 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 13
+4294 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 13
+4295 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 13
+4296 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 13
+4297 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 13
+4298 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 13
+4299 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 13
+4300 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 13
+4301 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 13
+4302 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 13
+4303 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 13
+4304 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 13
+4305 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 13
+4306 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 13
+4307 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 13
+4308 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 13
+4309 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 13
+4310 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 13
+4311 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 13
+4312 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 13
+4313 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 13
+4314 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 13
+4315 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 13
+4316 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 13
+4317 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 13
+4318 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 13
+4319 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 13
+4320 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 13
+4321 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 13
+4322 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 13
+4323 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 13
+4324 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 14
+4325 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 14
+4326 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 14
+4327 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 14
+4328 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 14
+4329 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 14
+4330 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 14
+4331 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 14
+4332 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 14
+4333 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 14
+4334 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 14
+4335 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 14
+4336 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 14
+4337 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 14
+4338 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 14
+4339 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 14
+4340 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 14
+4341 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 14
+4342 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 14
+4343 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 14
+4344 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 14
+4345 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 14
+4346 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 14
+4347 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 14
+4348 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 14
+4349 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 14
+4350 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 14
+4351 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 14
+4352 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 14
+4353 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 14
+4354 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 14
+4355 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 14
+4356 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 14
+4357 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 14
+4358 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 14
+4359 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 14
+4360 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 14
+4361 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 14
+4362 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 14
+4363 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 14
+4364 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 14
+4365 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 14
+4366 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 14
+4367 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 14
+4368 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 14
+4369 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 14
+4370 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 15
+4371 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 15
+4372 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 15
+4373 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 15
+4374 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 15
+4375 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 15
+4376 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 15
+4377 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 15
+4378 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 15
+4379 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 15
+4380 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 15
+4381 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 15
+4382 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 15
+4383 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 15
+4384 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 15
+4385 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 15
+4386 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 15
+4387 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 15
+4388 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 15
+4389 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 15
+4390 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 15
+4391 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 15
+4392 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 15
+4393 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 15
+4394 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 15
+4395 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 15
+4396 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 15
+4397 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 15
+4398 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 15
+4399 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 15
+4400 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 15
+4401 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 15
+4402 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 15
+4403 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 15
+4404 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 15
+4405 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 15
+4406 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 15
+4407 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 15
+4408 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 15
+4409 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 15
+4410 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 15
+4411 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 15
+4412 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 15
+4413 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 15
+4414 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 15
+4415 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 15
+4416 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 16
+4417 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 16
+4418 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 16
+4419 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 16
+4420 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 16
+4421 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 16
+4422 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 16
+4423 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 16
+4424 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 16
+4425 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 16
+4426 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 16
+4427 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 16
+4428 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 16
+4429 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 16
+4430 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 16
+4431 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 16
+4432 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 16
+4433 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 16
+4434 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 16
+4435 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 16
+4436 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 16
+4437 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 16
+4438 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 16
+4439 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 16
+4440 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 16
+4441 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 16
+4442 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 16
+4443 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 16
+4444 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 16
+4445 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 16
+4446 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 16
+4447 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 16
+4448 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 16
+4449 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 16
+4450 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 16
+4451 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 16
+4452 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 16
+4453 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 16
+4454 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 16
+4455 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 16
+4456 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 16
+4457 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 16
+4458 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 16
+4459 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 16
+4460 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 16
+4461 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 16
+4462 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 17
+4463 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 17
+4464 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 17
+4465 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 17
+4466 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 17
+4467 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 17
+4468 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 17
+4469 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 17
+4470 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 17
+4471 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 17
+4472 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 17
+4473 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 17
+4474 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 17
+4475 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 17
+4476 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 17
+4477 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 17
+4478 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 17
+4479 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 17
+4480 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 17
+4481 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 17
+4482 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 17
+4483 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 17
+4484 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 17
+4485 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 17
+4486 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 17
+4487 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 17
+4488 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 17
+4489 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 17
+4490 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 17
+4491 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 17
+4492 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 17
+4493 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 17
+4494 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 17
+4495 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 17
+4496 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 17
+4497 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 17
+4498 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 17
+4499 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 17
+4500 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 17
+4501 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 17
+4502 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 17
+4503 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 17
+4504 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 17
+4505 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 17
+4506 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 17
+4507 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 17
+4508 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 18
+4509 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 18
+4510 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 18
+4511 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 18
+4512 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 18
+4513 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 18
+4514 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 18
+4515 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 18
+4516 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 18
+4517 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 18
+4518 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 18
+4519 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 18
+4520 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 18
+4521 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 18
+4522 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 18
+4523 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 18
+4524 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 18
+4525 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 18
+4526 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 18
+4527 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 18
+4528 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 18
+4529 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 18
+4530 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 18
+4531 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 18
+4532 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 18
+4533 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 18
+4534 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 18
+4535 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 18
+4536 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 18
+4537 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 18
+4538 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 18
+4539 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 18
+4540 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 18
+4541 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 18
+4542 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 18
+4543 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 18
+4544 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 18
+4545 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 18
+4546 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 18
+4547 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 18
+4548 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 18
+4549 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 18
+4550 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 18
+4551 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 18
+4552 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 18
+4553 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 18
+4554 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 19
+4555 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 19
+4556 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 19
+4557 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 19
+4558 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 19
+4559 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 19
+4560 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 19
+4561 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 19
+4562 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 19
+4563 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 19
+4564 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 19
+4565 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 19
+4566 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 19
+4567 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 19
+4568 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 19
+4569 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 19
+4570 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 19
+4571 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 19
+4572 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 19
+4573 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 19
+4574 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 19
+4575 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 19
+4576 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 19
+4577 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 19
+4578 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 19
+4579 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 19
+4580 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 19
+4581 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 19
+4582 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 19
+4583 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 19
+4584 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 19
+4585 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 19
+4586 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 19
+4587 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 19
+4588 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 19
+4589 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 19
+4590 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 19
+4591 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 19
+4592 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 19
+4593 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 19
+4594 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 19
+4595 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 19
+4596 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 19
+4597 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 19
+4598 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 19
+4599 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 19
+4600 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 20
+4601 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 20
+4602 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 20
+4603 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 20
+4604 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 20
+4605 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 20
+4606 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 20
+4607 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 20
+4608 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 20
+4609 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 20
+4610 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 20
+4611 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 20
+4612 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 20
+4613 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 20
+4614 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 20
+4615 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 20
+4616 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 20
+4617 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 20
+4618 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 20
+4619 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 20
+4620 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 20
+4621 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 20
+4622 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 20
+4623 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 20
+4624 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 20
+4625 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 20
+4626 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 20
+4627 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 20
+4628 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 20
+4629 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 20
+4630 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 20
+4631 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 20
+4632 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 20
+4633 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 20
+4634 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 20
+4635 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 20
+4636 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 20
+4637 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 20
+4638 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 20
+4639 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 20
+4640 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 20
+4641 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 20
+4642 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 20
+4643 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 20
+4644 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 20
+4645 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 20
+4646 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 21
+4647 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 21
+4648 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 21
+4649 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 21
+4650 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 21
+4651 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 21
+4652 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 21
+4653 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 21
+4654 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 21
+4655 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 21
+4656 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 21
+4657 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 21
+4658 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 21
+4659 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 21
+4660 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 21
+4661 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 21
+4662 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 21
+4663 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 21
+4664 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 21
+4665 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 21
+4666 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 21
+4667 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 21
+4668 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 21
+4669 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 21
+4670 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 21
+4671 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 21
+4672 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 21
+4673 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 21
+4674 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 21
+4675 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 21
+4676 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 21
+4677 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 21
+4678 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 21
+4679 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 21
+4680 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 21
+4681 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 21
+4682 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 21
+4683 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 21
+4684 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 21
+4685 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 21
+4686 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 21
+4687 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 21
+4688 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 21
+4689 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 21
+4690 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 21
+4691 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 21
+4692 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 22
+4693 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 22
+4694 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 22
+4695 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 22
+4696 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 22
+4697 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 22
+4698 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 22
+4699 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 22
+4700 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 22
+4701 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 22
+4702 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 22
+4703 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 22
+4704 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 22
+4705 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 22
+4706 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 22
+4707 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 22
+4708 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 22
+4709 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 22
+4710 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 22
+4711 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 22
+4712 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 22
+4713 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 22
+4714 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 22
+4715 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 22
+4716 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 22
+4717 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 22
+4718 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 22
+4719 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 22
+4720 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 22
+4721 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 22
+4722 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 22
+4723 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 22
+4724 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 22
+4725 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 22
+4726 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 22
+4727 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 22
+4728 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 22
+4729 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 22
+4730 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 22
+4731 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 22
+4732 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 22
+4733 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 22
+4734 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 22
+4735 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 22
+4736 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 22
+4737 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 22
+4738 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 23
+4739 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 23
+4740 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 23
+4741 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 23
+4742 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 23
+4743 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 23
+4744 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 23
+4745 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 23
+4746 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 23
+4747 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 23
+4748 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 23
+4749 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 23
+4750 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 23
+4751 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 23
+4752 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 23
+4753 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 23
+4754 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 23
+4755 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 23
+4756 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 23
+4757 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 23
+4758 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 23
+4759 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 23
+4760 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 23
+4761 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 23
+4762 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 23
+4763 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 23
+4764 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 23
+4765 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 23
+4766 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 23
+4767 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 23
+4768 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 23
+4769 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 23
+4770 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 23
+4771 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 23
+4772 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 23
+4773 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 23
+4774 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 23
+4775 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 23
+4776 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 23
+4777 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 23
+4778 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 23
+4779 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 23
+4780 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 23
+4781 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 23
+4782 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 23
+4783 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 23
+4784 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 24
+4785 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 24
+4786 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 24
+4787 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 24
+4788 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 24
+4789 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 24
+4790 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 24
+4791 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 24
+4792 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 24
+4793 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 24
+4794 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 24
+4795 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 24
+4796 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 24
+4797 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 24
+4798 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 24
+4799 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 24
+4800 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 24
+4801 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 24
+4802 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 24
+4803 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 24
+4804 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 24
+4805 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 24
+4806 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 24
+4807 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 24
+4808 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 24
+4809 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 24
+4810 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 24
+4811 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 24
+4812 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 24
+4813 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 24
+4814 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 24
+4815 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 24
+4816 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 24
+4817 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 24
+4818 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 24
+4819 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 24
+4820 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 24
+4821 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 24
+4822 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 24
+4823 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 24
+4824 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 24
+4825 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 24
+4826 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 24
+4827 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 24
+4828 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 24
+4829 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 24
+4830 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 25
+4831 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 25
+4832 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 25
+4833 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 25
+4834 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 25
+4835 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 25
+4836 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 25
+4837 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 25
+4838 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 25
+4839 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 25
+4840 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 25
+4841 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 25
+4842 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 25
+4843 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 25
+4844 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 25
+4845 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 25
+4846 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 25
+4847 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 25
+4848 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 25
+4849 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 25
+4850 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 25
+4851 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 25
+4852 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 25
+4853 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 25
+4854 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 25
+4855 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 25
+4856 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 25
+4857 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 25
+4858 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 25
+4859 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 25
+4860 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 25
+4861 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 25
+4862 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 25
+4863 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 25
+4864 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 25
+4865 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 25
+4866 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 25
+4867 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 25
+4868 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 25
+4869 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 25
+4870 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 25
+4871 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 25
+4872 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 25
+4873 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 25
+4874 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 25
+4875 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 25
+4876 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 26
+4877 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 26
+4878 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 26
+4879 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 26
+4880 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 26
+4881 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 26
+4882 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 26
+4883 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 26
+4884 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 26
+4885 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 26
+4886 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 26
+4887 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 26
+4888 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 26
+4889 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 26
+4890 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 26
+4891 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 26
+4892 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 26
+4893 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 26
+4894 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 26
+4895 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 26
+4896 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 26
+4897 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 26
+4898 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 26
+4899 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 26
+4900 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 26
+4901 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 26
+4902 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 26
+4903 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 26
+4904 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 26
+4905 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 26
+4906 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 26
+4907 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 26
+4908 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 26
+4909 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 26
+4910 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 26
+4911 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 26
+4912 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 26
+4913 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 26
+4914 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 26
+4915 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 26
+4916 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 26
+4917 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 26
+4918 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 26
+4919 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 26
+4920 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 26
+4921 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 26
+4922 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 27
+4923 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 27
+4924 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 27
+4925 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 27
+4926 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 27
+4927 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 27
+4928 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 27
+4929 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 27
+4930 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 27
+4931 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 27
+4932 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 27
+4933 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 27
+4934 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 27
+4935 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 27
+4936 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 27
+4937 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 27
+4938 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 27
+4939 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 27
+4940 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 27
+4941 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 27
+4942 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 27
+4943 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 27
+4944 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 27
+4945 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 27
+4946 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 27
+4947 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 27
+4948 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 27
+4949 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 27
+4950 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 27
+4951 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 27
+4952 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 27
+4953 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 27
+4954 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 27
+4955 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 27
+4956 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 27
+4957 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 27
+4958 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 27
+4959 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 27
+4960 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 27
+4961 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 27
+4962 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 27
+4963 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 27
+4964 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 27
+4965 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 27
+4966 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 27
+4967 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 27
+4968 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 28
+4969 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 28
+4970 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 28
+4971 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 28
+4972 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 28
+4973 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 28
+4974 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 28
+4975 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 28
+4976 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 28
+4977 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 28
+4978 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 28
+4979 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 28
+4980 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 28
+4981 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 28
+4982 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 28
+4983 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 28
+4984 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 28
+4985 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 28
+4986 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 28
+4987 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 28
+4988 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 28
+4989 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 28
+4990 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 28
+4991 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 28
+4992 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 28
+4993 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 28
+4994 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 28
+4995 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 28
+4996 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 28
+4997 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 28
+4998 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 28
+4999 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 28
+5000 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 28
+5001 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 28
+5002 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 28
+5003 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 28
+5004 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 28
+5005 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 28
+5006 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 28
+5007 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 28
+5008 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 28
+5009 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 28
+5010 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 28
+5011 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 28
+5012 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 28
+5013 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 28
+5014 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 29
+5015 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 29
+5016 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 29
+5017 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 29
+5018 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 29
+5019 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 29
+5020 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 29
+5021 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 29
+5022 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 29
+5023 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 29
+5024 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 29
+5025 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 29
+5026 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 29
+5027 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 29
+5028 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 29
+5029 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 29
+5030 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 29
+5031 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 29
+5032 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 29
+5033 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 29
+5034 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 29
+5035 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 29
+5036 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 29
+5037 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 29
+5038 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 29
+5039 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 29
+5040 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 29
+5041 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 29
+5042 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 29
+5043 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 29
+5044 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 29
+5045 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 29
+5046 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 29
+5047 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 29
+5048 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 29
+5049 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 29
+5050 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 29
+5051 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 29
+5052 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 29
+5053 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 29
+5054 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 29
+5055 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 29
+5056 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 29
+5057 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 29
+5058 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 29
+5059 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 29
+5060 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 30
+5061 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 30
+5062 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 30
+5063 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 30
+5064 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 30
+5065 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 30
+5066 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 30
+5067 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 30
+5068 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 30
+5069 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 30
+5070 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 30
+5071 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 30
+5072 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 30
+5073 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 30
+5074 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 30
+5075 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 30
+5076 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 30
+5077 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 30
+5078 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 30
+5079 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 30
+5080 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 30
+5081 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 30
+5082 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 30
+5083 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 30
+5084 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 30
+5085 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 30
+5086 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 30
+5087 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 30
+5088 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 30
+5089 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 30
+5090 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 30
+5091 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 30
+5092 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 30
+5093 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 30
+5094 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 30
+5095 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 30
+5096 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 30
+5097 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 30
+5098 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 30
+5099 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 30
+5100 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 30
+5101 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 30
+5102 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 30
+5103 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 30
+5104 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 30
+5105 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 30
+5106 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 31
+5107 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 31
+5108 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 31
+5109 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 31
+5110 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 31
+5111 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 31
+5112 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 31
+5113 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 31
+5114 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 31
+5115 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 31
+5116 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 31
+5117 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 31
+5118 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 31
+5119 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 31
+5120 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 31
+5121 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 31
+5122 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 31
+5123 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 31
+5124 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 31
+5125 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 31
+5126 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 31
+5127 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 31
+5128 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 31
+5129 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 31
+5130 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 31
+5131 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 31
+5132 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 31
+5133 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 31
+5134 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 31
+5135 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 31
+5136 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 31
+5137 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 31
+5138 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 31
+5139 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 31
+5140 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 31
+5141 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 31
+5142 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 31
+5143 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 31
+5144 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 31
+5145 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 31
+5146 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 31
+5147 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 31
+5148 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 31
+5149 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 31
+5150 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 31
+5151 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 31
+5152 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 32
+5153 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 32
+5154 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 32
+5155 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 32
+5156 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 32
+5157 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 32
+5158 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 32
+5159 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 32
+5160 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 32
+5161 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 32
+5162 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 32
+5163 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 32
+5164 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 32
+5165 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 32
+5166 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 32
+5167 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 32
+5168 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 32
+5169 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 32
+5170 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 32
+5171 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 32
+5172 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 32
+5173 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 32
+5174 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 32
+5175 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 32
+5176 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 32
+5177 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 32
+5178 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 32
+5179 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 32
+5180 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 32
+5181 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 32
+5182 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 32
+5183 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 32
+5184 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 32
+5185 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 32
+5186 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 32
+5187 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 32
+5188 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 32
+5189 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 32
+5190 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 32
+5191 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 32
+5192 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 32
+5193 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 32
+5194 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 32
+5195 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 32
+5196 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 32
+5197 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 32
+5198 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 33
+5199 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 33
+5200 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 33
+5201 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 33
+5202 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 33
+5203 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 33
+5204 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 33
+5205 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 33
+5206 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 33
+5207 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 33
+5208 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 33
+5209 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 33
+5210 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 33
+5211 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 33
+5212 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 33
+5213 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 33
+5214 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 33
+5215 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 33
+5216 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 33
+5217 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 33
+5218 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 33
+5219 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 33
+5220 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 33
+5221 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 33
+5222 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 33
+5223 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 33
+5224 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 33
+5225 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 33
+5226 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 33
+5227 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 33
+5228 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 33
+5229 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 33
+5230 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 33
+5231 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 33
+5232 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 33
+5233 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 33
+5234 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 33
+5235 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 33
+5236 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 33
+5237 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 33
+5238 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 33
+5239 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 33
+5240 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 33
+5241 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 33
+5242 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 33
+5243 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 33
+5244 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 34
+5245 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 34
+5246 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 34
+5247 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 34
+5248 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 34
+5249 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 34
+5250 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 34
+5251 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 34
+5252 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 34
+5253 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 34
+5254 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 34
+5255 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 34
+5256 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 34
+5257 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 34
+5258 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 34
+5259 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 34
+5260 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 34
+5261 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 34
+5262 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 34
+5263 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 34
+5264 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 34
+5265 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 34
+5266 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 34
+5267 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 34
+5268 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 34
+5269 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 34
+5270 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 34
+5271 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 34
+5272 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 34
+5273 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 34
+5274 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 34
+5275 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 34
+5276 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 34
+5277 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 34
+5278 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 34
+5279 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 34
+5280 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 34
+5281 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 34
+5282 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 34
+5283 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 34
+5284 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 34
+5285 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 34
+5286 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 34
+5287 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 34
+5288 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 34
+5289 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 34
+5290 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 35
+5291 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 35
+5292 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 35
+5293 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 35
+5294 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 35
+5295 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 35
+5296 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 35
+5297 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 35
+5298 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 35
+5299 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 35
+5300 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 35
+5301 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 35
+5302 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 35
+5303 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 35
+5304 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 35
+5305 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 35
+5306 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 35
+5307 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 35
+5308 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 35
+5309 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 35
+5310 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 35
+5311 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 35
+5312 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 35
+5313 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 35
+5314 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 35
+5315 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 35
+5316 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 35
+5317 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 35
+5318 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 35
+5319 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 35
+5320 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 35
+5321 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 35
+5322 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 35
+5323 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 35
+5324 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 35
+5325 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 35
+5326 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 35
+5327 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 35
+5328 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 35
+5329 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 35
+5330 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 35
+5331 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 35
+5332 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 35
+5333 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 35
+5334 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 35
+5335 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 35
+5336 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 36
+5337 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 36
+5338 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 36
+5339 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 36
+5340 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 36
+5341 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 36
+5342 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 36
+5343 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 36
+5344 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 36
+5345 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 36
+5346 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 36
+5347 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 36
+5348 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 36
+5349 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 36
+5350 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 36
+5351 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 36
+5352 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 36
+5353 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 36
+5354 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 36
+5355 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 36
+5356 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 36
+5357 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 36
+5358 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 36
+5359 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 36
+5360 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 36
+5361 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 36
+5362 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 36
+5363 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 36
+5364 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 36
+5365 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 36
+5366 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 36
+5367 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 36
+5368 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 36
+5369 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 36
+5370 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 36
+5371 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 36
+5372 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 36
+5373 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 36
+5374 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 36
+5375 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 36
+5376 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 36
+5377 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 36
+5378 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 36
+5379 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 36
+5380 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 36
+5381 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 36
+5382 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 37
+5383 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 37
+5384 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 37
+5385 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 37
+5386 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 37
+5387 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 37
+5388 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 37
+5389 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 37
+5390 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 37
+5391 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 37
+5392 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 37
+5393 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 37
+5394 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 37
+5395 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 37
+5396 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 37
+5397 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 37
+5398 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 37
+5399 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 37
+5400 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 37
+5401 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 37
+5402 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 37
+5403 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 37
+5404 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 37
+5405 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 37
+5406 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 37
+5407 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 37
+5408 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 37
+5409 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 37
+5410 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 37
+5411 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 37
+5412 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 37
+5413 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 37
+5414 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 37
+5415 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 37
+5416 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 37
+5417 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 37
+5418 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 37
+5419 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 37
+5420 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 37
+5421 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 37
+5422 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 37
+5423 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 37
+5424 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 37
+5425 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 37
+5426 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 37
+5427 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 37
+5428 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 38
+5429 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 38
+5430 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 38
+5431 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 38
+5432 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 38
+5433 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 38
+5434 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 38
+5435 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 38
+5436 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 38
+5437 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 38
+5438 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 38
+5439 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 38
+5440 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 38
+5441 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 38
+5442 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 38
+5443 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 38
+5444 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 38
+5445 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 38
+5446 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 38
+5447 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 38
+5448 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 38
+5449 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 38
+5450 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 38
+5451 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 38
+5452 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 38
+5453 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 38
+5454 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 38
+5455 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 38
+5456 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 38
+5457 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 38
+5458 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 38
+5459 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 38
+5460 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 38
+5461 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 38
+5462 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 38
+5463 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 38
+5464 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 38
+5465 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 38
+5466 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 38
+5467 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 38
+5468 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 38
+5469 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 38
+5470 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 38
+5471 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 38
+5472 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 38
+5473 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 38
+5474 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 2 39
+5475 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 2 39
+5476 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 2 39
+5477 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 2 39
+5478 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 2 39
+5479 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 2 39
+5480 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 2 39
+5481 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 39
+5482 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 2 39
+5483 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 2 39
+5484 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 2 39
+5485 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 2 39
+5486 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 2 39
+5487 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 2 39
+5488 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 2 39
+5489 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 2 39
+5490 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 2 39
+5491 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 2 39
+5492 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 2 39
+5493 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 2 39
+5494 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 2 39
+5495 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 2 39
+5496 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 2 39
+5497 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 2 39
+5498 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 2 39
+5499 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 2 39
+5500 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 2 39
+5501 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 2 39
+5502 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 2 39
+5503 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 2 39
+5504 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 2 39
+5505 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 2 39
+5506 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 2 39
+5507 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 2 39
+5508 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 2 39
+5509 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 2 39
+5510 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 2 39
+5511 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 2 39
+5512 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 2 39
+5513 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 2 39
+5514 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 39
+5515 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 2 39
+5516 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 2 39
+5517 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 2 39
+5518 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 39
+5519 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 2 39
+5520 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 0
+5521 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 0
+5522 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 0
+5523 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 0
+5524 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 0
+5525 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 0
+5526 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 0
+5527 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 0
+5528 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 0
+5529 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 0
+5530 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 0
+5531 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 0
+5532 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 0
+5533 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 0
+5534 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 0
+5535 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 0
+5536 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 0
+5537 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 0
+5538 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 0
+5539 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 0
+5540 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 0
+5541 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 0
+5542 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 0
+5543 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 0
+5544 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 0
+5545 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 0
+5546 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 0
+5547 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 0
+5548 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 0
+5549 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 0
+5550 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 0
+5551 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 0
+5552 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 0
+5553 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 0
+5554 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 0
+5555 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 0
+5556 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 0
+5557 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 0
+5558 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 0
+5559 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 0
+5560 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 0
+5561 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 0
+5562 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 0
+5563 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 0
+5564 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 0
+5565 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 0
+5566 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 1
+5567 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 1
+5568 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 1
+5569 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 1
+5570 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 1
+5571 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 1
+5572 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 1
+5573 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 1
+5574 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 1
+5575 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 1
+5576 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 1
+5577 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 1
+5578 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 1
+5579 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 1
+5580 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 1
+5581 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 1
+5582 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 1
+5583 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 1
+5584 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 1
+5585 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 1
+5586 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 1
+5587 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 1
+5588 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 1
+5589 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 1
+5590 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 1
+5591 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 1
+5592 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 1
+5593 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 1
+5594 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 1
+5595 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 1
+5596 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 1
+5597 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 1
+5598 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 1
+5599 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 1
+5600 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 1
+5601 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 1
+5602 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 1
+5603 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 1
+5604 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 1
+5605 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 1
+5606 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 1
+5607 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 1
+5608 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 1
+5609 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 1
+5610 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 1
+5611 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 1
+5612 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 2
+5613 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 2
+5614 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 2
+5615 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 2
+5616 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 2
+5617 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 2
+5618 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 2
+5619 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 2
+5620 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 2
+5621 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 2
+5622 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 2
+5623 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 2
+5624 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 2
+5625 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 2
+5626 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 2
+5627 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 2
+5628 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 2
+5629 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 2
+5630 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 2
+5631 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 2
+5632 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 2
+5633 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 2
+5634 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 2
+5635 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 2
+5636 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 2
+5637 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 2
+5638 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 2
+5639 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 2
+5640 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 2
+5641 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 2
+5642 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 2
+5643 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 2
+5644 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 2
+5645 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 2
+5646 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 2
+5647 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 2
+5648 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 2
+5649 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 2
+5650 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 2
+5651 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 2
+5652 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 2
+5653 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 2
+5654 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 2
+5655 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 2
+5656 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 2
+5657 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 2
+5658 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 3
+5659 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 3
+5660 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 3
+5661 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 3
+5662 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 3
+5663 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 3
+5664 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 3
+5665 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 3
+5666 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 3
+5667 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 3
+5668 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 3
+5669 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 3
+5670 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 3
+5671 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 3
+5672 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 3
+5673 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 3
+5674 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 3
+5675 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 3
+5676 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 3
+5677 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 3
+5678 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 3
+5679 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 3
+5680 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 3
+5681 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 3
+5682 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 3
+5683 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 3
+5684 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 3
+5685 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 3
+5686 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 3
+5687 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 3
+5688 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 3
+5689 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 3
+5690 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 3
+5691 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 3
+5692 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 3
+5693 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 3
+5694 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 3
+5695 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 3
+5696 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 3
+5697 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 3
+5698 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 3
+5699 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 3
+5700 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 3
+5701 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 3
+5702 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 3
+5703 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 3
+5704 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 4
+5705 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 4
+5706 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 4
+5707 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 4
+5708 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 4
+5709 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 4
+5710 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 4
+5711 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 4
+5712 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 4
+5713 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 4
+5714 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 4
+5715 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 4
+5716 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 4
+5717 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 4
+5718 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 4
+5719 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 4
+5720 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 4
+5721 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 4
+5722 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 4
+5723 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 4
+5724 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 4
+5725 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 4
+5726 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 4
+5727 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 4
+5728 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 4
+5729 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 4
+5730 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 4
+5731 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 4
+5732 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 4
+5733 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 4
+5734 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 4
+5735 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 4
+5736 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 4
+5737 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 4
+5738 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 4
+5739 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 4
+5740 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 4
+5741 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 4
+5742 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 4
+5743 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 4
+5744 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 4
+5745 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 4
+5746 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 4
+5747 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 4
+5748 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 4
+5749 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 4
+5750 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 5
+5751 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 5
+5752 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 5
+5753 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 5
+5754 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 5
+5755 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 5
+5756 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 5
+5757 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 5
+5758 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 5
+5759 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 5
+5760 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 5
+5761 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 5
+5762 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 5
+5763 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 5
+5764 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 5
+5765 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 5
+5766 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 5
+5767 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 5
+5768 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 5
+5769 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 5
+5770 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 5
+5771 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 5
+5772 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 5
+5773 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 5
+5774 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 5
+5775 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 5
+5776 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 5
+5777 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 5
+5778 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 5
+5779 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 5
+5780 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 5
+5781 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 5
+5782 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 5
+5783 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 5
+5784 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 5
+5785 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 5
+5786 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 5
+5787 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 5
+5788 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 5
+5789 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 5
+5790 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 5
+5791 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 5
+5792 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 5
+5793 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 5
+5794 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 5
+5795 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 5
+5796 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 6
+5797 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 6
+5798 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 6
+5799 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 6
+5800 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 6
+5801 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 6
+5802 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 6
+5803 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 6
+5804 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 6
+5805 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 6
+5806 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 6
+5807 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 6
+5808 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 6
+5809 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 6
+5810 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 6
+5811 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 6
+5812 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 6
+5813 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 6
+5814 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 6
+5815 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 6
+5816 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 6
+5817 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 6
+5818 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 6
+5819 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 6
+5820 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 6
+5821 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 6
+5822 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 6
+5823 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 6
+5824 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 6
+5825 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 6
+5826 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 6
+5827 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 6
+5828 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 6
+5829 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 6
+5830 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 6
+5831 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 6
+5832 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 6
+5833 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 6
+5834 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 6
+5835 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 6
+5836 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 6
+5837 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 6
+5838 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 6
+5839 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 6
+5840 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 6
+5841 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 6
+5842 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 7
+5843 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 7
+5844 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 7
+5845 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 7
+5846 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 7
+5847 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 7
+5848 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 7
+5849 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 7
+5850 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 7
+5851 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 7
+5852 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 7
+5853 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 7
+5854 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 7
+5855 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 7
+5856 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 7
+5857 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 7
+5858 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 7
+5859 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 7
+5860 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 7
+5861 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 7
+5862 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 7
+5863 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 7
+5864 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 7
+5865 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 7
+5866 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 7
+5867 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 7
+5868 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 7
+5869 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 7
+5870 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 7
+5871 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 7
+5872 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 7
+5873 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 7
+5874 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 7
+5875 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 7
+5876 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 7
+5877 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 7
+5878 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 7
+5879 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 7
+5880 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 7
+5881 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 7
+5882 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 7
+5883 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 7
+5884 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 7
+5885 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 7
+5886 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 7
+5887 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 7
+5888 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 8
+5889 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 8
+5890 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 8
+5891 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 8
+5892 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 8
+5893 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 8
+5894 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 8
+5895 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 8
+5896 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 8
+5897 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 8
+5898 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 8
+5899 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 8
+5900 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 8
+5901 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 8
+5902 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 8
+5903 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 8
+5904 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 8
+5905 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 8
+5906 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 8
+5907 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 8
+5908 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 8
+5909 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 8
+5910 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 8
+5911 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 8
+5912 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 8
+5913 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 8
+5914 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 8
+5915 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 8
+5916 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 8
+5917 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 8
+5918 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 8
+5919 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 8
+5920 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 8
+5921 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 8
+5922 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 8
+5923 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 8
+5924 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 8
+5925 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 8
+5926 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 8
+5927 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 8
+5928 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 8
+5929 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 8
+5930 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 8
+5931 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 8
+5932 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 8
+5933 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 8
+5934 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 9
+5935 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 9
+5936 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 9
+5937 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 9
+5938 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 9
+5939 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 9
+5940 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 9
+5941 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 9
+5942 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 9
+5943 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 9
+5944 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 9
+5945 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 9
+5946 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 9
+5947 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 9
+5948 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 9
+5949 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 9
+5950 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 9
+5951 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 9
+5952 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 9
+5953 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 9
+5954 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 9
+5955 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 9
+5956 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 9
+5957 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 9
+5958 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 9
+5959 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 9
+5960 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 9
+5961 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 9
+5962 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 9
+5963 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 9
+5964 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 9
+5965 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 9
+5966 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 9
+5967 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 9
+5968 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 9
+5969 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 9
+5970 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 9
+5971 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 9
+5972 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 9
+5973 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 9
+5974 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 9
+5975 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 9
+5976 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 9
+5977 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 9
+5978 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 9
+5979 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 9
+5980 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 10
+5981 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 10
+5982 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 10
+5983 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 10
+5984 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 10
+5985 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 10
+5986 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 10
+5987 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 10
+5988 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 10
+5989 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 10
+5990 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 10
+5991 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 10
+5992 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 10
+5993 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 10
+5994 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 10
+5995 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 10
+5996 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 10
+5997 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 10
+5998 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 10
+5999 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 10
+6000 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 10
+6001 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 10
+6002 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 10
+6003 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 10
+6004 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 10
+6005 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 10
+6006 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 10
+6007 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 10
+6008 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 10
+6009 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 10
+6010 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 10
+6011 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 10
+6012 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 10
+6013 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 10
+6014 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 10
+6015 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 10
+6016 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 10
+6017 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 10
+6018 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 10
+6019 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 10
+6020 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 10
+6021 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 10
+6022 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 10
+6023 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 10
+6024 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 10
+6025 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 10
+6026 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 11
+6027 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 11
+6028 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 11
+6029 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 11
+6030 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 11
+6031 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 11
+6032 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 11
+6033 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 11
+6034 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 11
+6035 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 11
+6036 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 11
+6037 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 11
+6038 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 11
+6039 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 11
+6040 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 11
+6041 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 11
+6042 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 11
+6043 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 11
+6044 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 11
+6045 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 11
+6046 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 11
+6047 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 11
+6048 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 11
+6049 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 11
+6050 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 11
+6051 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 11
+6052 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 11
+6053 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 11
+6054 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 11
+6055 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 11
+6056 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 11
+6057 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 11
+6058 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 11
+6059 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 11
+6060 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 11
+6061 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 11
+6062 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 11
+6063 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 11
+6064 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 11
+6065 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 11
+6066 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 11
+6067 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 11
+6068 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 11
+6069 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 11
+6070 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 11
+6071 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 11
+6072 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 12
+6073 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 12
+6074 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 12
+6075 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 12
+6076 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 12
+6077 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 12
+6078 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 12
+6079 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 12
+6080 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 12
+6081 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 12
+6082 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 12
+6083 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 12
+6084 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 12
+6085 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 12
+6086 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 12
+6087 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 12
+6088 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 12
+6089 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 12
+6090 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 12
+6091 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 12
+6092 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 12
+6093 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 12
+6094 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 12
+6095 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 12
+6096 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 12
+6097 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 12
+6098 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 12
+6099 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 12
+6100 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 12
+6101 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 12
+6102 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 12
+6103 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 12
+6104 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 12
+6105 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 12
+6106 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 12
+6107 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 12
+6108 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 12
+6109 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 12
+6110 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 12
+6111 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 12
+6112 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 12
+6113 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 12
+6114 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 12
+6115 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 12
+6116 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 12
+6117 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 12
+6118 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 13
+6119 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 13
+6120 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 13
+6121 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 13
+6122 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 13
+6123 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 13
+6124 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 13
+6125 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 13
+6126 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 13
+6127 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 13
+6128 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 13
+6129 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 13
+6130 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 13
+6131 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 13
+6132 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 13
+6133 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 13
+6134 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 13
+6135 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 13
+6136 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 13
+6137 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 13
+6138 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 13
+6139 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 13
+6140 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 13
+6141 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 13
+6142 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 13
+6143 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 13
+6144 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 13
+6145 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 13
+6146 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 13
+6147 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 13
+6148 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 13
+6149 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 13
+6150 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 13
+6151 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 13
+6152 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 13
+6153 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 13
+6154 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 13
+6155 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 13
+6156 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 13
+6157 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 13
+6158 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 13
+6159 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 13
+6160 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 13
+6161 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 13
+6162 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 13
+6163 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 13
+6164 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 14
+6165 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 14
+6166 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 14
+6167 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 14
+6168 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 14
+6169 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 14
+6170 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 14
+6171 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 14
+6172 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 14
+6173 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 14
+6174 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 14
+6175 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 14
+6176 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 14
+6177 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 14
+6178 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 14
+6179 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 14
+6180 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 14
+6181 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 14
+6182 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 14
+6183 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 14
+6184 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 14
+6185 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 14
+6186 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 14
+6187 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 14
+6188 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 14
+6189 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 14
+6190 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 14
+6191 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 14
+6192 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 14
+6193 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 14
+6194 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 14
+6195 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 14
+6196 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 14
+6197 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 14
+6198 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 14
+6199 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 14
+6200 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 14
+6201 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 14
+6202 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 14
+6203 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 14
+6204 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 14
+6205 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 14
+6206 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 14
+6207 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 14
+6208 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 14
+6209 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 14
+6210 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 15
+6211 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 15
+6212 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 15
+6213 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 15
+6214 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 15
+6215 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 15
+6216 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 15
+6217 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 15
+6218 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 15
+6219 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 15
+6220 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 15
+6221 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 15
+6222 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 15
+6223 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 15
+6224 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 15
+6225 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 15
+6226 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 15
+6227 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 15
+6228 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 15
+6229 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 15
+6230 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 15
+6231 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 15
+6232 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 15
+6233 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 15
+6234 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 15
+6235 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 15
+6236 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 15
+6237 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 15
+6238 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 15
+6239 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 15
+6240 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 15
+6241 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 15
+6242 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 15
+6243 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 15
+6244 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 15
+6245 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 15
+6246 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 15
+6247 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 15
+6248 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 15
+6249 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 15
+6250 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 15
+6251 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 15
+6252 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 15
+6253 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 15
+6254 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 15
+6255 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 15
+6256 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 16
+6257 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 16
+6258 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 16
+6259 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 16
+6260 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 16
+6261 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 16
+6262 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 16
+6263 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 16
+6264 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 16
+6265 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 16
+6266 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 16
+6267 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 16
+6268 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 16
+6269 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 16
+6270 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 16
+6271 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 16
+6272 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 16
+6273 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 16
+6274 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 16
+6275 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 16
+6276 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 16
+6277 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 16
+6278 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 16
+6279 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 16
+6280 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 16
+6281 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 16
+6282 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 16
+6283 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 16
+6284 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 16
+6285 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 16
+6286 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 16
+6287 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 16
+6288 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 16
+6289 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 16
+6290 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 16
+6291 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 16
+6292 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 16
+6293 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 16
+6294 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 16
+6295 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 16
+6296 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 16
+6297 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 16
+6298 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 16
+6299 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 16
+6300 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 16
+6301 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 16
+6302 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 17
+6303 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 17
+6304 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 17
+6305 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 17
+6306 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 17
+6307 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 17
+6308 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 17
+6309 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 17
+6310 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 17
+6311 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 17
+6312 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 17
+6313 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 17
+6314 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 17
+6315 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 17
+6316 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 17
+6317 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 17
+6318 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 17
+6319 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 17
+6320 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 17
+6321 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 17
+6322 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 17
+6323 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 17
+6324 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 17
+6325 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 17
+6326 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 17
+6327 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 17
+6328 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 17
+6329 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 17
+6330 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 17
+6331 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 17
+6332 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 17
+6333 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 17
+6334 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 17
+6335 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 17
+6336 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 17
+6337 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 17
+6338 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 17
+6339 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 17
+6340 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 17
+6341 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 17
+6342 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 17
+6343 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 17
+6344 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 17
+6345 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 17
+6346 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 17
+6347 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 17
+6348 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 18
+6349 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 18
+6350 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 18
+6351 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 18
+6352 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 18
+6353 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 18
+6354 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 18
+6355 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 18
+6356 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 18
+6357 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 18
+6358 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 18
+6359 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 18
+6360 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 18
+6361 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 18
+6362 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 18
+6363 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 18
+6364 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 18
+6365 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 18
+6366 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 18
+6367 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 18
+6368 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 18
+6369 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 18
+6370 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 18
+6371 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 18
+6372 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 18
+6373 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 18
+6374 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 18
+6375 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 18
+6376 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 18
+6377 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 18
+6378 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 18
+6379 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 18
+6380 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 18
+6381 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 18
+6382 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 18
+6383 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 18
+6384 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 18
+6385 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 18
+6386 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 18
+6387 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 18
+6388 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 18
+6389 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 18
+6390 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 18
+6391 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 18
+6392 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 18
+6393 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 18
+6394 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 19
+6395 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 19
+6396 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 19
+6397 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 19
+6398 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 19
+6399 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 19
+6400 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 19
+6401 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 19
+6402 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 19
+6403 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 19
+6404 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 19
+6405 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 19
+6406 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 19
+6407 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 19
+6408 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 19
+6409 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 19
+6410 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 19
+6411 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 19
+6412 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 19
+6413 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 19
+6414 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 19
+6415 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 19
+6416 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 19
+6417 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 19
+6418 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 19
+6419 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 19
+6420 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 19
+6421 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 19
+6422 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 19
+6423 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 19
+6424 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 19
+6425 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 19
+6426 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 19
+6427 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 19
+6428 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 19
+6429 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 19
+6430 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 19
+6431 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 19
+6432 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 19
+6433 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 19
+6434 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 19
+6435 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 19
+6436 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 19
+6437 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 19
+6438 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 19
+6439 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 19
+6440 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 20
+6441 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 20
+6442 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 20
+6443 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 20
+6444 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 20
+6445 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 20
+6446 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 20
+6447 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 20
+6448 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 20
+6449 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 20
+6450 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 20
+6451 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 20
+6452 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 20
+6453 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 20
+6454 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 20
+6455 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 20
+6456 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 20
+6457 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 20
+6458 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 20
+6459 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 20
+6460 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 20
+6461 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 20
+6462 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 20
+6463 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 20
+6464 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 20
+6465 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 20
+6466 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 20
+6467 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 20
+6468 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 20
+6469 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 20
+6470 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 20
+6471 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 20
+6472 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 20
+6473 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 20
+6474 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 20
+6475 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 20
+6476 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 20
+6477 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 20
+6478 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 20
+6479 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 20
+6480 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 20
+6481 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 20
+6482 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 20
+6483 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 20
+6484 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 20
+6485 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 20
+6486 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 21
+6487 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 21
+6488 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 21
+6489 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 21
+6490 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 21
+6491 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 21
+6492 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 21
+6493 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 21
+6494 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 21
+6495 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 21
+6496 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 21
+6497 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 21
+6498 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 21
+6499 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 21
+6500 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 21
+6501 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 21
+6502 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 21
+6503 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 21
+6504 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 21
+6505 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 21
+6506 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 21
+6507 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 21
+6508 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 21
+6509 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 21
+6510 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 21
+6511 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 21
+6512 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 21
+6513 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 21
+6514 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 21
+6515 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 21
+6516 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 21
+6517 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 21
+6518 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 21
+6519 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 21
+6520 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 21
+6521 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 21
+6522 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 21
+6523 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 21
+6524 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 21
+6525 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 21
+6526 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 21
+6527 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 21
+6528 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 21
+6529 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 21
+6530 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 21
+6531 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 21
+6532 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 22
+6533 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 22
+6534 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 22
+6535 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 22
+6536 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 22
+6537 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 22
+6538 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 22
+6539 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 22
+6540 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 22
+6541 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 22
+6542 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 22
+6543 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 22
+6544 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 22
+6545 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 22
+6546 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 22
+6547 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 22
+6548 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 22
+6549 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 22
+6550 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 22
+6551 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 22
+6552 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 22
+6553 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 22
+6554 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 22
+6555 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 22
+6556 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 22
+6557 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 22
+6558 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 22
+6559 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 22
+6560 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 22
+6561 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 22
+6562 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 22
+6563 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 22
+6564 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 22
+6565 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 22
+6566 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 22
+6567 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 22
+6568 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 22
+6569 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 22
+6570 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 22
+6571 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 22
+6572 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 22
+6573 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 22
+6574 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 22
+6575 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 22
+6576 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 22
+6577 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 22
+6578 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 23
+6579 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 23
+6580 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 23
+6581 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 23
+6582 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 23
+6583 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 23
+6584 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 23
+6585 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 23
+6586 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 23
+6587 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 23
+6588 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 23
+6589 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 23
+6590 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 23
+6591 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 23
+6592 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 23
+6593 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 23
+6594 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 23
+6595 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 23
+6596 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 23
+6597 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 23
+6598 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 23
+6599 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 23
+6600 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 23
+6601 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 23
+6602 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 23
+6603 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 23
+6604 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 23
+6605 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 23
+6606 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 23
+6607 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 23
+6608 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 23
+6609 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 23
+6610 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 23
+6611 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 23
+6612 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 23
+6613 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 23
+6614 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 23
+6615 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 23
+6616 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 23
+6617 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 23
+6618 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 23
+6619 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 23
+6620 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 23
+6621 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 23
+6622 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 23
+6623 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 23
+6624 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 24
+6625 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 24
+6626 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 24
+6627 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 24
+6628 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 24
+6629 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 24
+6630 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 24
+6631 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 24
+6632 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 24
+6633 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 24
+6634 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 24
+6635 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 24
+6636 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 24
+6637 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 24
+6638 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 24
+6639 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 24
+6640 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 24
+6641 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 24
+6642 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 24
+6643 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 24
+6644 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 24
+6645 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 24
+6646 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 24
+6647 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 24
+6648 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 24
+6649 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 24
+6650 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 24
+6651 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 24
+6652 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 24
+6653 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 24
+6654 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 24
+6655 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 24
+6656 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 24
+6657 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 24
+6658 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 24
+6659 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 24
+6660 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 24
+6661 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 24
+6662 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 24
+6663 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 24
+6664 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 24
+6665 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 24
+6666 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 24
+6667 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 24
+6668 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 24
+6669 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 24
+6670 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 25
+6671 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 25
+6672 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 25
+6673 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 25
+6674 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 25
+6675 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 25
+6676 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 25
+6677 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 25
+6678 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 25
+6679 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 25
+6680 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 25
+6681 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 25
+6682 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 25
+6683 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 25
+6684 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 25
+6685 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 25
+6686 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 25
+6687 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 25
+6688 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 25
+6689 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 25
+6690 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 25
+6691 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 25
+6692 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 25
+6693 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 25
+6694 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 25
+6695 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 25
+6696 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 25
+6697 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 25
+6698 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 25
+6699 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 25
+6700 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 25
+6701 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 25
+6702 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 25
+6703 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 25
+6704 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 25
+6705 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 25
+6706 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 25
+6707 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 25
+6708 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 25
+6709 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 25
+6710 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 25
+6711 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 25
+6712 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 25
+6713 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 25
+6714 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 25
+6715 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 25
+6716 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 26
+6717 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 26
+6718 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 26
+6719 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 26
+6720 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 26
+6721 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 26
+6722 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 26
+6723 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 26
+6724 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 26
+6725 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 26
+6726 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 26
+6727 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 26
+6728 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 26
+6729 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 26
+6730 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 26
+6731 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 26
+6732 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 26
+6733 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 26
+6734 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 26
+6735 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 26
+6736 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 26
+6737 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 26
+6738 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 26
+6739 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 26
+6740 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 26
+6741 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 26
+6742 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 26
+6743 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 26
+6744 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 26
+6745 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 26
+6746 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 26
+6747 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 26
+6748 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 26
+6749 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 26
+6750 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 26
+6751 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 26
+6752 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 26
+6753 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 26
+6754 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 26
+6755 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 26
+6756 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 26
+6757 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 26
+6758 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 26
+6759 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 26
+6760 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 26
+6761 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 26
+6762 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 27
+6763 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 27
+6764 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 27
+6765 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 27
+6766 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 27
+6767 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 27
+6768 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 27
+6769 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 27
+6770 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 27
+6771 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 27
+6772 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 27
+6773 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 27
+6774 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 27
+6775 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 27
+6776 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 27
+6777 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 27
+6778 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 27
+6779 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 27
+6780 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 27
+6781 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 27
+6782 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 27
+6783 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 27
+6784 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 27
+6785 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 27
+6786 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 27
+6787 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 27
+6788 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 27
+6789 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 27
+6790 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 27
+6791 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 27
+6792 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 27
+6793 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 27
+6794 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 27
+6795 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 27
+6796 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 27
+6797 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 27
+6798 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 27
+6799 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 27
+6800 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 27
+6801 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 27
+6802 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 27
+6803 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 27
+6804 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 27
+6805 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 27
+6806 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 27
+6807 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 27
+6808 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 28
+6809 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 28
+6810 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 28
+6811 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 28
+6812 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 28
+6813 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 28
+6814 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 28
+6815 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 28
+6816 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 28
+6817 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 28
+6818 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 28
+6819 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 28
+6820 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 28
+6821 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 28
+6822 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 28
+6823 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 28
+6824 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 28
+6825 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 28
+6826 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 28
+6827 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 28
+6828 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 28
+6829 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 28
+6830 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 28
+6831 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 28
+6832 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 28
+6833 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 28
+6834 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 28
+6835 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 28
+6836 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 28
+6837 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 28
+6838 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 28
+6839 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 28
+6840 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 28
+6841 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 28
+6842 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 28
+6843 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 28
+6844 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 28
+6845 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 28
+6846 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 28
+6847 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 28
+6848 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 28
+6849 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 28
+6850 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 28
+6851 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 28
+6852 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 28
+6853 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 28
+6854 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 29
+6855 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 29
+6856 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 29
+6857 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 29
+6858 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 29
+6859 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 29
+6860 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 29
+6861 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 29
+6862 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 29
+6863 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 29
+6864 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 29
+6865 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 29
+6866 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 29
+6867 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 29
+6868 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 29
+6869 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 29
+6870 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 29
+6871 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 29
+6872 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 29
+6873 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 29
+6874 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 29
+6875 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 29
+6876 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 29
+6877 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 29
+6878 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 29
+6879 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 29
+6880 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 29
+6881 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 29
+6882 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 29
+6883 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 29
+6884 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 29
+6885 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 29
+6886 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 29
+6887 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 29
+6888 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 29
+6889 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 29
+6890 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 29
+6891 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 29
+6892 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 29
+6893 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 29
+6894 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 29
+6895 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 29
+6896 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 29
+6897 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 29
+6898 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 29
+6899 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 29
+6900 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 30
+6901 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 30
+6902 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 30
+6903 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 30
+6904 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 30
+6905 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 30
+6906 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 30
+6907 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 30
+6908 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 30
+6909 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 30
+6910 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 30
+6911 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 30
+6912 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 30
+6913 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 30
+6914 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 30
+6915 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 30
+6916 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 30
+6917 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 30
+6918 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 30
+6919 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 30
+6920 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 30
+6921 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 30
+6922 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 30
+6923 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 30
+6924 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 30
+6925 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 30
+6926 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 30
+6927 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 30
+6928 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 30
+6929 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 30
+6930 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 30
+6931 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 30
+6932 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 30
+6933 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 30
+6934 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 30
+6935 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 30
+6936 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 30
+6937 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 30
+6938 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 30
+6939 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 30
+6940 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 30
+6941 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 30
+6942 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 30
+6943 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 30
+6944 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 30
+6945 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 30
+6946 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 31
+6947 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 31
+6948 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 31
+6949 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 31
+6950 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 31
+6951 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 31
+6952 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 31
+6953 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 31
+6954 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 31
+6955 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 31
+6956 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 31
+6957 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 31
+6958 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 31
+6959 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 31
+6960 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 31
+6961 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 31
+6962 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 31
+6963 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 31
+6964 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 31
+6965 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 31
+6966 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 31
+6967 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 31
+6968 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 31
+6969 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 31
+6970 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 31
+6971 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 31
+6972 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 31
+6973 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 31
+6974 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 31
+6975 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 31
+6976 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 31
+6977 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 31
+6978 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 31
+6979 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 31
+6980 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 31
+6981 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 31
+6982 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 31
+6983 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 31
+6984 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 31
+6985 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 31
+6986 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 31
+6987 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 31
+6988 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 31
+6989 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 31
+6990 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 31
+6991 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 31
+6992 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 32
+6993 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 32
+6994 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 32
+6995 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 32
+6996 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 32
+6997 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 32
+6998 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 32
+6999 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 32
+7000 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 32
+7001 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 32
+7002 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 32
+7003 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 32
+7004 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 32
+7005 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 32
+7006 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 32
+7007 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 32
+7008 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 32
+7009 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 32
+7010 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 32
+7011 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 32
+7012 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 32
+7013 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 32
+7014 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 32
+7015 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 32
+7016 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 32
+7017 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 32
+7018 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 32
+7019 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 32
+7020 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 32
+7021 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 32
+7022 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 32
+7023 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 32
+7024 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 32
+7025 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 32
+7026 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 32
+7027 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 32
+7028 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 32
+7029 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 32
+7030 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 32
+7031 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 32
+7032 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 32
+7033 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 32
+7034 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 32
+7035 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 32
+7036 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 32
+7037 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 32
+7038 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 33
+7039 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 33
+7040 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 33
+7041 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 33
+7042 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 33
+7043 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 33
+7044 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 33
+7045 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 33
+7046 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 33
+7047 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 33
+7048 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 33
+7049 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 33
+7050 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 33
+7051 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 33
+7052 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 33
+7053 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 33
+7054 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 33
+7055 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 33
+7056 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 33
+7057 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 33
+7058 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 33
+7059 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 33
+7060 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 33
+7061 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 33
+7062 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 33
+7063 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 33
+7064 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 33
+7065 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 33
+7066 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 33
+7067 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 33
+7068 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 33
+7069 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 33
+7070 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 33
+7071 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 33
+7072 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 33
+7073 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 33
+7074 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 33
+7075 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 33
+7076 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 33
+7077 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 33
+7078 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 33
+7079 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 33
+7080 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 33
+7081 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 33
+7082 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 33
+7083 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 33
+7084 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 34
+7085 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 34
+7086 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 34
+7087 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 34
+7088 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 34
+7089 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 34
+7090 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 34
+7091 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 34
+7092 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 34
+7093 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 34
+7094 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 34
+7095 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 34
+7096 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 34
+7097 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 34
+7098 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 34
+7099 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 34
+7100 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 34
+7101 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 34
+7102 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 34
+7103 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 34
+7104 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 34
+7105 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 34
+7106 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 34
+7107 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 34
+7108 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 34
+7109 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 34
+7110 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 34
+7111 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 34
+7112 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 34
+7113 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 34
+7114 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 34
+7115 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 34
+7116 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 34
+7117 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 34
+7118 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 34
+7119 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 34
+7120 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 34
+7121 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 34
+7122 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 34
+7123 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 34
+7124 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 34
+7125 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 34
+7126 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 34
+7127 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 34
+7128 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 34
+7129 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 34
+7130 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 35
+7131 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 35
+7132 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 35
+7133 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 35
+7134 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 35
+7135 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 35
+7136 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 35
+7137 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 35
+7138 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 35
+7139 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 35
+7140 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 35
+7141 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 35
+7142 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 35
+7143 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 35
+7144 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 35
+7145 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 35
+7146 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 35
+7147 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 35
+7148 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 35
+7149 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 35
+7150 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 35
+7151 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 35
+7152 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 35
+7153 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 35
+7154 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 35
+7155 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 35
+7156 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 35
+7157 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 35
+7158 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 35
+7159 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 35
+7160 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 35
+7161 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 35
+7162 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 35
+7163 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 35
+7164 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 35
+7165 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 35
+7166 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 35
+7167 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 35
+7168 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 35
+7169 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 35
+7170 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 35
+7171 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 35
+7172 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 35
+7173 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 35
+7174 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 35
+7175 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 35
+7176 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 36
+7177 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 36
+7178 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 36
+7179 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 36
+7180 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 36
+7181 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 36
+7182 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 36
+7183 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 36
+7184 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 36
+7185 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 36
+7186 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 36
+7187 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 36
+7188 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 36
+7189 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 36
+7190 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 36
+7191 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 36
+7192 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 36
+7193 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 36
+7194 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 36
+7195 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 36
+7196 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 36
+7197 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 36
+7198 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 36
+7199 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 36
+7200 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 36
+7201 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 36
+7202 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 36
+7203 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 36
+7204 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 36
+7205 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 36
+7206 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 36
+7207 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 36
+7208 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 36
+7209 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 36
+7210 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 36
+7211 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 36
+7212 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 36
+7213 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 36
+7214 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 36
+7215 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 36
+7216 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 36
+7217 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 36
+7218 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 36
+7219 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 36
+7220 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 36
+7221 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 36
+7222 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 37
+7223 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 37
+7224 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 37
+7225 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 37
+7226 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 37
+7227 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 37
+7228 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 37
+7229 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 37
+7230 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 37
+7231 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 37
+7232 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 37
+7233 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 37
+7234 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 37
+7235 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 37
+7236 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 37
+7237 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 37
+7238 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 37
+7239 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 37
+7240 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 37
+7241 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 37
+7242 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 37
+7243 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 37
+7244 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 37
+7245 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 37
+7246 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 37
+7247 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 37
+7248 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 37
+7249 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 37
+7250 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 37
+7251 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 37
+7252 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 37
+7253 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 37
+7254 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 37
+7255 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 37
+7256 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 37
+7257 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 37
+7258 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 37
+7259 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 37
+7260 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 37
+7261 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 37
+7262 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 37
+7263 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 37
+7264 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 37
+7265 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 37
+7266 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 37
+7267 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 37
+7268 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 38
+7269 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 38
+7270 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 38
+7271 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 38
+7272 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 38
+7273 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 38
+7274 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 38
+7275 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 38
+7276 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 38
+7277 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 38
+7278 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 38
+7279 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 38
+7280 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 38
+7281 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 38
+7282 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 38
+7283 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 38
+7284 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 38
+7285 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 38
+7286 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 38
+7287 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 38
+7288 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 38
+7289 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 38
+7290 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 38
+7291 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 38
+7292 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 38
+7293 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 38
+7294 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 38
+7295 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 38
+7296 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 38
+7297 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 38
+7298 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 38
+7299 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 38
+7300 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 38
+7301 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 38
+7302 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 38
+7303 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 38
+7304 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 38
+7305 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 38
+7306 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 38
+7307 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 38
+7308 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 38
+7309 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 38
+7310 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 38
+7311 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 38
+7312 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 38
+7313 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 38
+7314 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 3 39
+7315 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 3 39
+7316 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 3 39
+7317 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 3 39
+7318 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 3 39
+7319 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 3 39
+7320 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 3 39
+7321 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 39
+7322 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 3 39
+7323 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 3 39
+7324 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 3 39
+7325 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 3 39
+7326 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 3 39
+7327 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 3 39
+7328 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 3 39
+7329 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 3 39
+7330 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 3 39
+7331 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 3 39
+7332 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 3 39
+7333 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 3 39
+7334 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 3 39
+7335 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 3 39
+7336 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 3 39
+7337 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 3 39
+7338 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 3 39
+7339 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 3 39
+7340 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 3 39
+7341 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 3 39
+7342 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 3 39
+7343 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 3 39
+7344 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 3 39
+7345 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 3 39
+7346 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 3 39
+7347 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 3 39
+7348 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 3 39
+7349 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 3 39
+7350 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 3 39
+7351 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 3 39
+7352 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 3 39
+7353 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 3 39
+7354 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 39
+7355 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 3 39
+7356 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 3 39
+7357 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 3 39
+7358 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 39
+7359 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 3 39
+7360 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 0
+7361 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 0
+7362 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 0
+7363 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 0
+7364 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 0
+7365 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 0
+7366 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 0
+7367 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 0
+7368 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 0
+7369 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 0
+7370 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 0
+7371 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 0
+7372 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 0
+7373 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 0
+7374 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 0
+7375 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 0
+7376 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 0
+7377 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 0
+7378 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 0
+7379 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 0
+7380 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 0
+7381 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 0
+7382 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 0
+7383 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 0
+7384 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 0
+7385 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 0
+7386 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 0
+7387 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 0
+7388 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 0
+7389 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 0
+7390 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 0
+7391 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 0
+7392 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 0
+7393 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 0
+7394 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 0
+7395 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 0
+7396 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 0
+7397 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 0
+7398 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 0
+7399 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 0
+7400 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 0
+7401 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 0
+7402 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 0
+7403 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 0
+7404 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 0
+7405 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 0
+7406 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 1
+7407 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 1
+7408 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 1
+7409 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 1
+7410 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 1
+7411 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 1
+7412 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 1
+7413 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 1
+7414 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 1
+7415 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 1
+7416 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 1
+7417 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 1
+7418 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 1
+7419 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 1
+7420 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 1
+7421 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 1
+7422 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 1
+7423 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 1
+7424 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 1
+7425 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 1
+7426 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 1
+7427 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 1
+7428 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 1
+7429 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 1
+7430 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 1
+7431 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 1
+7432 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 1
+7433 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 1
+7434 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 1
+7435 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 1
+7436 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 1
+7437 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 1
+7438 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 1
+7439 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 1
+7440 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 1
+7441 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 1
+7442 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 1
+7443 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 1
+7444 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 1
+7445 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 1
+7446 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 1
+7447 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 1
+7448 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 1
+7449 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 1
+7450 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 1
+7451 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 1
+7452 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 2
+7453 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 2
+7454 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 2
+7455 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 2
+7456 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 2
+7457 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 2
+7458 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 2
+7459 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 2
+7460 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 2
+7461 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 2
+7462 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 2
+7463 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 2
+7464 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 2
+7465 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 2
+7466 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 2
+7467 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 2
+7468 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 2
+7469 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 2
+7470 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 2
+7471 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 2
+7472 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 2
+7473 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 2
+7474 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 2
+7475 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 2
+7476 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 2
+7477 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 2
+7478 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 2
+7479 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 2
+7480 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 2
+7481 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 2
+7482 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 2
+7483 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 2
+7484 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 2
+7485 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 2
+7486 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 2
+7487 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 2
+7488 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 2
+7489 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 2
+7490 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 2
+7491 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 2
+7492 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 2
+7493 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 2
+7494 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 2
+7495 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 2
+7496 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 2
+7497 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 2
+7498 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 3
+7499 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 3
+7500 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 3
+7501 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 3
+7502 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 3
+7503 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 3
+7504 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 3
+7505 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 3
+7506 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 3
+7507 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 3
+7508 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 3
+7509 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 3
+7510 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 3
+7511 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 3
+7512 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 3
+7513 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 3
+7514 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 3
+7515 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 3
+7516 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 3
+7517 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 3
+7518 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 3
+7519 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 3
+7520 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 3
+7521 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 3
+7522 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 3
+7523 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 3
+7524 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 3
+7525 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 3
+7526 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 3
+7527 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 3
+7528 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 3
+7529 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 3
+7530 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 3
+7531 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 3
+7532 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 3
+7533 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 3
+7534 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 3
+7535 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 3
+7536 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 3
+7537 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 3
+7538 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 3
+7539 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 3
+7540 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 3
+7541 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 3
+7542 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 3
+7543 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 3
+7544 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 4
+7545 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 4
+7546 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 4
+7547 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 4
+7548 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 4
+7549 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 4
+7550 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 4
+7551 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 4
+7552 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 4
+7553 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 4
+7554 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 4
+7555 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 4
+7556 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 4
+7557 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 4
+7558 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 4
+7559 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 4
+7560 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 4
+7561 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 4
+7562 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 4
+7563 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 4
+7564 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 4
+7565 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 4
+7566 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 4
+7567 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 4
+7568 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 4
+7569 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 4
+7570 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 4
+7571 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 4
+7572 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 4
+7573 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 4
+7574 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 4
+7575 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 4
+7576 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 4
+7577 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 4
+7578 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 4
+7579 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 4
+7580 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 4
+7581 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 4
+7582 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 4
+7583 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 4
+7584 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 4
+7585 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 4
+7586 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 4
+7587 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 4
+7588 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 4
+7589 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 4
+7590 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 5
+7591 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 5
+7592 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 5
+7593 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 5
+7594 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 5
+7595 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 5
+7596 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 5
+7597 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 5
+7598 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 5
+7599 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 5
+7600 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 5
+7601 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 5
+7602 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 5
+7603 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 5
+7604 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 5
+7605 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 5
+7606 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 5
+7607 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 5
+7608 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 5
+7609 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 5
+7610 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 5
+7611 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 5
+7612 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 5
+7613 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 5
+7614 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 5
+7615 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 5
+7616 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 5
+7617 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 5
+7618 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 5
+7619 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 5
+7620 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 5
+7621 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 5
+7622 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 5
+7623 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 5
+7624 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 5
+7625 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 5
+7626 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 5
+7627 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 5
+7628 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 5
+7629 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 5
+7630 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 5
+7631 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 5
+7632 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 5
+7633 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 5
+7634 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 5
+7635 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 5
+7636 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 6
+7637 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 6
+7638 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 6
+7639 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 6
+7640 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 6
+7641 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 6
+7642 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 6
+7643 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 6
+7644 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 6
+7645 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 6
+7646 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 6
+7647 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 6
+7648 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 6
+7649 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 6
+7650 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 6
+7651 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 6
+7652 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 6
+7653 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 6
+7654 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 6
+7655 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 6
+7656 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 6
+7657 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 6
+7658 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 6
+7659 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 6
+7660 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 6
+7661 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 6
+7662 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 6
+7663 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 6
+7664 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 6
+7665 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 6
+7666 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 6
+7667 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 6
+7668 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 6
+7669 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 6
+7670 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 6
+7671 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 6
+7672 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 6
+7673 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 6
+7674 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 6
+7675 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 6
+7676 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 6
+7677 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 6
+7678 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 6
+7679 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 6
+7680 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 6
+7681 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 6
+7682 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 7
+7683 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 7
+7684 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 7
+7685 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 7
+7686 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 7
+7687 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 7
+7688 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 7
+7689 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 7
+7690 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 7
+7691 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 7
+7692 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 7
+7693 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 7
+7694 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 7
+7695 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 7
+7696 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 7
+7697 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 7
+7698 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 7
+7699 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 7
+7700 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 7
+7701 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 7
+7702 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 7
+7703 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 7
+7704 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 7
+7705 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 7
+7706 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 7
+7707 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 7
+7708 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 7
+7709 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 7
+7710 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 7
+7711 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 7
+7712 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 7
+7713 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 7
+7714 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 7
+7715 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 7
+7716 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 7
+7717 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 7
+7718 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 7
+7719 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 7
+7720 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 7
+7721 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 7
+7722 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 7
+7723 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 7
+7724 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 7
+7725 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 7
+7726 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 7
+7727 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 7
+7728 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 8
+7729 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 8
+7730 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 8
+7731 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 8
+7732 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 8
+7733 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 8
+7734 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 8
+7735 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 8
+7736 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 8
+7737 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 8
+7738 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 8
+7739 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 8
+7740 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 8
+7741 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 8
+7742 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 8
+7743 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 8
+7744 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 8
+7745 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 8
+7746 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 8
+7747 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 8
+7748 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 8
+7749 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 8
+7750 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 8
+7751 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 8
+7752 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 8
+7753 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 8
+7754 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 8
+7755 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 8
+7756 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 8
+7757 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 8
+7758 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 8
+7759 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 8
+7760 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 8
+7761 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 8
+7762 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 8
+7763 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 8
+7764 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 8
+7765 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 8
+7766 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 8
+7767 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 8
+7768 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 8
+7769 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 8
+7770 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 8
+7771 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 8
+7772 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 8
+7773 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 8
+7774 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 9
+7775 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 9
+7776 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 9
+7777 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 9
+7778 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 9
+7779 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 9
+7780 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 9
+7781 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 9
+7782 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 9
+7783 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 9
+7784 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 9
+7785 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 9
+7786 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 9
+7787 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 9
+7788 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 9
+7789 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 9
+7790 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 9
+7791 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 9
+7792 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 9
+7793 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 9
+7794 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 9
+7795 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 9
+7796 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 9
+7797 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 9
+7798 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 9
+7799 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 9
+7800 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 9
+7801 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 9
+7802 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 9
+7803 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 9
+7804 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 9
+7805 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 9
+7806 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 9
+7807 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 9
+7808 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 9
+7809 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 9
+7810 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 9
+7811 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 9
+7812 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 9
+7813 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 9
+7814 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 9
+7815 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 9
+7816 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 9
+7817 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 9
+7818 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 9
+7819 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 9
+7820 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 10
+7821 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 10
+7822 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 10
+7823 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 10
+7824 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 10
+7825 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 10
+7826 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 10
+7827 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 10
+7828 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 10
+7829 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 10
+7830 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 10
+7831 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 10
+7832 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 10
+7833 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 10
+7834 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 10
+7835 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 10
+7836 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 10
+7837 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 10
+7838 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 10
+7839 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 10
+7840 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 10
+7841 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 10
+7842 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 10
+7843 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 10
+7844 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 10
+7845 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 10
+7846 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 10
+7847 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 10
+7848 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 10
+7849 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 10
+7850 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 10
+7851 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 10
+7852 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 10
+7853 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 10
+7854 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 10
+7855 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 10
+7856 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 10
+7857 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 10
+7858 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 10
+7859 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 10
+7860 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 10
+7861 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 10
+7862 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 10
+7863 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 10
+7864 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 10
+7865 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 10
+7866 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 11
+7867 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 11
+7868 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 11
+7869 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 11
+7870 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 11
+7871 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 11
+7872 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 11
+7873 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 11
+7874 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 11
+7875 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 11
+7876 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 11
+7877 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 11
+7878 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 11
+7879 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 11
+7880 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 11
+7881 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 11
+7882 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 11
+7883 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 11
+7884 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 11
+7885 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 11
+7886 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 11
+7887 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 11
+7888 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 11
+7889 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 11
+7890 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 11
+7891 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 11
+7892 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 11
+7893 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 11
+7894 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 11
+7895 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 11
+7896 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 11
+7897 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 11
+7898 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 11
+7899 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 11
+7900 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 11
+7901 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 11
+7902 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 11
+7903 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 11
+7904 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 11
+7905 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 11
+7906 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 11
+7907 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 11
+7908 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 11
+7909 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 11
+7910 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 11
+7911 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 11
+7912 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 12
+7913 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 12
+7914 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 12
+7915 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 12
+7916 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 12
+7917 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 12
+7918 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 12
+7919 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 12
+7920 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 12
+7921 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 12
+7922 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 12
+7923 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 12
+7924 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 12
+7925 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 12
+7926 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 12
+7927 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 12
+7928 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 12
+7929 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 12
+7930 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 12
+7931 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 12
+7932 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 12
+7933 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 12
+7934 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 12
+7935 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 12
+7936 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 12
+7937 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 12
+7938 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 12
+7939 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 12
+7940 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 12
+7941 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 12
+7942 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 12
+7943 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 12
+7944 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 12
+7945 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 12
+7946 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 12
+7947 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 12
+7948 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 12
+7949 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 12
+7950 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 12
+7951 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 12
+7952 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 12
+7953 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 12
+7954 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 12
+7955 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 12
+7956 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 12
+7957 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 12
+7958 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 13
+7959 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 13
+7960 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 13
+7961 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 13
+7962 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 13
+7963 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 13
+7964 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 13
+7965 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 13
+7966 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 13
+7967 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 13
+7968 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 13
+7969 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 13
+7970 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 13
+7971 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 13
+7972 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 13
+7973 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 13
+7974 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 13
+7975 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 13
+7976 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 13
+7977 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 13
+7978 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 13
+7979 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 13
+7980 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 13
+7981 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 13
+7982 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 13
+7983 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 13
+7984 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 13
+7985 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 13
+7986 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 13
+7987 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 13
+7988 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 13
+7989 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 13
+7990 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 13
+7991 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 13
+7992 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 13
+7993 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 13
+7994 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 13
+7995 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 13
+7996 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 13
+7997 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 13
+7998 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 13
+7999 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 13
+8000 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 13
+8001 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 13
+8002 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 13
+8003 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 13
+8004 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 14
+8005 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 14
+8006 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 14
+8007 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 14
+8008 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 14
+8009 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 14
+8010 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 14
+8011 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 14
+8012 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 14
+8013 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 14
+8014 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 14
+8015 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 14
+8016 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 14
+8017 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 14
+8018 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 14
+8019 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 14
+8020 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 14
+8021 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 14
+8022 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 14
+8023 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 14
+8024 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 14
+8025 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 14
+8026 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 14
+8027 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 14
+8028 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 14
+8029 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 14
+8030 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 14
+8031 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 14
+8032 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 14
+8033 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 14
+8034 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 14
+8035 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 14
+8036 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 14
+8037 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 14
+8038 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 14
+8039 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 14
+8040 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 14
+8041 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 14
+8042 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 14
+8043 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 14
+8044 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 14
+8045 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 14
+8046 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 14
+8047 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 14
+8048 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 14
+8049 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 14
+8050 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 15
+8051 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 15
+8052 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 15
+8053 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 15
+8054 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 15
+8055 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 15
+8056 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 15
+8057 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 15
+8058 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 15
+8059 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 15
+8060 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 15
+8061 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 15
+8062 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 15
+8063 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 15
+8064 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 15
+8065 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 15
+8066 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 15
+8067 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 15
+8068 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 15
+8069 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 15
+8070 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 15
+8071 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 15
+8072 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 15
+8073 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 15
+8074 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 15
+8075 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 15
+8076 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 15
+8077 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 15
+8078 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 15
+8079 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 15
+8080 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 15
+8081 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 15
+8082 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 15
+8083 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 15
+8084 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 15
+8085 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 15
+8086 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 15
+8087 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 15
+8088 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 15
+8089 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 15
+8090 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 15
+8091 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 15
+8092 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 15
+8093 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 15
+8094 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 15
+8095 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 15
+8096 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 16
+8097 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 16
+8098 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 16
+8099 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 16
+8100 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 16
+8101 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 16
+8102 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 16
+8103 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 16
+8104 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 16
+8105 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 16
+8106 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 16
+8107 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 16
+8108 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 16
+8109 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 16
+8110 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 16
+8111 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 16
+8112 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 16
+8113 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 16
+8114 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 16
+8115 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 16
+8116 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 16
+8117 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 16
+8118 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 16
+8119 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 16
+8120 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 16
+8121 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 16
+8122 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 16
+8123 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 16
+8124 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 16
+8125 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 16
+8126 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 16
+8127 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 16
+8128 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 16
+8129 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 16
+8130 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 16
+8131 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 16
+8132 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 16
+8133 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 16
+8134 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 16
+8135 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 16
+8136 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 16
+8137 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 16
+8138 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 16
+8139 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 16
+8140 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 16
+8141 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 16
+8142 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 17
+8143 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 17
+8144 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 17
+8145 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 17
+8146 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 17
+8147 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 17
+8148 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 17
+8149 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 17
+8150 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 17
+8151 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 17
+8152 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 17
+8153 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 17
+8154 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 17
+8155 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 17
+8156 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 17
+8157 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 17
+8158 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 17
+8159 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 17
+8160 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 17
+8161 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 17
+8162 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 17
+8163 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 17
+8164 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 17
+8165 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 17
+8166 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 17
+8167 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 17
+8168 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 17
+8169 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 17
+8170 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 17
+8171 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 17
+8172 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 17
+8173 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 17
+8174 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 17
+8175 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 17
+8176 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 17
+8177 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 17
+8178 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 17
+8179 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 17
+8180 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 17
+8181 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 17
+8182 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 17
+8183 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 17
+8184 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 17
+8185 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 17
+8186 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 17
+8187 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 17
+8188 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 18
+8189 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 18
+8190 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 18
+8191 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 18
+8192 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 18
+8193 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 18
+8194 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 18
+8195 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 18
+8196 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 18
+8197 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 18
+8198 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 18
+8199 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 18
+8200 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 18
+8201 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 18
+8202 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 18
+8203 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 18
+8204 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 18
+8205 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 18
+8206 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 18
+8207 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 18
+8208 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 18
+8209 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 18
+8210 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 18
+8211 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 18
+8212 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 18
+8213 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 18
+8214 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 18
+8215 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 18
+8216 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 18
+8217 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 18
+8218 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 18
+8219 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 18
+8220 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 18
+8221 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 18
+8222 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 18
+8223 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 18
+8224 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 18
+8225 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 18
+8226 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 18
+8227 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 18
+8228 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 18
+8229 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 18
+8230 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 18
+8231 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 18
+8232 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 18
+8233 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 18
+8234 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 19
+8235 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 19
+8236 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 19
+8237 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 19
+8238 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 19
+8239 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 19
+8240 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 19
+8241 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 19
+8242 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 19
+8243 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 19
+8244 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 19
+8245 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 19
+8246 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 19
+8247 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 19
+8248 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 19
+8249 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 19
+8250 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 19
+8251 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 19
+8252 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 19
+8253 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 19
+8254 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 19
+8255 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 19
+8256 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 19
+8257 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 19
+8258 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 19
+8259 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 19
+8260 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 19
+8261 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 19
+8262 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 19
+8263 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 19
+8264 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 19
+8265 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 19
+8266 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 19
+8267 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 19
+8268 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 19
+8269 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 19
+8270 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 19
+8271 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 19
+8272 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 19
+8273 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 19
+8274 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 19
+8275 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 19
+8276 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 19
+8277 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 19
+8278 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 19
+8279 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 19
+8280 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 20
+8281 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 20
+8282 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 20
+8283 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 20
+8284 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 20
+8285 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 20
+8286 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 20
+8287 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 20
+8288 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 20
+8289 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 20
+8290 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 20
+8291 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 20
+8292 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 20
+8293 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 20
+8294 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 20
+8295 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 20
+8296 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 20
+8297 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 20
+8298 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 20
+8299 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 20
+8300 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 20
+8301 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 20
+8302 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 20
+8303 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 20
+8304 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 20
+8305 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 20
+8306 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 20
+8307 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 20
+8308 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 20
+8309 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 20
+8310 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 20
+8311 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 20
+8312 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 20
+8313 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 20
+8314 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 20
+8315 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 20
+8316 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 20
+8317 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 20
+8318 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 20
+8319 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 20
+8320 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 20
+8321 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 20
+8322 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 20
+8323 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 20
+8324 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 20
+8325 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 20
+8326 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 21
+8327 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 21
+8328 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 21
+8329 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 21
+8330 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 21
+8331 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 21
+8332 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 21
+8333 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 21
+8334 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 21
+8335 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 21
+8336 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 21
+8337 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 21
+8338 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 21
+8339 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 21
+8340 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 21
+8341 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 21
+8342 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 21
+8343 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 21
+8344 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 21
+8345 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 21
+8346 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 21
+8347 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 21
+8348 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 21
+8349 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 21
+8350 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 21
+8351 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 21
+8352 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 21
+8353 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 21
+8354 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 21
+8355 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 21
+8356 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 21
+8357 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 21
+8358 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 21
+8359 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 21
+8360 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 21
+8361 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 21
+8362 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 21
+8363 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 21
+8364 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 21
+8365 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 21
+8366 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 21
+8367 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 21
+8368 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 21
+8369 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 21
+8370 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 21
+8371 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 21
+8372 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 22
+8373 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 22
+8374 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 22
+8375 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 22
+8376 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 22
+8377 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 22
+8378 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 22
+8379 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 22
+8380 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 22
+8381 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 22
+8382 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 22
+8383 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 22
+8384 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 22
+8385 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 22
+8386 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 22
+8387 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 22
+8388 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 22
+8389 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 22
+8390 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 22
+8391 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 22
+8392 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 22
+8393 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 22
+8394 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 22
+8395 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 22
+8396 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 22
+8397 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 22
+8398 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 22
+8399 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 22
+8400 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 22
+8401 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 22
+8402 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 22
+8403 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 22
+8404 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 22
+8405 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 22
+8406 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 22
+8407 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 22
+8408 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 22
+8409 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 22
+8410 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 22
+8411 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 22
+8412 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 22
+8413 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 22
+8414 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 22
+8415 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 22
+8416 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 22
+8417 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 22
+8418 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 23
+8419 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 23
+8420 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 23
+8421 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 23
+8422 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 23
+8423 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 23
+8424 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 23
+8425 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 23
+8426 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 23
+8427 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 23
+8428 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 23
+8429 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 23
+8430 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 23
+8431 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 23
+8432 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 23
+8433 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 23
+8434 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 23
+8435 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 23
+8436 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 23
+8437 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 23
+8438 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 23
+8439 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 23
+8440 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 23
+8441 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 23
+8442 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 23
+8443 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 23
+8444 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 23
+8445 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 23
+8446 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 23
+8447 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 23
+8448 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 23
+8449 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 23
+8450 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 23
+8451 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 23
+8452 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 23
+8453 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 23
+8454 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 23
+8455 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 23
+8456 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 23
+8457 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 23
+8458 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 23
+8459 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 23
+8460 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 23
+8461 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 23
+8462 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 23
+8463 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 23
+8464 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 24
+8465 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 24
+8466 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 24
+8467 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 24
+8468 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 24
+8469 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 24
+8470 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 24
+8471 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 24
+8472 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 24
+8473 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 24
+8474 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 24
+8475 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 24
+8476 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 24
+8477 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 24
+8478 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 24
+8479 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 24
+8480 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 24
+8481 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 24
+8482 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 24
+8483 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 24
+8484 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 24
+8485 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 24
+8486 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 24
+8487 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 24
+8488 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 24
+8489 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 24
+8490 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 24
+8491 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 24
+8492 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 24
+8493 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 24
+8494 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 24
+8495 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 24
+8496 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 24
+8497 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 24
+8498 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 24
+8499 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 24
+8500 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 24
+8501 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 24
+8502 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 24
+8503 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 24
+8504 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 24
+8505 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 24
+8506 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 24
+8507 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 24
+8508 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 24
+8509 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 24
+8510 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 25
+8511 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 25
+8512 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 25
+8513 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 25
+8514 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 25
+8515 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 25
+8516 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 25
+8517 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 25
+8518 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 25
+8519 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 25
+8520 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 25
+8521 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 25
+8522 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 25
+8523 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 25
+8524 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 25
+8525 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 25
+8526 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 25
+8527 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 25
+8528 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 25
+8529 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 25
+8530 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 25
+8531 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 25
+8532 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 25
+8533 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 25
+8534 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 25
+8535 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 25
+8536 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 25
+8537 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 25
+8538 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 25
+8539 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 25
+8540 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 25
+8541 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 25
+8542 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 25
+8543 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 25
+8544 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 25
+8545 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 25
+8546 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 25
+8547 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 25
+8548 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 25
+8549 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 25
+8550 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 25
+8551 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 25
+8552 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 25
+8553 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 25
+8554 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 25
+8555 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 25
+8556 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 26
+8557 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 26
+8558 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 26
+8559 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 26
+8560 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 26
+8561 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 26
+8562 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 26
+8563 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 26
+8564 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 26
+8565 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 26
+8566 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 26
+8567 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 26
+8568 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 26
+8569 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 26
+8570 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 26
+8571 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 26
+8572 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 26
+8573 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 26
+8574 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 26
+8575 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 26
+8576 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 26
+8577 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 26
+8578 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 26
+8579 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 26
+8580 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 26
+8581 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 26
+8582 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 26
+8583 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 26
+8584 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 26
+8585 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 26
+8586 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 26
+8587 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 26
+8588 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 26
+8589 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 26
+8590 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 26
+8591 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 26
+8592 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 26
+8593 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 26
+8594 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 26
+8595 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 26
+8596 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 26
+8597 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 26
+8598 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 26
+8599 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 26
+8600 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 26
+8601 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 26
+8602 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 27
+8603 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 27
+8604 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 27
+8605 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 27
+8606 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 27
+8607 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 27
+8608 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 27
+8609 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 27
+8610 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 27
+8611 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 27
+8612 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 27
+8613 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 27
+8614 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 27
+8615 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 27
+8616 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 27
+8617 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 27
+8618 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 27
+8619 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 27
+8620 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 27
+8621 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 27
+8622 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 27
+8623 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 27
+8624 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 27
+8625 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 27
+8626 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 27
+8627 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 27
+8628 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 27
+8629 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 27
+8630 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 27
+8631 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 27
+8632 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 27
+8633 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 27
+8634 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 27
+8635 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 27
+8636 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 27
+8637 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 27
+8638 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 27
+8639 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 27
+8640 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 27
+8641 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 27
+8642 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 27
+8643 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 27
+8644 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 27
+8645 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 27
+8646 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 27
+8647 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 27
+8648 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 28
+8649 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 28
+8650 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 28
+8651 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 28
+8652 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 28
+8653 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 28
+8654 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 28
+8655 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 28
+8656 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 28
+8657 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 28
+8658 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 28
+8659 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 28
+8660 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 28
+8661 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 28
+8662 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 28
+8663 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 28
+8664 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 28
+8665 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 28
+8666 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 28
+8667 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 28
+8668 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 28
+8669 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 28
+8670 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 28
+8671 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 28
+8672 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 28
+8673 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 28
+8674 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 28
+8675 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 28
+8676 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 28
+8677 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 28
+8678 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 28
+8679 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 28
+8680 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 28
+8681 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 28
+8682 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 28
+8683 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 28
+8684 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 28
+8685 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 28
+8686 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 28
+8687 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 28
+8688 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 28
+8689 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 28
+8690 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 28
+8691 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 28
+8692 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 28
+8693 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 28
+8694 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 29
+8695 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 29
+8696 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 29
+8697 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 29
+8698 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 29
+8699 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 29
+8700 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 29
+8701 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 29
+8702 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 29
+8703 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 29
+8704 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 29
+8705 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 29
+8706 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 29
+8707 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 29
+8708 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 29
+8709 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 29
+8710 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 29
+8711 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 29
+8712 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 29
+8713 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 29
+8714 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 29
+8715 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 29
+8716 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 29
+8717 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 29
+8718 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 29
+8719 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 29
+8720 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 29
+8721 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 29
+8722 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 29
+8723 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 29
+8724 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 29
+8725 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 29
+8726 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 29
+8727 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 29
+8728 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 29
+8729 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 29
+8730 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 29
+8731 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 29
+8732 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 29
+8733 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 29
+8734 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 29
+8735 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 29
+8736 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 29
+8737 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 29
+8738 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 29
+8739 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 29
+8740 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 30
+8741 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 30
+8742 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 30
+8743 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 30
+8744 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 30
+8745 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 30
+8746 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 30
+8747 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 30
+8748 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 30
+8749 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 30
+8750 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 30
+8751 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 30
+8752 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 30
+8753 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 30
+8754 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 30
+8755 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 30
+8756 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 30
+8757 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 30
+8758 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 30
+8759 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 30
+8760 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 30
+8761 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 30
+8762 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 30
+8763 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 30
+8764 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 30
+8765 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 30
+8766 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 30
+8767 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 30
+8768 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 30
+8769 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 30
+8770 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 30
+8771 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 30
+8772 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 30
+8773 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 30
+8774 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 30
+8775 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 30
+8776 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 30
+8777 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 30
+8778 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 30
+8779 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 30
+8780 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 30
+8781 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 30
+8782 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 30
+8783 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 30
+8784 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 30
+8785 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 30
+8786 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 31
+8787 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 31
+8788 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 31
+8789 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 31
+8790 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 31
+8791 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 31
+8792 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 31
+8793 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 31
+8794 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 31
+8795 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 31
+8796 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 31
+8797 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 31
+8798 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 31
+8799 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 31
+8800 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 31
+8801 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 31
+8802 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 31
+8803 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 31
+8804 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 31
+8805 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 31
+8806 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 31
+8807 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 31
+8808 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 31
+8809 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 31
+8810 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 31
+8811 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 31
+8812 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 31
+8813 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 31
+8814 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 31
+8815 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 31
+8816 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 31
+8817 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 31
+8818 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 31
+8819 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 31
+8820 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 31
+8821 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 31
+8822 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 31
+8823 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 31
+8824 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 31
+8825 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 31
+8826 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 31
+8827 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 31
+8828 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 31
+8829 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 31
+8830 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 31
+8831 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 31
+8832 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 32
+8833 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 32
+8834 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 32
+8835 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 32
+8836 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 32
+8837 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 32
+8838 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 32
+8839 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 32
+8840 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 32
+8841 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 32
+8842 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 32
+8843 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 32
+8844 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 32
+8845 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 32
+8846 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 32
+8847 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 32
+8848 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 32
+8849 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 32
+8850 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 32
+8851 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 32
+8852 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 32
+8853 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 32
+8854 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 32
+8855 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 32
+8856 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 32
+8857 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 32
+8858 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 32
+8859 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 32
+8860 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 32
+8861 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 32
+8862 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 32
+8863 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 32
+8864 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 32
+8865 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 32
+8866 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 32
+8867 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 32
+8868 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 32
+8869 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 32
+8870 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 32
+8871 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 32
+8872 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 32
+8873 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 32
+8874 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 32
+8875 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 32
+8876 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 32
+8877 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 32
+8878 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 33
+8879 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 33
+8880 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 33
+8881 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 33
+8882 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 33
+8883 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 33
+8884 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 33
+8885 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 33
+8886 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 33
+8887 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 33
+8888 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 33
+8889 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 33
+8890 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 33
+8891 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 33
+8892 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 33
+8893 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 33
+8894 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 33
+8895 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 33
+8896 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 33
+8897 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 33
+8898 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 33
+8899 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 33
+8900 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 33
+8901 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 33
+8902 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 33
+8903 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 33
+8904 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 33
+8905 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 33
+8906 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 33
+8907 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 33
+8908 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 33
+8909 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 33
+8910 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 33
+8911 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 33
+8912 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 33
+8913 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 33
+8914 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 33
+8915 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 33
+8916 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 33
+8917 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 33
+8918 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 33
+8919 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 33
+8920 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 33
+8921 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 33
+8922 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 33
+8923 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 33
+8924 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 34
+8925 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 34
+8926 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 34
+8927 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 34
+8928 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 34
+8929 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 34
+8930 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 34
+8931 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 34
+8932 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 34
+8933 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 34
+8934 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 34
+8935 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 34
+8936 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 34
+8937 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 34
+8938 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 34
+8939 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 34
+8940 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 34
+8941 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 34
+8942 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 34
+8943 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 34
+8944 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 34
+8945 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 34
+8946 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 34
+8947 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 34
+8948 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 34
+8949 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 34
+8950 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 34
+8951 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 34
+8952 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 34
+8953 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 34
+8954 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 34
+8955 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 34
+8956 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 34
+8957 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 34
+8958 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 34
+8959 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 34
+8960 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 34
+8961 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 34
+8962 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 34
+8963 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 34
+8964 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 34
+8965 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 34
+8966 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 34
+8967 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 34
+8968 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 34
+8969 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 34
+8970 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 35
+8971 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 35
+8972 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 35
+8973 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 35
+8974 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 35
+8975 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 35
+8976 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 35
+8977 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 35
+8978 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 35
+8979 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 35
+8980 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 35
+8981 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 35
+8982 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 35
+8983 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 35
+8984 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 35
+8985 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 35
+8986 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 35
+8987 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 35
+8988 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 35
+8989 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 35
+8990 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 35
+8991 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 35
+8992 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 35
+8993 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 35
+8994 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 35
+8995 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 35
+8996 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 35
+8997 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 35
+8998 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 35
+8999 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 35
+9000 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 35
+9001 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 35
+9002 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 35
+9003 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 35
+9004 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 35
+9005 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 35
+9006 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 35
+9007 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 35
+9008 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 35
+9009 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 35
+9010 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 35
+9011 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 35
+9012 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 35
+9013 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 35
+9014 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 35
+9015 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 35
+9016 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 36
+9017 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 36
+9018 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 36
+9019 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 36
+9020 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 36
+9021 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 36
+9022 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 36
+9023 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 36
+9024 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 36
+9025 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 36
+9026 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 36
+9027 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 36
+9028 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 36
+9029 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 36
+9030 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 36
+9031 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 36
+9032 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 36
+9033 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 36
+9034 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 36
+9035 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 36
+9036 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 36
+9037 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 36
+9038 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 36
+9039 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 36
+9040 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 36
+9041 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 36
+9042 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 36
+9043 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 36
+9044 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 36
+9045 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 36
+9046 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 36
+9047 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 36
+9048 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 36
+9049 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 36
+9050 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 36
+9051 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 36
+9052 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 36
+9053 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 36
+9054 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 36
+9055 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 36
+9056 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 36
+9057 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 36
+9058 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 36
+9059 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 36
+9060 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 36
+9061 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 36
+9062 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 37
+9063 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 37
+9064 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 37
+9065 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 37
+9066 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 37
+9067 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 37
+9068 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 37
+9069 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 37
+9070 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 37
+9071 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 37
+9072 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 37
+9073 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 37
+9074 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 37
+9075 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 37
+9076 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 37
+9077 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 37
+9078 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 37
+9079 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 37
+9080 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 37
+9081 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 37
+9082 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 37
+9083 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 37
+9084 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 37
+9085 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 37
+9086 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 37
+9087 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 37
+9088 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 37
+9089 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 37
+9090 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 37
+9091 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 37
+9092 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 37
+9093 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 37
+9094 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 37
+9095 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 37
+9096 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 37
+9097 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 37
+9098 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 37
+9099 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 37
+9100 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 37
+9101 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 37
+9102 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 37
+9103 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 37
+9104 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 37
+9105 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 37
+9106 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 37
+9107 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 37
+9108 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 38
+9109 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 38
+9110 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 38
+9111 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 38
+9112 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 38
+9113 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 38
+9114 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 38
+9115 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 38
+9116 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 38
+9117 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 38
+9118 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 38
+9119 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 38
+9120 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 38
+9121 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 38
+9122 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 38
+9123 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 38
+9124 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 38
+9125 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 38
+9126 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 38
+9127 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 38
+9128 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 38
+9129 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 38
+9130 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 38
+9131 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 38
+9132 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 38
+9133 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 38
+9134 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 38
+9135 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 38
+9136 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 38
+9137 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 38
+9138 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 38
+9139 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 38
+9140 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 38
+9141 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 38
+9142 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 38
+9143 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 38
+9144 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 38
+9145 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 38
+9146 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 38
+9147 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 38
+9148 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 38
+9149 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 38
+9150 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 38
+9151 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 38
+9152 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 38
+9153 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 38
+9154 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 4 39
+9155 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 4 39
+9156 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 4 39
+9157 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 4 39
+9158 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 4 39
+9159 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 4 39
+9160 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 4 39
+9161 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 39
+9162 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 4 39
+9163 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 4 39
+9164 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 4 39
+9165 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 4 39
+9166 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 4 39
+9167 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 4 39
+9168 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 4 39
+9169 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 4 39
+9170 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 4 39
+9171 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 4 39
+9172 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 4 39
+9173 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 4 39
+9174 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 4 39
+9175 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 4 39
+9176 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 4 39
+9177 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 4 39
+9178 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 4 39
+9179 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 4 39
+9180 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 4 39
+9181 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 4 39
+9182 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 4 39
+9183 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 4 39
+9184 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 4 39
+9185 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 4 39
+9186 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 4 39
+9187 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 4 39
+9188 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 4 39
+9189 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 4 39
+9190 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 4 39
+9191 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 4 39
+9192 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 4 39
+9193 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 4 39
+9194 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 39
+9195 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 4 39
+9196 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 4 39
+9197 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 4 39
+9198 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 39
+9199 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 4 39
+9200 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 0
+9201 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 0
+9202 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 0
+9203 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 0
+9204 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 0
+9205 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 0
+9206 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 0
+9207 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 0
+9208 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 0
+9209 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 0
+9210 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 0
+9211 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 0
+9212 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 0
+9213 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 0
+9214 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 0
+9215 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 0
+9216 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 0
+9217 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 0
+9218 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 0
+9219 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 0
+9220 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 0
+9221 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 0
+9222 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 0
+9223 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 0
+9224 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 0
+9225 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 0
+9226 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 0
+9227 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 0
+9228 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 0
+9229 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 0
+9230 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 0
+9231 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 0
+9232 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 0
+9233 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 0
+9234 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 0
+9235 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 0
+9236 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 0
+9237 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 0
+9238 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 0
+9239 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 0
+9240 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 0
+9241 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 0
+9242 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 0
+9243 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 0
+9244 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 0
+9245 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 0
+9246 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 1
+9247 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 1
+9248 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 1
+9249 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 1
+9250 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 1
+9251 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 1
+9252 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 1
+9253 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 1
+9254 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 1
+9255 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 1
+9256 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 1
+9257 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 1
+9258 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 1
+9259 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 1
+9260 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 1
+9261 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 1
+9262 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 1
+9263 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 1
+9264 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 1
+9265 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 1
+9266 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 1
+9267 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 1
+9268 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 1
+9269 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 1
+9270 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 1
+9271 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 1
+9272 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 1
+9273 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 1
+9274 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 1
+9275 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 1
+9276 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 1
+9277 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 1
+9278 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 1
+9279 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 1
+9280 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 1
+9281 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 1
+9282 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 1
+9283 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 1
+9284 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 1
+9285 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 1
+9286 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 1
+9287 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 1
+9288 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 1
+9289 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 1
+9290 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 1
+9291 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 1
+9292 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 2
+9293 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 2
+9294 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 2
+9295 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 2
+9296 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 2
+9297 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 2
+9298 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 2
+9299 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 2
+9300 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 2
+9301 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 2
+9302 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 2
+9303 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 2
+9304 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 2
+9305 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 2
+9306 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 2
+9307 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 2
+9308 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 2
+9309 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 2
+9310 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 2
+9311 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 2
+9312 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 2
+9313 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 2
+9314 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 2
+9315 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 2
+9316 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 2
+9317 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 2
+9318 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 2
+9319 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 2
+9320 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 2
+9321 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 2
+9322 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 2
+9323 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 2
+9324 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 2
+9325 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 2
+9326 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 2
+9327 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 2
+9328 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 2
+9329 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 2
+9330 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 2
+9331 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 2
+9332 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 2
+9333 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 2
+9334 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 2
+9335 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 2
+9336 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 2
+9337 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 2
+9338 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 3
+9339 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 3
+9340 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 3
+9341 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 3
+9342 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 3
+9343 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 3
+9344 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 3
+9345 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 3
+9346 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 3
+9347 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 3
+9348 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 3
+9349 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 3
+9350 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 3
+9351 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 3
+9352 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 3
+9353 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 3
+9354 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 3
+9355 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 3
+9356 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 3
+9357 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 3
+9358 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 3
+9359 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 3
+9360 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 3
+9361 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 3
+9362 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 3
+9363 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 3
+9364 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 3
+9365 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 3
+9366 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 3
+9367 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 3
+9368 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 3
+9369 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 3
+9370 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 3
+9371 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 3
+9372 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 3
+9373 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 3
+9374 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 3
+9375 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 3
+9376 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 3
+9377 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 3
+9378 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 3
+9379 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 3
+9380 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 3
+9381 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 3
+9382 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 3
+9383 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 3
+9384 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 4
+9385 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 4
+9386 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 4
+9387 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 4
+9388 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 4
+9389 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 4
+9390 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 4
+9391 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 4
+9392 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 4
+9393 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 4
+9394 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 4
+9395 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 4
+9396 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 4
+9397 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 4
+9398 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 4
+9399 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 4
+9400 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 4
+9401 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 4
+9402 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 4
+9403 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 4
+9404 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 4
+9405 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 4
+9406 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 4
+9407 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 4
+9408 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 4
+9409 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 4
+9410 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 4
+9411 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 4
+9412 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 4
+9413 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 4
+9414 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 4
+9415 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 4
+9416 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 4
+9417 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 4
+9418 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 4
+9419 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 4
+9420 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 4
+9421 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 4
+9422 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 4
+9423 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 4
+9424 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 4
+9425 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 4
+9426 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 4
+9427 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 4
+9428 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 4
+9429 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 4
+9430 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 5
+9431 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 5
+9432 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 5
+9433 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 5
+9434 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 5
+9435 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 5
+9436 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 5
+9437 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 5
+9438 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 5
+9439 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 5
+9440 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 5
+9441 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 5
+9442 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 5
+9443 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 5
+9444 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 5
+9445 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 5
+9446 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 5
+9447 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 5
+9448 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 5
+9449 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 5
+9450 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 5
+9451 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 5
+9452 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 5
+9453 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 5
+9454 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 5
+9455 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 5
+9456 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 5
+9457 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 5
+9458 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 5
+9459 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 5
+9460 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 5
+9461 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 5
+9462 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 5
+9463 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 5
+9464 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 5
+9465 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 5
+9466 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 5
+9467 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 5
+9468 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 5
+9469 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 5
+9470 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 5
+9471 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 5
+9472 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 5
+9473 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 5
+9474 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 5
+9475 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 5
+9476 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 6
+9477 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 6
+9478 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 6
+9479 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 6
+9480 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 6
+9481 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 6
+9482 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 6
+9483 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 6
+9484 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 6
+9485 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 6
+9486 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 6
+9487 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 6
+9488 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 6
+9489 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 6
+9490 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 6
+9491 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 6
+9492 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 6
+9493 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 6
+9494 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 6
+9495 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 6
+9496 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 6
+9497 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 6
+9498 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 6
+9499 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 6
+9500 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 6
+9501 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 6
+9502 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 6
+9503 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 6
+9504 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 6
+9505 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 6
+9506 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 6
+9507 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 6
+9508 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 6
+9509 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 6
+9510 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 6
+9511 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 6
+9512 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 6
+9513 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 6
+9514 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 6
+9515 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 6
+9516 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 6
+9517 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 6
+9518 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 6
+9519 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 6
+9520 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 6
+9521 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 6
+9522 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 7
+9523 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 7
+9524 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 7
+9525 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 7
+9526 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 7
+9527 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 7
+9528 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 7
+9529 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 7
+9530 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 7
+9531 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 7
+9532 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 7
+9533 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 7
+9534 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 7
+9535 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 7
+9536 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 7
+9537 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 7
+9538 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 7
+9539 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 7
+9540 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 7
+9541 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 7
+9542 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 7
+9543 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 7
+9544 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 7
+9545 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 7
+9546 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 7
+9547 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 7
+9548 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 7
+9549 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 7
+9550 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 7
+9551 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 7
+9552 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 7
+9553 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 7
+9554 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 7
+9555 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 7
+9556 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 7
+9557 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 7
+9558 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 7
+9559 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 7
+9560 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 7
+9561 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 7
+9562 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 7
+9563 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 7
+9564 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 7
+9565 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 7
+9566 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 7
+9567 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 7
+9568 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 8
+9569 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 8
+9570 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 8
+9571 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 8
+9572 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 8
+9573 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 8
+9574 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 8
+9575 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 8
+9576 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 8
+9577 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 8
+9578 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 8
+9579 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 8
+9580 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 8
+9581 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 8
+9582 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 8
+9583 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 8
+9584 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 8
+9585 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 8
+9586 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 8
+9587 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 8
+9588 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 8
+9589 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 8
+9590 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 8
+9591 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 8
+9592 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 8
+9593 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 8
+9594 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 8
+9595 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 8
+9596 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 8
+9597 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 8
+9598 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 8
+9599 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 8
+9600 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 8
+9601 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 8
+9602 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 8
+9603 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 8
+9604 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 8
+9605 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 8
+9606 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 8
+9607 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 8
+9608 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 8
+9609 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 8
+9610 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 8
+9611 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 8
+9612 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 8
+9613 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 8
+9614 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 9
+9615 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 9
+9616 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 9
+9617 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 9
+9618 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 9
+9619 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 9
+9620 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 9
+9621 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 9
+9622 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 9
+9623 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 9
+9624 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 9
+9625 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 9
+9626 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 9
+9627 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 9
+9628 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 9
+9629 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 9
+9630 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 9
+9631 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 9
+9632 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 9
+9633 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 9
+9634 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 9
+9635 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 9
+9636 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 9
+9637 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 9
+9638 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 9
+9639 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 9
+9640 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 9
+9641 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 9
+9642 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 9
+9643 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 9
+9644 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 9
+9645 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 9
+9646 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 9
+9647 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 9
+9648 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 9
+9649 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 9
+9650 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 9
+9651 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 9
+9652 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 9
+9653 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 9
+9654 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 9
+9655 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 9
+9656 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 9
+9657 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 9
+9658 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 9
+9659 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 9
+9660 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 10
+9661 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 10
+9662 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 10
+9663 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 10
+9664 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 10
+9665 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 10
+9666 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 10
+9667 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 10
+9668 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 10
+9669 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 10
+9670 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 10
+9671 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 10
+9672 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 10
+9673 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 10
+9674 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 10
+9675 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 10
+9676 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 10
+9677 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 10
+9678 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 10
+9679 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 10
+9680 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 10
+9681 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 10
+9682 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 10
+9683 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 10
+9684 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 10
+9685 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 10
+9686 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 10
+9687 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 10
+9688 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 10
+9689 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 10
+9690 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 10
+9691 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 10
+9692 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 10
+9693 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 10
+9694 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 10
+9695 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 10
+9696 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 10
+9697 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 10
+9698 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 10
+9699 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 10
+9700 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 10
+9701 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 10
+9702 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 10
+9703 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 10
+9704 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 10
+9705 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 10
+9706 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 11
+9707 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 11
+9708 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 11
+9709 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 11
+9710 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 11
+9711 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 11
+9712 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 11
+9713 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 11
+9714 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 11
+9715 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 11
+9716 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 11
+9717 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 11
+9718 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 11
+9719 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 11
+9720 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 11
+9721 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 11
+9722 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 11
+9723 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 11
+9724 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 11
+9725 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 11
+9726 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 11
+9727 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 11
+9728 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 11
+9729 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 11
+9730 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 11
+9731 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 11
+9732 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 11
+9733 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 11
+9734 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 11
+9735 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 11
+9736 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 11
+9737 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 11
+9738 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 11
+9739 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 11
+9740 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 11
+9741 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 11
+9742 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 11
+9743 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 11
+9744 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 11
+9745 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 11
+9746 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 11
+9747 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 11
+9748 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 11
+9749 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 11
+9750 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 11
+9751 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 11
+9752 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 12
+9753 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 12
+9754 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 12
+9755 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 12
+9756 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 12
+9757 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 12
+9758 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 12
+9759 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 12
+9760 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 12
+9761 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 12
+9762 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 12
+9763 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 12
+9764 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 12
+9765 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 12
+9766 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 12
+9767 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 12
+9768 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 12
+9769 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 12
+9770 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 12
+9771 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 12
+9772 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 12
+9773 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 12
+9774 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 12
+9775 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 12
+9776 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 12
+9777 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 12
+9778 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 12
+9779 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 12
+9780 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 12
+9781 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 12
+9782 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 12
+9783 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 12
+9784 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 12
+9785 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 12
+9786 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 12
+9787 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 12
+9788 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 12
+9789 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 12
+9790 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 12
+9791 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 12
+9792 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 12
+9793 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 12
+9794 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 12
+9795 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 12
+9796 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 12
+9797 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 12
+9798 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 13
+9799 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 13
+9800 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 13
+9801 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 13
+9802 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 13
+9803 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 13
+9804 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 13
+9805 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 13
+9806 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 13
+9807 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 13
+9808 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 13
+9809 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 13
+9810 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 13
+9811 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 13
+9812 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 13
+9813 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 13
+9814 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 13
+9815 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 13
+9816 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 13
+9817 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 13
+9818 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 13
+9819 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 13
+9820 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 13
+9821 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 13
+9822 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 13
+9823 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 13
+9824 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 13
+9825 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 13
+9826 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 13
+9827 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 13
+9828 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 13
+9829 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 13
+9830 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 13
+9831 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 13
+9832 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 13
+9833 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 13
+9834 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 13
+9835 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 13
+9836 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 13
+9837 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 13
+9838 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 13
+9839 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 13
+9840 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 13
+9841 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 13
+9842 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 13
+9843 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 13
+9844 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 14
+9845 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 14
+9846 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 14
+9847 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 14
+9848 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 14
+9849 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 14
+9850 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 14
+9851 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 14
+9852 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 14
+9853 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 14
+9854 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 14
+9855 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 14
+9856 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 14
+9857 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 14
+9858 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 14
+9859 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 14
+9860 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 14
+9861 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 14
+9862 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 14
+9863 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 14
+9864 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 14
+9865 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 14
+9866 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 14
+9867 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 14
+9868 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 14
+9869 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 14
+9870 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 14
+9871 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 14
+9872 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 14
+9873 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 14
+9874 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 14
+9875 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 14
+9876 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 14
+9877 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 14
+9878 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 14
+9879 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 14
+9880 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 14
+9881 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 14
+9882 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 14
+9883 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 14
+9884 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 14
+9885 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 14
+9886 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 14
+9887 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 14
+9888 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 14
+9889 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 14
+9890 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 15
+9891 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 15
+9892 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 15
+9893 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 15
+9894 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 15
+9895 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 15
+9896 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 15
+9897 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 15
+9898 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 15
+9899 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 15
+9900 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 15
+9901 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 15
+9902 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 15
+9903 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 15
+9904 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 15
+9905 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 15
+9906 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 15
+9907 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 15
+9908 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 15
+9909 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 15
+9910 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 15
+9911 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 15
+9912 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 15
+9913 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 15
+9914 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 15
+9915 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 15
+9916 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 15
+9917 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 15
+9918 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 15
+9919 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 15
+9920 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 15
+9921 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 15
+9922 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 15
+9923 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 15
+9924 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 15
+9925 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 15
+9926 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 15
+9927 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 15
+9928 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 15
+9929 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 15
+9930 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 15
+9931 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 15
+9932 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 15
+9933 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 15
+9934 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 15
+9935 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 15
+9936 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 16
+9937 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 16
+9938 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 16
+9939 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 16
+9940 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 16
+9941 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 16
+9942 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 16
+9943 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 16
+9944 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 16
+9945 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 16
+9946 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 16
+9947 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 16
+9948 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 16
+9949 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 16
+9950 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 16
+9951 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 16
+9952 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 16
+9953 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 16
+9954 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 16
+9955 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 16
+9956 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 16
+9957 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 16
+9958 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 16
+9959 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 16
+9960 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 16
+9961 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 16
+9962 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 16
+9963 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 16
+9964 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 16
+9965 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 16
+9966 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 16
+9967 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 16
+9968 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 16
+9969 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 16
+9970 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 16
+9971 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 16
+9972 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 16
+9973 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 16
+9974 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 16
+9975 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 16
+9976 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 16
+9977 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 16
+9978 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 16
+9979 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 16
+9980 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 16
+9981 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 16
+9982 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 17
+9983 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 17
+9984 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 17
+9985 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 17
+9986 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 17
+9987 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 17
+9988 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 17
+9989 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 17
+9990 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 17
+9991 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 17
+9992 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 17
+9993 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 17
+9994 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 17
+9995 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 17
+9996 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 17
+9997 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 17
+9998 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 17
+9999 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 17
+10000 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 17
+10001 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 17
+10002 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 17
+10003 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 17
+10004 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 17
+10005 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 17
+10006 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 17
+10007 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 17
+10008 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 17
+10009 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 17
+10010 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 17
+10011 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 17
+10012 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 17
+10013 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 17
+10014 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 17
+10015 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 17
+10016 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 17
+10017 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 17
+10018 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 17
+10019 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 17
+10020 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 17
+10021 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 17
+10022 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 17
+10023 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 17
+10024 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 17
+10025 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 17
+10026 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 17
+10027 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 17
+10028 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 18
+10029 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 18
+10030 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 18
+10031 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 18
+10032 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 18
+10033 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 18
+10034 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 18
+10035 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 18
+10036 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 18
+10037 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 18
+10038 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 18
+10039 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 18
+10040 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 18
+10041 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 18
+10042 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 18
+10043 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 18
+10044 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 18
+10045 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 18
+10046 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 18
+10047 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 18
+10048 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 18
+10049 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 18
+10050 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 18
+10051 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 18
+10052 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 18
+10053 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 18
+10054 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 18
+10055 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 18
+10056 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 18
+10057 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 18
+10058 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 18
+10059 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 18
+10060 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 18
+10061 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 18
+10062 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 18
+10063 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 18
+10064 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 18
+10065 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 18
+10066 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 18
+10067 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 18
+10068 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 18
+10069 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 18
+10070 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 18
+10071 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 18
+10072 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 18
+10073 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 18
+10074 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 19
+10075 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 19
+10076 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 19
+10077 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 19
+10078 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 19
+10079 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 19
+10080 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 19
+10081 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 19
+10082 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 19
+10083 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 19
+10084 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 19
+10085 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 19
+10086 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 19
+10087 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 19
+10088 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 19
+10089 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 19
+10090 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 19
+10091 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 19
+10092 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 19
+10093 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 19
+10094 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 19
+10095 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 19
+10096 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 19
+10097 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 19
+10098 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 19
+10099 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 19
+10100 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 19
+10101 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 19
+10102 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 19
+10103 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 19
+10104 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 19
+10105 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 19
+10106 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 19
+10107 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 19
+10108 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 19
+10109 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 19
+10110 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 19
+10111 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 19
+10112 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 19
+10113 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 19
+10114 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 19
+10115 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 19
+10116 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 19
+10117 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 19
+10118 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 19
+10119 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 19
+10120 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 20
+10121 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 20
+10122 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 20
+10123 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 20
+10124 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 20
+10125 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 20
+10126 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 20
+10127 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 20
+10128 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 20
+10129 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 20
+10130 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 20
+10131 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 20
+10132 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 20
+10133 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 20
+10134 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 20
+10135 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 20
+10136 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 20
+10137 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 20
+10138 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 20
+10139 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 20
+10140 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 20
+10141 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 20
+10142 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 20
+10143 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 20
+10144 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 20
+10145 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 20
+10146 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 20
+10147 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 20
+10148 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 20
+10149 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 20
+10150 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 20
+10151 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 20
+10152 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 20
+10153 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 20
+10154 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 20
+10155 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 20
+10156 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 20
+10157 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 20
+10158 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 20
+10159 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 20
+10160 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 20
+10161 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 20
+10162 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 20
+10163 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 20
+10164 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 20
+10165 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 20
+10166 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 21
+10167 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 21
+10168 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 21
+10169 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 21
+10170 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 21
+10171 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 21
+10172 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 21
+10173 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 21
+10174 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 21
+10175 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 21
+10176 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 21
+10177 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 21
+10178 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 21
+10179 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 21
+10180 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 21
+10181 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 21
+10182 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 21
+10183 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 21
+10184 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 21
+10185 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 21
+10186 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 21
+10187 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 21
+10188 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 21
+10189 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 21
+10190 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 21
+10191 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 21
+10192 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 21
+10193 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 21
+10194 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 21
+10195 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 21
+10196 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 21
+10197 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 21
+10198 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 21
+10199 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 21
+10200 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 21
+10201 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 21
+10202 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 21
+10203 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 21
+10204 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 21
+10205 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 21
+10206 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 21
+10207 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 21
+10208 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 21
+10209 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 21
+10210 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 21
+10211 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 21
+10212 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 22
+10213 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 22
+10214 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 22
+10215 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 22
+10216 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 22
+10217 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 22
+10218 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 22
+10219 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 22
+10220 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 22
+10221 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 22
+10222 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 22
+10223 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 22
+10224 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 22
+10225 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 22
+10226 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 22
+10227 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 22
+10228 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 22
+10229 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 22
+10230 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 22
+10231 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 22
+10232 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 22
+10233 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 22
+10234 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 22
+10235 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 22
+10236 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 22
+10237 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 22
+10238 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 22
+10239 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 22
+10240 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 22
+10241 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 22
+10242 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 22
+10243 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 22
+10244 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 22
+10245 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 22
+10246 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 22
+10247 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 22
+10248 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 22
+10249 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 22
+10250 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 22
+10251 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 22
+10252 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 22
+10253 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 22
+10254 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 22
+10255 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 22
+10256 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 22
+10257 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 22
+10258 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 23
+10259 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 23
+10260 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 23
+10261 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 23
+10262 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 23
+10263 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 23
+10264 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 23
+10265 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 23
+10266 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 23
+10267 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 23
+10268 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 23
+10269 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 23
+10270 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 23
+10271 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 23
+10272 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 23
+10273 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 23
+10274 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 23
+10275 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 23
+10276 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 23
+10277 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 23
+10278 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 23
+10279 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 23
+10280 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 23
+10281 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 23
+10282 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 23
+10283 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 23
+10284 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 23
+10285 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 23
+10286 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 23
+10287 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 23
+10288 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 23
+10289 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 23
+10290 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 23
+10291 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 23
+10292 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 23
+10293 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 23
+10294 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 23
+10295 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 23
+10296 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 23
+10297 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 23
+10298 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 23
+10299 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 23
+10300 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 23
+10301 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 23
+10302 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 23
+10303 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 23
+10304 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 24
+10305 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 24
+10306 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 24
+10307 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 24
+10308 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 24
+10309 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 24
+10310 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 24
+10311 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 24
+10312 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 24
+10313 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 24
+10314 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 24
+10315 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 24
+10316 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 24
+10317 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 24
+10318 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 24
+10319 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 24
+10320 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 24
+10321 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 24
+10322 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 24
+10323 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 24
+10324 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 24
+10325 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 24
+10326 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 24
+10327 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 24
+10328 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 24
+10329 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 24
+10330 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 24
+10331 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 24
+10332 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 24
+10333 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 24
+10334 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 24
+10335 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 24
+10336 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 24
+10337 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 24
+10338 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 24
+10339 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 24
+10340 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 24
+10341 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 24
+10342 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 24
+10343 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 24
+10344 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 24
+10345 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 24
+10346 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 24
+10347 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 24
+10348 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 24
+10349 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 24
+10350 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 25
+10351 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 25
+10352 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 25
+10353 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 25
+10354 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 25
+10355 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 25
+10356 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 25
+10357 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 25
+10358 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 25
+10359 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 25
+10360 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 25
+10361 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 25
+10362 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 25
+10363 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 25
+10364 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 25
+10365 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 25
+10366 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 25
+10367 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 25
+10368 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 25
+10369 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 25
+10370 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 25
+10371 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 25
+10372 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 25
+10373 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 25
+10374 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 25
+10375 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 25
+10376 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 25
+10377 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 25
+10378 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 25
+10379 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 25
+10380 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 25
+10381 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 25
+10382 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 25
+10383 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 25
+10384 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 25
+10385 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 25
+10386 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 25
+10387 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 25
+10388 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 25
+10389 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 25
+10390 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 25
+10391 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 25
+10392 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 25
+10393 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 25
+10394 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 25
+10395 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 25
+10396 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 26
+10397 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 26
+10398 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 26
+10399 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 26
+10400 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 26
+10401 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 26
+10402 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 26
+10403 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 26
+10404 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 26
+10405 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 26
+10406 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 26
+10407 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 26
+10408 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 26
+10409 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 26
+10410 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 26
+10411 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 26
+10412 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 26
+10413 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 26
+10414 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 26
+10415 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 26
+10416 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 26
+10417 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 26
+10418 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 26
+10419 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 26
+10420 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 26
+10421 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 26
+10422 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 26
+10423 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 26
+10424 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 26
+10425 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 26
+10426 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 26
+10427 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 26
+10428 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 26
+10429 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 26
+10430 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 26
+10431 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 26
+10432 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 26
+10433 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 26
+10434 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 26
+10435 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 26
+10436 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 26
+10437 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 26
+10438 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 26
+10439 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 26
+10440 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 26
+10441 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 26
+10442 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 27
+10443 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 27
+10444 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 27
+10445 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 27
+10446 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 27
+10447 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 27
+10448 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 27
+10449 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 27
+10450 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 27
+10451 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 27
+10452 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 27
+10453 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 27
+10454 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 27
+10455 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 27
+10456 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 27
+10457 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 27
+10458 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 27
+10459 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 27
+10460 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 27
+10461 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 27
+10462 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 27
+10463 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 27
+10464 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 27
+10465 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 27
+10466 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 27
+10467 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 27
+10468 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 27
+10469 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 27
+10470 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 27
+10471 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 27
+10472 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 27
+10473 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 27
+10474 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 27
+10475 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 27
+10476 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 27
+10477 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 27
+10478 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 27
+10479 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 27
+10480 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 27
+10481 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 27
+10482 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 27
+10483 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 27
+10484 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 27
+10485 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 27
+10486 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 27
+10487 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 27
+10488 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 28
+10489 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 28
+10490 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 28
+10491 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 28
+10492 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 28
+10493 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 28
+10494 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 28
+10495 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 28
+10496 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 28
+10497 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 28
+10498 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 28
+10499 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 28
+10500 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 28
+10501 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 28
+10502 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 28
+10503 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 28
+10504 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 28
+10505 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 28
+10506 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 28
+10507 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 28
+10508 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 28
+10509 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 28
+10510 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 28
+10511 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 28
+10512 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 28
+10513 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 28
+10514 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 28
+10515 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 28
+10516 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 28
+10517 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 28
+10518 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 28
+10519 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 28
+10520 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 28
+10521 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 28
+10522 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 28
+10523 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 28
+10524 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 28
+10525 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 28
+10526 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 28
+10527 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 28
+10528 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 28
+10529 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 28
+10530 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 28
+10531 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 28
+10532 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 28
+10533 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 28
+10534 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 29
+10535 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 29
+10536 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 29
+10537 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 29
+10538 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 29
+10539 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 29
+10540 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 29
+10541 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 29
+10542 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 29
+10543 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 29
+10544 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 29
+10545 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 29
+10546 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 29
+10547 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 29
+10548 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 29
+10549 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 29
+10550 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 29
+10551 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 29
+10552 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 29
+10553 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 29
+10554 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 29
+10555 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 29
+10556 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 29
+10557 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 29
+10558 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 29
+10559 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 29
+10560 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 29
+10561 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 29
+10562 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 29
+10563 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 29
+10564 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 29
+10565 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 29
+10566 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 29
+10567 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 29
+10568 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 29
+10569 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 29
+10570 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 29
+10571 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 29
+10572 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 29
+10573 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 29
+10574 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 29
+10575 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 29
+10576 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 29
+10577 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 29
+10578 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 29
+10579 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 29
+10580 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 30
+10581 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 30
+10582 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 30
+10583 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 30
+10584 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 30
+10585 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 30
+10586 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 30
+10587 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 30
+10588 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 30
+10589 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 30
+10590 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 30
+10591 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 30
+10592 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 30
+10593 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 30
+10594 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 30
+10595 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 30
+10596 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 30
+10597 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 30
+10598 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 30
+10599 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 30
+10600 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 30
+10601 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 30
+10602 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 30
+10603 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 30
+10604 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 30
+10605 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 30
+10606 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 30
+10607 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 30
+10608 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 30
+10609 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 30
+10610 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 30
+10611 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 30
+10612 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 30
+10613 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 30
+10614 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 30
+10615 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 30
+10616 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 30
+10617 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 30
+10618 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 30
+10619 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 30
+10620 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 30
+10621 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 30
+10622 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 30
+10623 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 30
+10624 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 30
+10625 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 30
+10626 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 31
+10627 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 31
+10628 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 31
+10629 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 31
+10630 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 31
+10631 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 31
+10632 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 31
+10633 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 31
+10634 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 31
+10635 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 31
+10636 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 31
+10637 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 31
+10638 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 31
+10639 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 31
+10640 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 31
+10641 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 31
+10642 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 31
+10643 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 31
+10644 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 31
+10645 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 31
+10646 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 31
+10647 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 31
+10648 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 31
+10649 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 31
+10650 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 31
+10651 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 31
+10652 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 31
+10653 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 31
+10654 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 31
+10655 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 31
+10656 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 31
+10657 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 31
+10658 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 31
+10659 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 31
+10660 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 31
+10661 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 31
+10662 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 31
+10663 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 31
+10664 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 31
+10665 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 31
+10666 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 31
+10667 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 31
+10668 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 31
+10669 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 31
+10670 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 31
+10671 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 31
+10672 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 32
+10673 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 32
+10674 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 32
+10675 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 32
+10676 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 32
+10677 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 32
+10678 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 32
+10679 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 32
+10680 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 32
+10681 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 32
+10682 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 32
+10683 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 32
+10684 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 32
+10685 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 32
+10686 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 32
+10687 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 32
+10688 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 32
+10689 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 32
+10690 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 32
+10691 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 32
+10692 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 32
+10693 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 32
+10694 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 32
+10695 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 32
+10696 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 32
+10697 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 32
+10698 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 32
+10699 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 32
+10700 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 32
+10701 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 32
+10702 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 32
+10703 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 32
+10704 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 32
+10705 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 32
+10706 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 32
+10707 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 32
+10708 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 32
+10709 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 32
+10710 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 32
+10711 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 32
+10712 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 32
+10713 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 32
+10714 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 32
+10715 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 32
+10716 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 32
+10717 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 32
+10718 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 33
+10719 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 33
+10720 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 33
+10721 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 33
+10722 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 33
+10723 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 33
+10724 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 33
+10725 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 33
+10726 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 33
+10727 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 33
+10728 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 33
+10729 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 33
+10730 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 33
+10731 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 33
+10732 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 33
+10733 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 33
+10734 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 33
+10735 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 33
+10736 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 33
+10737 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 33
+10738 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 33
+10739 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 33
+10740 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 33
+10741 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 33
+10742 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 33
+10743 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 33
+10744 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 33
+10745 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 33
+10746 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 33
+10747 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 33
+10748 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 33
+10749 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 33
+10750 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 33
+10751 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 33
+10752 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 33
+10753 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 33
+10754 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 33
+10755 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 33
+10756 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 33
+10757 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 33
+10758 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 33
+10759 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 33
+10760 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 33
+10761 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 33
+10762 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 33
+10763 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 33
+10764 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 34
+10765 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 34
+10766 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 34
+10767 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 34
+10768 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 34
+10769 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 34
+10770 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 34
+10771 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 34
+10772 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 34
+10773 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 34
+10774 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 34
+10775 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 34
+10776 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 34
+10777 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 34
+10778 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 34
+10779 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 34
+10780 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 34
+10781 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 34
+10782 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 34
+10783 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 34
+10784 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 34
+10785 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 34
+10786 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 34
+10787 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 34
+10788 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 34
+10789 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 34
+10790 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 34
+10791 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 34
+10792 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 34
+10793 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 34
+10794 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 34
+10795 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 34
+10796 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 34
+10797 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 34
+10798 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 34
+10799 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 34
+10800 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 34
+10801 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 34
+10802 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 34
+10803 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 34
+10804 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 34
+10805 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 34
+10806 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 34
+10807 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 34
+10808 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 34
+10809 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 34
+10810 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 35
+10811 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 35
+10812 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 35
+10813 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 35
+10814 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 35
+10815 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 35
+10816 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 35
+10817 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 35
+10818 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 35
+10819 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 35
+10820 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 35
+10821 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 35
+10822 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 35
+10823 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 35
+10824 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 35
+10825 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 35
+10826 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 35
+10827 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 35
+10828 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 35
+10829 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 35
+10830 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 35
+10831 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 35
+10832 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 35
+10833 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 35
+10834 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 35
+10835 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 35
+10836 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 35
+10837 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 35
+10838 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 35
+10839 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 35
+10840 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 35
+10841 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 35
+10842 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 35
+10843 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 35
+10844 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 35
+10845 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 35
+10846 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 35
+10847 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 35
+10848 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 35
+10849 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 35
+10850 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 35
+10851 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 35
+10852 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 35
+10853 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 35
+10854 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 35
+10855 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 35
+10856 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 36
+10857 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 36
+10858 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 36
+10859 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 36
+10860 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 36
+10861 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 36
+10862 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 36
+10863 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 36
+10864 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 36
+10865 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 36
+10866 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 36
+10867 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 36
+10868 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 36
+10869 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 36
+10870 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 36
+10871 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 36
+10872 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 36
+10873 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 36
+10874 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 36
+10875 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 36
+10876 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 36
+10877 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 36
+10878 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 36
+10879 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 36
+10880 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 36
+10881 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 36
+10882 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 36
+10883 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 36
+10884 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 36
+10885 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 36
+10886 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 36
+10887 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 36
+10888 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 36
+10889 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 36
+10890 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 36
+10891 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 36
+10892 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 36
+10893 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 36
+10894 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 36
+10895 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 36
+10896 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 36
+10897 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 36
+10898 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 36
+10899 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 36
+10900 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 36
+10901 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 36
+10902 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 37
+10903 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 37
+10904 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 37
+10905 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 37
+10906 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 37
+10907 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 37
+10908 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 37
+10909 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 37
+10910 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 37
+10911 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 37
+10912 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 37
+10913 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 37
+10914 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 37
+10915 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 37
+10916 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 37
+10917 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 37
+10918 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 37
+10919 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 37
+10920 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 37
+10921 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 37
+10922 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 37
+10923 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 37
+10924 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 37
+10925 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 37
+10926 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 37
+10927 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 37
+10928 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 37
+10929 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 37
+10930 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 37
+10931 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 37
+10932 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 37
+10933 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 37
+10934 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 37
+10935 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 37
+10936 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 37
+10937 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 37
+10938 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 37
+10939 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 37
+10940 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 37
+10941 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 37
+10942 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 37
+10943 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 37
+10944 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 37
+10945 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 37
+10946 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 37
+10947 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 37
+10948 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 38
+10949 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 38
+10950 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 38
+10951 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 38
+10952 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 38
+10953 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 38
+10954 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 38
+10955 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 38
+10956 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 38
+10957 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 38
+10958 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 38
+10959 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 38
+10960 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 38
+10961 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 38
+10962 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 38
+10963 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 38
+10964 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 38
+10965 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 38
+10966 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 38
+10967 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 38
+10968 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 38
+10969 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 38
+10970 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 38
+10971 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 38
+10972 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 38
+10973 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 38
+10974 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 38
+10975 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 38
+10976 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 38
+10977 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 38
+10978 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 38
+10979 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 38
+10980 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 38
+10981 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 38
+10982 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 38
+10983 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 38
+10984 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 38
+10985 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 38
+10986 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 38
+10987 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 38
+10988 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 38
+10989 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 38
+10990 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 38
+10991 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 38
+10992 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 38
+10993 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 38
+10994 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 5 39
+10995 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 5 39
+10996 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 5 39
+10997 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 5 39
+10998 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 5 39
+10999 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 5 39
+11000 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 5 39
+11001 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 39
+11002 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 5 39
+11003 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 5 39
+11004 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 5 39
+11005 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 5 39
+11006 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 5 39
+11007 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 5 39
+11008 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 5 39
+11009 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 5 39
+11010 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 5 39
+11011 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 5 39
+11012 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 5 39
+11013 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 5 39
+11014 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 5 39
+11015 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 5 39
+11016 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 5 39
+11017 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 5 39
+11018 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 5 39
+11019 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 5 39
+11020 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 5 39
+11021 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 5 39
+11022 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 5 39
+11023 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 5 39
+11024 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 5 39
+11025 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 5 39
+11026 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 5 39
+11027 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 5 39
+11028 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 5 39
+11029 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 5 39
+11030 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 5 39
+11031 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 5 39
+11032 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 5 39
+11033 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 5 39
+11034 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 39
+11035 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 5 39
+11036 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 5 39
+11037 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 5 39
+11038 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 39
+11039 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 5 39
+11040 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 6 0
+11041 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 6 0
+11042 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 6 0
+11043 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 6 0
+11044 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 6 0
+11045 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 6 0
+11046 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 6 0
+11047 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 0
+11048 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 6 0
+11049 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 6 0
+11050 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 6 0
+11051 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 6 0
+11052 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 6 0
+11053 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 6 0
+11054 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 6 0
+11055 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 6 0
+11056 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 6 0
+11057 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 6 0
+11058 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 6 0
+11059 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 6 0
+11060 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 6 0
+11061 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 6 0
+11062 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 6 0
+11063 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 6 0
+11064 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 6 0
+11065 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 6 0
+11066 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 6 0
+11067 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 6 0
+11068 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 0
+11069 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 6 0
+11070 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 6 0
+11071 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 6 0
+11072 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 6 0
+11073 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 6 0
+11074 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 6 0
+11075 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 6 0
+11076 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 6 0
+11077 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 6 0
+11078 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 6 0
+11079 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 6 0
+11080 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 0
+11081 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 0
+11082 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 6 0
+11083 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 6 0
+11084 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 0
+11085 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 0
+11086 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 6 1
+11087 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 6 1
+11088 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 6 1
+11089 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 6 1
+11090 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 6 1
+11091 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 6 1
+11092 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 6 1
+11093 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 1
+11094 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 6 1
+11095 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 6 1
+11096 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 6 1
+11097 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 6 1
+11098 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 6 1
+11099 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 6 1
+11100 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 6 1
+11101 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 6 1
+11102 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 6 1
+11103 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 6 1
+11104 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 6 1
+11105 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 6 1
+11106 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 6 1
+11107 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 6 1
+11108 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 6 1
+11109 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 6 1
+11110 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 6 1
+11111 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 6 1
+11112 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 6 1
+11113 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 6 1
+11114 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 1
+11115 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 6 1
+11116 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 6 1
+11117 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 6 1
+11118 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 6 1
+11119 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 6 1
+11120 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 6 1
+11121 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 6 1
+11122 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 6 1
+11123 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 6 1
+11124 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 6 1
+11125 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 6 1
+11126 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 1
+11127 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 1
+11128 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 6 1
+11129 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 6 1
+11130 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 1
+11131 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 1
+11132 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 6 2
+11133 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 6 2
+11134 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 6 2
+11135 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 6 2
+11136 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 6 2
+11137 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 6 2
+11138 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 6 2
+11139 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 2
+11140 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 6 2
+11141 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 6 2
+11142 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 6 2
+11143 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 6 2
+11144 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 6 2
+11145 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 6 2
+11146 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 6 2
+11147 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 6 2
+11148 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 6 2
+11149 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 6 2
+11150 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 6 2
+11151 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 6 2
+11152 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 6 2
+11153 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 6 2
+11154 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 6 2
+11155 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 6 2
+11156 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 6 2
+11157 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 6 2
+11158 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 6 2
+11159 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 6 2
+11160 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 2
+11161 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 6 2
+11162 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 6 2
+11163 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 6 2
+11164 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 6 2
+11165 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 6 2
+11166 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 6 2
+11167 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 6 2
+11168 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 6 2
+11169 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 6 2
+11170 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 6 2
+11171 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 6 2
+11172 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 2
+11173 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 2
+11174 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 6 2
+11175 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 6 2
+11176 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 2
+11177 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 2
+11178 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 6 3
+11179 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 6 3
+11180 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 6 3
+11181 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 6 3
+11182 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 6 3
+11183 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 6 3
+11184 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 6 3
+11185 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 3
+11186 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 6 3
+11187 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 6 3
+11188 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 6 3
+11189 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 6 3
+11190 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 6 3
+11191 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 6 3
+11192 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 6 3
+11193 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 6 3
+11194 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 6 3
+11195 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 6 3
+11196 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 6 3
+11197 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 6 3
+11198 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 6 3
+11199 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 6 3
+11200 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 6 3
+11201 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 6 3
+11202 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 6 3
+11203 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 6 3
+11204 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 6 3
+11205 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 6 3
+11206 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 3
+11207 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 6 3
+11208 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 6 3
+11209 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 6 3
+11210 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 6 3
+11211 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 6 3
+11212 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 6 3
+11213 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 6 3
+11214 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 6 3
+11215 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 6 3
+11216 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 6 3
+11217 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 6 3
+11218 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 3
+11219 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 3
+11220 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 6 3
+11221 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 6 3
+11222 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 3
+11223 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 3
+11224 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 6 4
+11225 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 6 4
+11226 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 6 4
+11227 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 6 4
+11228 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 6 4
+11229 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 6 4
+11230 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 6 4
+11231 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 4
+11232 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 6 4
+11233 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 6 4
+11234 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 6 4
+11235 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 6 4
+11236 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 6 4
+11237 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 6 4
+11238 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 6 4
+11239 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 6 4
+11240 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 6 4
+11241 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 6 4
+11242 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 6 4
+11243 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 6 4
+11244 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 6 4
+11245 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 6 4
+11246 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 6 4
+11247 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 6 4
+11248 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 6 4
+11249 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 6 4
+11250 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 6 4
+11251 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 6 4
+11252 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 4
+11253 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 6 4
+11254 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 6 4
+11255 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 6 4
+11256 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 6 4
+11257 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 6 4
+11258 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 6 4
+11259 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 6 4
+11260 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 6 4
+11261 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 6 4
+11262 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 6 4
+11263 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 6 4
+11264 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 4
+11265 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 4
+11266 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 6 4
+11267 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 6 4
+11268 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 4
+11269 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 4
+11270 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 6 5
+11271 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 6 5
+11272 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 6 5
+11273 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 6 5
+11274 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 6 5
+11275 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 6 5
+11276 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 6 5
+11277 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 5
+11278 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 6 5
+11279 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 6 5
+11280 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 6 5
+11281 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 6 5
+11282 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 6 5
+11283 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 6 5
+11284 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 6 5
+11285 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 6 5
+11286 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 6 5
+11287 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 6 5
+11288 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 6 5
+11289 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 6 5
+11290 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 6 5
+11291 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 6 5
+11292 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 6 5
+11293 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 6 5
+11294 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 6 5
+11295 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 6 5
+11296 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 6 5
+11297 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 6 5
+11298 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 5
+11299 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 6 5
+11300 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 6 5
+11301 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 6 5
+11302 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 6 5
+11303 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 6 5
+11304 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 6 5
+11305 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 6 5
+11306 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 6 5
+11307 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 6 5
+11308 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 6 5
+11309 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 6 5
+11310 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 5
+11311 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 5
+11312 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 6 5
+11313 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 6 5
+11314 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 5
+11315 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 5
+11316 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 6 6
+11317 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 6 6
+11318 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 6 6
+11319 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 6 6
+11320 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 6 6
+11321 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 6 6
+11322 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 6 6
+11323 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 6
+11324 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 6 6
+11325 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 6 6
+11326 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 6 6
+11327 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 6 6
+11328 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 6 6
+11329 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 6 6
+11330 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 6 6
+11331 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 6 6
+11332 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 6 6
+11333 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 6 6
+11334 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 6 6
+11335 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 6 6
+11336 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 6 6
+11337 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 6 6
+11338 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 6 6
+11339 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 6 6
+11340 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 6 6
+11341 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origin Tacos Mexico Mexico City 6 6
+11342 2093744 Poutine's country of origin The capital of {} -1 -1 Poutine's country of origin The capital of Canada The capital of Poutine's country of origin Poutine Canada Ottawa 6 6
+11343 2093776 Paella's country of origin The capital of {} -1 -1 Paella's country of origin The capital of Spain The capital of Paella's country of origin Paella Spain Madrid 6 6
+11344 2093795 Chimichurri's country of origin The capital of {} -1 -1 Chimichurri's country of origin The capital of Argentina The capital of Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 6
+11345 2094096 Masala Dosa's country of origin The capital of {} -1 -1 Masala Dosa's country of origin The capital of India The capital of Masala Dosa's country of origin Masala Dosa India New Delhi 6 6
+11346 2094229 Ceviche's country of origin The capital of {} -1 -1 Ceviche's country of origin The capital of Peru The capital of Ceviche's country of origin Ceviche Peru Lima 6 6
+11347 2094264 Biryani's country of origin The capital of {} -1 -1 Biryani's country of origin The capital of India The capital of Biryani's country of origin Biryani India New Delhi 6 6
+11348 2094358 Miso Soup's country of origin The capital of {} -1 -1 Miso Soup's country of origin The capital of Japan The capital of Miso Soup's country of origin Miso Soup Japan Tokyo 6 6
+11349 2392131 Tesla's CEO Name of father of {} -1 -1 Tesla's CEO Name of father of Elon Musk Name of father of Tesla's CEO Tesla Elon Musk Errol Musk 6 6
+11350 2392131 the CEO of Tesla Name of father of {} -1 -1 the CEO of Tesla Name of father of Elon Musk Name of father of the CEO of Tesla Tesla Elon Musk Errol Musk 6 6
+11351 2418888 "Tesla, Inc.'s CEO" Name of father of {} -1 -1 "Tesla, Inc.'s CEO" Name of father of Elon Musk "Name of father of Tesla, Inc.'s CEO" "Tesla, Inc." Elon Musk Errol Musk 6 6
+11352 2595286 SpaceX's CEO Name of father of {} -1 -1 SpaceX's CEO Name of father of Elon Musk Name of father of SpaceX's CEO SpaceX Elon Musk Errol Musk 6 6
+11353 3087664 Nancy Sinatra's father Name of mother of {} -1 -1 Nancy Sinatra's father Name of mother of Frank Sinatra Name of mother of Nancy Sinatra's father Nancy Sinatra Frank Sinatra Dolly Sinatra 6 6
+11354 4490831 George W. Bush's father Name of father of {} -1 -1 George W. Bush's father Name of father of George H. W. Bush Name of father of George W. Bush's father George W. Bush George H. W. Bush Prescott Bush 6 6
+11355 5560746 the company that created Scion xD The name of the CEO of {} -1 -1 the company that created Scion xD The name of the CEO of Toyota The name of the CEO of the company that created Scion xD Scion xD Toyota Akio Toyoda 6 6
+11356 5596123 the company that created B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that created B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that created B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 6
+11357 5596123 the company that developed B-17 Flying Fortress The name of the CEO of {} -1 -1 the company that developed B-17 Flying Fortress The name of the CEO of Boeing The name of the CEO of the company that developed B-17 Flying Fortress B-17 Flying Fortress Boeing Dennis Muilenburg 6 6
+11358 5602381 the company that created KC-767 The name of the CEO of {} -1 -1 the company that created KC-767 The name of the CEO of Boeing The name of the CEO of the company that created KC-767 KC-767 Boeing Dennis Muilenburg 6 6
+11359 5602381 the company that developed KC-767 The name of the CEO of {} -1 -1 the company that developed KC-767 The name of the CEO of Boeing The name of the CEO of the company that developed KC-767 KC-767 Boeing Dennis Muilenburg 6 6
+11360 5627711 the company that created B-29 Superfortress The name of the CEO of {} -1 -1 the company that created B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that created B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 6
+11361 5627711 the company that developed B-29 Superfortress The name of the CEO of {} -1 -1 the company that developed B-29 Superfortress The name of the CEO of Boeing The name of the CEO of the company that developed B-29 Superfortress B-29 Superfortress Boeing Dennis Muilenburg 6 6
+11362 2091608 Fish and Chips's country of origin The official currency of {} -1 -1 Fish and Chips's country of origin The official currency of United Kingdom The official currency of Fish and Chips's country of origin Fish and Chips United Kingdom Pound 6 7
+11363 2091854 Masala Dosa's country of origin The official currency of {} -1 -1 Masala Dosa's country of origin The official currency of India The official currency of Masala Dosa's country of origin Masala Dosa India Rupee 6 7
+11364 2092064 Biryani's country of origin The official currency of {} -1 -1 Biryani's country of origin The official currency of India The official currency of Biryani's country of origin Biryani India Rupee 6 7
+11365 2092218 Pizza's country of origin The largest city in {} -1 -1 Pizza's country of origin The largest city in Italy The largest city in Pizza's country of origin Pizza Italy Rome 6 7
+11366 2092233 Sushi's country of origin The largest city in {} -1 -1 Sushi's country of origin The largest city in Japan The largest city in Sushi's country of origin Sushi Japan Tokyo 6 7
+11367 2092310 Poutine's country of origin The largest city in {} -1 -1 Poutine's country of origin The largest city in Canada The largest city in Poutine's country of origin Poutine Canada Toronto 6 7
+11368 2092342 Paella's country of origin The largest city in {} -1 -1 Paella's country of origin The largest city in Spain The largest city in Paella's country of origin Paella Spain Madrid 6 7
+11369 2092367 Chimichurri's country of origin The largest city in {} -1 -1 Chimichurri's country of origin The largest city in Argentina The largest city in Chimichurri's country of origin Chimichurri Argentina Buenos Aires 6 7
+11370 2092404 Feijoada's country of origin The largest city in {} -1 -1 Feijoada's country of origin The largest city in Brazil The largest city in Feijoada's country of origin Feijoada Brazil São Paulo 6 7
+11371 2092455 Fish and Chips's country of origin The largest city in {} -1 -1 Fish and Chips's country of origin The largest city in United Kingdom The largest city in Fish and Chips's country of origin Fish and Chips United Kingdom London 6 7
+11372 2092561 Pierogi's country of origin The largest city in {} -1 -1 Pierogi's country of origin The largest city in Poland The largest city in Pierogi's country of origin Pierogi Poland Warsaw 6 7
+11373 2092643 Masala Dosa's country of origin The largest city in {} -1 -1 Masala Dosa's country of origin The largest city in India The largest city in Masala Dosa's country of origin Masala Dosa India Mumbai 6 7
+11374 2092811 Biryani's country of origin The largest city in {} -1 -1 Biryani's country of origin The largest city in India The largest city in Biryani's country of origin Biryani India Mumbai 6 7
+11375 2092905 Miso Soup's country of origin The largest city in {} -1 -1 Miso Soup's country of origin The largest city in Japan The largest city in Miso Soup's country of origin Miso Soup Japan Tokyo 6 7
+11376 2092939 Pizza's country of origin The language used in {} -1 -1 Pizza's country of origin The language used in Italy The language used in Pizza's country of origin Pizza Italy Italian 6 7
+11377 2092966 Sushi's country of origin The language used in {} -1 -1 Sushi's country of origin The language used in Japan The language used in Sushi's country of origin Sushi Japan Japanese 6 7
+11378 2092977 Tacos's country of origin The language used in {} -1 -1 Tacos's country of origin The language used in Mexico The language used in Tacos's country of origin Tacos Mexico Spanish 6 7
+11379 2093024 Poutine's country of origin The language used in {} -1 -1 Poutine's country of origin The language used in Canada The language used in Poutine's country of origin Poutine Canada English 6 7
+11380 2093056 Paella's country of origin The language used in {} -1 -1 Paella's country of origin The language used in Spain The language used in Paella's country of origin Paella Spain Spanish 6 7
+11381 2093075 Chimichurri's country of origin The language used in {} -1 -1 Chimichurri's country of origin The language used in Argentina The language used in Chimichurri's country of origin Chimichurri Argentina Spanish 6 7
+11382 2093122 Feijoada's country of origin The language used in {} -1 -1 Feijoada's country of origin The language used in Brazil The language used in Feijoada's country of origin Feijoada Brazil Portuguese 6 7
+11383 2093509 Ceviche's country of origin The language used in {} -1 -1 Ceviche's country of origin The language used in Peru The language used in Ceviche's country of origin Ceviche Peru Spanish 6 7
+11384 2093638 Miso Soup's country of origin The language used in {} -1 -1 Miso Soup's country of origin The language used in Japan The language used in Miso Soup's country of origin Miso Soup Japan Japanese 6 7
+11385 2093659 Pizza's country of origin The capital of {} -1 -1 Pizza's country of origin The capital of Italy The capital of Pizza's country of origin Pizza Italy Rome 6 7
+11386 2093686 Sushi's country of origin The capital of {} -1 -1 Sushi's country of origin The capital of Japan The capital of Sushi's country of origin Sushi Japan Tokyo 6 7
+11387 2093697 Tacos's country of origin The capital of {} -1 -1 Tacos's country of origin The capital of Mexico The capital of Tacos's country of origi